Skip to content

Commit

Permalink
Merge pull request #25 from janelia-cellmap/rhoadesj_simple_changes
Browse files Browse the repository at this point in the history
Rhoadesj simple changes
  • Loading branch information
mzouink authored Feb 8, 2024
2 parents 34e8253 + fe23b5d commit 5f50f9b
Show file tree
Hide file tree
Showing 10 changed files with 337 additions and 57 deletions.
196 changes: 192 additions & 4 deletions dacapo/apply.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,200 @@
import logging
from pathlib import Path
from typing import Optional, Union

import numpy as np
from funlib.geometry import Roi, Coordinate

from dacapo.experiments.datasplits.datasets.arrays.array import Array
from dacapo.experiments.datasplits.datasets.dataset import Dataset
from dacapo.experiments.run import Run
from dacapo.experiments.tasks.post_processors.post_processor_parameters import (
    PostProcessorParameters,
)
import dacapo.experiments.tasks.post_processors as post_processors
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.predict import predict
from dacapo.compute_context import LocalTorch, ComputeContext
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.store import (
    create_config_store,
    create_weights_store,
)

logger = logging.getLogger(__name__)


def apply(
    run_name: str,
    input_container: Union[Path, str],
    input_dataset: str,
    output_path: Union[Path, str],
    validation_dataset: Optional[Union[Dataset, str]] = None,
    criterion: Optional[str] = "voi",
    iteration: Optional[int] = None,
    parameters: Optional[Union[PostProcessorParameters, str]] = None,
    roi: Optional[Union[Roi, str]] = None,
    num_cpu_workers: int = 30,
    output_dtype: Optional[Union[np.dtype, str]] = np.uint8,
    # NOTE: this default is instantiated once, when the module is imported.
    compute_context: ComputeContext = LocalTorch(),
    overwrite: bool = True,
    file_format: str = "zarr",
):
    """Load weights and apply a model to a dataset.

    If iteration is None, the best iteration based on the criterion is used.
    If roi is None, the whole input dataset is used.

    Args:
        run_name: Name of the run whose config is retrieved from the config
            store.
        input_container: Path of the container holding the input data.
        input_dataset: Name of the dataset inside ``input_container``.
        output_path: Directory in which the output container is placed. The
            container is named after the input container, with its last
            suffix replaced by ``file_format``.
        validation_dataset: A validation dataset, or the name of one of the
            run's validation datasets. Used to pick the best iteration and/or
            the best post-processing parameters when those are not given.
        criterion: Criterion used to pick the best iteration / parameters.
        iteration: Iteration whose weights are used. If None, the best
            iteration for ``validation_dataset`` and ``criterion`` is used.
        parameters: Post-processing parameters, either as an object or as a
            string of the form ``'post_processor_name(arg1=val1, ...)'``. If
            None, the evaluator's overall best parameters are used.
        roi: Region to process, either as a ``Roi`` or as a string of the form
            ``'[lower:upper, lower:upper, ...]'``. It is snapped (growing) to
            the input voxel grid and intersected with the input's ROI.
        num_cpu_workers: Forwarded to ``apply_run``.
        output_dtype: Output data type, or its numpy dtype name.
        compute_context: Forwarded to ``apply_run``.
        overwrite: Forwarded to ``apply_run``.
        file_format: Extension given to the output container.

    Returns:
        Whatever ``apply_run`` returns.

    Raises:
        AssertionError: If neither (``validation_dataset`` and ``criterion``)
            nor ``iteration`` is provided, if no parameters could be
            determined, or if the parameters are of the wrong type.
        ValueError: If a ``roi`` or ``parameters`` string cannot be parsed.
        IndexError: If ``validation_dataset`` names a dataset the run does not
            have.
    """
    if isinstance(output_dtype, str):
        output_dtype = np.dtype(output_dtype)

    if isinstance(roi, str):
        roi = _parse_roi_string(roi)

    assert (validation_dataset is not None and isinstance(criterion, str)) or (
        isinstance(iteration, int)
    ), "Either validation_dataset and criterion, or iteration must be provided."

    # retrieving run
    logger.info("Loading run %s", run_name)
    config_store = create_config_store()
    run_config = config_store.retrieve_run_config(run_name)
    run = Run(run_config)

    # create weights store
    weights_store = create_weights_store()

    # load weights
    if iteration is None:
        iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion)
    logger.info("Loading weights for iteration %i", iteration)
    # NOTE(review): the original author asked "shouldn't this be load_weights?"
    # -- the return value is discarded here; confirm the weights end up in
    # run.model.
    weights_store.retrieve_weights(run, iteration)

    # resolve a validation dataset given by name
    if isinstance(validation_dataset, str):
        val_ds_name = validation_dataset
        matches = [
            dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name
        ]
        if not matches:
            # Same exception type as indexing the empty list, but with a
            # message that says what went wrong.
            raise IndexError(
                f"Run {run_name} has no validation dataset named {val_ds_name!r}."
            )
        validation_dataset = matches[0]

    # find the best parameters
    if parameters is None:
        logger.info(
            "Finding best parameters for validation dataset %s", validation_dataset
        )
        parameters = run.task.evaluator.get_overall_best_parameters(
            validation_dataset, criterion
        )
        assert (
            parameters is not None
        ), "Unable to retieve parameters. Parameters must be provided explicitly."

    elif isinstance(parameters, str):
        post_processor_name, post_processor_kwargs = _parse_post_processor_string(
            parameters
        )
        try:
            parameters = getattr(post_processors, post_processor_name)(
                **post_processor_kwargs
            )
        except Exception:
            logger.exception(
                f"Could not instantiate post-processor {post_processor_name} with arguments {post_processor_kwargs}."
            )
            raise

    assert isinstance(
        parameters, PostProcessorParameters
    ), "Parameters must be parsable to a PostProcessorParameters object."

    # make array identifiers for input, predictions and outputs
    input_array_identifier = LocalArrayIdentifier(input_container, input_dataset)
    input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
    if roi is None:
        # No ROI requested: process the whole input dataset, as documented.
        roi = input_array.roi
    else:
        roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect(
            input_array.roi
        )
    # Path.stem removes only the last suffix, so names containing other dots
    # (or no suffix at all) keep their full base name.
    output_container = Path(
        output_path,
        Path(input_container).stem + f".{file_format}",
    )
    prediction_array_identifier = LocalArrayIdentifier(
        output_container, f"prediction_{run_name}_{iteration}"
    )
    output_array_identifier = LocalArrayIdentifier(
        output_container, f"output_{run_name}_{iteration}_{parameters}"
    )

    logger.info(
        "Applying best results from run %s at iteration %i to dataset %s",
        run.name,
        iteration,
        Path(input_container, input_dataset),
    )
    return apply_run(
        run,
        parameters,
        input_array,
        prediction_array_identifier,
        output_array_identifier,
        roi,
        num_cpu_workers,
        output_dtype,
        compute_context,
        overwrite,
    )


def _parse_roi_string(roi: str) -> Roi:
    """Turn a '[lower:upper, lower:upper, ...]' string into a Roi."""
    start, end = zip(
        *[
            tuple(int(coord) for coord in axis.split(":"))
            for axis in roi.strip("[]").split(",")
        ]
    )
    # Roi is built from an offset and a shape, hence end - start.
    return Roi(
        Coordinate(start),
        Coordinate(end) - Coordinate(start),
    )


def _parse_post_processor_string(parameters: str):
    """Split 'name(arg1=val1, ...)' into the name and a dict of keyword arguments.

    Values made only of digits become ints, values that are digits with a
    single '.' become floats, and everything else stays a string.
    """
    try:
        pieces = parameters.split("(")
        post_processor_name = pieces[0]
        post_processor_kwargs = {}
        for arg in pieces[1].strip(")").split(","):
            key, value = arg.split("=")
            key, value = key.strip(), value.strip()
            if value.isdigit():
                post_processor_kwargs[key] = int(value)
            elif value.replace(".", "", 1).isdigit():
                post_processor_kwargs[key] = float(value)
            else:
                post_processor_kwargs[key] = value
    except (IndexError, ValueError) as err:
        raise ValueError(
            f"Could not parse parameters string {parameters}. Must be of the form 'post_processor_name(arg1=val1, arg2=val2, ...)'"
        ) from err
    return post_processor_name, post_processor_kwargs


def apply_run(
    run: Run,
    parameters: PostProcessorParameters,
    input_array: Array,
    prediction_array_identifier: LocalArrayIdentifier,
    output_array_identifier: LocalArrayIdentifier,
    roi: Optional[Roi] = None,
    num_cpu_workers: int = 30,
    output_dtype: Optional[np.dtype] = np.uint8,
    compute_context: ComputeContext = LocalTorch(),
    overwrite: bool = True,
):
    """Predict on ``input_array`` with the run's model, then post-process.

    The model is assumed to be loaded already. The prediction is written to
    ``prediction_array_identifier``; the run's post-processor then turns it
    into ``output_array_identifier`` using ``parameters``. ``roi`` is passed
    to ``predict`` as its output ROI.
    """
    # Put the model in evaluation mode before running it.
    run.model.eval()

    # Step 1: render the prediction dataset.
    logger.info("Predicting on dataset %s", prediction_array_identifier)
    prediction_options = {
        "output_roi": roi,
        "num_cpu_workers": num_cpu_workers,
        "output_dtype": output_dtype,
        "compute_context": compute_context,
        "overwrite": overwrite,
    }
    predict(
        run.model,
        input_array,
        prediction_array_identifier,
        **prediction_options,
    )

    # Step 2: post-process the prediction into the final output.
    logger.info("Post-processing output to dataset %s", output_array_identifier)
    task_post_processor = run.task.post_processor
    task_post_processor.set_prediction(prediction_array_identifier)
    task_post_processor.process(
        parameters,
        output_array_identifier,
        overwrite=overwrite,
        blockwise=True,
    )

    logger.info("Done")
55 changes: 44 additions & 11 deletions dacapo/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import dacapo
import click
import logging
Expand Down Expand Up @@ -40,21 +42,52 @@ def validate(run_name, iteration):

@cli.command()
@click.option(
    "-r", "--run-name", required=True, type=str, help="The name of the run to apply."
)
@click.option(
    "-ic",
    "--input_container",
    required=True,
    type=click.Path(exists=True, file_okay=False),
    help="The container holding the input data.",
)
@click.option(
    "-id",
    "--input_dataset",
    required=True,
    type=str,
    help="The name of the dataset inside the input container.",
)
@click.option(
    "-op",
    "--output_path",
    required=True,
    type=click.Path(file_okay=False),
    help="The directory in which the output container is placed.",
)
@click.option(
    "-vd",
    "--validation_dataset",
    type=str,
    default=None,
    help="The name of the validation dataset used to pick the best iteration and parameters.",
)
@click.option(
    "-c",
    "--criterion",
    default="voi",
    help="The criterion used to pick the best iteration and parameters.",
)
@click.option(
    "-i",
    "--iteration",
    type=int,
    default=None,
    help="The iteration weights to use. If omitted, the best iteration is used.",
)
@click.option(
    "-p",
    "--parameters",
    type=str,
    default=None,
    help="The post-processing parameters. Passed in as post_processor_name(arg1=val1, arg2=val2, ...)",
)
@click.option(
    "-roi",
    "--roi",
    type=str,
    required=False,
    help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]",
)
@click.option("-w", "--num_cpu_workers", type=int, default=30)
@click.option(
    "-dt",
    "--output_dtype",
    type=str,
    default="uint8",
    help="The numpy dtype name of the output.",
)
def apply(
    run_name: str,
    input_container: str,
    input_dataset: str,
    output_path: str,
    validation_dataset: Optional[str] = None,
    criterion: Optional[str] = "voi",
    iteration: Optional[int] = None,
    parameters: Optional[str] = None,
    roi: Optional[str] = None,
    num_cpu_workers: int = 30,
    output_dtype: Optional[str] = "uint8",
):
    """Apply a trained run to a dataset.

    Every option is forwarded unchanged to ``dacapo.apply``.
    """
    # Keyword arguments keep this call correct even if the order of
    # dacapo.apply's parameters changes.
    dacapo.apply(
        run_name=run_name,
        input_container=input_container,
        input_dataset=input_dataset,
        output_path=output_path,
        validation_dataset=validation_dataset,
        criterion=criterion,
        iteration=iteration,
        parameters=parameters,
        roi=roi,
        num_cpu_workers=num_cpu_workers,
        output_dtype=output_dtype,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def axes(self):
logger.debug(
"DaCapo expects Zarr datasets to have an 'axes' attribute!\n"
f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}",
f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}",
)
return ["c", "z", "y", "x"][-self.dims : :]

Expand Down
6 changes: 5 additions & 1 deletion dacapo/experiments/tasks/affinities_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def __init__(self, task_config):
"""Create a `DummyTask` from a `DummyTaskConfig`."""

self.predictor = AffinitiesPredictor(
neighborhood=task_config.neighborhood, lsds=task_config.lsds
neighborhood=task_config.neighborhood,
lsds=task_config.lsds,
num_voxels=task_config.num_voxels,
downsample_lsds=task_config.downsample_lsds,
grow_boundary_iterations=task_config.grow_boundary_iterations,
)
self.loss = AffinitiesLoss(len(task_config.neighborhood))
self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood)
Expand Down
20 changes: 20 additions & 0 deletions dacapo/experiments/tasks/affinities_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,23 @@ class AffinitiesTaskConfig(TaskConfig):
"It has been shown that lsds as an auxiliary task can help affinity predictions."
},
)
num_voxels: int = attr.ib(
default=20,
metadata={
"help_text": "The number of voxels to use for the gaussian sigma when computing lsds."
},
)
downsample_lsds: int = attr.ib(
default=1,
metadata={
"help_text": "The amount to downsample the lsds. "
"This is useful for speeding up training and inference."
},
)
grow_boundary_iterations: int = attr.ib(
default=0,
metadata={
"help_text": "The number of iterations to run the grow boundaries algorithm. "
"This is useful for refining the boundaries of the affinities, and reducing merging of adjacent objects."
},
)
33 changes: 24 additions & 9 deletions dacapo/experiments/tasks/predictors/affinities_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@


class AffinitiesPredictor(Predictor):
def __init__(self, neighborhood: List[Coordinate], lsds: bool = True):
def __init__(
self,
neighborhood: List[Coordinate],
lsds: bool = True,
num_voxels: int = 20,
downsample_lsds: int = 1,
grow_boundary_iterations: int = 0,
):
self.neighborhood = neighborhood
self.lsds = lsds
self.num_voxels = num_voxels
if lsds:
self._extractor = None
if self.dims == 2:
Expand All @@ -30,12 +38,16 @@ def __init__(self, neighborhood: List[Coordinate], lsds: bool = True):
raise ValueError(
f"Cannot compute lsds on volumes with {self.dims} dimensions"
)
self.downsample_lsds = downsample_lsds
else:
self.num_lsds = 0
self.grow_boundary_iterations = grow_boundary_iterations

def extractor(self, voxel_size):
    """Return the LsdExtractor for this predictor, creating it on first use.

    The extractor is built once, from ``self.sigma(voxel_size)`` and
    ``self.downsample_lsds``, and then cached on the instance; later calls
    return the cached object regardless of the ``voxel_size`` passed in.
    """
    if self._extractor is None:
        self._extractor = LsdExtractor(
            self.sigma(voxel_size), downsample=self.downsample_lsds
        )

    return self._extractor

Expand All @@ -45,8 +57,7 @@ def dims(self):

def sigma(self, voxel_size):
    """Return the sigma as a Coordinate with one identical entry per dimension.

    The value is the largest entry of ``voxel_size`` multiplied by the
    configurable ``self.num_voxels``.
    """
    voxel_dist = max(voxel_size)  # arbitrarily chosen
    sigma = voxel_dist * self.num_voxels
    return Coordinate((sigma,) * self.dims)

def lsd_pad(self, voxel_size):
Expand Down Expand Up @@ -118,7 +129,9 @@ def _grow_boundaries(self, mask, slab):
slice(start[d], start[d] + slab[d]) for d in range(len(slab))
)
mask_slab = mask[slices]
dilated_mask_slab = ndimage.binary_dilation(mask_slab, iterations=1)
dilated_mask_slab = ndimage.binary_dilation(
mask_slab, iterations=self.grow_boundary_iterations
)
foreground[slices] = dilated_mask_slab

# label new background
Expand All @@ -130,10 +143,12 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
(moving_class_counts, moving_lsd_class_counts) = (
moving_class_counts if moving_class_counts is not None else (None, None)
)
# mask_data = self._grow_boundaries(
# mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes)
# )
mask_data = mask[target.roi]
if self.grow_boundary_iterations > 0:
mask_data = self._grow_boundaries(
mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes)
)
else:
mask_data = mask[target.roi]
aff_weights, moving_class_counts = balance_weights(
target[target.roi][: self.num_channels - self.num_lsds].astype(np.uint8),
2,
Expand Down
Loading

0 comments on commit 5f50f9b

Please sign in to comment.