Skip to content

Commit

Permalink
Merge branch 'develop' into fix/netcdf-variables
Browse files Browse the repository at this point in the history
  • Loading branch information
dietervdb-meteo authored Jan 10, 2025
2 parents 911e6a1 + ceb7e43 commit 6dd8da2
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 21 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ jobs:
codecov_upload: true
secrets: inherit

# Build downstream packages on HPC
downstream-ci-hpc:
name: downstream-ci-hpc
if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }}
uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci-hpc.yml@main
with:
anemoi-inference: ecmwf/anemoi-inference@${{ github.event.pull_request.head.sha || github.sha }}
secrets: inherit
# # Build downstream packages on HPC
# downstream-ci-hpc:
# name: downstream-ci-hpc
# if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }}
# uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci-hpc.yml@main
# with:
# anemoi-inference: ecmwf/anemoi-inference@${{ github.event.pull_request.head.sha || github.sha }}
# secrets: inherit
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ repos:
- --force-single-line-imports
- --profile black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.1
rev: v0.8.6
hooks:
- id: ruff
args:
Expand All @@ -64,7 +64,7 @@ repos:
hooks:
- id: pyproject-fmt
- repo: https://github.com/jshwi/docsig # Check docstrings against function sig
rev: v0.65.0
rev: v0.66.1
hooks:
- id: docsig
args:
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Keep it human-readable, your future self will thank you!
## [Unreleased]

### Added
- Add support for models with unconnected nodes dropped from input [#95](https://github.com/ecmwf/anemoi-inference/pull/95).
- Change trigger for boundary forcings [#95](https://github.com/ecmwf/anemoi-inference/pull/95).
- Add support for automatic loading of anemoi-datasets of more general type [#95](https://github.com/ecmwf/anemoi-inference/pull/95).
- Add initial state output in netcdf format
- Fix: Enable inference when no constant forcings are used
- Add anemoi-transform link to documentation
Expand All @@ -35,6 +38,7 @@ Keep it human-readable, your future self will thank you!
- Fix SimpleRunner

### Removed
- ci: turn off hpc workflow


## [0.2.0](https://github.com/ecmwf/anemoi-inference/compare/0.1.9...0.2.0) - Use earthkit-data
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def mars_requests(self, *, variables, dates, use_grib_paramid=False, **kwargs):

@cached_property
def _supporting_arrays(self):
    """Supporting arrays stored in the checkpoint metadata.

    Returns
    -------
    dict
        Named auxiliary arrays (e.g. ``output_mask``, ``grid_indices``)
        shipped alongside the model checkpoint.
    """
    # NOTE(review): delegates to the metadata's private attribute of the
    # same name (post-merge behavior); the pre-merge public accessor
    # `supporting_arrays` was dropped in this commit.
    return self._metadata._supporting_arrays

@property
def name(self):
Expand Down
6 changes: 4 additions & 2 deletions src/anemoi/inference/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ def __init__(self, context, input, variables, variables_mask):
self.variables_mask = variables_mask
assert isinstance(input, DatasetInput), "Currently only boundary forcings from dataset supported."
self.input = input
num_lam, num_other = input.ds.grids
self.spatial_mask = np.array([False] * num_lam + [True] * num_other, dtype=bool)
if "output_mask" in context.checkpoint._supporting_arrays:
self.spatial_mask = ~context.checkpoint.load_supporting_array("output_mask")
else:
self.spatial_mask = np.array([False] * len(input["latitudes"]), dtype=bool)
self.kinds = dict(retrieved=True) # Used for debugging

def __repr__(self):
Expand Down
24 changes: 21 additions & 3 deletions src/anemoi/inference/inputs/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,25 @@ class DatasetInput(Input):

def __init__(self, context, args, kwargs):
    """Input that reads model states from an anemoi dataset.

    Parameters
    ----------
    context : Context
        Runtime context; provides ``verbosity`` and the ``checkpoint``
        used to look up supporting arrays.
    args : list
        Positional arguments forwarded to ``open_dataset``.
    kwargs : dict
        Keyword arguments forwarded to ``open_dataset``. A
        ``grid_indices`` entry, if present, is popped here and used to
        subset the input grid instead of being passed on.
    """
    super().__init__(context)

    # An explicit `grid_indices` in the input config takes precedence
    # over anything stored in the checkpoint.
    grid_indices = kwargs.pop("grid_indices", None)

    self.args, self.kwargs = args, kwargs
    if context.verbosity > 0:
        LOG.info(
            "Opening dataset with\nargs=%s\nkwargs=%s", json.dumps(args, indent=4), json.dumps(kwargs, indent=4)
        )

    # Fall back to the checkpoint's supporting array, if one was stored.
    if grid_indices is None and "grid_indices" in context.checkpoint._supporting_arrays:
        grid_indices = context.checkpoint.load_supporting_array("grid_indices")
        if context.verbosity > 0:
            # Fix: the original used a backslash line continuation inside
            # the string literal, which embedded the following line's
            # leading whitespace into the logged message. Adjacent string
            # literals concatenate cleanly instead.
            LOG.info(
                "Loading supporting array `grid_indices` from checkpoint, "
                "the input grid will be reduced accordingly."
            )

    # `slice(None)` is a no-op index, so downstream code can apply
    # `self.grid_indices` unconditionally.
    self.grid_indices = slice(None) if grid_indices is None else grid_indices

@cached_property
def ds(self):
from anemoi.datasets import open_dataset
Expand All @@ -48,11 +61,13 @@ def create_input_state(self, *, date=None):
raise ValueError("`date` must be provided")

date = to_datetime(date)
latitudes = self.ds.latitudes
longitudes = self.ds.longitudes

input_state = dict(
date=date,
latitudes=self.ds.latitudes,
longitudes=self.ds.longitudes,
latitudes=latitudes[self.grid_indices],
longitudes=longitudes[self.grid_indices],
fields=dict(),
)

Expand All @@ -69,7 +84,8 @@ def create_input_state(self, *, date=None):
if variable not in requested_variables:
continue
# Squeeze the data to remove the ensemble dimension
fields[variable] = np.squeeze(data[:, i], axis=1)
values = np.squeeze(data[:, i], axis=1)
fields[variable] = values[:, self.grid_indices]

return input_state

Expand All @@ -82,6 +98,8 @@ def load_forcings(self, *, variables, dates):
data = np.squeeze(data, axis=2)
# Reorder the dimensions to (variable, date, values)
data = np.swapaxes(data, 0, 1)
# apply reduction to `grid_indices`
data = data[..., self.grid_indices]
return data

def _load_dates(self, dates):
Expand Down
9 changes: 4 additions & 5 deletions src/anemoi/inference/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def output_tensor_index_to_variable(self):
@cached_property
def number_of_grid_points(self):
"""Return the number of grid points per fields"""
if "grid_indices" in self._supporting_arrays:
return len(self.load_supporting_array("grid_indices"))
try:
return self._metadata.dataset.shape[-1]
except AttributeError:
Expand Down Expand Up @@ -510,14 +512,13 @@ def _find(x):
_find(y)

if isinstance(x, dict):
if "dataset" in x:
if "dataset" in x and isinstance(x["dataset"], str):
result.append(x["dataset"])

for k, v in x.items():
_find(v)

_find(self._config.dataloader.training.dataset)

return result

def open_dataset_args_kwargs(self, *, use_original_paths, from_dataloader=None):
Expand Down Expand Up @@ -717,9 +718,7 @@ def boundary_forcings_inputs(self, context, input_state):

result = []

output_mask = self._config_model.get("output_mask", None)
if output_mask is not None:
assert output_mask == "cutout", "Currently only cutout as output mask supported."
if "output_mask" in self._supporting_arrays:
result.append(
context.create_boundary_forcings(
self.prognostic_variables,
Expand Down

0 comments on commit 6dd8da2

Please sign in to comment.