Skip to content

Commit

Permalink
fix bce loss
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Sep 27, 2023
1 parent ee03505 commit cd4077d
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions dacapo/experiments/tasks/losses/hot_distance_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ def compute(self, prediction, target, weight):
return self.hot_loss(prediction_hot, target_hot, weight_hot) + self.distance_loss(prediction_distance, target_distance, weight_distance)

def hot_loss(self, prediction, target, weight):
loss = torch.nn.BCELoss()
return loss(prediction * weight, target * weight)
loss = torch.nn.BCEWithLogitsLoss(reduction='none')
return torch.mean(loss(prediction , target) * weight)
# return abs(prediction * weight - target * weight).sum()

def distance_loss(self, prediction, target, weight):
loss = torch.nn.MSELoss()
Expand Down

0 comments on commit cd4077d

Please sign in to comment.