Skip to content

Commit

Permalink
feat: ✨ Implement spatial transforms.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Mar 30, 2024
1 parent 3929c35 commit cf092c9
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 22 deletions.
53 changes: 41 additions & 12 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ class CellMapDataset(Dataset):
target_arrays: dict[str, dict[str, Sequence[int | float]]]
input_sources: dict[str, CellMapImage]
target_sources: dict[str, dict[str, CellMapImage | EmptyImage]]
spatial_transforms: Optional[Sequence[dict[str, any]]] # type: ignore
spatial_transforms: Optional[dict[str, any]] # type: ignore
raw_value_transforms: Optional[Callable]
gt_value_transforms: Optional[Callable | Sequence[Callable] | dict[str, Callable]]
context: Optional[tensorstore.Context] # type: ignore
has_data: bool
is_train: bool
axis_order: str
context: Optional[tensorstore.Context] # type: ignore
_bounding_box: Optional[Dict[str, list[int]]]
_bounding_box_shape: Optional[Dict[str, int]]
_sampling_box: Optional[Dict[str, list[int]]]
Expand All @@ -51,12 +52,13 @@ def __init__(
classes: Sequence[str],
input_arrays: dict[str, dict[str, Sequence[int | float]]],
target_arrays: dict[str, dict[str, Sequence[int | float]]],
spatial_transforms: Optional[Sequence[dict[str, any]]] = None, # type: ignore
spatial_transforms: Optional[dict[str, any]] = None, # type: ignore
raw_value_transforms: Optional[Callable] = None,
gt_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
] = None,
is_train: bool = False,
axis_order: str = "zyx",
context: Optional[tensorstore.Context] = None, # type: ignore
):
"""Initializes the CellMapDataset class.
Expand Down Expand Up @@ -99,6 +101,7 @@ def __init__(
self.raw_value_transforms = raw_value_transforms
self.gt_value_transforms = gt_value_transforms
self.is_train = is_train
self.axis_order = axis_order
self.context = context
self.construct()

Expand All @@ -116,8 +119,8 @@ def __len__(self):

def __getitem__(self, idx):
"""Returns a crop of the input and target data as PyTorch tensors, corresponding to the coordinate of the unwrapped index."""
# TODO: make center dictionary by axis
center = np.unravel_index(idx, list(self.sampling_box_shape.values()))
center = {c: center[i] for i, c in enumerate(self.axis_order)}
self._current_center = center
spatial_transforms = self.generate_spatial_transforms()
outputs = {}
Expand Down Expand Up @@ -161,6 +164,7 @@ def construct(self):
self._iter_coords = None
self._current_center = None
self._current_spatial_transforms = None
self._rng = None
self.input_sources = {}
for array_name, array_info in self.input_arrays.items():
self.input_sources[array_name] = CellMapImage(
Expand Down Expand Up @@ -203,22 +207,47 @@ def construct(self):

def generate_spatial_transforms(self):
    """Generate a concrete set of random spatial transforms for one sample.

    Draws from ``self.spatial_transforms`` (a per-transform parameter dict)
    and resolves each probabilistic spec into a concrete instruction:

    - ``"mirror"``: input ``{"axes": {"x": 0.5, "y": 0.5, "z": 0.1}}`` maps
      axis name -> flip probability; output is the list of axis names that
      were randomly selected to be flipped, e.g. ``["x", "y"]``.
    - ``"transpose"``: input ``{"axes": ["x", "z"]}`` lists the axes whose
      order may be shuffled; output maps every axis in ``self.axis_order``
      to its (possibly permuted) index, e.g. ``{"x": 2, "y": 1, "z": 0}``.
      Axes not listed keep their original index.

    Returns:
        dict[str, Any] | None: The resolved transforms, also cached on
        ``self._current_spatial_transforms``; ``None`` when the dataset is
        not in training mode or no transforms were configured.

    Raises:
        ValueError: If an unknown transform name is configured.
    """
    if not self.is_train or self.spatial_transforms is None:
        return None

    # Use the dataset's seeded generator when present so sampling is
    # reproducible; otherwise fall back to a fresh default generator.
    rng = self._rng if self._rng is not None else np.random.default_rng()

    spatial_transforms: dict[str, Any] = {}
    for transform, params in self.spatial_transforms.items():
        if transform == "mirror":
            # Keep each axis independently with its configured probability.
            # NOTE: params["axes"] is a dict, so iterate .items() to get
            # (axis, probability) pairs.
            spatial_transforms[transform] = [
                axis
                for axis, prob in params["axes"].items()
                if rng.random() < prob
            ]
        elif transform == "transpose":
            # Start from the identity mapping over the full axis order,
            # then shuffle only the indices of the requested axes.
            axes = {axis: i for i, axis in enumerate(self.axis_order)}
            shuffled = rng.permutation([axes[a] for a in params["axes"]])
            # dict.update mutates in place and returns None, so update
            # first and store the dict itself (the original code stored
            # the None return value of update()).
            axes.update(
                {axis: int(shuffled[i]) for i, axis in enumerate(params["axes"])}
            )
            spatial_transforms[transform] = axes
        else:
            raise ValueError(f"Unknown spatial transform: {transform}")

    self._current_spatial_transforms = spatial_transforms
    # Return the transforms: __getitem__ assigns this method's result, so
    # returning None here (as the original did) dropped the transforms.
    return spatial_transforms

@property
def largest_voxel_sizes(self):
"""Returns the largest voxel size of the dataset."""
if self._largest_voxel_size is None:
largest_voxel_size = {c: 0 for c in "zyx"}
largest_voxel_size = {c: 0 for c in self.axis_order}
for source in [self.input_sources.values(), self.target_sources.values()]:
if source.scale is None:
continue
for c, size in zip("zyx", source.scale):
for c, size in zip(self.axis_order, source.scale):
largest_voxel_size[c] = max(largest_voxel_size[c], size)
self._largest_voxel_size = largest_voxel_size

Expand All @@ -228,7 +257,7 @@ def largest_voxel_sizes(self):
def bounding_box(self):
"""Returns the bounding box of the dataset."""
if self._bounding_box is None:
bounding_box = {c: [0, 2**32] for c in "zyx"}
bounding_box = {c: [0, 2**32] for c in self.axis_order}
for source in [self.input_sources.values(), self.target_sources.values()]:
if source.bounding_box is None:
continue
Expand All @@ -242,7 +271,7 @@ def bounding_box(self):
def bounding_box_shape(self):
"""Returns the shape of the bounding box of the dataset in voxels of the largest voxel size."""
if self._bounding_box_shape is None:
bounding_box_shape = {c: 0 for c in "zyx"}
bounding_box_shape = {c: 0 for c in self.axis_order}
for c, (start, stop) in self.bounding_box.items():
size = stop - start
size /= self.largest_voxel_sizes[c]
Expand All @@ -254,7 +283,7 @@ def bounding_box_shape(self):
def sampling_box(self):
"""Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box)."""
if self._sampling_box is None:
sampling_box = {c: [0, 2**32] for c in "zyx"}
sampling_box = {c: [0, 2**32] for c in self.axis_order}
for source in [self.input_sources.values(), self.target_sources.values()]:
if source.sampling_box is None:
continue
Expand All @@ -268,7 +297,7 @@ def sampling_box(self):
def sampling_box_shape(self):
"""Returns the shape of the sampling box of the dataset in voxels of the largest voxel size."""
if self._sampling_box_shape is None:
sampling_box_shape = {c: 0 for c in "zyx"}
sampling_box_shape = {c: 0 for c in self.axis_order}
for c, (start, stop) in self.sampling_box.items():
size = stop - start
size /= self.largest_voxel_sizes[c]
Expand Down
1 change: 0 additions & 1 deletion src/cellmap_data/datasplit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import csv
from torchvision.transforms.v2 import RandomApply
from typing import Callable, Dict, Iterable, Optional, Sequence
import tensorstore

Expand Down
23 changes: 14 additions & 9 deletions src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,9 @@ def __init__(
self.context = context
self.construct()

def __getitem__(self, center: Sequence[float] | dict[str, float]) -> torch.Tensor:
def __getitem__(self, center: dict[str, float]) -> torch.Tensor:
"""Returns image data centered around the given point, based on the scale and shape of the target output image."""
# Find vectors of coordinates in world space to pull data from
if isinstance(center, Sequence):
temp = {c: center[i] for i, c in enumerate(self.axes)}
center = temp
coords = {}
for c in self.axes:
coords[c] = np.linspace(
Expand Down Expand Up @@ -136,27 +133,35 @@ def apply_spatial_transforms(
) -> torch.Tensor:
"""Applies spatial transformations to the given coordinates."""
# Apply spatial transformations to the coordinates
# TODO: Implement non-90 degree rotations
if self._current_spatial_transforms is not None:
for transform, params in self._current_spatial_transforms.items():
if transform not in self.post_image_transforms:
# TODO: Implement non-90 degree rotations
# TODO
...
if transform == "mirror":
for axis in params:
# TODO: Make sure this works and doesn't collapse to coords
coords[axis] = coords[axis][::-1]
else:
raise ValueError(f"Unknown spatial transform: {transform}")
self._last_coords = coords

# Pull data from the image
data = self.return_data(coords)
data = data.values

# Apply and spatial transformations that require the image array (e.g. transpose)
if self._current_spatial_transforms is not None:
for transform, params in self._current_spatial_transforms.items():
if transform in self.post_image_transforms:
if transform == "transpose":
data = data.transpose(*params)
# TODO ... make sure this works
# data = data.transpose(*params)
new_order = [params[c] for c in self.axes]
data = np.transpose(data, new_order)
else:
raise ValueError(f"Unknown spatial transform: {transform}")

return torch.tensor(data.values)
return torch.tensor(data)

def return_data(self, coords: dict[str, Sequence[float]]):
# Pull data from the image based on the given coordinates. This interpolates the data to the nearest pixel automatically.
Expand Down

0 comments on commit cf092c9

Please sign in to comment.