Skip to content

Commit

Permalink
Here we go again
Browse files Browse the repository at this point in the history
  • Loading branch information
olaversl committed Dec 18, 2024
1 parent d13af3d commit 7f4d98f
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions src/anemoi/training/losses/binarycrossentropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@

import torch

from torch.nn import BCELoss

from anemoi.training.losses.weightedloss import BaseWeightedLoss
import torch.nn as nn

LOGGER = logging.getLogger(__name__)


class BinaryCrossEntropyLoss(BaseWeightedLoss):
class BinaryCrossEntropyLoss(nn.Module):
"""Node-weighted binary cross entropy loss."""

name = "binarycrossentropy"
Expand All @@ -43,11 +41,8 @@ def __init__(
by default False
"""
super().__init__(
node_weights=node_weights,
ignore_nans=ignore_nans,
**kwargs,
)
super().__init__()
self.bce_loss = nn.BCELoss()

def forward(
self,
Expand Down Expand Up @@ -78,5 +73,4 @@ def forward(
torch.Tensor
Weighted binary cross entropy loss
"""
loss = BCELoss()
return loss(pred.float(), target.float()) # BCE likes floats.
return self.bce_loss(pred.float(), target.float()) # BCE likes floats.

0 comments on commit 7f4d98f

Please sign in to comment.