From 978b328576cb49f147dfe21e4a01c769ad21c797 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 31 May 2024 14:47:20 +0100 Subject: [PATCH] support for old weights --- .../inference/checkpoint/metadata/__init__.py | 2 + .../checkpoint/metadata/version_0_0_0.py | 120 ++++++++---------- 2 files changed, 56 insertions(+), 66 deletions(-) diff --git a/src/anemoi/inference/checkpoint/metadata/__init__.py b/src/anemoi/inference/checkpoint/metadata/__init__.py index 1b2aef9..e3cc2c0 100644 --- a/src/anemoi/inference/checkpoint/metadata/__init__.py +++ b/src/anemoi/inference/checkpoint/metadata/__init__.py @@ -224,6 +224,8 @@ def _computed_forcings(self): ] ) + print("FORCINGS", self._forcing_params()) + constants = set(self._forcing_params()) - set(self.constants_from_input) - set(self.computed_constants) if constants - known: diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_0_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_0_0.py index afa8b03..bbda5f9 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_0_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_0_0.py @@ -22,6 +22,54 @@ class Version_0_0_0(Metadata): def __init__(self, metadata): super().__init__(metadata) + FORCING_PARAMS = [ + "z", + "lsm", + "sdor", + "slor", + "cos_latitude", + "cos_longitude", + "sin_latitude", + "sin_longitude", + "cos_julian_day", + "cos_local_time", + "sin_julian_day", + "sin_local_time", + "insolation", + ] + + indices = dict( + forcing=self._index_of(FORCING_PARAMS), + full=self._index_of(self.variables), + diagnostic=[], + prognostic=self._index_of(self.ordering), + ) + + config = dict( + data_indices=dict( + data=dict( + input=indices, + output=indices, + ), + model=dict( + input=indices, + output=indices, + ), + ), + config=dict( + data=dict(timestep=6, frequency=6), + training=dict( + multistep_input=2, + precision="32", + ), + ), + ) + + self._metadata.update(config) + + def _index_of(self, names): + return [self.variable_to_index[name] for name in names] + def dump(self, indent=0): print("Version_0_0_0: Not implemented") @@ -46,6 +94,8 @@ def dump(self, indent=0): [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000], ) + param_level_ml = tuple() + ordering = [ "q_50", "q_100", @@ -146,9 +196,8 @@ def dump(self, indent=0): "sin_latitude", "sin_longitude", ] - computed_constants_mask = [] - computer_forcing = [ + computed_forcing = [ "cos_julian_day", "cos_local_time", "sin_julian_day", @@ -158,19 +207,11 @@ def dump(self, indent=0): @property def variables(self): - return self.ordering + self.computed_constants + self.forcing_params - - @property - def num_input_features(self): - raise NotImplementedError() - - @property - def data_to_model(self): - raise NotImplementedError() + return self.ordering + self.computed_constants + self.computed_forcing @property - def model_to_data(self): - raise NotImplementedError() + def variables_with_nans(self): + return [] ########################################################################### @property @@ -187,56 +228,3 @@ def select(self): param_level=self.variables, remapping={"param_level": "{param}_{levelist}"}, ) - - ########################################################################### - - @property - def constants_from_input(self): - raise NotImplementedError() - - @property - def constants_from_input_mask(self): - raise NotImplementedError() - - @property - def constant_data_from_input_mask(self): - raise NotImplementedError() - - ########################################################################### - - @property - def prognostic_input_mask(self): - raise NotImplementedError() - - @property - def prognostic_data_input_mask(self): - raise NotImplementedError() - - @property - def prognostic_output_mask(self): - raise NotImplementedError() - - @property - def diagnostic_output_mask(self): - raise NotImplementedError() - - @property - def diagnostic_params(self): - raise NotImplementedError() - - @property - def prognostic_params(self): - raise NotImplementedError() - - ########################################################################### - @property - def precision(self): - raise NotImplementedError() - - @property - def multi_step(self): - raise NotImplementedError() - - @property - def imputable_variables(self): - raise NotImplementedError()