Skip to content

Commit

Permalink
Fixed model weights path
Browse files Browse the repository at this point in the history
  • Loading branch information
urbj committed Nov 6, 2024
1 parent c2b1e4d commit 92dbf59
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion CandyCrunch/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
if torch.cuda.is_available():
device = "cuda:0"

sdict = os.path.join(this_dir, 'sugarbase.pt')
sdict = os.path.join(this_dir, 'candycrunch.pt')
sdict = torch.load(sdict, map_location = device)
sdict = {k.replace('module.', ''): v for k, v in sdict.items()}
candycrunch = CandyCrunch_CNN(2048, num_classes = len(glycans)).to(device)
Expand Down

0 comments on commit 92dbf59

Please sign in to comment.