torch-em

Deep-learning-based semantic and instance segmentation for 3D electron microscopy and other bioimage analysis problems, built on PyTorch. Any feedback is highly appreciated; just open an issue!

Highlights:

  • Functional API with sensible defaults to train a state-of-the-art segmentation model with a few lines of code.
  • Differentiable augmentations on GPU and CPU thanks to kornia.
  • Off-the-shelf logging with tensorboard or wandb.
  • Export trained models to bioimage.io model format with one function call to deploy them in ilastik or deepimageJ.

Design:

  • All parameters are specified in code, no configuration files.
  • No callback logic; to extend the core functionality inherit from torch_em.trainer.DefaultTrainer instead.
  • All data-loading is lazy to support training on large datasets.
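To make the lazy-loading point concrete, here is a minimal sketch of the general idea, not torch_em's actual implementation. It assumes the data is stored as a `.npy` file and memory-maps it with NumPy, so that only the requested patch is read when a sample is accessed. The class name `LazyPatchDataset` and all of its parameters are hypothetical and exist only for this illustration.

```python
import os
import tempfile

import numpy as np


class LazyPatchDataset:
    """Hypothetical dataset that reads one random patch from disk per sample."""

    def __init__(self, path, patch_shape, n_samples, seed=0):
        self.patch_shape = patch_shape
        self.n_samples = n_samples
        self.rng = np.random.default_rng(seed)
        # mmap_mode="r" maps the file instead of loading it into memory.
        self.volume = np.load(path, mmap_mode="r")

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        # Pick a random corner such that the patch fits inside the volume.
        corner = [
            int(self.rng.integers(0, size - patch + 1))
            for size, patch in zip(self.volume.shape, self.patch_shape)
        ]
        roi = tuple(slice(c, c + p) for c, p in zip(corner, self.patch_shape))
        # np.array copies only the selected region into memory.
        return np.array(self.volume[roi])


# Write a small dummy volume to disk so that the sketch is self-contained.
tmp_dir = tempfile.mkdtemp()
volume_path = os.path.join(tmp_dir, "volume.npy")
np.save(volume_path, np.zeros((8, 64, 64), dtype="uint8"))

dataset = LazyPatchDataset(volume_path, patch_shape=(1, 32, 32), n_samples=4)
patch = dataset[0]
```

With this pattern the memory footprint is governed by the patch size and batch size rather than by the size of the full volume, which is what makes training on large datasets feasible.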

torch_em can be installed via conda: conda install -c conda-forge torch_em. An example script that trains a 2D U-Net is given below; check out the documentation for more details.

# Train a 2D U-Net for foreground and boundary segmentation of nuclei, using data from
# https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip

import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets import get_dsb_loader

model = UNet2d(in_channels=1, out_channels=2)

# Transform to convert from instance segmentation labels to foreground and boundary probabilities.
label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True, ndim=2)

# Create the training and validation data loader.
data_path = "./dsb"  # The training data will be downloaded and saved here.
train_loader = get_dsb_loader(
    data_path,
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="train",
    download=True,
    label_transform=label_transform,
)
val_loader = get_dsb_loader(
    data_path,
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="test",
    label_transform=label_transform,
)

# The trainer handles the details of the training process.
# It will save checkpoints in "checkpoints/dsb-boundary-model"
# and the tensorboard logs in "logs/dsb-boundary-model".
trainer = torch_em.default_segmentation_trainer(
    name="dsb-boundary-model",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=1e-4,
)
trainer.fit(iterations=5000)  # Fit for 5000 iterations.

# Export the trained model to the bioimage.io model format.
from glob import glob
import imageio
from torch_em.util import export_bioimageio_model

# Load one of the images to use as reference image.
# Crop it to a shape that is guaranteed to fit the network.
# Sort the file list so that the same reference image is picked on every run.
test_im = imageio.imread(sorted(glob(f"{data_path}/test/images/*.tif"))[0])[:256, :256]

# Export the model.
export_bioimageio_model("./checkpoints/dsb-boundary-model", "./bioimageio-model", test_im)
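The label transform in the script above converts an instance label image into two target channels, foreground and boundary. The following NumPy sketch illustrates what such targets look like. It is not the implementation of torch_em.transform.BoundaryTransform; the function name `foreground_and_boundary`, the 4-neighborhood boundary definition, and the channel order (foreground first) are assumptions made for this example.

```python
import numpy as np


def foreground_and_boundary(labels):
    """Return a (2, H, W) float32 target: channel 0 foreground, channel 1 boundary."""
    foreground = labels > 0
    boundary = np.zeros(labels.shape, dtype=bool)
    # Mark both pixels of each vertically adjacent pair with different labels.
    diff_rows = labels[:-1, :] != labels[1:, :]
    boundary[:-1, :] |= diff_rows
    boundary[1:, :] |= diff_rows
    # Do the same for horizontally adjacent pairs.
    diff_cols = labels[:, :-1] != labels[:, 1:]
    boundary[:, :-1] |= diff_cols
    boundary[:, 1:] |= diff_cols
    return np.stack([foreground, boundary]).astype("float32")


# A toy label image with a single 2x2 object (id 1) on background (id 0).
labels = np.array([
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
])
target = foreground_and_boundary(labels)
```

A network with out_channels=2, as in the script, predicts one probability map per target channel. The boundary channel is what allows touching nuclei to be separated into individual instances in a post-processing step.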