diff --git a/src/anemoi/training/losses/binarycrossentropy.py b/src/anemoi/training/losses/binarycrossentropy.py
index 08f582fd..988b47f0 100644
--- a/src/anemoi/training/losses/binarycrossentropy.py
+++ b/src/anemoi/training/losses/binarycrossentropy.py
@@ -14,14 +14,12 @@

 import torch

-from torch.nn import BCELoss
-
-from anemoi.training.losses.weightedloss import BaseWeightedLoss
+import torch.nn as nn

 LOGGER = logging.getLogger(__name__)


-class BinaryCrossEntropyLoss(BaseWeightedLoss):
+class BinaryCrossEntropyLoss(nn.Module):
     """Node-weighted binary cross entropy loss."""

     name = "binarycrossentropy"

@@ -43,11 +41,8 @@ def __init__(
            by default False
        """
-        super().__init__(
-            node_weights=node_weights,
-            ignore_nans=ignore_nans,
-            **kwargs,
-        )
+        super().__init__()
+        self.bce_loss = nn.BCELoss()

    def forward(
        self,
@@ -78,5 +73,4 @@ def forward(
        torch.Tensor
            Weighted binary cross entropy loss
        """
-        loss = BCELoss()
-        return loss(pred.float(), target.float())  # BCE likes floats.
+        return self.bce_loss(pred.float(), target.float())  # BCE likes floats.
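
For reference, a minimal sketch of how the refactored class might be exercised after this change. The constructor signature is assumed to still accept node_weights (the diff does not touch it, even though the wrapped nn.BCELoss no longer applies node weighting), and the tensor shapes and sigmoid activation are illustrative assumptions only, since nn.BCELoss expects predicted probabilities in [0, 1].

import torch

from anemoi.training.losses.binarycrossentropy import BinaryCrossEntropyLoss

# Hypothetical smoke test: node_weights is assumed to still be an accepted
# constructor argument; shapes are illustrative, not taken from the diff.
loss_fn = BinaryCrossEntropyLoss(node_weights=torch.ones(40320))
pred = torch.sigmoid(torch.randn(2, 40320, 1))        # probabilities in [0, 1]
target = torch.randint(0, 2, (2, 40320, 1)).float()   # binary targets
loss = loss_fn(pred, target)                           # scalar mean BCE over all elements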