torch_em provides deep-learning-based semantic and instance segmentation for 3D electron microscopy and other bioimage analysis problems, built on PyTorch. Feedback is highly appreciated; just open an issue!
Highlights:
- Functional API with sensible defaults to train a state-of-the-art segmentation model with a few lines of code.
- Differentiable augmentations on GPU and CPU thanks to kornia.
- Off-the-shelf logging with tensorboard or wandb.
- Export trained models to the bioimage.io model format with one function call to deploy them in ilastik or deepImageJ.
Design:
- All parameters are specified in code, no configuration files.
- No callback logic; to extend the core functionality, inherit from torch_em.trainer.DefaultTrainer instead.
- All data-loading is lazy to support training on large datasets.
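The inheritance-based design can be illustrated with a minimal, library-free sketch. The classes `MinimalTrainer` and `BestLossTrainer` below are hypothetical stand-ins and do not reflect the actual methods of torch_em.trainer.DefaultTrainer; they only show how overriding a method can take the place of a callback.

```python
class MinimalTrainer:
    """Hypothetical stand-in for a trainer base class (NOT torch_em's API)."""

    def __init__(self, step_fn):
        # step_fn is an assumed callable that returns the loss for one iteration.
        self.step_fn = step_fn
        self.iteration = 0

    def train_step(self):
        # Run a single training iteration and return its loss.
        return self.step_fn(self.iteration)

    def fit(self, iterations):
        # Run the requested number of iterations and collect the losses.
        losses = []
        for _ in range(iterations):
            losses.append(self.train_step())
            self.iteration += 1
        return losses


class BestLossTrainer(MinimalTrainer):
    """Adds behaviour by overriding a method instead of registering a callback."""

    def __init__(self, step_fn):
        super().__init__(step_fn)
        self.best_loss = float("inf")

    def train_step(self):
        # Reuse the base implementation, then track the lowest loss seen so far.
        loss = super().train_step()
        self.best_loss = min(self.best_loss, loss)
        return loss


# Toy loss that shrinks with the iteration count, standing in for real training.
trainer = BestLossTrainer(step_fn=lambda i: 1.0 / (i + 1))
losses = trainer.fit(iterations=4)
```

With this pattern the extra state (`best_loss`) lives on the trainer itself, so no separate callback registry is needed.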
torch_em can be installed via conda:
conda install -c conda-forge torch_em
The example script below trains a 2D U-Net with torch_em; see the documentation for more details.
# Train a 2d U-Net for foreground and boundary segmentation of nuclei, using data from
# https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip
import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets import get_dsb_loader
model = UNet2d(in_channels=1, out_channels=2)
# Transform to convert from instance segmentation labels to foreground and boundary probabilities.
label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True, ndim=2)
# Create the training and validation data loader.
data_path = "./dsb"  # The training data will be downloaded and saved here.
train_loader = get_dsb_loader(
    data_path,
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="train",
    download=True,
    label_transform=label_transform,
)
val_loader = get_dsb_loader(
    data_path,
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="test",
    label_transform=label_transform,
)
# The trainer handles the details of the training process.
# It will save checkpoints in "checkpoints/dsb-boundary-model"
# and the tensorboard logs in "logs/dsb-boundary-model".
trainer = torch_em.default_segmentation_trainer(
    name="dsb-boundary-model",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=1e-4,
)
trainer.fit(iterations=5000)  # Fit for 5000 iterations.
# Export the trained model to the bioimage.io model format.
from glob import glob
import imageio
from torch_em.util import export_bioimageio_model
# Load one of the images to use as reference image.
# Crop it to a shape that is guaranteed to fit the network.
test_im = imageio.imread(glob(f"{data_path}/test/images/*.tif")[0])[:256, :256]
# Export the model.
export_bioimageio_model("./checkpoints/dsb-boundary-model", "./bioimageio-model", test_im)
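To make the label transform in the script above more concrete, here is a small, dependency-free sketch of how an instance label image could be turned into a foreground target and a boundary target. It is an illustration only and not the implementation of torch_em.transform.BoundaryTransform; the function name `instance_labels_to_targets` and the rule "a pixel is a boundary if any 4-connected neighbour has a different label" are assumptions made for this example.

```python
def instance_labels_to_targets(labels):
    """Derive binary foreground and boundary targets from instance labels.

    labels: 2D list of ints, 0 is background, positive ints are instance ids.
    Returns (foreground, boundary) as 2D lists of 0/1.
    Hypothetical helper for illustration, not part of torch_em.
    """
    height = len(labels)
    width = len(labels[0]) if height else 0

    # Foreground: every pixel that belongs to some instance.
    foreground = [[1 if value != 0 else 0 for value in row] for row in labels]

    # Boundary: every pixel with a 4-connected neighbour of a different label.
    boundary = [[0] * width for _ in range(height)]
    for y in range(height):
        for x in range(width):
            for dy, dx in ((0, 1), (1, 0), (0, -1), (-1, 0)):
                ny, nx = y + dy, x + dx
                inside = 0 <= ny < height and 0 <= nx < width
                if inside and labels[ny][nx] != labels[y][x]:
                    boundary[y][x] = 1
                    break
    return foreground, boundary


# Two touching instances (ids 1 and 2) below a row of background.
labels = [
    [0, 0, 0, 0],
    [1, 1, 2, 2],
    [1, 1, 2, 2],
    [1, 1, 2, 2],
]
foreground, boundary = instance_labels_to_targets(labels)
```

The boundary target is what allows touching nuclei to be separated after prediction: the line where instances 1 and 2 meet is marked as boundary even though both sides are foreground.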