From 8389f9fb0ae1d74eee4284d2583a7676745c8069 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Tue, 22 Oct 2024 16:06:03 -0400 Subject: [PATCH] issue with load_state_dict and zpad --- cellpose/resnet_torch.py | 61 ++++++++++++++++++++++++++++------------ cellpose/transforms.py | 1 + 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/cellpose/resnet_torch.py b/cellpose/resnet_torch.py index c21cd0cf..f51265e5 100644 --- a/cellpose/resnet_torch.py +++ b/cellpose/resnet_torch.py @@ -258,7 +258,7 @@ def save_model(self, filename): filename (str): The path to the file where the model will be saved. """ torch.save(self.state_dict(), filename) - + def load_model(self, filename, device=None): """ Load the model from a file. @@ -268,34 +268,21 @@ def load_model(self, filename, device=None): device (torch.device, optional): The device to load the model on. Defaults to None. """ if (device is not None) and (device.type != "cpu"): - state_dict = torch.load(filename, map_location=device, weights_only=True) + state_dict = torch.load(filename, map_location=device) else: self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D, self.diam_mean) - state_dict = torch.load(filename, map_location=torch.device("cpu"), weights_only=True) - - self.load_state_dict(state_dict) - - def load_state_dict(self, state_dict): - """ - Load the state dictionary into the model. + state_dict = torch.load(filename, map_location=torch.device("cpu")) - This method overrides the default `load_state_dict` to handle Cellpose's custom - loading mechanism and ensures compatibility with BioImage.IO Core. - - Args: - state_dict (Mapping[str, Any]): A state dictionary to load into the model - """ if state_dict["output.2.weight"].shape[0] != self.nout: for name in self.state_dict(): if "output" not in name: self.state_dict()[name].copy_(state_dict[name]) else: - super().load_state_dict( - {name: param for name, param in state_dict.items()}, + self.load_state_dict( + dict([(name, param) for name, param in state_dict.items()]), strict=False) - class CPnetBioImageIO(CPnet): """ A subclass of the CPnet model compatible with the BioImage.IO Spec. @@ -316,3 +303,41 @@ def forward(self, x): """ output_tensor, style_tensor, downsampled_tensors = super().forward(x) return output_tensor, style_tensor, *downsampled_tensors + + + def load_model(self, filename, device=None): + """ + Load the model from a file. + + Args: + filename (str): The path to the file where the model is saved. + device (torch.device, optional): The device to load the model on. Defaults to None. + """ + if (device is not None) and (device.type != "cpu"): + state_dict = torch.load(filename, map_location=device, weights_only=True) + else: + self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D, + self.diam_mean) + state_dict = torch.load(filename, map_location=torch.device("cpu"), weights_only=True) + + self.load_state_dict(state_dict) + + def load_state_dict(self, state_dict): + """ + Load the state dictionary into the model. + + This method overrides the default `load_state_dict` to handle Cellpose's custom + loading mechanism and ensures compatibility with BioImage.IO Core. + + Args: + state_dict (Mapping[str, Any]): A state dictionary to load into the model + """ + if state_dict["output.2.weight"].shape[0] != self.nout: + for name in self.state_dict(): + if "output" not in name: + self.state_dict()[name].copy_(state_dict[name]) + else: + super().load_state_dict( + {name: param for name, param in state_dict.items()}, + strict=False) + diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 7738d748..655c2c73 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -176,6 +176,7 @@ def normalize99(Y, lower=1, upper=99, copy=True, downsample=False): ndarray: The normalized image. """ X = Y.copy() if copy else Y + X = X.astype("float32") if X.dtype!="float64" and X.dtype!="float32" else X if downsample and X.size > 224**3: nskip = [max(1, X.shape[i] // 224) for i in range(X.ndim)] nskip[0] = max(1, X.shape[0] // 50) if X.ndim == 3 else nskip[0]