Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commit add ensemble step to steps.py and ensemble.py to utils, where the Ensemble model as ArtModule is stored. Example usage (using our tutorial's code from MNIST example): ```python import torch.nn as nn from dataset import MNISTDataModule from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping from art.metrics import build_metric_name from art.utils.enums import TrainingStage from art.utils.enums import ( INPUT, TARGET, ) def get_data_module(n_train=200): mnist_data = datasets.load_dataset("mnist") mnist_data = mnist_data.rename_columns({"image": INPUT, "label": TARGET}) mnist_data['train'] = mnist_data['train'].select(range(n_train)) return MNISTDataModule(mnist_data) datamodule = get_data_module() project = ArtProject(name="mnist-ensemble", datamodule=datamodule) accuracy_metric, ce_loss = Accuracy( task="multiclass", num_classes=10), nn.CrossEntropyLoss() project.register_metrics([accuracy_metric, ce_loss]) checkpoint = ModelCheckpoint(monitor=build_metric_name( accuracy_metric, TrainingStage.VALIDATION.value), mode="max") early_stopping = EarlyStopping(monitor=build_metric_name( ce_loss, TrainingStage.VALIDATION.value), mode="min") project.add_step(Ensemble(MNISTModel, 10, trainer_kwargs={ "max_epochs": 6, "callbacks": [checkpoint, early_stopping], "check_val_every_n_epoch": 5})) project.run_all(force_rerun=True) ```
- Loading branch information