Skip to content

Commit

Permalink
update size checking
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Feb 12, 2024
1 parent 3c2f0fe commit 4a4bd94
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions dacapo/experiments/starts/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@ def initialize_weights(self, model):
# if the model is not the same, we can try to load the weights
# of the common layers
model_dict = model.state_dict()
common_layers = set(model_dict.keys()) & set(weights.model.keys())
for layer in common_layers:
if model_dict[layer].shape == weights.model[layer].shape:
model_dict[layer] = weights.model[layer]
else:
logger.warning(f"layer {layer} has different shape, not loading")
pretrained_dict = {k: v for k, v in weights.model.items() if k in model_dict and v.size() == model_dict[k].size()}
model_dict.update(pretrained_dict) # update only the existing and matching layers
model.load_state_dict(model_dict)
logger.warning(f"loaded only common layers from weights")

0 comments on commit 4a4bd94

Please sign in to comment.