Skip to content

Commit

Permalink
rework selection workflow to precalculate coordinates
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Dec 10, 2024
1 parent 976d41b commit 8320ef6
Showing 1 changed file with 46 additions and 21 deletions.
67 changes: 46 additions & 21 deletions src/scportrait/pipeline/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import os
import numpy as np
import h5py
import pickle
from lmd.lib import SegmentationLoader
from alphabase.io import tempmmap

from lmd.segmentation import _create_coord_index_sparse
import timeit

class LMDSelection(ProcessingStep):
"""
Expand All @@ -14,10 +16,19 @@ class LMDSelection(ProcessingStep):

# define all valid path optimization methods used with the "path_optimization" argument in the configuration
VALID_PATH_OPTIMIZERS = ["none", "hilbert", "greedy"]
COORD_PICKLE_FILE = "coord_index.pkl"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

#check config for required parameters
assert "segmentation_channel" in self.config, "segmentation_channel not defined in config"

self.segmentation_channel_to_select = self.config["segmentation_channel"]

#the coord pickle file should be saved in the same directory as the segmentation results because it is based on that segmentation (if that segmentation is updated or changed it should also be recalculated)
self.coord_pickle_file_path = os.path.join(self.project_location, self.DEFAULT_SEGMENTATION_DIR_NAME, f"{self.segmentation_channel_to_select}_{self.COORD_PICKLE_FILE}")

def process(self, hdf_location, cell_sets, calibration_marker, name=None):
"""
Process function for selecting cells and generating their XML.
Expand All @@ -27,12 +38,13 @@ def process(self, hdf_location, cell_sets, calibration_marker, name=None):
hdf_location (str): Path of the segmentation hdf5 file. If this class is used as part of a project processing workflow, this argument will be provided.
cell_sets (list of dict): List of dictionaries containing the sets of cells which should be sorted into a single well.
calibration_marker (numpy.array): Array of size ‘(3,2)’ containing the calibration marker coordinates in the ‘(row, column)’ format.
name (str): Name of the output file. If not provided, the name will be generated based on the names of the cell sets or if also not specified set to "selected_cells".
Important:
If this class is used as part of a project processing workflow, the first argument will be provided by the ``Project``
class based on the previous segmentation. Therefore, only the second and third argument need to be provided. The Project
class will automaticly provide the most recent segmentation forward together with the supplied parameters.
class will automatically provide the most recent segmentation together with the supplied parameters.
Example:
Expand All @@ -46,7 +58,6 @@ class will automaticly provide the most recent segmentation forward together wit
# A numpy Array of shape (3, 2) should be passed.
calibration_marker = np.array([marker_0, marker_1, marker_2])
# Sets of cells can be defined by providing a name and a list of classes in a dictionary.
cells_to_select = [{"name": "dataset1", "classes": [1,2,3]}]
Expand All @@ -70,7 +81,7 @@ class will automaticly provide the most recent segmentation forward together wit
threads: 10
# the number of parallel processes to use for generation of cell sets each set
# will processed with the designated number of threads
# will be processed with the designated number of threads
processes_cell_sets: 1
# defines the channel used for generating cutting masks
Expand Down Expand Up @@ -123,22 +134,36 @@ class will automaticly provide the most recent segmentation forward together wit

self.log("Selection process started")

## TO Do
# check if classes and seglookup table already exist as pickle file
# if not create them
# else load them and proceed with selection

# load segmentation from hdf5
hf = h5py.File(hdf_location, "r")
hdf_labels = hf.get("labels")

# create memory mapped temporary array for saving the segmentation
c, x, y = hdf_labels.shape
segmentation = tempmmap.array(
shape=(x, y), dtype=hdf_labels.dtype, tmp_dir_abs_path=self._tmp_dir_path
)
segmentation = hdf_labels[self.config["segmentation_channel"], :, :]

#calculate a coordinate lookup file where for each cell id the coordinates for their location in the segmentation mask are stored
if os.path.exists(self.coord_pickle_file_path):
self.log(f"Loading coordinate lookup index from file {self.coord_pickle_file_path}.")
with open(self.coord_pickle_file_path, "rb") as f:
coord_index = pickle.load(f)
segmentation = None
else:
self.log("Calculating coordinate lookup index.")

#start timer for performance evaluation
start_time = timeit.default_timer()

# load segmentation from hdf5
with h5py.File(hdf_location, "r") as hf:
hdf_labels = hf.get("labels")

# create memory mapped temporary array for saving the segmentation
c, x, y = hdf_labels.shape
segmentation = tempmmap.array(
shape=(x, y), dtype=hdf_labels.dtype, tmp_dir_abs_path=self._tmp_dir_path
)
segmentation = hdf_labels[self.config["segmentation_channel"], :, :]

coord_index = _create_coord_index_sparse(segmentation)
with open(self.coord_pickle_file_path, "wb") as f:
pickle.dump(coord_index, f)
self.log(f"Coordinate lookup index saved to file {self.coord_pickle_file_path}.")
self.log(f"Coordinate lookup index calculation took {timeit.default_timer() - start_time} seconds.")

#add default orientation transform
self.config["orientation_transform"] = np.array([[0, -1], [1, 0]])

sl = SegmentationLoader(
Expand All @@ -147,7 +172,7 @@ class will automaticly provide the most recent segmentation forward together wit
processes=self.config["processes_cell_sets"],
)

shape_collection = sl(segmentation, cell_sets, calibration_marker)
shape_collection = sl(segmentation, cell_sets, calibration_marker, coords_lookup=coord_index)

if self.debug:
shape_collection.plot(calibration=True)
Expand Down

0 comments on commit 8320ef6

Please sign in to comment.