Skip to content

Commit

Permalink
if test to check for scaled_attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
havardhhaugen committed Nov 26, 2024
1 parent 7bb2919 commit 87262ee
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/anemoi/training/losses/nodeweights.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,12 @@ def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str
def weights(self, graph_data: HeteroData) -> torch.Tensor:
attr_weight = super().weights(graph_data)

mask = graph_data[self.target][self.scaled_attribute].squeeze().bool()
if self.scaled_attribute in graph_data[self.target]:
mask = graph_data[self.target][self.scaled_attribute].squeeze().bool()
else:
error_msg = f"scaled_attribute {self.scaled_attribute} not found in graph_object"
raise KeyError(error_msg)

unmasked_sum = torch.sum(attr_weight[~mask])
weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / sum(mask)
attr_weight[mask] = weight_per_masked_node
Expand Down

0 comments on commit 87262ee

Please sign in to comment.