Skip to content

Commit

Permalink
Readd missing function
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson committed Nov 16, 2024
1 parent d070cbf commit e5d2665
Showing 1 changed file with 105 additions and 12 deletions.
117 changes: 105 additions & 12 deletions src/scportrait/pipeline/_utils/sdata_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import os
import shutil
from typing import Literal, TypeAlias
from pathlib import Path
from typing import Any, Literal, TypeAlias

import datatree
import numpy as np
import xarray
from alphabase.io import tempmmap
from spatialdata import SpatialData
Expand All @@ -18,7 +20,6 @@
get_chunk_size,
)

# Type aliases
ChunkSize: TypeAlias = tuple[int, int]
ObjectType: TypeAlias = Literal["images", "labels", "points", "tables"]

Expand Down Expand Up @@ -173,7 +174,7 @@ def _write_segmentation_object_sdata(

def _write_segmentation_sdata(
self,
segmentation: xarray.DataArray,
segmentation: xarray.DataArray | np.ndarray,
segmentation_label: str,
classes: set[str] | None = None,
chunks: ChunkSize = (1000, 1000),
Expand Down Expand Up @@ -249,26 +250,32 @@ def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None
centroids_object = self._get_centers(_sdata, segmentation_label)
self._write_points_object_sdata(centroids_object, self.centers_name, overwrite=overwrite)

def _load_input_image_to_memmap(self, tmp_dir_abs_path: str, image: xarray.DataArray | None = None) -> str:
"""Load input image to memory mapped array.
def _load_input_image_to_memmap(self, tmp_dir_abs_path: str | Path, image: np.NDArray[Any] | None = None) -> str:
"""Helper function to load the input image from sdata to memory mapped temp arrays for faster access.
Loading happens in a chunked manner to avoid memory issues.
Args:
tmp_dir_abs_path: Path for temporary storage
image: Optional image data to load
tmp_dir_abs_path: Absolute path to the directory where the memory mapped arrays should be stored.
image: Optional pre-loaded image array to process.
Returns:
Path to memory mapped array
Path to the memory mapped array. Can be reconneted to using the `mmap_array_from_path`
function from the alphabase.io.tempmmap module.
Raises:
ValueError: If input image not found
ValueError: If input image is not found in sdata object.
"""
if image is None:
_sdata = self._check_sdata_status(return_sdata=True)

if not self.input_image_status:
raise ValueError("Input image not found in sdata object.")
image = self._get_input_image(_sdata)

image = self._get_input_image(_sdata)
shape = image.shape

# initialize empty memory mapped arrays to store the data
path_input_image = tempmmap.create_empty_mmap(
shape=shape,
dtype=image.dtype,
Expand All @@ -277,15 +284,101 @@ def _load_input_image_to_memmap(self, tmp_dir_abs_path: str, image: xarray.DataA

input_image_mmap = tempmmap.mmap_array_from_path(path_input_image)

Z: int | None = None
if len(shape) == 3:
C, Y, X = shape
for c in range(C):
input_image_mmap[c] = image[c].compute()

elif len(shape) == 4:
Z, C, Y, X = shape

if Z is not None:
for z in range(Z):
for c in range(C):
input_image_mmap[z][c] = image[z][c].compute()
else:
for c in range(C):
input_image_mmap[c] = image[c].compute()

# cleanup the cache
del input_image_mmap, image

return path_input_image

def _load_seg_to_memmap(
self,
seg_name: list[str],
tmp_dir_abs_path: str | Path,
) -> str:
"""Helper function to load segmentation masks from sdata to memory mapped temp arrays for faster access.
Loading happens in a chunked manner to avoid memory issues.
Args:
seg_name: List of segmentation element names that should be loaded found in the sdata object.
The segmentation elments need to have the same size.
tmp_dir_abs_path: Absolute path to the directory where the memory mapped arrays should be stored.
Returns:
Path to the memory mapped array. Can be reconneted to using the `mmap_array_from_path`
function from the alphabase.io.tempmmap module.
Raises:
AssertionError: If not all segmentation elements are found in sdata object or if shapes don't match.
"""
_sdata = self._check_sdata_status(return_sdata=True)

assert all(
seg in _sdata.labels for seg in seg_name
), "Not all passed segmentation elements found in sdata object."

seg_objects = [_sdata.labels[seg] for seg in seg_name]

shapes = [seg.shape for seg in seg_objects]

Z: int | None = None
Y: int | None = None
X: int | None = None
for shape in shapes:
if len(shape) == 2:
if Y is None:
Y, X = shape
else:
assert Y == shape[0]
assert X == shape[1]
elif len(shape) == 3:
if Z is None:
Z, Y, X = shape
else:
assert Z == shape[0]
assert Y == shape[1]
assert X == shape[2]

n_masks = len(seg_objects)

if Z is not None and Y is not None and X is not None:
shape = (n_masks, Z, Y, X)
elif Y is not None and X is not None:
shape = (n_masks, Y, X)
else:
raise ValueError("Unable to determine shape from segmentation masks")

# initialize empty memory mapped arrays to store the data
path_seg_masks = tempmmap.create_empty_mmap(
shape=shape,
dtype=seg_objects[0].data.dtype,
tmp_dir_abs_path=tmp_dir_abs_path,
)

seg_masks = tempmmap.mmap_array_from_path(path_seg_masks)

for i, seg in enumerate(seg_objects):
if Z is not None:
for z in range(Z):
seg_masks[i][z] = seg.data[z].compute()
else:
seg_masks[i] = seg.data.compute()

# cleanup the cache
del seg_masks, seg_objects, seg

return path_seg_masks

0 comments on commit e5d2665

Please sign in to comment.