diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py index 122526b14..1475c7b97 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py @@ -5,6 +5,9 @@ import numpy as np from typing import Dict, Any +import logging + +logger = logging.getLogger(__file__) class ConcatArray(Array): @@ -116,5 +119,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray: axis=0, ) if concatenated.shape[0] == 1: - raise Exception(f"{concatenated.shape}, shapes") + logger.info( + f"Concatenated array has only one channel: {self.name} {concatenated.shape}" + ) return concatenated diff --git a/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py b/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py index beaa474d1..e08ffe562 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py @@ -41,7 +41,7 @@ def attrs(self): @property def axes(self): - return ["t", "z", "y", "x"][-self.dims :] + return ["c", "z", "y", "x"][-self.dims :] @property def dims(self) -> int: diff --git a/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py b/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py index 7101d737e..5f2bc0483 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py @@ -35,7 +35,7 @@ def from_gp_array(cls, array: gp.Array): ((["b", "c"] if len(array.data.shape) == instance.dims + 2 else [])) + (["c"] if len(array.data.shape) == instance.dims + 1 else []) + [ - "t", + "c", "z", "y", "x", diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index cadfcb6cd..42030e701 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -54,7 +54,7 @@ def axes(self): f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n" f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}", ) - return ["t", "z", "y", "x"][-self.dims : :] + return ["c", "z", "y", "x"][-self.dims : :] @property def dims(self) -> int: diff --git a/dacapo/experiments/model.py b/dacapo/experiments/model.py index bbaacb2dc..fe1f8e7d5 100644 --- a/dacapo/experiments/model.py +++ b/dacapo/experiments/model.py @@ -24,7 +24,7 @@ def __init__( self, architecture: Architecture, prediction_head: torch.nn.Module, - eval_activation: torch.nn.Module = None, + eval_activation: torch.nn.Module | None = None, ): super().__init__() diff --git a/dacapo/train.py b/dacapo/train.py index 9203c1be3..86473ee36 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -16,6 +16,7 @@ def train(run_name: str, compute_context: ComputeContext = LocalTorch()): """Train a run""" if compute_context.train(run_name): + logger.error("Run %s is already being trained", run_name) # if compute context runs train in some other process # we are done here. return @@ -96,10 +97,15 @@ def train_run( weights_store.retrieve_weights(run, iteration=trained_until) elif latest_weights_iteration > trained_until: - raise RuntimeError( + weights_store.retrieve_weights(run, iteration=latest_weights_iteration) + logger.error( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}." ) + # raise RuntimeError( + # f"Found weights for iteration {latest_weights_iteration}, but " + # f"run {run.name} was only trained until {trained_until}." + # ) # start/resume training @@ -157,7 +163,7 @@ def train_run( run.model.eval() # free up optimizer memory to allow larger validation blocks - run.model = run.model.to(torch.device("cpu")) + # run.model = run.model.to(torch.device("cpu")) run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True) weights_store.store_weights(run, iteration_stats.iteration + 1) diff --git a/dacapo/validate.py b/dacapo/validate.py index 25b7463e1..a1cf9da7d 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -141,6 +141,7 @@ def validate_run( prediction_array_identifier = array_store.validation_prediction_array( run.name, iteration, validation_dataset ) + logger.info("Predicting on dataset %s", validation_dataset.name) predict( run.model, validation_dataset.raw, @@ -148,6 +149,7 @@ def validate_run( compute_context=compute_context, output_roi=validation_dataset.gt.roi, ) + logger.info("Predicted on dataset %s", validation_dataset.name) post_processor.set_prediction(prediction_array_identifier) diff --git a/setup.py b/setup.py index 3ba1f0d0b..b38a41edd 100644 --- a/setup.py +++ b/setup.py @@ -36,5 +36,9 @@ "funlib.evaluate @ git+https://github.com/pattonw/funlib.evaluate", "gunpowder>=1.3", "lsds>=0.1.3", + "xarray", + "cattrs", + "numpy-indexed", + "click", ], )