From cb0d30d40870d604fcde09695d7b6d76c503716b Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Sun, 11 Feb 2024 14:30:19 -0500 Subject: [PATCH] =?UTF-8?q?feat:=20=F0=9F=92=A9=20dev/post-process:=20post?= =?UTF-8?q?-processing=20related=20changes,=20starting=20from=20rhoadesj/d?= =?UTF-8?q?ev?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../post_processors/argmax_post_processor.py | 3 +- .../post_processors/dummy_post_processor.py | 2 +- .../tasks/post_processors/post_processor.py | 2 + .../threshold_post_processor.py | 2 + .../watershed_post_processor.py | 108 +++++++++++------- 5 files changed, 72 insertions(+), 45 deletions(-) diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index 709d1de34..799f2651e 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -19,7 +19,7 @@ def set_prediction(self, prediction_array_identifier): prediction_array_identifier ) - def process(self, parameters, output_array_identifier): + def process(self, parameters, output_array_identifier, overwrite: bool = False): output_array = ZarrArray.create_from_array_identifier( output_array_identifier, [dim for dim in self.prediction_array.axes if dim != "c"], @@ -27,6 +27,7 @@ def process(self, parameters, output_array_identifier): None, self.prediction_array.voxel_size, np.uint8, + overwrite=overwrite, ) output_array[self.prediction_array.roi] = np.argmax( diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py index 5a2c7810a..ddb249539 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py @@ -21,7 +21,7 @@ def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]: def set_prediction(self, prediction_array): pass - def process(self, parameters, output_array_identifier): + def process(self, parameters, output_array_identifier, overwrite: bool = False): # store some dummy data f = zarr.open(str(output_array_identifier.container), "a") f[output_array_identifier.dataset] = np.ones((10, 10, 10)) * parameters.min_size diff --git a/dacapo/experiments/tasks/post_processors/post_processor.py b/dacapo/experiments/tasks/post_processors/post_processor.py index 020361cb9..4e4102d6b 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor.py +++ b/dacapo/experiments/tasks/post_processors/post_processor.py @@ -33,6 +33,8 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", + overwrite: "bool", + blockwise: "bool", ) -> "Array": """Convert predictions into the final output.""" pass diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index 67ffdd066..32bf4cfc0 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -28,6 +28,7 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", + overwrite: bool = False, ) -> ZarrArray: # TODO: Investigate Liskov substitution princple and whether it is a problem here # OOP theory states the super class should always be replaceable with its subclasses @@ -47,6 +48,7 @@ def process( self.prediction_array.num_channels, self.prediction_array.voxel_size, np.uint8, + overwrite=overwrite, ) output_array[self.prediction_array.roi] = ( diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 1a7c4627b..307806772 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -24,7 +24,9 @@ def enumerate_parameters(self): """Enumerate all possible parameters of this post-processor. Should return instances of ``PostProcessorParameters``.""" - for i, bias in enumerate([0.1, 0.25, 0.5, 0.75, 0.9]): + for i, bias in enumerate( + [0.1, 0.3, 0.5, 0.7, 0.9] + ): # TODO: add this to the config yield WatershedPostProcessorParameters(id=i, bias=bias) def set_prediction(self, prediction_array_identifier): @@ -32,45 +34,65 @@ def set_prediction(self, prediction_array_identifier): prediction_array_identifier ) - def process(self, parameters, output_array_identifier): - output_array = ZarrArray.create_from_array_identifier( - output_array_identifier, - [axis for axis in self.prediction_array.axes if axis != "c"], - self.prediction_array.roi, - None, - self.prediction_array.voxel_size, - np.uint64, - ) - # if a previous segmentation is provided, it must have a "grid graph" - # in its metadata. - pred_data = self.prediction_array[self.prediction_array.roi] - affs = pred_data[: len(self.offsets)].astype(np.float64) - segmentation = mws.agglom( - affs - parameters.bias, - self.offsets, - ) - # filter fragments - average_affs = np.mean(affs, axis=0) - - filtered_fragments = [] - - fragment_ids = np.unique(segmentation) - - for fragment, mean in zip( - fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids) - ): - if mean < parameters.bias: - filtered_fragments.append(fragment) - - filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype) - replace = np.zeros_like(filtered_fragments) - - # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input - if filtered_fragments.size > 0: - segmentation = npi.remap( - segmentation.flatten(), filtered_fragments, replace - ).reshape(segmentation.shape) - - output_array[self.prediction_array.roi] = segmentation - - return output_array + def process( + self, + parameters, + output_array_identifier, + overwrite: bool = False, + blockwise: bool = False, + ): # TODO: will probably break with large arrays... + if not blockwise: + output_array = ZarrArray.create_from_array_identifier( + output_array_identifier, + [axis for axis in self.prediction_array.axes if axis != "c"], + self.prediction_array.roi, + None, + self.prediction_array.voxel_size, + np.uint64, + overwrite=overwrite, + ) + # if a previous segmentation is provided, it must have a "grid graph" + # in its metadata. + # pred_data = self.prediction_array[self.prediction_array.roi] + # affs = pred_data[: len(self.offsets)].astype( + # np.float64 + # ) # TODO: shouldn't need to be float64 + affs = self.prediction_array[self.prediction_array.roi][: len(self.offsets)] + if affs.dtype == np.uint8: + affs = affs.astype(np.float64) / 255.0 + else: + affs = affs.astype(np.float64) + segmentation = mws.agglom( + affs - parameters.bias, + self.offsets, + ) + # filter fragments + average_affs = np.mean(affs, axis=0) + + filtered_fragments = [] + + fragment_ids = np.unique(segmentation) + + for fragment, mean in zip( + fragment_ids, + measurements.mean(average_affs, segmentation, fragment_ids), + ): + if mean < parameters.bias: + filtered_fragments.append(fragment) + + filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype) + replace = np.zeros_like(filtered_fragments) + + # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input + if filtered_fragments.size > 0: + segmentation = npi.remap( + segmentation.flatten(), filtered_fragments, replace + ).reshape(segmentation.shape) + + output_array[self.prediction_array.roi] = segmentation + + return output_array + else: + raise NotImplementedError( + "Blockwise processing not yet implemented." + ) # TODO: add rusty mws