diff --git a/models/src/anemoi/models/layers/block.py b/models/src/anemoi/models/layers/block.py
index 72e487d2..afcb776e 100644
--- a/models/src/anemoi/models/layers/block.py
+++ b/models/src/anemoi/models/layers/block.py
@@ -68,6 +68,7 @@ def __init__(
         num_heads: int,
         activation: str,
         window_size: int,
+        positional_encoding_hidden: Optional[Tensor] = None,
         dropout_p: float = 0.0,
     ):
         super().__init__()
@@ -80,6 +81,12 @@
 
         self.layer_norm1 = nn.LayerNorm(num_channels)
 
+        self.register_buffer("positional_encoding_hidden", positional_encoding_hidden)
+        if self.positional_encoding_hidden is not None:
+            self.pos_embedder = nn.Linear(
+                self.positional_encoding_hidden.shape[-1], num_channels
+            )  # hidden_dim is num_channels
+
         self.attention = MultiHeadSelfAttention(
             num_heads=num_heads,
             embed_dim=num_channels,
@@ -99,6 +106,11 @@
     def forward(
         self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
     ) -> Tensor:
+        if self.positional_encoding_hidden is not None:
+            pos_embedding = self.pos_embedder(self.positional_encoding_hidden)
+            pos_embedding = pos_embedding.repeat(batch_size, 1)
+            x = x + pos_embedding
+
         # Need to be out of place for gradient propagation
         x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group)
         x = x + self.mlp(self.layer_norm2(x))
diff --git a/models/src/anemoi/models/layers/chunk.py b/models/src/anemoi/models/layers/chunk.py
index 5c4fae38..67e1481c 100644
--- a/models/src/anemoi/models/layers/chunk.py
+++ b/models/src/anemoi/models/layers/chunk.py
@@ -74,6 +74,7 @@
         num_heads: int = 16,
         mlp_hidden_ratio: int = 4,
         activation: str = "GELU",
+        positional_encoding_hidden: Optional[Tensor] = None,
         dropout_p: float = 0.0,
     ) -> None:
         """Initialize TransformerProcessor.
@@ -102,6 +103,7 @@
             num_heads=num_heads,
             activation=activation,
             window_size=window_size,
+            positional_encoding_hidden=positional_encoding_hidden,
             dropout_p=dropout_p,
         )
 
diff --git a/models/src/anemoi/models/layers/positionalencoding.py b/models/src/anemoi/models/layers/positionalencoding.py
new file mode 100644
index 00000000..06b454d8
--- /dev/null
+++ b/models/src/anemoi/models/layers/positionalencoding.py
@@ -0,0 +1,59 @@
+from abc import ABC
+from abc import abstractmethod
+
+import torch
+from torch import Tensor
+
+
+class BasePositionalEncoding(ABC):
+    """Configurable method calculating positional encodings for the latlons of a grid.
+
+    To enable the positional encoding, add the following to the model-config file and
+    choose the corresponding positional-encoding class:
+    ```
+    positional_encoding:
+      _target_: anemoi.models.layers.positionalencoding.CosSinLatCosSinLon
+      _convert_: all
+    ```
+    If the entry positional_encoding does not exist or is None, no positional encoding is used.
+
+    """
+
+    def __init__(self) -> None:
+        """Initialise function for calculating the positional encodings."""
+
+    @abstractmethod
+    def positional_encoding(self, latlons_hidden: Tensor) -> Tensor: ...
+
+
+class LatCosSinLon(BasePositionalEncoding):
+    """Lat, cos(lon), sin(lon) for grid points."""
+
+    def positional_encoding(self, latlons_hidden: Tensor) -> Tensor:
+        """Output lat, cos(lon), sin(lon) for grid points."""
+        lat_coslon_sinlon_hidden = torch.cat(
+            (
+                latlons_hidden[:, 0].unsqueeze(-1),
+                torch.cos(latlons_hidden[:, 1].unsqueeze(-1)),
+                torch.sin(latlons_hidden[:, 1].unsqueeze(-1)),
+            ),
+            dim=-1,
+        )
+        return lat_coslon_sinlon_hidden
+
+
+class CosSinLatCosSinLon(BasePositionalEncoding):
+    """Cos(lat), sin(lat), cos(lon), sin(lon) for grid points."""
+
+    def positional_encoding(self, latlons_hidden: Tensor) -> Tensor:
+        """Output cos(lat), sin(lat), cos(lon), sin(lon) for grid points."""
+        coslat_sinlat_coslon_sinlon_hidden = torch.cat(
+            (
+                torch.cos(latlons_hidden[:, 0].unsqueeze(-1)),
+                torch.sin(latlons_hidden[:, 0].unsqueeze(-1)),
+                torch.cos(latlons_hidden[:, 1].unsqueeze(-1)),
+                torch.sin(latlons_hidden[:, 1].unsqueeze(-1)),
+            ),
+            dim=-1,
+        )
+        return coslat_sinlat_coslon_sinlon_hidden
diff --git a/models/src/anemoi/models/layers/processor.py b/models/src/anemoi/models/layers/processor.py
index 8dba1f66..24aa3026 100644
--- a/models/src/anemoi/models/layers/processor.py
+++ b/models/src/anemoi/models/layers/processor.py
@@ -40,6 +40,7 @@ def __init__(
         num_chunks: int = 2,
         activation: str = "GELU",
         cpu_offload: bool = False,
+        positional_encoding_hidden: Optional[Tensor] = None,
         **kwargs,
     ) -> None:
         """Initialize BaseProcessor."""
@@ -49,6 +50,7 @@
         self.num_chunks = num_chunks
         self.num_channels = num_channels
         self.chunk_size = num_layers // num_chunks
+        self.positional_encoding_hidden = positional_encoding_hidden
 
         assert (
             num_layers % num_chunks == 0
@@ -94,6 +96,7 @@
         num_chunks: int = 2,
         activation: str = "GELU",
         cpu_offload: bool = False,
+        positional_encoding_hidden: Optional[Tensor] = None,
         num_heads: int = 16,
         mlp_hidden_ratio: int = 4,
         dropout_p: float = 0.1,
@@ -125,6 +128,7 @@
             num_chunks=num_chunks,
             activation=activation,
             cpu_offload=cpu_offload,
+            positional_encoding_hidden=positional_encoding_hidden,
             num_heads=num_heads,
             mlp_hidden_ratio=mlp_hidden_ratio,
         )
@@ -137,6 +141,7 @@
             num_layers=self.chunk_size,
             window_size=window_size,
             activation=activation,
+            positional_encoding_hidden=positional_encoding_hidden,
             dropout_p=dropout_p,
         )
 
diff --git a/models/src/anemoi/models/models/encoder_processor_decoder.py b/models/src/anemoi/models/models/encoder_processor_decoder.py
index 5e08adb4..1e5cb6c5 100644
--- a/models/src/anemoi/models/models/encoder_processor_decoder.py
+++ b/models/src/anemoi/models/models/encoder_processor_decoder.py
@@ -76,11 +76,22 @@ def __init__(
             dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
         )
 
+        positional_encoding_hidden = None
+        if model_config.model.get("positional_encoding") is not None:
+            LOGGER.info(
+                "Using positional encoding. Target function: %s", model_config.model.positional_encoding._target_
+            )
+            self.positional_encoding = instantiate(model_config.model.positional_encoding)
+            positional_encoding_hidden = self.positional_encoding.positional_encoding(
+                self.node_attributes.get_coordinates(self._graph_name_hidden)
+            )
+
         # Processor hidden -> hidden
         self.processor = instantiate(
             model_config.model.processor,
             num_channels=self.num_channels,
             sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
+            positional_encoding_hidden=positional_encoding_hidden,
             src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
             dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
         )
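
Not part of the diff: a minimal sketch of what the two new encoding classes return. The `num_nodes` value and the random `latlons_hidden` tensor below are placeholders for the `(num_nodes, 2)` coordinate tensor that `node_attributes.get_coordinates(...)` supplies in `encoder_processor_decoder.py`.

```python
# Sketch only (not from the PR): exercises the two classes added in positionalencoding.py.
import torch

from anemoi.models.layers.positionalencoding import CosSinLatCosSinLon, LatCosSinLon

num_nodes = 1000  # placeholder for the number of hidden-grid nodes
latlons_hidden = torch.rand(num_nodes, 2)  # stand-in for node_attributes.get_coordinates(...)

enc3 = LatCosSinLon().positional_encoding(latlons_hidden)        # lat, cos(lon), sin(lon)
enc4 = CosSinLatCosSinLon().positional_encoding(latlons_hidden)  # cos(lat), sin(lat), cos(lon), sin(lon)

assert enc3.shape == (num_nodes, 3)
assert enc4.shape == (num_nodes, 4)
```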
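Likewise, a rough sketch of the shape bookkeeping behind the `pos_embedder` / `repeat` lines added to `TransformerProcessorBlock.forward` above, with arbitrary sizes and assuming the block input `x` is flattened batch-major to `(batch_size * num_nodes, num_channels)`.

```python
# Sketch only: mirrors the projection-and-repeat logic added to forward() above.
import torch
from torch import nn

num_nodes, enc_dim, num_channels, batch_size = 1000, 4, 1024, 2  # arbitrary sizes

positional_encoding_hidden = torch.randn(num_nodes, enc_dim)  # the registered buffer
pos_embedder = nn.Linear(enc_dim, num_channels)               # as built in __init__

x = torch.randn(batch_size * num_nodes, num_channels)         # flattened block input

pos_embedding = pos_embedder(positional_encoding_hidden)      # (num_nodes, num_channels)
pos_embedding = pos_embedding.repeat(batch_size, 1)           # (batch_size * num_nodes, num_channels)
x = x + pos_embedding                                         # same encoding added for every batch member
```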