diff --git a/dacapo/apply.py b/dacapo/apply.py
index 64f23df3c..b33cffe46 100644
--- a/dacapo/apply.py
+++ b/dacapo/apply.py
@@ -1,12 +1,200 @@
 import logging
+from typing import Optional, Union
+from funlib.geometry import Roi, Coordinate
+import numpy as np
+from dacapo.experiments.datasplits.datasets.arrays.array import Array
+from dacapo.experiments.datasplits.datasets.dataset import Dataset
+from dacapo.experiments.run import Run
+
+from dacapo.experiments.tasks.post_processors.post_processor_parameters import (
+    PostProcessorParameters,
+)
+import dacapo.experiments.tasks.post_processors as post_processors
+from dacapo.store.array_store import LocalArrayIdentifier
+from dacapo.predict import predict
+from dacapo.compute_context import LocalTorch, ComputeContext
+from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
+from dacapo.store import (
+    create_config_store,
+    create_weights_store,
+)
+
+from pathlib import Path

 logger = logging.getLogger(__name__)


-def apply(run_name: str, iteration: int, dataset_name: str):
+def apply(
+    run_name: str,
+    input_container: Union[Path, str],
+    input_dataset: str,
+    output_path: Union[Path, str],
+    validation_dataset: Optional[Union[Dataset, str]] = None,
+    criterion: Optional[str] = "voi",
+    iteration: Optional[int] = None,
+    parameters: Optional[Union[PostProcessorParameters, str]] = None,
+    roi: Optional[Union[Roi, str]] = None,
+    num_cpu_workers: int = 30,
+    output_dtype: Optional[Union[np.dtype, str]] = np.uint8,
+    compute_context: ComputeContext = LocalTorch(),
+    overwrite: bool = True,
+    file_format: str = "zarr",
+):
+    """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used."""
+    if isinstance(output_dtype, str):
+        output_dtype = np.dtype(output_dtype)
+
+    if isinstance(roi, str):
+        start, end = zip(
+            *[
+                tuple(int(coord) for coord in axis.split(":"))
+                for axis in roi.strip("[]").split(",")
+            ]
+        )
+        roi = Roi(
+            Coordinate(start),
+            Coordinate(end) - Coordinate(start),
+        )
+
+    assert (validation_dataset is not None and isinstance(criterion, str)) or (
+        isinstance(iteration, int)
+    ), "Either validation_dataset and criterion, or iteration must be provided."
+
+    # retrieving run
+    logger.info("Loading run %s", run_name)
+    config_store = create_config_store()
+    run_config = config_store.retrieve_run_config(run_name)
+    run = Run(run_config)
+
+    # create weights store
+    weights_store = create_weights_store()
+
+    # load weights
+    if iteration is None:
+        # weights_store._load_best(run, criterion)
+        iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion)
+    logger.info("Loading weights for iteration %i", iteration)
+    weights_store.retrieve_weights(run, iteration)  # shouldn't this be load_weights?
+
+    # find the best parameters
+    if isinstance(validation_dataset, str):
+        val_ds_name = validation_dataset
+        validation_dataset = [
+            dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name
+        ][0]
+    logger.info("Finding best parameters for validation dataset %s", validation_dataset)
+    if parameters is None:
+        parameters = run.task.evaluator.get_overall_best_parameters(
+            validation_dataset, criterion
+        )
+        assert (
+            parameters is not None
+        ), "Unable to retrieve parameters. Parameters must be provided explicitly."
+
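+    # Alternatively, ``parameters`` may be passed as a string of the form
+    # "PostProcessorParametersClass(arg1=val1, arg2=val2, ...)", for example
+    # (hypothetical values) "WatershedPostProcessorParameters(id=1, bias=0.5)";
+    # the string is parsed into the matching PostProcessorParameters subclass below.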
+    elif isinstance(parameters, str):
+        try:
+            post_processor_name = parameters.split("(")[0]
+            post_processor_kwargs = parameters.split("(")[1].strip(")").split(",")
+            post_processor_kwargs = {
+                key.strip(): value.strip()
+                for key, value in [arg.split("=") for arg in post_processor_kwargs]
+            }
+            for key, value in post_processor_kwargs.items():
+                if value.isdigit():
+                    post_processor_kwargs[key] = int(value)
+                elif value.replace(".", "", 1).isdigit():
+                    post_processor_kwargs[key] = float(value)
+        except Exception:
+            raise ValueError(
+                f"Could not parse parameters string {parameters}. Must be of the form 'post_processor_name(arg1=val1, arg2=val2, ...)'"
+            )
+        try:
+            parameters = getattr(post_processors, post_processor_name)(
+                **post_processor_kwargs
+            )
+        except Exception as e:
+            logger.error(
+                f"Could not instantiate post-processor {post_processor_name} with arguments {post_processor_kwargs}.",
+                exc_info=True,
+            )
+            raise e
+
+    assert isinstance(
+        parameters, PostProcessorParameters
+    ), "Parameters must be parsable to a PostProcessorParameters object."
+
+    # make array identifiers for input, predictions and outputs
+    input_array_identifier = LocalArrayIdentifier(input_container, input_dataset)
+    input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
+    if roi is not None:
+        roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect(
+            input_array.roi
+        )
+    output_container = Path(
+        output_path,
+        "".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}",
+    )
+    prediction_array_identifier = LocalArrayIdentifier(
+        output_container, f"prediction_{run_name}_{iteration}"
+    )
+    output_array_identifier = LocalArrayIdentifier(
+        output_container, f"output_{run_name}_{iteration}_{parameters}"
+    )
+
     logger.info(
-        "Applying results from run %s at iteration %d to dataset %s",
-        run_name,
+        "Applying best results from run %s at iteration %i to dataset %s",
+        run.name,
         iteration,
-        dataset_name,
+        Path(input_container, input_dataset),
+    )
+    return apply_run(
+        run,
+        parameters,
+        input_array,
+        prediction_array_identifier,
+        output_array_identifier,
+        roi,
+        num_cpu_workers,
+        output_dtype,
+        compute_context,
+        overwrite,
+    )
+
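+# Example invocation (hypothetical run name and paths):
+#
+#   apply(
+#       "my_run",
+#       "/path/to/data.zarr",
+#       "volumes/raw",
+#       "/path/to/outputs",
+#       validation_dataset="my_validation_dataset",
+#       criterion="voi",
+#       roi="[0:1024,0:1024,0:1024]",
+#       output_dtype="uint8",
+#   )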
+
+def apply_run(
+    run: Run,
+    parameters: PostProcessorParameters,
+    input_array: Array,
+    prediction_array_identifier: LocalArrayIdentifier,
+    output_array_identifier: LocalArrayIdentifier,
+    roi: Optional[Roi] = None,
+    num_cpu_workers: int = 30,
+    output_dtype: Optional[np.dtype] = np.uint8,
+    compute_context: ComputeContext = LocalTorch(),
+    overwrite: bool = True,
+):
+    """Apply the model to a dataset. If roi is None, the whole input dataset is used.
+    Assumes model is already loaded."""
+    run.model.eval()
+
+    # render prediction dataset
+    logger.info("Predicting on dataset %s", prediction_array_identifier)
+    predict(
+        run.model,
+        input_array,
+        prediction_array_identifier,
+        output_roi=roi,
+        num_cpu_workers=num_cpu_workers,
+        output_dtype=output_dtype,
+        compute_context=compute_context,
+        overwrite=overwrite,
     )
+
+    # post-process the output
+    logger.info("Post-processing output to dataset %s", output_array_identifier)
+    post_processor = run.task.post_processor
+    post_processor.set_prediction(prediction_array_identifier)
+    post_processor.process(
+        parameters, output_array_identifier, overwrite=overwrite, blockwise=True
+    )
+
+    logger.info("Done")
+    return
diff --git a/dacapo/cli.py b/dacapo/cli.py
index 76a5e18e0..f97906508 100644
--- a/dacapo/cli.py
+++ b/dacapo/cli.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import dacapo
 import click
 import logging
@@ -40,21 +42,52 @@ def validate(run_name, iteration):

 @cli.command()
 @click.option(
-    "-r", "--run", required=True, type=str, help="The name of the run to use."
+    "-r", "--run-name", required=True, type=str, help="The name of the run to apply."
 )
 @click.option(
-    "-i",
-    "--iteration",
+    "-ic",
+    "--input_container",
     required=True,
-    type=int,
-    help="The iteration weights and parameters to use.",
+    type=click.Path(exists=True, file_okay=False),
 )
+@click.option("-id", "--input_dataset", required=True, type=str)
+@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False))
+@click.option("-vd", "--validation_dataset", type=str, default=None)
+@click.option("-c", "--criterion", default="voi")
+@click.option("-i", "--iteration", type=int, default=None)
+@click.option("-p", "--parameters", type=str, default=None)
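+# The ROI string is parsed as "[start:end,start:end,...]", one start:end pair per axis
+# in world units, e.g. --roi "[420:8040,300:8000,1200:9000]" (hypothetical values).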
]", ) -def apply(run_name, iteration, dataset_name): - dacapo.apply(run_name, iteration, dataset_name) +@click.option("-w", "--num_cpu_workers", type=int, default=30) +@click.option("-dt", "--output_dtype", type=str, default="uint8") +def apply( + run_name: str, + input_container: str, + input_dataset: str, + output_path: str, + validation_dataset: Optional[str] = None, + criterion: Optional[str] = "voi", + iteration: Optional[int] = None, + parameters: Optional[str] = None, + roi: Optional[str] = None, + num_cpu_workers: int = 30, + output_dtype: Optional[str] = "uint8", +): + dacapo.apply( + run_name, + input_container, + input_dataset, + output_path, + validation_dataset, + criterion, + iteration, + parameters, + roi, + num_cpu_workers, + output_dtype, + ) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 42030e701..25f2c224e 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -52,7 +52,7 @@ def axes(self): logger.debug( "DaCapo expects Zarr datasets to have an 'axes' attribute!\n" f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n" - f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}", + f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}", ) return ["c", "z", "y", "x"][-self.dims : :] diff --git a/dacapo/experiments/tasks/affinities_task.py b/dacapo/experiments/tasks/affinities_task.py index c1014fd02..4a1b8cc4a 100644 --- a/dacapo/experiments/tasks/affinities_task.py +++ b/dacapo/experiments/tasks/affinities_task.py @@ -12,7 +12,11 @@ def __init__(self, task_config): """Create a `DummyTask` from a `DummyTaskConfig`.""" self.predictor = AffinitiesPredictor( - neighborhood=task_config.neighborhood, lsds=task_config.lsds + neighborhood=task_config.neighborhood, + lsds=task_config.lsds, + num_voxels=task_config.num_voxels, + downsample_lsds=task_config.downsample_lsds, + grow_boundary_iterations=task_config.grow_boundary_iterations, ) self.loss = AffinitiesLoss(len(task_config.neighborhood)) self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood) diff --git a/dacapo/experiments/tasks/affinities_task_config.py b/dacapo/experiments/tasks/affinities_task_config.py index d4b2c6199..0a94db79d 100644 --- a/dacapo/experiments/tasks/affinities_task_config.py +++ b/dacapo/experiments/tasks/affinities_task_config.py @@ -30,3 +30,23 @@ class AffinitiesTaskConfig(TaskConfig): "It has been shown that lsds as an auxiliary task can help affinity predictions." }, ) + num_voxels: int = attr.ib( + default=20, + metadata={ + "help_text": "The number of voxels to use for the gaussian sigma when computing lsds." + }, + ) + downsample_lsds: int = attr.ib( + default=1, + metadata={ + "help_text": "The amount to downsample the lsds. " + "This is useful for speeding up training and inference." + }, + ) + grow_boundary_iterations: int = attr.ib( + default=0, + metadata={ + "help_text": "The number of iterations to run the grow boundaries algorithm. " + "This is useful for refining the boundaries of the affinities, and reducing merging of adjacent objects." 
diff --git a/dacapo/experiments/tasks/predictors/affinities_predictor.py b/dacapo/experiments/tasks/predictors/affinities_predictor.py
index 81efb2375..40d81f5da 100644
--- a/dacapo/experiments/tasks/predictors/affinities_predictor.py
+++ b/dacapo/experiments/tasks/predictors/affinities_predictor.py
@@ -17,9 +17,17 @@ class AffinitiesPredictor(Predictor):
-    def __init__(self, neighborhood: List[Coordinate], lsds: bool = True):
+    def __init__(
+        self,
+        neighborhood: List[Coordinate],
+        lsds: bool = True,
+        num_voxels: int = 20,
+        downsample_lsds: int = 1,
+        grow_boundary_iterations: int = 0,
+    ):
         self.neighborhood = neighborhood
         self.lsds = lsds
+        self.num_voxels = num_voxels
         if lsds:
             self._extractor = None
             if self.dims == 2:
@@ -30,12 +38,16 @@ def __init__(self, neighborhood: List[Coordinate], lsds: bool = True):
                 raise ValueError(
                     f"Cannot compute lsds on volumes with {self.dims} dimensions"
                 )
+            self.downsample_lsds = downsample_lsds
         else:
             self.num_lsds = 0
+        self.grow_boundary_iterations = grow_boundary_iterations

     def extractor(self, voxel_size):
         if self._extractor is None:
-            self._extractor = LsdExtractor(self.sigma(voxel_size))
+            self._extractor = LsdExtractor(
+                self.sigma(voxel_size), downsample=self.downsample_lsds
+            )
         return self._extractor

@@ -45,8 +57,7 @@ def dims(self):

     def sigma(self, voxel_size):
         voxel_dist = max(voxel_size)  # arbitrarily chosen
-        num_voxels = 10  # arbitrarily chosen
-        sigma = voxel_dist * num_voxels
+        sigma = voxel_dist * self.num_voxels  # arbitrarily chosen
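+        # e.g. with an (assumed) voxel size of (8, 8, 8) and the default num_voxels of
+        # 20, this gives sigma = 8 * 20 = 160 in world units along each dimension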
         return Coordinate((sigma,) * self.dims)

     def lsd_pad(self, voxel_size):
@@ -118,7 +129,9 @@ def _grow_boundaries(self, mask, slab):
                 slice(start[d], start[d] + slab[d]) for d in range(len(slab))
             )
             mask_slab = mask[slices]
-            dilated_mask_slab = ndimage.binary_dilation(mask_slab, iterations=1)
+            dilated_mask_slab = ndimage.binary_dilation(
+                mask_slab, iterations=self.grow_boundary_iterations
+            )
             foreground[slices] = dilated_mask_slab

         # label new background
@@ -130,10 +143,12 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
         (moving_class_counts, moving_lsd_class_counts) = (
             moving_class_counts if moving_class_counts is not None else (None, None)
         )
-        # mask_data = self._grow_boundaries(
-        #     mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes)
-        # )
-        mask_data = mask[target.roi]
+        if self.grow_boundary_iterations > 0:
+            mask_data = self._grow_boundaries(
+                mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes)
+            )
+        else:
+            mask_data = mask[target.roi]
         aff_weights, moving_class_counts = balance_weights(
             target[target.roi][: self.num_channels - self.num_lsds].astype(np.uint8),
             2,
diff --git a/dacapo/experiments/training_stats.py b/dacapo/experiments/training_stats.py
index cd3fcd012..72c631ed4 100644
--- a/dacapo/experiments/training_stats.py
+++ b/dacapo/experiments/training_stats.py
@@ -16,7 +16,9 @@ class TrainingStats:

     def add_iteration_stats(self, iteration_stats: TrainingIterationStats) -> None:
         if len(self.iteration_stats) > 0:
-            assert iteration_stats.iteration == self.iteration_stats[-1].iteration + 1
+            assert (
+                iteration_stats.iteration == self.iteration_stats[-1].iteration + 1
+            ), f"Expected iteration {self.iteration_stats[-1].iteration + 1}, got {iteration_stats.iteration}"
         self.iteration_stats.append(iteration_stats)
diff --git a/dacapo/predict.py b/dacapo/predict.py
index 5a40e303c..340517528 100644
--- a/dacapo/predict.py
+++ b/dacapo/predict.py
@@ -24,6 +24,8 @@ def predict(
     num_cpu_workers: int = 4,
     compute_context: ComputeContext = LocalTorch(),
     output_roi: Optional[Roi] = None,
+    output_dtype: Optional[np.dtype] = np.float32,  # add necessary type conversions
+    overwrite: bool = False,
 ):

     # get the model's input and output size
@@ -56,7 +58,8 @@
         output_roi,
         model.num_out_channels,
         output_voxel_size,
-        np.float32,
+        output_dtype,
+        overwrite=overwrite,
     )

     # create gunpowder keys
@@ -75,8 +78,8 @@
     # raw: (1, c, d, h, w)

     gt_padding = (output_size - output_roi.shape) % output_size
-    prediction_roi = output_roi.grow(gt_padding)
-
+    prediction_roi = output_roi.grow(gt_padding)  # TODO: are we sure this makes sense?
+    # TODO: Add cache node?
     # predict
     pipeline += gp_torch.Predict(
         model=model,
@@ -84,7 +87,9 @@
         outputs={0: prediction},
         array_specs={
             prediction: gp.ArraySpec(
-                roi=prediction_roi, voxel_size=output_voxel_size, dtype=np.float32
+                roi=prediction_roi,
+                voxel_size=output_voxel_size,
+                dtype=np.float32,  # assumes network output is float32
             )
         },
         spawn_subprocess=False,
@@ -97,22 +102,29 @@
     pipeline += gp.Squeeze([raw, prediction])
     # raw: (c, d, h, w)
     # prediction: (c, d, h, w)
-    # raw: (c, d, h, w)
-    # prediction: (c, d, h, w)
+
+    # convert to uint8 if necessary:
+    if output_dtype == np.uint8:
+        pipeline += gp.IntensityScaleShift(
+            prediction, scale=255.0, shift=0.0
+        )  # assumes float32 is [0,1]
+        pipeline += gp.AsType(prediction, output_dtype)
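+        # e.g. a float32 prediction of 0.5 is rescaled to 0.5 * 255 = 127.5, then cast to 127 as uint8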
" ) # start/resume training @@ -129,18 +129,20 @@ def train_run( # train for at most 100 iterations at a time, then store training stats iterations = min(100, run.train_until - trained_until) iteration_stats = None - - for iteration_stats in tqdm( + bar = tqdm( trainer.iterate( iterations, run.model, run.optimizer, compute_context.device, ), - "training", - iterations, - ): + desc=f"training until {iterations + trained_until}", + total=run.train_until, + initial=trained_until, + ) + for iteration_stats in bar: run.training_stats.add_iteration_stats(iteration_stats) + bar.set_postfix({"loss": iteration_stats.loss}) if (iteration_stats.iteration + 1) % run.validation_interval == 0: break @@ -162,22 +164,26 @@ def train_run( run.model = run.model.to(torch.device("cpu")) run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True) - weights_store.store_weights(run, iteration_stats.iteration + 1) - validate_run( - run, - iteration_stats.iteration + 1, - compute_context=compute_context, - ) - stats_store.store_validation_iteration_scores( - run.name, run.validation_scores - ) stats_store.store_training_stats(run.name, run.training_stats) + weights_store.store_weights(run, iteration_stats.iteration + 1) + try: + validate_run( + run, + iteration_stats.iteration + 1, + compute_context=compute_context, + ) + stats_store.store_validation_iteration_scores( + run.name, run.validation_scores + ) + except Exception as e: + logger.error( + f"Validation failed for run {run.name} at iteration " + f"{iteration_stats.iteration + 1}.", + exc_info=e, + ) # make sure to move optimizer back to the correct device run.move_optimizer(compute_context.device) run.model.train() - weights_store.store_weights(run, run.training_stats.trained_until()) - stats_store.store_training_stats(run.name, run.training_stats) - logger.info("Trained until %d, finished.", trained_until) diff --git a/setup.py b/setup.py index b38a41edd..34faf365b 100644 --- a/setup.py +++ b/setup.py @@ -5,16 +5,16 @@ description="Framework for easy composition of volumetric machine learning jobs.", long_description=open("README.md", "r").read(), long_description_content_type="text/markdown", - version="0.1", + version="0.1.1", url="https://github.com/funkelab/dacapo", - author="Jan Funke, Will Patton", - author_email="funkej@janelia.hhmi.org, pattonw@janelia.hhmi.org", + author="Jan Funke, Will Patton, Jeff Rhoades", + author_email="funkej@janelia.hhmi.org, pattonw@janelia.hhmi.org, rhoadesj@hhmi.org", license="MIT", packages=find_packages(), entry_points={"console_scripts": ["dacapo=dacapo.cli:cli"]}, include_package_data=True, install_requires=[ - "numpy", + "numpy==1.22.3", "pyyaml", "zarr", "cattrs",