diff --git a/src/cellmap_data/__init__.py b/src/cellmap_data/__init__.py index 7117ef0..d85cc6e 100644 --- a/src/cellmap_data/__init__.py +++ b/src/cellmap_data/__init__.py @@ -16,6 +16,7 @@ __email__ = "rhoadesj@hhmi.org" from .dataloader import CellMapDataLoader +from .multidataset import CellMapMultiDataset from .datasplit import CellMapDataSplit from .dataset import CellMapDataset from .image import CellMapImage diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py index 1c49441..cf1be89 100644 --- a/src/cellmap_data/dataloader.py +++ b/src/cellmap_data/dataloader.py @@ -1,17 +1,48 @@ from torch.utils.data import DataLoader + +from .multidataset import CellMapMultiDataset from .datasplit import CellMapDataSplit -from typing import Callable, Iterable +from typing import Callable, Iterable, Optional -class CellMapDataLoader(DataLoader): - """This subclasses PyTorch DataLoader to load CellMap data for training. It maintains the same API as the DataLoader class. This includes applying augmentations to the data and returning the data in the correct format for training, such as generating the target arrays (e.g. signed distance transform of labels). It retrieves raw and groundtruth data from a CellMapDataSplit object, which is a subclass of PyTorch Dataset. Training and validation data are split using the CellMapDataSplit object, and separate dataloaders are maintained as `train_loader` and `val_loader` respectively.""" +class CellMapDataLoader: + # TODO: docstring corrections + """This subclasses PyTorch DataLoader to load CellMap data for training. It maintains the same API as the DataLoader class. This includes applying augmentations to the data and returning the data in the correct format for training, such as generating the target arrays (e.g. signed distance transform of labels). It retrieves raw and groundtruth data from a CellMapDataSplit object, which is a subclass of PyTorch Dataset. Training and validation data are split using the CellMapDataSplit object, and separate dataloaders are maintained as `train_loader` and `validate_loader` respectively.""" - input_arrays: dict[str, dict[str, tuple[int | float]]] - target_arrays: dict[str, dict[str, tuple[int | float]]] - classes: list[str] datasplit: CellMapDataSplit + train_datasets: CellMapMultiDataset + validate_datasets: CellMapMultiDataset train_loader: DataLoader - val_loader: DataLoader + validate_loader: DataLoader + batch_size: int + num_workers: int is_train: bool - augmentations: list[dict[str, any]] - to_target: Callable + + def __init__( + self, + datasplit: CellMapDataSplit, + batch_size: int, + num_workers: int, + is_train: bool, + ): + self.datasplit = datasplit + self.batch_size = batch_size + self.num_workers = num_workers + self.is_train = is_train + self.construct() + + def construct(self): + self.train_datasets = self.datasplit.train_datasets_combined + self.validate_datasets = self.datasplit.validate_datasets_combined + self.train_loader = DataLoader( + self.train_datasets, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=self.is_train, + ) + self.validate_loader = DataLoader( + self.validate_datasets, + batch_size=1, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index e93afa7..58a72ed 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -1,6 +1,6 @@ # %% import csv -from typing import Callable, Dict, Iterable, Optional +from typing import Callable, Dict, Sequence, Optional import numpy as np import torch from torch.utils.data import Dataset @@ -24,42 +24,58 @@ class CellMapDataset(Dataset): raw_path: str gt_path: str - classes: Iterable[str] - input_arrays: dict[str, dict[str, Iterable[int | float]]] - target_arrays: dict[str, dict[str, Iterable[int | float]]] + classes: Sequence[str] + input_arrays: dict[str, dict[str, Sequence[int | float]]] + target_arrays: dict[str, dict[str, Sequence[int | float]]] input_sources: dict[str, CellMapImage] target_sources: dict[str, dict[str, CellMapImage | EmptyImage]] + to_target: Callable + transforms: RandomApply | None + has_data: bool + _bounding_box: Optional[Dict[str, list[int]]] + _bounding_box_shape: Optional[Dict[str, int]] + _sampling_box: Optional[Dict[str, list[int]]] + _sampling_box_shape: Optional[Dict[str, int]] + _class_counts: Optional[Dict[str, Dict[str, int]]] + _largest_voxel_sizes: Optional[Dict[str, int]] + _len: Optional[int] + _iter_coords: Optional[...] def __init__( self, raw_path: str, gt_path: str, - classes: Iterable[str], - input_arrays: dict[str, dict[str, Iterable[int | float]]], - target_arrays: dict[str, dict[str, Iterable[int | float]]], + classes: Sequence[str], + input_arrays: dict[str, dict[str, Sequence[int | float]]], + target_arrays: dict[str, dict[str, Sequence[int | float]]], + to_target: Callable, + transforms: Optional[Sequence[Callable]] = None, ): """Initializes the CellMapDataset class. Args: raw_path (str): The path to the raw data. gt_path (str): The path to the ground truth data. - classes (Iterable[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. Classes not contained in the dataset will be filled in with zeros. - input_arrays (dict[str, dict[str, Iterable[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure: + classes (Sequence[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. Classes not contained in the dataset will be filled in with zeros. + input_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure: { "array_name": { "shape": typle[int], - "scale": Iterable[float], + "scale": Sequence[float], }, ... } - target_arrays (dict[str, dict[str, Iterable[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the following structure: + target_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the following structure: { "array_name": { "shape": typle[int], - "scale": Iterable[float], + "scale": Sequence[float], }, ... } + to_target (Callable): A function to convert the ground truth data to target arrays. The function should have the following structure: + def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tensor]: + transforms (Optional[Sequence[Callable]], optional): A sequence of transformations to apply to the data. Defaults to None. """ self.raw_path = raw_path self.gt_paths = gt_path @@ -67,6 +83,8 @@ def __init__( self.classes = classes self.input_arrays = input_arrays self.target_arrays = target_arrays + self.to_target = to_target + self.transforms = transforms self.construct() def __len__(self): @@ -99,7 +117,10 @@ def __getitem__(self, idx): def __iter__(self): """Iterates over the dataset, covering each section of the bounding box. For instance, for calculating validation scores.""" - ... + # TODO + if self._iter_coords is None: + self._iter_coords = ... + yield self.__getitem__(self._iter_coords) def construct(self): """Constructs the input and target sources for the dataset.""" @@ -137,6 +158,7 @@ def construct(self): self._class_counts = None self._largest_voxel_sizes = None self._len = None + self._iter_coords = None @property def largest_voxel_sizes(self): diff --git a/src/cellmap_data/datasplit.py b/src/cellmap_data/datasplit.py index 4869458..9c089bb 100644 --- a/src/cellmap_data/datasplit.py +++ b/src/cellmap_data/datasplit.py @@ -1,48 +1,60 @@ import csv -from torch.utils.data import Dataset -from typing import Callable, Dict, Iterable, Optional +from torchvision.transforms.v2 import RandomApply +from typing import Callable, Dict, Iterable, Optional, Sequence + +from .multidataset import CellMapMultiDataset from .dataset import CellMapDataset -class CellMapDataSplit(Dataset): +class CellMapDataSplit: """ This subclasses PyTorch Dataset to split data into training and validation sets. It maintains the same API as the Dataset class. It retrieves raw and groundtruth data from CellMapDataset objects. """ - input_arrays: dict[str, dict[str, Iterable[int | float]]] - target_arrays: dict[str, dict[str, Iterable[int | float]]] - classes: Iterable[str] + input_arrays: dict[str, dict[str, Sequence[int | float]]] + target_arrays: dict[str, dict[str, Sequence[int | float]]] + classes: Sequence[str] + to_target: Callable datasets: dict[str, Iterable[CellMapDataset]] + train_datasets: Iterable[CellMapDataset] + validate_datasets: Iterable[CellMapDataset] + train_datasets_combined: CellMapMultiDataset + validate_datasets_combined: CellMapMultiDataset + transforms: RandomApply | None def __init__( self, - input_arrays: dict[str, dict[str, Iterable[int | float]]], - target_arrays: dict[str, dict[str, Iterable[int | float]]], - classes: Iterable[str], + input_arrays: dict[str, dict[str, Sequence[int | float]]], + target_arrays: dict[str, dict[str, Sequence[int | float]]], + classes: Sequence[str], + to_target: Callable, datasets: Optional[Dict[str, Iterable[CellMapDataset]]] = None, dataset_dict: Optional[Dict[str, Dict[str, str]]] = None, csv_path: Optional[str] = None, + transforms: Optional[Sequence[Callable]] = None, ): """Initializes the CellMapDatasets class. Args: - input_arrays (dict[str, dict[str, Iterable[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure: + input_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure: { "array_name": { "shape": typle[int], - "scale": Iterable[float], + "scale": Sequence[float], }, ... } - target_arrays (dict[str, dict[str, Iterable[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the following structure: + target_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the following structure: { "array_name": { "shape": typle[int], - "scale": Iterable[float], + "scale": Sequence[float], }, ... } - classes (Iterable[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. Classes not contained in the dataset will be filled in with zeros. + classes (Sequence[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. Classes not contained in the dataset will be filled in with zeros. + to_target (Callable): A function to convert the ground truth data to target arrays. The function should have the following structure: + def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tensor]: datasets (Optional[Dict[str, CellMapDataset]], optional): A dictionary containing the dataset objects. The dictionary should have the following structure: { "train": Iterable[CellMapDataset], @@ -58,6 +70,7 @@ def __init__( }. Defaults to None. csv_path (Optional[str], optional): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure: train | validate, raw path, gt path + transforms (Optional[Iterable[dict[str, any]]], optional): A list of transforms to apply to the data. Each augmentation should be a dictionary containing the following structure: """ self.input_arrays = input_arrays self.target_arrays = target_arrays @@ -71,13 +84,8 @@ def __init__( self.construct(dataset_dict) elif csv_path is not None: self.from_csv(csv_path) - - def __len__(self): - return len(self.train_datasets) - - def __getitem__(self, idx): ... - - def __iter__(self): ... + self.to_target = to_target + self.transforms = RandomApply(transforms) if transforms is not None else None def from_csv(self, csv_path): # Load file data from csv file @@ -99,7 +107,13 @@ def construct(self, dataset_dict): for raw, gt in zip(dataset_dict["train"]["raw"], dataset_dict["train"]["gt"]): self.train_datasets.append( CellMapDataset( - raw, gt, self.classes, self.input_arrays, self.target_arrays + raw, + gt, + self.classes, + self.input_arrays, + self.target_arrays, + self.to_target, + self.transforms, ) ) for raw, gt in zip( @@ -107,9 +121,26 @@ def construct(self, dataset_dict): ): self.validate_datasets.append( CellMapDataset( - raw, gt, self.classes, self.input_arrays, self.target_arrays + raw, + gt, + self.classes, + self.input_arrays, + self.target_arrays, + self.to_target, ) ) + self.train_datasets_combined = CellMapMultiDataset( + self.classes, + self.input_arrays, + self.target_arrays, + [ds for ds in self.train_datasets if ds.has_data], + ) + self.validate_datasets_combined = CellMapMultiDataset( + self.classes, + self.input_arrays, + self.target_arrays, + [ds for ds in self.validate_datasets if ds.has_data], + ) # Example input arrays: diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 4a9087e..3f37a11 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -1,7 +1,9 @@ from typing import Iterable, Optional import torch -import tensorstore as ts from fibsem_tools.io.core import read_xarray +import xarray +import tensorstore +import xarray_tensorstore as xt class CellMapImage: @@ -10,7 +12,8 @@ class CellMapImage: shape: tuple[float, ...] scale: tuple[float, ...] label_class: str - store: ts.TensorStore + array: xarray.DataArray + context: Optional[tensorstore.Context] = None def __init__( self, @@ -18,6 +21,7 @@ def __init__( target_class: str, target_scale: Iterable[float], target_voxel_shape: Iterable[int], + context: Optional[tensorstore.Context] = None, ): """Initializes a CellMapImage object. @@ -36,12 +40,29 @@ def __init__( self.output_shape = tuple( target_voxel_shape ) # TODO: this should be a dictionary of shapes for each axis + self.context = context self.construct() def construct(self): self._bounding_box = None self._sampling_box = None self._class_counts = None + self.ds = read_xarray(self.path) + # Find correct multiscale level based on target scale + # TODO + ... + self.array_path = ... + # Construct an xarray with Tensorstore backend + spec = xt._zarr_spec_from_path(self.array_path) + array_future = tensorstore.open( # type: ignore + spec, read=True, write=False, context=self.context + ) + array = array_future.result() + new_data = xt._TensorStoreAdapter(array) + self.array = ds.copy(data=new_data) # type: ignore + self.xs = self.array.coords["x"] + self.ys = self.array.coords["y"] + self.zs = self.array.coords["z"] def __getitem__(self, center: Iterable[float]) -> torch.Tensor: """Returns image data centered around the given point, based on the scale and shape of the target output image.""" diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py new file mode 100644 index 0000000..b50a430 --- /dev/null +++ b/src/cellmap_data/multidataset.py @@ -0,0 +1,44 @@ +from typing import Iterable, Sequence +from torch.utils.data import Dataset + +from .dataset import CellMapDataset + + +class CellMapMultiDataset(Dataset): + """ + This subclasses PyTorch Dataset to wrap multiple CellMapDataset objects under a common API, which can be used for dataloading. It maintains the same API as the Dataset class. It retrieves raw and groundtruth data from CellMapDataset objects. + """ + + classes: Sequence[str] + input_arrays: dict[str, dict[str, Sequence[int | float]]] + target_arrays: dict[str, dict[str, Sequence[int | float]]] + datasets: Iterable[CellMapDataset] + + def __init__( + self, + classes: Sequence[str], + input_arrays: dict[str, dict[str, Sequence[int | float]]], + target_arrays: dict[str, dict[str, Sequence[int | float]]], + datasets: Iterable[CellMapDataset], + ): + self.input_arrays = input_arrays + self.target_arrays = target_arrays + self.classes = classes + self.datasets = datasets + self.construct() + + def __len__(self): + # TODO + ... + + def __getitem__(self, idx: int): + # TODO + ... + + def __iter__(self): + # TODO + ... + + def construct(self): + # TODO + ...