add ensemble step
This commit adds an Ensemble step to steps.py and an ensemble.py module to utils,
which holds the ArtEnsemble model (an ArtModule).

Example usage (based on our MNIST tutorial's code):

```python
import datasets  # HuggingFace datasets, used below via datasets.load_dataset
import torch.nn as nn
from torchmetrics import Accuracy

from dataset import MNISTDataModule  # tutorial's local datamodule
from model import MNISTModel  # tutorial's local model module (assumed name)

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

from art.metrics import build_metric_name
from art.project import ArtProject
from art.steps import Ensemble
from art.utils.enums import INPUT, TARGET, TrainingStage

def get_data_module(n_train=200):
    mnist_data = datasets.load_dataset("mnist")

    mnist_data = mnist_data.rename_columns({"image": INPUT, "label": TARGET})
    mnist_data['train'] = mnist_data['train'].select(range(n_train))

    return MNISTDataModule(mnist_data)

datamodule = get_data_module()
project = ArtProject(name="mnist-ensemble", datamodule=datamodule)
accuracy_metric = Accuracy(task="multiclass", num_classes=10)
ce_loss = nn.CrossEntropyLoss()
project.register_metrics([accuracy_metric, ce_loss])
checkpoint = ModelCheckpoint(monitor=build_metric_name(
    accuracy_metric, TrainingStage.VALIDATION.value), mode="max")
early_stopping = EarlyStopping(monitor=build_metric_name(
    ce_loss, TrainingStage.VALIDATION.value), mode="min")
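# Ensemble step: train 10 copies of MNISTModel and average their predictions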
project.add_step(
    Ensemble(
        MNISTModel,
        num_models=10,
        trainer_kwargs={
            "max_epochs": 6,
            "callbacks": [checkpoint, early_stopping],
            "check_val_every_n_epoch": 5,
        },
    )
)

project.run_all(force_rerun=True)
```
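
During `run_all`, the Ensemble step trains ten separate `MNISTModel` instances, reloads each best checkpoint, and validates a model that averages their predictions. The combination itself is just a mean over the stacked member outputs; a minimal sketch of that averaging in plain PyTorch (not the library API):

```python
from typing import List

import torch


def average_predictions(member_outputs: List[torch.Tensor]) -> torch.Tensor:
    """Average same-shaped prediction tensors, mirroring ArtEnsemble.predict below."""
    return torch.stack(member_outputs).mean(dim=0)
```
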
kordc committed Dec 2, 2023
1 parent edd17dd commit c9cd71d
Showing 2 changed files with 107 additions and 0 deletions.
71 changes: 71 additions & 0 deletions art/steps.py
@@ -26,6 +26,7 @@
from art.utils.enums import TrainingStage
from art.utils.paths import get_checkpoint_logs_folder_path
from art.utils.savers import JSONStepSaver
from art.utils.ensemble import ArtEnsemble


class NoModelUsed:
@@ -829,3 +830,73 @@ def change_lr(model):
            model.lr = self.fine_tune_lr

        self.model_modifiers.append(change_lr)


class Ensemble(ModelStep):
    """This step trains multiple copies of a model and ensembles them."""

    name = "Ensemble"
    description = "Ensembles models"

    def __init__(
        self,
        model: ArtModule,
        num_models: int = 5,
        logger: Optional[Logger] = None,
        trainer_kwargs: Dict = {},
        model_kwargs: Dict = {},
        model_modifiers: List[Callable] = [],
        datamodule_modifiers: List[Callable] = [],
    ):
        """
        This method initializes the step

        Args:
            model (ArtModule): model class, instantiated once per ensemble member
            num_models (int, optional): number of models in the ensemble. Defaults to 5.
            logger (Logger, optional): logger. Defaults to None.
            trainer_kwargs (Dict, optional): Kwargs passed to lightning Trainer. Defaults to {}.
            model_kwargs (Dict, optional): Kwargs passed to model. Defaults to {}.
            model_modifiers (List[Callable], optional): model modifiers. Defaults to [].
            datamodule_modifiers (List[Callable], optional): datamodule modifiers. Defaults to [].
        """
        super().__init__(
            model,
            trainer_kwargs,
            model_kwargs,
            model_modifiers,
            datamodule_modifiers,
            logger=logger,
        )
        self.num_models = num_models

    def do(self, previous_states: Dict):
        """
        This method trains num_models models and ensembles them.

        Args:
            previous_states (Dict): previous states
        """
        # Train each ensemble member from scratch and remember its best checkpoint.
        models_paths = []
        for _ in range(self.num_models):
            self.reset_trainer(
                logger=self.trainer.logger, trainer_kwargs=self.trainer_kwargs
            )
            self.train(trainer_kwargs={"datamodule": self.datamodule})
            models_paths.append(self.trainer.checkpoint_callback.best_model_path)

        # Reload the best checkpoints and switch the members to eval mode.
        initialized_models = []
        for path in models_paths:
            model = self.model_class.load_from_checkpoint(path)
            model.eval()
            initialized_models.append(model)

        # Validate the averaged ensemble on the validation set.
        self.model = ArtEnsemble(initialized_models)
        self.validate(trainer_kwargs={"datamodule": self.datamodule})

    def get_check_stage(self):
        """Returns check stage"""
        return TrainingStage.VALIDATION.value

    def log_model_params(self, model):
        self.results["parameters"]["num_models"] = self.num_models
        super().log_model_params(model)
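
Unlike the single-model steps, `Ensemble` takes the model class together with the number of members to train; each member gets a freshly reset trainer and is reloaded from `trainer.checkpoint_callback.best_model_path`, which is why the example above passes an explicit `ModelCheckpoint` to control which epoch counts as best. A minimal registration sketch reusing those objects (the kwargs are illustrative):

```python
# Sketch only: project, MNISTModel, checkpoint and early_stopping are the objects
# created in the commit-message example above.
project.add_step(
    Ensemble(
        MNISTModel,                  # model class, instantiated once per member
        num_models=5,                # the default; the example above trains 10
        trainer_kwargs={"max_epochs": 6, "callbacks": [checkpoint, early_stopping]},
    )
)
```
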
36 changes: 36 additions & 0 deletions art/utils/ensemble.py
@@ -0,0 +1,36 @@
from copy import deepcopy
from typing import List

import torch
from torch import nn

from art.core import ArtModule
from art.utils.enums import BATCH, PREDICTION


class ArtEnsemble(ArtModule):
    """
    Base class for ensembles.
    """

    def __init__(self, models: List[ArtModule]):
        super().__init__()
        self.models = nn.ModuleList(models)

    def predict(self, data):
        # Average the per-member predictions over the ensemble.
        predictions = torch.stack(
            [self.predict_on_model_from_dataloader(model, deepcopy(data)) for model in self.models]
        )
        return torch.mean(predictions, dim=0)

    def predict_on_model_from_dataloader(self, model, dataloader):
        # Run a single member over the whole dataloader and concatenate its predictions.
        predictions = []
        for batch in dataloader:
            model.to(self.device)
            batch_processed = model.parse_data({BATCH: batch})
            predictions.append(model.predict(batch_processed)[PREDICTION])
        return torch.cat(predictions)

    def log_params(self):
        return {
            "num_models": len(self.models),
            "models": [model.log_params() for model in self.models],
        }
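
As `predict_on_model_from_dataloader` shows, the ensemble iterates a dataloader, lets every member parse and predict on each batch, and averages the stacked results. A standalone usage sketch, assuming the tutorial's `MNISTModel` and the `datamodule` from the example above, with placeholder checkpoint paths:

```python
# Sketch only: checkpoint paths are placeholders; MNISTModel and datamodule come
# from the MNIST example in the commit message.
from art.utils.ensemble import ArtEnsemble

members = []
for path in ["member_0.ckpt", "member_1.ckpt"]:  # placeholder checkpoint paths
    member = MNISTModel.load_from_checkpoint(path)
    member.eval()
    members.append(member)

ensemble = ArtEnsemble(members)
datamodule.setup("validate")  # assumed LightningDataModule setup hook
averaged = ensemble.predict(datamodule.val_dataloader())  # mean of member predictions
```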
