Skip to content

Commit

Permalink
feat: 🔥 Add Tensorstore based data fetching
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Mar 28, 2024
1 parent a49ade1 commit 73185d2
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 47 deletions.
1 change: 1 addition & 0 deletions src/cellmap_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
__email__ = "[email protected]"

from .dataloader import CellMapDataLoader
from .multidataset import CellMapMultiDataset
from .datasplit import CellMapDataSplit
from .dataset import CellMapDataset
from .image import CellMapImage
49 changes: 40 additions & 9 deletions src/cellmap_data/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,48 @@
from torch.utils.data import DataLoader

from .multidataset import CellMapMultiDataset
from .datasplit import CellMapDataSplit
from typing import Callable, Iterable
from typing import Callable, Iterable, Optional


class CellMapDataLoader(DataLoader):
"""This subclasses PyTorch DataLoader to load CellMap data for training. It maintains the same API as the DataLoader class. This includes applying augmentations to the data and returning the data in the correct format for training, such as generating the target arrays (e.g. signed distance transform of labels). It retrieves raw and groundtruth data from a CellMapDataSplit object, which is a subclass of PyTorch Dataset. Training and validation data are split using the CellMapDataSplit object, and separate dataloaders are maintained as `train_loader` and `val_loader` respectively."""
class CellMapDataLoader:
# TODO: docstring corrections
"""This subclasses PyTorch DataLoader to load CellMap data for training. It maintains the same API as the DataLoader class. This includes applying augmentations to the data and returning the data in the correct format for training, such as generating the target arrays (e.g. signed distance transform of labels). It retrieves raw and groundtruth data from a CellMapDataSplit object, which is a subclass of PyTorch Dataset. Training and validation data are split using the CellMapDataSplit object, and separate dataloaders are maintained as `train_loader` and `validate_loader` respectively."""

input_arrays: dict[str, dict[str, tuple[int | float]]]
target_arrays: dict[str, dict[str, tuple[int | float]]]
classes: list[str]
datasplit: CellMapDataSplit
train_datasets: CellMapMultiDataset
validate_datasets: CellMapMultiDataset
train_loader: DataLoader
val_loader: DataLoader
validate_loader: DataLoader
batch_size: int
num_workers: int
is_train: bool
augmentations: list[dict[str, any]]
to_target: Callable

def __init__(
self,
datasplit: CellMapDataSplit,
batch_size: int,
num_workers: int,
is_train: bool,
):
self.datasplit = datasplit
self.batch_size = batch_size
self.num_workers = num_workers
self.is_train = is_train
self.construct()

def construct(self):
self.train_datasets = self.datasplit.train_datasets_combined
self.validate_datasets = self.datasplit.validate_datasets_combined
self.train_loader = DataLoader(
self.train_datasets,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=self.is_train,
)
self.validate_loader = DataLoader(
self.validate_datasets,
batch_size=1,
num_workers=self.num_workers,
shuffle=False,
)
48 changes: 35 additions & 13 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# %%
import csv
from typing import Callable, Dict, Iterable, Optional
from typing import Callable, Dict, Sequence, Optional
import numpy as np
import torch
from torch.utils.data import Dataset
Expand All @@ -24,49 +24,67 @@ class CellMapDataset(Dataset):

raw_path: str
gt_path: str
classes: Iterable[str]
input_arrays: dict[str, dict[str, Iterable[int | float]]]
target_arrays: dict[str, dict[str, Iterable[int | float]]]
classes: Sequence[str]
input_arrays: dict[str, dict[str, Sequence[int | float]]]
target_arrays: dict[str, dict[str, Sequence[int | float]]]
input_sources: dict[str, CellMapImage]
target_sources: dict[str, dict[str, CellMapImage | EmptyImage]]
to_target: Callable
transforms: RandomApply | None
has_data: bool
_bounding_box: Optional[Dict[str, list[int]]]
_bounding_box_shape: Optional[Dict[str, int]]
_sampling_box: Optional[Dict[str, list[int]]]
_sampling_box_shape: Optional[Dict[str, int]]
_class_counts: Optional[Dict[str, Dict[str, int]]]
_largest_voxel_sizes: Optional[Dict[str, int]]
_len: Optional[int]
_iter_coords: Optional[...]

def __init__(
self,
raw_path: str,
gt_path: str,
classes: Iterable[str],
input_arrays: dict[str, dict[str, Iterable[int | float]]],
target_arrays: dict[str, dict[str, Iterable[int | float]]],
classes: Sequence[str],
input_arrays: dict[str, dict[str, Sequence[int | float]]],
target_arrays: dict[str, dict[str, Sequence[int | float]]],
to_target: Callable,
transforms: Optional[Sequence[Callable]] = None,
):
"""Initializes the CellMapDataset class.
Args:
raw_path (str): The path to the raw data.
gt_path (str): The path to the ground truth data.
classes (Iterable[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. Classes not contained in the dataset will be filled in with zeros.
input_arrays (dict[str, dict[str, Iterable[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure:
classes (Sequence[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. Classes not contained in the dataset will be filled in with zeros.
input_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure:
{
"array_name": {
"shape": typle[int],
"scale": Iterable[float],
"scale": Sequence[float],
},
...
}
target_arrays (dict[str, dict[str, Iterable[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the following structure:
target_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the following structure:
{
"array_name": {
"shape": typle[int],
"scale": Iterable[float],
"scale": Sequence[float],
},
...
}
to_target (Callable): A function to convert the ground truth data to target arrays. The function should have the following structure:
def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tensor]:
transforms (Optional[Sequence[Callable]], optional): A sequence of transformations to apply to the data. Defaults to None.
"""
self.raw_path = raw_path
self.gt_paths = gt_path
self.gt_path_str, self.classes_with_path = split_gt_path(gt_path)
self.classes = classes
self.input_arrays = input_arrays
self.target_arrays = target_arrays
self.to_target = to_target
self.transforms = transforms
self.construct()

def __len__(self):
Expand Down Expand Up @@ -99,7 +117,10 @@ def __getitem__(self, idx):

def __iter__(self):
"""Iterates over the dataset, covering each section of the bounding box. For instance, for calculating validation scores."""
...
# TODO
if self._iter_coords is None:
self._iter_coords = ...
yield self.__getitem__(self._iter_coords)

def construct(self):
"""Constructs the input and target sources for the dataset."""
Expand Down Expand Up @@ -137,6 +158,7 @@ def construct(self):
self._class_counts = None
self._largest_voxel_sizes = None
self._len = None
self._iter_coords = None

@property
def largest_voxel_sizes(self):
Expand Down
77 changes: 54 additions & 23 deletions src/cellmap_data/datasplit.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,60 @@
import csv
from torch.utils.data import Dataset
from typing import Callable, Dict, Iterable, Optional
from torchvision.transforms.v2 import RandomApply
from typing import Callable, Dict, Iterable, Optional, Sequence

from .multidataset import CellMapMultiDataset
from .dataset import CellMapDataset


class CellMapDataSplit(Dataset):
class CellMapDataSplit:
"""
This subclasses PyTorch Dataset to split data into training and validation sets. It maintains the same API as the Dataset class. It retrieves raw and groundtruth data from CellMapDataset objects.
"""

input_arrays: dict[str, dict[str, Iterable[int | float]]]
target_arrays: dict[str, dict[str, Iterable[int | float]]]
classes: Iterable[str]
input_arrays: dict[str, dict[str, Sequence[int | float]]]
target_arrays: dict[str, dict[str, Sequence[int | float]]]
classes: Sequence[str]
to_target: Callable
datasets: dict[str, Iterable[CellMapDataset]]
train_datasets: Iterable[CellMapDataset]
validate_datasets: Iterable[CellMapDataset]
train_datasets_combined: CellMapMultiDataset
validate_datasets_combined: CellMapMultiDataset
transforms: RandomApply | None

def __init__(
self,
input_arrays: dict[str, dict[str, Iterable[int | float]]],
target_arrays: dict[str, dict[str, Iterable[int | float]]],
classes: Iterable[str],
input_arrays: dict[str, dict[str, Sequence[int | float]]],
target_arrays: dict[str, dict[str, Sequence[int | float]]],
classes: Sequence[str],
to_target: Callable,
datasets: Optional[Dict[str, Iterable[CellMapDataset]]] = None,
dataset_dict: Optional[Dict[str, Dict[str, str]]] = None,
csv_path: Optional[str] = None,
transforms: Optional[Sequence[Callable]] = None,
):
"""Initializes the CellMapDatasets class.
Args:
input_arrays (dict[str, dict[str, Iterable[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure:
input_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure:
{
"array_name": {
"shape": typle[int],
"scale": Iterable[float],
"scale": Sequence[float],
},
...
}
target_arrays (dict[str, dict[str, Iterable[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the following structure:
target_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the following structure:
{
"array_name": {
"shape": typle[int],
"scale": Iterable[float],
"scale": Sequence[float],
},
...
}
classes (Iterable[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. Classes not contained in the dataset will be filled in with zeros.
classes (Sequence[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. Classes not contained in the dataset will be filled in with zeros.
to_target (Callable): A function to convert the ground truth data to target arrays. The function should have the following structure:
def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tensor]:
datasets (Optional[Dict[str, CellMapDataset]], optional): A dictionary containing the dataset objects. The dictionary should have the following structure:
{
"train": Iterable[CellMapDataset],
Expand All @@ -58,6 +70,7 @@ def __init__(
}. Defaults to None.
csv_path (Optional[str], optional): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure:
train | validate, raw path, gt path
transforms (Optional[Iterable[dict[str, any]]], optional): A list of transforms to apply to the data. Each augmentation should be a dictionary containing the following structure:
"""
self.input_arrays = input_arrays
self.target_arrays = target_arrays
Expand All @@ -71,13 +84,8 @@ def __init__(
self.construct(dataset_dict)
elif csv_path is not None:
self.from_csv(csv_path)

def __len__(self):
return len(self.train_datasets)

def __getitem__(self, idx): ...

def __iter__(self): ...
self.to_target = to_target
self.transforms = RandomApply(transforms) if transforms is not None else None

def from_csv(self, csv_path):
# Load file data from csv file
Expand All @@ -99,17 +107,40 @@ def construct(self, dataset_dict):
for raw, gt in zip(dataset_dict["train"]["raw"], dataset_dict["train"]["gt"]):
self.train_datasets.append(
CellMapDataset(
raw, gt, self.classes, self.input_arrays, self.target_arrays
raw,
gt,
self.classes,
self.input_arrays,
self.target_arrays,
self.to_target,
self.transforms,
)
)
for raw, gt in zip(
dataset_dict["validate"]["raw"], dataset_dict["validate"]["gt"]
):
self.validate_datasets.append(
CellMapDataset(
raw, gt, self.classes, self.input_arrays, self.target_arrays
raw,
gt,
self.classes,
self.input_arrays,
self.target_arrays,
self.to_target,
)
)
self.train_datasets_combined = CellMapMultiDataset(
self.classes,
self.input_arrays,
self.target_arrays,
[ds for ds in self.train_datasets if ds.has_data],
)
self.validate_datasets_combined = CellMapMultiDataset(
self.classes,
self.input_arrays,
self.target_arrays,
[ds for ds in self.validate_datasets if ds.has_data],
)


# Example input arrays:
Expand Down
25 changes: 23 additions & 2 deletions src/cellmap_data/image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Iterable, Optional
import torch
import tensorstore as ts
from fibsem_tools.io.core import read_xarray
import xarray
import tensorstore
import xarray_tensorstore as xt


class CellMapImage:
Expand All @@ -10,14 +12,16 @@ class CellMapImage:
shape: tuple[float, ...]
scale: tuple[float, ...]
label_class: str
store: ts.TensorStore
array: xarray.DataArray
context: Optional[tensorstore.Context] = None

def __init__(
self,
path: str,
target_class: str,
target_scale: Iterable[float],
target_voxel_shape: Iterable[int],
context: Optional[tensorstore.Context] = None,
):
"""Initializes a CellMapImage object.
Expand All @@ -36,12 +40,29 @@ def __init__(
self.output_shape = tuple(
target_voxel_shape
) # TODO: this should be a dictionary of shapes for each axis
self.context = context
self.construct()

def construct(self):
self._bounding_box = None
self._sampling_box = None
self._class_counts = None
self.ds = read_xarray(self.path)
# Find correct multiscale level based on target scale
# TODO
...
self.array_path = ...
# Construct an xarray with Tensorstore backend
spec = xt._zarr_spec_from_path(self.array_path)
array_future = tensorstore.open( # type: ignore
spec, read=True, write=False, context=self.context
)
array = array_future.result()
new_data = xt._TensorStoreAdapter(array)
self.array = ds.copy(data=new_data) # type: ignore
self.xs = self.array.coords["x"]
self.ys = self.array.coords["y"]
self.zs = self.array.coords["z"]

def __getitem__(self, center: Iterable[float]) -> torch.Tensor:
"""Returns image data centered around the given point, based on the scale and shape of the target output image."""
Expand Down
Loading

0 comments on commit 73185d2

Please sign in to comment.