Update documentation for torch_em.trainer
constantinpape committed Jan 4, 2025
1 parent 69b4d51 commit 8be48b7
Showing 6 changed files with 187 additions and 39 deletions.
25 changes: 22 additions & 3 deletions torch_em/segmentation.py
@@ -446,12 +446,31 @@ def default_segmentation_trainer(
compile_model: Optional[Union[bool, str]] = None,
rank: Optional[int] = None,
):
"""Get a trainer for training a segmentation network.
"""Get a trainer for a segmentation network.
It will create a `torch.optim.AdamW` optimizer and learning rate scheduler
that reduces the learning rate on plateau.
It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
By default, it uses the dice score as loss and metric.
This can be changed by passing arguments for `loss` and/or `metric`.
See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.
Here's an example for training a 2D U-Net with this function:
```python
import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader
# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
trainer = default_segmentation_trainer(
name="unet-training"
model=UNet2d(in_channels=1, out_channels=1)
train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
)
trainer.fit(iterations=int(2.5e4)) # Train for 25,000 iterations.
```
Args:
name: The name of the checkpoint that will be created by the trainer.
model: The model to train.
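The docstring above notes that the dice score is used as both loss and metric by default and can be replaced via the `loss` and `metric` arguments. A minimal sketch of that, assuming these keywords are forwarded to the underlying trainer and reusing the DSB loaders from the example:

```python
import torch
import torch_em
from torch_em.loss import DiceLoss
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader

data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)

# Swap the default dice loss for binary cross-entropy, but keep a dice-based
# validation metric. Both keywords are assumptions based on the docstring above.
trainer = torch_em.default_segmentation_trainer(
    name="unet-bce-training",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    loss=torch.nn.BCEWithLogitsLoss(),
    metric=DiceLoss(),
)
trainer.fit(iterations=int(2.5e4))
```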
124 changes: 100 additions & 24 deletions torch_em/trainer/default_trainer.py
@@ -23,18 +23,52 @@
class DefaultTrainer:
"""Trainer class for training a segmentation network.
TODO
The trainer class implements the core logic for training a network with PyTorch.
It provides a training loop that runs training and validation, which is started with `fit`.
The checkpoints and logs from the training run will be saved in the current working directory,
or in the directory specified by `save_root`. Training can be continued from a checkpoint
by passing its location to the `load_from_checkpoint` argument of `fit`.
A pre-configured instance of the trainer can be obtained from `torch_em.default_segmentation_trainer`.
Alternatively, the trainer class can be instantiated directly, as in this example:
```python
import torch
from torch_em.loss import DiceLoss
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader
from torch_em.trainer import DefaultTrainer
# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
# Create the model and optimizer.
model = UNet2d(in_channels=1, out_channels=1)
optimizer = torch.optim.AdamW(model.parameters())
trainer = DefaultTrainer(
name="unet-training",
train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
model=model,
loss=DiceLoss(), # The loss function.
optimizer=optimizer,
metric=DiceLoss(), # The metric. The trainer expects smaller values to represent better results.
device="cuda", # The device to use for training.
)
trainer.fit(iterations=int(2.5e4)) # Train for 25,000 iterations.
```
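Training can also be resumed from a saved state via the `load_from_checkpoint` argument of `fit`, as noted above. A minimal sketch, assuming the trainer configured in this example and that a checkpoint name such as "latest" is resolved inside the trainer's checkpoint folder, as `load_checkpoint` further below does:

```python
# Continue the run above for another 10,000 iterations, starting from the
# latest saved checkpoint (assumed to live in ./checkpoints/unet-training).
trainer.fit(iterations=int(1e4), load_from_checkpoint="latest")
```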
Args:
name: The name of the checkpoint that will be created by the trainer.
train_loader: The data loader containing the training data.
val_loader: The data loader containing the validation data.
model: The model to train.
loss: The loss function for training.
optimizer: TODO
optimizer: The optimizer.
metric: The metric for validation.
device: The torch device to use for training. If None, will use a GPU if available.
lr_scheduler: TODO
lr_scheduler: The learning rate scheduler.
log_image_interval: The interval for saving images during logging, in training iterations.
mixed_precision: Whether to train with mixed precision.
early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
@@ -56,7 +90,7 @@ def __init__(
optimizer: torch.optim.Optimizer,
metric: Callable,
device: Union[str, torch.device],
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
log_image_interval: int = 100,
mixed_precision: bool = True,
early_stopping: Optional[int] = None,
@@ -106,12 +140,12 @@ def __init__(
self.logger_kwargs = logger_kwargs
self.log_image_interval = log_image_interval

@property # because the logger may generate and set trainer.id on logger.__init__
@property
def checkpoint_folder(self):
assert self.id_ is not None
# save_root enables saving the checkpoints somewhere else than in the local
# folder. This is handy for filesystems with limited space, where saving the checkpoints
# and log files can easily lead to running out of space.
assert self.id_ is not None # Because the logger may generate and set trainer.id on logger.__init__.
# save_root enables saving the checkpoints somewhere other than the local folder.
# This is handy for filesystems with limited space, where saving the checkpoints
# and log files can easily lead to running out of space.
save_root = getattr(self, "save_root", None)
return os.path.join("./checkpoints", self.id_) if save_root is None else\
os.path.join(save_root, "./checkpoints", self.id_)
@@ -125,7 +159,7 @@ def epoch(self):
return self._epoch

class Deserializer:
"""Determines how to deserialize the trainer kwargs from serialized 'init_data'
"""Determines how to deserialize the trainer kwargs from serialized 'init_data'.
Examples:
To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
@@ -146,19 +180,25 @@ class Deserializer:
>>> self.trainer_kwargs["the_answer"] = generic_answer + 1
>>> else:
>>> self.trainer_kwargs["the_answer"] = generic_answer * 2
Args:
init_data: The initialization data of the trainer.
save_path: The path where the checkpoint was saved.
device: The device.
"""

def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]):
def __init__(self, init_data: Dict, save_path: str, device: Union[str, torch.device]):
self.init_data = init_data
self.save_path = save_path
# populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'
# Populate with deserialized trainer kwargs during deserialization; possibly overwrite 'device'.
self.trainer_kwargs: Dict[str, Any] = dict(
device=torch.device(self.init_data["device"]) if device is None else torch.device(device)
)

def load(self, kwarg_name: str, optional):
"""`optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'"""

"""@private
"""
# `optional` is True if self.trainer.__class__.__init__ specifies a default value for 'kwarg_name'
if kwarg_name == "device":
pass # deserialized in __init__
elif kwarg_name.endswith("_loader"):
@@ -168,6 +208,8 @@ def load(self, kwarg_name: str, optional):
load(kwarg_name, optional=optional)

def load_data_loader(self, loader_name, optional) -> None:
"""@private
"""
ds = self.init_data.get(loader_name.replace("_loader", "_dataset"))
if ds is None and optional:
return
@@ -186,6 +228,8 @@ def load_generic(
only_class: bool = False,
dynamic_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""@private
"""
if kwarg_name in self.init_data:
self.trainer_kwargs[kwarg_name] = self.init_data[kwarg_name]
return
@@ -209,16 +253,24 @@ def load_generic(
)

def load_name(self, kwarg_name: str, optional: bool):
"""@private
"""
self.trainer_kwargs[kwarg_name] = os.path.split(os.path.dirname(self.save_path))[1]

def load_optimizer(self, kwarg_name: str, optional: bool):
"""@private
"""
self.load_generic(kwarg_name, self.trainer_kwargs["model"].parameters(), optional=optional)

def load_lr_scheduler(self, kwarg_name: str, optional: bool):
"""@private
"""
self.load_generic(kwarg_name, self.trainer_kwargs["optimizer"], optional=optional)

# todo: remove and rename kwarg 'logger' to 'logger_class'
def load_logger(self, kwarg_name: str, optional: bool):
"""@private
"""
assert kwarg_name == "logger"
self.load_generic("logger", optional=optional, only_class=True)

@@ -235,6 +287,8 @@ def from_checkpoint(
name: Literal["best", "latest"] = "best",
device: Optional[Union[str, torch.device]] = None,
):
"""@private
"""
save_path = os.path.join(checkpoint_folder, f"{name}.pt")
# make sure the correct device is set if we don't have access to CUDA
if not torch.cuda.is_available():
@@ -271,7 +325,7 @@ def from_checkpoint(
return trainer

class Serializer:
"""Implements how to serialize trainer kwargs from a trainer instance
"""Implements how to serialize trainer kwargs from a trainer instance.
Examples:
To extend the serialization process you can inherit from this Serializer in a derived Trainer class.
@@ -317,13 +371,18 @@ class Serializer:
>>> assert kwarg_name == "the_answer"
>>> # 'optional' is True if MyTrainer.__init__ specifies a default value for 'kwarg_name'
>>> self.trainer_kwargs[kwarg_name] = self.init_data["the"] * 10 + self.init_data["answer"]
Args:
trainer: The trainer instance.
"""

def __init__(self, trainer: DefaultTrainer):
self.trainer = trainer
self.init_data = {} # to be populated during serialization process

def dump(self, kwarg_name: str) -> None:
"""@private
"""
dumper = getattr(self, f"dump_{kwarg_name}", None)
if dumper is not None:
dumper(kwarg_name)
@@ -357,16 +416,22 @@ def dump(self, kwarg_name: str) -> None:
self.dump_generic_instance(kwarg_name)

def dump_generic_builtin(self, kwarg_name: str) -> None:
"""@private
"""
assert hasattr(self.trainer, kwarg_name)
self.init_data[kwarg_name] = getattr(self.trainer, kwarg_name)

def dump_generic_class(self, kwarg_name: str) -> None:
"""@private
"""
assert hasattr(self.trainer, kwarg_name)
assert kwarg_name.endswith("_class")
obj = getattr(self.trainer, kwarg_name)
self.init_data[kwarg_name] = None if obj is None else f"{obj.__module__}.{obj.__name__}"

def dump_generic_instance(self, kwarg_name: str) -> None:
"""@private
"""
assert hasattr(self.trainer, kwarg_name)
instance = getattr(self.trainer, kwarg_name)
self.init_data.update(
Expand All @@ -377,10 +442,14 @@ def dump_generic_instance(self, kwarg_name: str) -> None:
)

def dump_device(self, kwarg_name: str):
"""@private
"""
assert hasattr(self.trainer, kwarg_name)
self.init_data[kwarg_name] = str(getattr(self.trainer, kwarg_name))

def dump_data_loader(self, kwarg_name: str) -> None:
"""@private
"""
assert hasattr(self.trainer, kwarg_name)
loader = getattr(self.trainer, kwarg_name)
if loader is None:
Expand All @@ -393,9 +462,13 @@ def dump_data_loader(self, kwarg_name: str) -> None:
)

def dump_logger(self, kwarg_name: str): # todo: remove and rename kwarg 'logger' to 'logger_class'
"""@private
"""
self.dump_generic_class(f"{kwarg_name}_class")

def dump_model(self, kwarg_name: str):
"""@private
"""
if is_compiled(self.trainer.model):
self.init_data.update(
{"model_class": self.trainer._model_class, "model_kwargs": self.trainer._model_kwargs}
@@ -497,6 +570,8 @@ def _initialize(self, iterations, load_from_checkpoint, epochs=None):
return best_metric

def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
"""@private
"""
save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
extra_init_dict = extra_save_dict.pop("init", {})
save_dict = {
@@ -521,6 +596,8 @@ def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **e
torch.save(save_dict, save_path)

def load_checkpoint(self, checkpoint="best"):
"""@private
"""
if isinstance(checkpoint, str):
save_path = os.path.join(self.checkpoint_folder, f"{checkpoint}.pt")
if not os.path.exists(save_path):
@@ -577,15 +654,14 @@ def fit(
Exactly one of 'iterations' or 'epochs' has to be passed.
Parameters:
iterations [int] - how long to train, specified in iterations (default: None)
load_from_checkpoint [str] - path to a checkpoint from where training should be continued (default: None)
epochs [int] - how long to train, specified in epochs (default: None)
save_every_kth_epoch [int] - save checkpoints after every kth epoch separately.
The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'. (default: None)
progress [progress_bar] - optional progress bar for integration with external tools.
Expected to follow the tqdm interface.
overwrite_training [bool] - Whether to overwrite the trained model.
Args:
iterations: How long to train, specified in iterations.
load_from_checkpoint: Path to a checkpoint from which training should be continued.
epochs: How long to train, specified in epochs.
save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
progress: Optional progress bar for integration with external tools. Expected to follow the tqdm interface.
overwrite_training: Whether to overwrite existing checkpoints in the save directory.
"""
best_metric = self._initialize(iterations, load_from_checkpoint, epochs)

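As documented above, `fit` takes exactly one of `iterations` or `epochs`, and `save_every_kth_epoch` keeps additional per-epoch checkpoints. A short sketch, assuming a trainer instance configured as in the earlier examples:

```python
# Train for 50 epochs instead of a fixed iteration count and keep a separate
# checkpoint every 10th epoch, written as 'epoch-{epoch}.pt' in the checkpoint
# folder next to 'best.pt' and 'latest.pt'.
trainer.fit(epochs=50, save_every_kth_epoch=10)
```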
2 changes: 2 additions & 0 deletions torch_em/trainer/logger_base.py
@@ -5,6 +5,8 @@


class TorchEmLogger:
"""@private
"""
def __init__(self, trainer, save_root, **kwargs):
self.trainer = trainer
self.save_root = save_root
26 changes: 22 additions & 4 deletions torch_em/trainer/spoco_trainer.py
@@ -1,18 +1,32 @@
import time
from copy import deepcopy
from typing import Optional

import torch
from .default_trainer import DefaultTrainer
from .tensorboard_logger import TensorboardLogger


class SPOCOTrainer(DefaultTrainer):
"""Trainer for a SPOCO model.
For details check out "Sparse Object-level Supervision for Instance Segmentation with Pixel Embeddings":
https://arxiv.org/abs/2103.14572
Args:
model: The model to train.
momentum: The momentum value for exponential moving weight averaging.
semisupervised_loss: Optional loss for semi-supervised learning.
semisupervised_loader: Optional data loader for semi-supervised learning.
logger: The logger.
kwargs: Additional keyword arguments for `torch_em.trainer.DefaultTrainer`.
"""
def __init__(
self,
model,
momentum=0.999,
semisupervised_loss=None,
semisupervised_loader=None,
model: torch.nn.Module,
momentum: float = 0.999,
semisupervised_loss: Optional[torch.nn.Module] = None,
semisupervised_loader: Optional[torch.utils.data.DataLoader] = None,
logger=TensorboardLogger,
**kwargs,
):
@@ -33,11 +47,15 @@ def _momentum_update(self):
param_teacher.data = param_teacher.data * self.momentum + param_model.data * (1. - self.momentum)

def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
"""@private
"""
super().save_checkpoint(
name, current_metric, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict
)

def load_checkpoint(self, checkpoint="best"):
"""@private
"""
save_dict = super().load_checkpoint(checkpoint)
self.model2.load_state_dict(save_dict["model2_state"])
self.model2.to(self.device)
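The `_momentum_update` shown above keeps the second (teacher) model as an exponential moving average of the trained model's weights. A standalone sketch of that update rule, using a small generic module for illustration:

```python
from copy import deepcopy

import torch

model = torch.nn.Linear(8, 8)  # Stand-in for the segmentation model.
model2 = deepcopy(model)       # The momentum (teacher) copy maintained by the trainer.
momentum = 0.999


@torch.no_grad()
def momentum_update(model, model2, momentum):
    # Same rule as SPOCOTrainer._momentum_update: the teacher parameters follow
    # the student parameters as an exponential moving average.
    for param_model, param_teacher in zip(model.parameters(), model2.parameters()):
        param_teacher.data = param_teacher.data * momentum + param_model.data * (1.0 - momentum)


momentum_update(model, model2, momentum)
```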