diff --git a/src/anemoi/training/losses/nodeweights.py b/src/anemoi/training/losses/nodeweights.py index 1b8fbf9a..ed4afaf4 100644 --- a/src/anemoi/training/losses/nodeweights.py +++ b/src/anemoi/training/losses/nodeweights.py @@ -122,7 +122,12 @@ def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str def weights(self, graph_data: HeteroData) -> torch.Tensor: attr_weight = super().weights(graph_data) - mask = graph_data[self.target][self.scaled_attribute].squeeze().bool() + if self.scaled_attribute in graph_data[self.target]: + mask = graph_data[self.target][self.scaled_attribute].squeeze().bool() + else: + error_msg = f"scaled_attribute {self.scaled_attribute} not found in graph_object" + raise KeyError(error_msg) + unmasked_sum = torch.sum(attr_weight[~mask]) weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / sum(mask) attr_weight[mask] = weight_per_masked_node