Skip to content

Commit

Permalink
Fixed BCE loss (hopefully)..
Browse files Browse the repository at this point in the history
  • Loading branch information
olaversl committed Dec 18, 2024
1 parent c22cbaa commit d13af3d
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/anemoi/training/losses/binarycrossentropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import torch

from torch.nn import BCELoss

from anemoi.training.losses.weightedloss import BaseWeightedLoss

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,6 +78,5 @@ def forward(
torch.Tensor
Weighted binary cross entropy loss
"""
out = torch.nn.BCELoss(pred, target)
out = self.scale(out, scalar_indices, without_scalars=without_scalars)
return self.scale_by_node_weights(out, squash)
loss = BCELoss()
return loss(pred.float(), target.float()) # BCE likes floats.

0 comments on commit d13af3d

Please sign in to comment.