diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py
index 5248e08..0a9261d 100644
--- a/src/cellmap_data/dataset.py
+++ b/src/cellmap_data/dataset.py
@@ -29,12 +29,13 @@ class CellMapDataset(Dataset):
     target_arrays: dict[str, dict[str, Sequence[int | float]]]
     input_sources: dict[str, CellMapImage]
     target_sources: dict[str, dict[str, CellMapImage | EmptyImage]]
-    spatial_transforms: Optional[Sequence[dict[str, any]]]  # type: ignore
+    spatial_transforms: Optional[dict[str, any]]  # type: ignore
     raw_value_transforms: Optional[Callable]
     gt_value_transforms: Optional[Callable | Sequence[Callable] | dict[str, Callable]]
-    context: Optional[tensorstore.Context]  # type: ignore
     has_data: bool
     is_train: bool
+    axis_order: str
+    context: Optional[tensorstore.Context]  # type: ignore
     _bounding_box: Optional[Dict[str, list[int]]]
     _bounding_box_shape: Optional[Dict[str, int]]
     _sampling_box: Optional[Dict[str, list[int]]]
@@ -51,12 +52,13 @@ def __init__(
         classes: Sequence[str],
         input_arrays: dict[str, dict[str, Sequence[int | float]]],
         target_arrays: dict[str, dict[str, Sequence[int | float]]],
-        spatial_transforms: Optional[Sequence[dict[str, any]]] = None,  # type: ignore
+        spatial_transforms: Optional[dict[str, any]] = None,  # type: ignore
         raw_value_transforms: Optional[Callable] = None,
         gt_value_transforms: Optional[
             Callable | Sequence[Callable] | dict[str, Callable]
         ] = None,
         is_train: bool = False,
+        axis_order: str = "zyx",
         context: Optional[tensorstore.Context] = None,  # type: ignore
     ):
         """Initializes the CellMapDataset class.
@@ -99,6 +101,7 @@ def __init__(
         self.raw_value_transforms = raw_value_transforms
         self.gt_value_transforms = gt_value_transforms
         self.is_train = is_train
+        self.axis_order = axis_order
         self.context = context
 
         self.construct()
@@ -116,8 +119,8 @@ def __len__(self):
 
     def __getitem__(self, idx):
         """Returns a crop of the input and target data as PyTorch tensors, corresponding to the coordinate of the unwrapped index."""
-        # TODO: make center dictionary by axis
         center = np.unravel_index(idx, list(self.sampling_box_shape.values()))
+        center = {c: center[i] for i, c in enumerate(self.axis_order)}
         self._current_center = center
         spatial_transforms = self.generate_spatial_transforms()
         outputs = {}
@@ -161,6 +164,7 @@ def construct(self):
         self._iter_coords = None
         self._current_center = None
         self._current_spatial_transforms = None
+        self._rng = None
         self.input_sources = {}
         for array_name, array_info in self.input_arrays.items():
             self.input_sources[array_name] = CellMapImage(
@@ -203,22 +207,48 @@ def construct(self):
 
     def generate_spatial_transforms(self):
         """Generates spatial transforms for the dataset."""
+        if self._rng is None:
+            rng = np.random.default_rng()
+        else:
+            rng = self._rng
+
         if not self.is_train or self.spatial_transforms is None:
             return None
         spatial_transforms = {}
-        # TODO
-        ...
+        for transform, params in self.spatial_transforms.items():
+            if transform == "mirror":
+                # input: "mirror": {"axes": {"x": 0.5, "y": 0.5, "z":0.1}}
+                # output: {"mirror": ["x", "y"]}
+                spatial_transforms[transform] = []
+                for axis, prob in params["axes"].items():
+                    if rng.random() < prob:
+                        spatial_transforms[transform].append(axis)
+            elif transform == "transpose":
+                # only reorder axes specified in params
+                # input: "transpose": {"axes": ["x", "z"]}
+                # output: {"transpose": {"x": 2, "y": 1, "z": 0}}
+                axes = {axis: i for i, axis in enumerate(self.axis_order)}
+                shuffled_axes = rng.permutation(
+                    [axes[a] for a in params["axes"]]
+                )  # shuffle indices
+                shuffled_axes = {
+                    axis: shuffled_axes[i] for i, axis in enumerate(params["axes"])
+                }  # reassign axes
+                spatial_transforms[transform] = {**axes, **shuffled_axes}
+            else:
+                raise ValueError(f"Unknown spatial transform: {transform}")
         self._current_spatial_transforms = spatial_transforms
+        return spatial_transforms
 
     @property
     def largest_voxel_sizes(self):
         """Returns the largest voxel size of the dataset."""
         if self._largest_voxel_size is None:
-            largest_voxel_size = {c: 0 for c in "zyx"}
+            largest_voxel_size = {c: 0 for c in self.axis_order}
             for source in [self.input_sources.values(), self.target_sources.values()]:
                 if source.scale is None:
                     continue
-                for c, size in zip("zyx", source.scale):
+                for c, size in zip(self.axis_order, source.scale):
                     largest_voxel_size[c] = max(largest_voxel_size[c], size)
             self._largest_voxel_size = largest_voxel_size
 
@@ -228,7 +258,7 @@ def largest_voxel_sizes(self):
     def bounding_box(self):
         """Returns the bounding box of the dataset."""
         if self._bounding_box is None:
-            bounding_box = {c: [0, 2**32] for c in "zyx"}
+            bounding_box = {c: [0, 2**32] for c in self.axis_order}
             for source in [self.input_sources.values(), self.target_sources.values()]:
                 if source.bounding_box is None:
                     continue
@@ -242,7 +272,7 @@ def bounding_box(self):
     def bounding_box_shape(self):
         """Returns the shape of the bounding box of the dataset in voxels of the largest voxel size."""
         if self._bounding_box_shape is None:
-            bounding_box_shape = {c: 0 for c in "zyx"}
+            bounding_box_shape = {c: 0 for c in self.axis_order}
             for c, (start, stop) in self.bounding_box.items():
                 size = stop - start
                 size /= self.largest_voxel_sizes[c]
@@ -254,7 +284,7 @@ def bounding_box_shape(self):
     def sampling_box(self):
         """Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box)."""
         if self._sampling_box is None:
-            sampling_box = {c: [0, 2**32] for c in "zyx"}
+            sampling_box = {c: [0, 2**32] for c in self.axis_order}
             for source in [self.input_sources.values(), self.target_sources.values()]:
                 if source.sampling_box is None:
                     continue
@@ -268,7 +298,7 @@ def sampling_box(self):
     def sampling_box_shape(self):
         """Returns the shape of the sampling box of the dataset in voxels of the largest voxel size."""
         if self._sampling_box_shape is None:
-            sampling_box_shape = {c: 0 for c in "zyx"}
+            sampling_box_shape = {c: 0 for c in self.axis_order}
             for c, (start, stop) in self.sampling_box.items():
                 size = stop - start
                 size /= self.largest_voxel_sizes[c]
diff --git a/src/cellmap_data/datasplit.py b/src/cellmap_data/datasplit.py
index 5870d28..5a89af7 100644
--- a/src/cellmap_data/datasplit.py
+++ b/src/cellmap_data/datasplit.py
@@ -1,5 +1,4 @@
 import csv
-from torchvision.transforms.v2 import RandomApply
 from typing import Callable, Dict, Iterable, Optional, Sequence
 import tensorstore
 
diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py
index a1e06aa..8263c5b 100644
--- a/src/cellmap_data/image.py
+++ b/src/cellmap_data/image.py
@@ -58,12 +58,9 @@ def __init__(
         self.context = context
         self.construct()
 
-    def __getitem__(self, center: Sequence[float] | dict[str, float]) -> torch.Tensor:
+    def __getitem__(self, center: dict[str, float]) -> torch.Tensor:
         """Returns image data centered around the given point, based on the scale and shape of the target output image."""
         # Find vectors of coordinates in world space to pull data from
-        if isinstance(center, Sequence):
-            temp = {c: center[i] for i, c in enumerate(self.axes)}
-            center = temp
         coords = {}
         for c in self.axes:
             coords[c] = np.linspace(
@@ -136,27 +133,35 @@ def apply_spatial_transforms(
     ) -> torch.Tensor:
         """Applies spatial transformations to the given coordinates."""
         # Apply spatial transformations to the coordinates
+        # TODO: Implement non-90 degree rotations
         if self._current_spatial_transforms is not None:
             for transform, params in self._current_spatial_transforms.items():
                 if transform not in self.post_image_transforms:
-                    # TODO: Implement non-90 degree rotations
-                    # TODO
-                    ...
+                    if transform == "mirror":
+                        for axis in params:
+                            # TODO: Make sure this works and doesn't collapse to coords
+                            coords[axis] = coords[axis][::-1]
+                    else:
+                        raise ValueError(f"Unknown spatial transform: {transform}")
         self._last_coords = coords
 
         # Pull data from the image
         data = self.return_data(coords)
+        data = data.values
 
         # Apply and spatial transformations that require the image array (e.g. transpose)
         if self._current_spatial_transforms is not None:
             for transform, params in self._current_spatial_transforms.items():
                 if transform in self.post_image_transforms:
                     if transform == "transpose":
-                        data = data.transpose(*params)
+                        # TODO ... make sure this works
+                        # data = data.transpose(*params)
+                        new_order = [params[c] for c in self.axes]
+                        data = np.transpose(data, new_order)
                     else:
                         raise ValueError(f"Unknown spatial transform: {transform}")
 
-        return torch.tensor(data.values)
+        return torch.tensor(data)
 
     def return_data(self, coords: dict[str, Sequence[float]]):
         # Pull data from the image based on the given coordinates. This interpolates the data to the nearest pixel automatically.