Skip to content

Commit

Permalink
issue with load_state_dict and zpad
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Oct 22, 2024
1 parent 729b701 commit 8389f9f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
61 changes: 43 additions & 18 deletions cellpose/resnet_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def save_model(self, filename):
filename (str): The path to the file where the model will be saved.
"""
torch.save(self.state_dict(), filename)

def load_model(self, filename, device=None):
"""
Load the model from a file.
Expand All @@ -268,34 +268,21 @@ def load_model(self, filename, device=None):
device (torch.device, optional): The device to load the model on. Defaults to None.
"""
if (device is not None) and (device.type != "cpu"):
state_dict = torch.load(filename, map_location=device, weights_only=True)
state_dict = torch.load(filename, map_location=device)
else:
self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D,
self.diam_mean)
state_dict = torch.load(filename, map_location=torch.device("cpu"), weights_only=True)

self.load_state_dict(state_dict)

def load_state_dict(self, state_dict):
"""
Load the state dictionary into the model.
state_dict = torch.load(filename, map_location=torch.device("cpu"))

This method overrides the default `load_state_dict` to handle Cellpose's custom
loading mechanism and ensures compatibility with BioImage.IO Core.
Args:
state_dict (Mapping[str, Any]): A state dictionary to load into the model
"""
if state_dict["output.2.weight"].shape[0] != self.nout:
for name in self.state_dict():
if "output" not in name:
self.state_dict()[name].copy_(state_dict[name])
else:
super().load_state_dict(
{name: param for name, param in state_dict.items()},
self.load_state_dict(
dict([(name, param) for name, param in state_dict.items()]),
strict=False)


class CPnetBioImageIO(CPnet):
"""
A subclass of the CPnet model compatible with the BioImage.IO Spec.
Expand All @@ -316,3 +303,41 @@ def forward(self, x):
"""
output_tensor, style_tensor, downsampled_tensors = super().forward(x)
return output_tensor, style_tensor, *downsampled_tensors


def load_model(self, filename, device=None):
"""
Load the model from a file.
Args:
filename (str): The path to the file where the model is saved.
device (torch.device, optional): The device to load the model on. Defaults to None.
"""
if (device is not None) and (device.type != "cpu"):
state_dict = torch.load(filename, map_location=device, weights_only=True)
else:
self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D,
self.diam_mean)
state_dict = torch.load(filename, map_location=torch.device("cpu"), weights_only=True)

self.load_state_dict(state_dict)

def load_state_dict(self, state_dict):
"""
Load the state dictionary into the model.
This method overrides the default `load_state_dict` to handle Cellpose's custom
loading mechanism and ensures compatibility with BioImage.IO Core.
Args:
state_dict (Mapping[str, Any]): A state dictionary to load into the model
"""
if state_dict["output.2.weight"].shape[0] != self.nout:
for name in self.state_dict():
if "output" not in name:
self.state_dict()[name].copy_(state_dict[name])
else:
super().load_state_dict(
{name: param for name, param in state_dict.items()},
strict=False)

1 change: 1 addition & 0 deletions cellpose/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def normalize99(Y, lower=1, upper=99, copy=True, downsample=False):
ndarray: The normalized image.
"""
X = Y.copy() if copy else Y
X = X.astype("float32") if X.dtype!="float64" and X.dtype!="float32" else X
if downsample and X.size > 224**3:
nskip = [max(1, X.shape[i] // 224) for i in range(X.ndim)]
nskip[0] = max(1, X.shape[0] // 50) if X.ndim == 3 else nskip[0]
Expand Down

0 comments on commit 8389f9f

Please sign in to comment.