Implement optional positional embedding in processor #57

Open · wants to merge 10 commits into base: develop
12 changes: 12 additions & 0 deletions models/src/anemoi/models/layers/block.py
@@ -68,6 +68,7 @@ def __init__(
        num_heads: int,
        activation: str,
        window_size: int,
        positional_encoding_hidden: Optional[Tensor] = None,
        dropout_p: float = 0.0,
    ):
        super().__init__()
@@ -80,6 +81,12 @@ def __init__(

        self.layer_norm1 = nn.LayerNorm(num_channels)

        self.register_buffer("positional_encoding_hidden", positional_encoding_hidden)
        if self.positional_encoding_hidden is not None:
            self.pos_embedder = nn.Linear(
                self.positional_encoding_hidden.shape[-1], num_channels
            )  # hidden_dim is num_channels

        self.attention = MultiHeadSelfAttention(
            num_heads=num_heads,
            embed_dim=num_channels,
@@ -99,6 +106,11 @@ def __init__(
    def forward(
        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
    ) -> Tensor:
        if self.positional_encoding_hidden is not None:
            pos_embedding = self.pos_embedder(self.positional_encoding_hidden)
            pos_embedding = pos_embedding.repeat(batch_size, 1)
            x = x + pos_embedding

        # Need to be out of place for gradient propagation
        x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group)
        x = x + self.mlp(self.layer_norm2(x))
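For orientation, here is a minimal standalone sketch (not part of the diff) of what the added path in `forward` does: the per-node encoding is projected to `num_channels` by a linear layer and tiled across the batch before being added to the flattened node features. All sizes below are illustrative.

```python
import torch
from torch import nn

# Illustrative sizes only; the real values come from the model config and the hidden grid.
num_nodes, enc_dim, num_channels, batch_size = 1000, 4, 64, 2

positional_encoding_hidden = torch.randn(num_nodes, enc_dim)  # e.g. cos/sin of lat/lon per node
pos_embedder = nn.Linear(enc_dim, num_channels)               # mirrors self.pos_embedder above

x = torch.randn(batch_size * num_nodes, num_channels)         # node features, batch flattened as in the block
pos_embedding = pos_embedder(positional_encoding_hidden)      # (num_nodes, num_channels)
pos_embedding = pos_embedding.repeat(batch_size, 1)           # (batch_size * num_nodes, num_channels)
x = x + pos_embedding                                         # out-of-place add, as in the block's forward
print(x.shape)                                                 # torch.Size([2000, 64])
```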
2 changes: 2 additions & 0 deletions models/src/anemoi/models/layers/chunk.py
@@ -74,6 +74,7 @@ def __init__(
        num_heads: int = 16,
        mlp_hidden_ratio: int = 4,
        activation: str = "GELU",
        positional_encoding_hidden: Optional[Tensor] = None,
        dropout_p: float = 0.0,
    ) -> None:
        """Initialize TransformerProcessor.
@@ -102,6 +103,7 @@ def __init__(
            num_heads=num_heads,
            activation=activation,
            window_size=window_size,
            positional_encoding_hidden=positional_encoding_hidden,
            dropout_p=dropout_p,
        )

59 changes: 59 additions & 0 deletions models/src/anemoi/models/layers/positionalencoding.py
@@ -0,0 +1,59 @@
from abc import ABC
from abc import abstractmethod

import torch
from torch import Tensor


class BasePositionalEncoding(ABC):
    """Configurable method for calculating positional encodings for the latlons of a grid.

    To enable the positional encoding, add the following to the model-config file and
    choose the corresponding positional-encoding class:
    ```
    positional_encoding:
        _target_: anemoi.models.layers.positionalencoding.CosSinLatCosSinLon
        _convert_: all
    ```
    If the entry positional_encoding does not exist or is None, no positional encoding is used.
    """

    def __init__(self) -> None:
        """Initialise function for calculating the positional encodings."""

    @abstractmethod
    def positional_encoding(self, latlons_hidden: Tensor) -> Tensor: ...


class LatCosSinLon(BasePositionalEncoding):
    """Lat, cos(lon), sin(lon) for grid points."""

    def positional_encoding(self, latlons_hidden: Tensor) -> Tensor:
        """Output lat, cos(lon), sin(lon) for grid points."""
        lat_coslon_sinlon_hidden = torch.cat(
            (
                latlons_hidden[:, 0].unsqueeze(-1),
                torch.cos(latlons_hidden[:, 1].unsqueeze(-1)),
                torch.sin(latlons_hidden[:, 1].unsqueeze(-1)),
            ),
            dim=-1,
        )
        return lat_coslon_sinlon_hidden


class CosSinLatCosSinLon(BasePositionalEncoding):
    """Cos(lat), sin(lat), cos(lon), sin(lon) for grid points."""

    def positional_encoding(self, latlons_hidden: Tensor) -> Tensor:
        """Output cos(lat), sin(lat), cos(lon), sin(lon) for grid points."""
        coslat_sinlat_coslon_sinlon_hidden = torch.cat(
            (
                torch.cos(latlons_hidden[:, 0].unsqueeze(-1)),
                torch.sin(latlons_hidden[:, 0].unsqueeze(-1)),
                torch.cos(latlons_hidden[:, 1].unsqueeze(-1)),
                torch.sin(latlons_hidden[:, 1].unsqueeze(-1)),
            ),
            dim=-1,
        )
        return coslat_sinlat_coslon_sinlon_hidden
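A brief usage sketch (not part of the diff) for the two encodings above, assuming `latlons_hidden` holds (lat, lon) pairs in radians so that `torch.cos`/`torch.sin` are meaningful:

```python
import torch

from anemoi.models.layers.positionalencoding import CosSinLatCosSinLon, LatCosSinLon

# Three dummy grid points as (lat, lon) pairs in radians.
latlons_hidden = torch.tensor([[0.0, 0.0], [0.5, 1.0], [-1.2, 2.5]])

print(LatCosSinLon().positional_encoding(latlons_hidden).shape)        # torch.Size([3, 3])
print(CosSinLatCosSinLon().positional_encoding(latlons_hidden).shape)  # torch.Size([3, 4])
```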
5 changes: 5 additions & 0 deletions models/src/anemoi/models/layers/processor.py
@@ -40,6 +40,7 @@ def __init__(
        num_chunks: int = 2,
        activation: str = "GELU",
        cpu_offload: bool = False,
        positional_encoding_hidden: Optional[Tensor] = None,
        **kwargs,
    ) -> None:
        """Initialize BaseProcessor."""
@@ -49,6 +50,7 @@ def __init__(
        self.num_chunks = num_chunks
        self.num_channels = num_channels
        self.chunk_size = num_layers // num_chunks
        self.positional_encoding_hidden = positional_encoding_hidden

        assert (
            num_layers % num_chunks == 0
@@ -94,6 +96,7 @@ def __init__(
        num_chunks: int = 2,
        activation: str = "GELU",
        cpu_offload: bool = False,
        positional_encoding_hidden: Optional[Tensor] = None,
        num_heads: int = 16,
        mlp_hidden_ratio: int = 4,
        dropout_p: float = 0.1,
@@ -125,6 +128,7 @@ def __init__(
            num_chunks=num_chunks,
            activation=activation,
            cpu_offload=cpu_offload,
            positional_encoding_hidden=positional_encoding_hidden,
            num_heads=num_heads,
            mlp_hidden_ratio=mlp_hidden_ratio,
        )
@@ -137,6 +141,7 @@ def __init__(
            num_layers=self.chunk_size,
            window_size=window_size,
            activation=activation,
            positional_encoding_hidden=positional_encoding_hidden,
            dropout_p=dropout_p,
        )

11 changes: 11 additions & 0 deletions models/src/anemoi/models/models/encoder_processor_decoder.py
@@ -76,11 +76,22 @@ def __init__(
            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
        )

        positional_encoding_hidden = None
        if model_config.model.get("positional_encoding") is not None:
            LOGGER.info(
                "Using positional encoding. Target function: %s", model_config.model.positional_encoding._target_
            )
            self.positional_encoding = instantiate(model_config.model.positional_encoding)
            positional_encoding_hidden = self.positional_encoding.positional_encoding(
                self.node_attributes.get_coordinates(self._graph_name_hidden)
            )

        # Processor hidden -> hidden
        self.processor = instantiate(
            model_config.model.processor,
            num_channels=self.num_channels,
            sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
            positional_encoding_hidden=positional_encoding_hidden,
            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
        )
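To make the wiring above concrete, a hedged sketch (not part of the diff) of how the `positional_encoding` config entry resolves via Hydra, following the example entry from the `BasePositionalEncoding` docstring; the random coordinate tensor is a stand-in for `self.node_attributes.get_coordinates(...)`:

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "positional_encoding": {
            "_target_": "anemoi.models.layers.positionalencoding.CosSinLatCosSinLon",
            "_convert_": "all",
        }
    }
)

encoder = instantiate(cfg.positional_encoding)          # -> CosSinLatCosSinLon instance
latlons_hidden = torch.rand(10, 2)                      # stand-in for the hidden-grid coordinates
positional_encoding_hidden = encoder.positional_encoding(latlons_hidden)
print(positional_encoding_hidden.shape)                 # torch.Size([10, 4])
```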