diff --git a/CHANGELOG.md b/CHANGELOG.md
index ec4c47f3..7f2d4c19 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,55 +8,82 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 Please add your functional changes to the appropriate section in the PR.
 Keep it human-readable, your future self will thank you!
 
-## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD)
+## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.1...HEAD)
+
+## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28
+
+### Changed
+- Perform full shuffle of training dataset [#153](https://github.com/ecmwf/anemoi-training/pull/153)
+
 ### Fixed
+- Update `n_pixels` used by datashader to better adapt across resolutions [#152](https://github.com/ecmwf/anemoi-training/pull/152)
+  - Fixed bug in power spectra plotting for the n320 resolution.
+
 - Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165)
 
-### Added
-- Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155)
+### Added
+- Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155)
 - Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76)
+- Bump `anemoi-graphs` version to 0.4.1 [#159](https://github.com/ecmwf/anemoi-training/pull/159)
 
 ### Changed
+
 ## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14
 
 ### Changed
+
 - Increase the default MlFlow HTTP max retries [#111](https://github.com/ecmwf/anemoi-training/pull/111)
 
 ### Fixed
 
 - Rename loss_scaling to variable_loss_scaling [#138](https://github.com/ecmwf/anemoi-training/pull/138)
+
 - Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60)
+
 - Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115)
 - Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119)
 - Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87)
+
   - Enable longer validation rollout than training
 - Expand iterables in logging [#91](https://github.com/ecmwf/anemoi-training/pull/91)
+
   - Save entire config in mlflow
 
 ### Added
 
 - Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)
+
 - Include option to use datashader and optimised asyncronohous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102)
-  - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)
+
+  - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)
+
 - Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137)
+
   - Add without subsetting in ScaleTensor
 - Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63)
+
 - Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92)
+
 - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/)
+
 - Feat: Save a gif for longer rollouts in validation [#65](https://github.com/ecmwf/anemoi-training/pull/65)
+
 - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/)
+
 - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133)
 - Functionality to change the weight attribute of nodes in the graph at the start of training without re-generating the graph. [#136] (https://github.com/ecmwf/anemoi-training/pull/136)
+
 - Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147)
+
 
 ### Changed
 
 - Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118)
diff --git a/pyproject.toml b/pyproject.toml
index 8d685acd..10d8efda 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,8 +40,8 @@ classifiers = [
 dynamic = [ "version" ]
 
 dependencies = [
-  "anemoi-datasets>=0.4",
-  "anemoi-graphs>=0.4",
+  "anemoi-datasets>=0.5.2",
+  "anemoi-graphs>=0.4.1",
   "anemoi-models>=0.3",
   "anemoi-utils[provenance]>=0.4.4",
   "datashader>=0.16.3",
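Note on the dependency bumps: `anemoi-graphs>=0.4.1` is required for the new `edge_builders` config schema used in the graph configs below, and the `anemoi-datasets` floor is raised alongside it. A minimal sketch — not part of the patch, and assuming the `packaging` distribution is available — for checking an existing environment against the new floors:

```python
# Quick environment check against the new dependency floors in pyproject.toml.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version  # assumes `packaging` is installed

# Minimum versions taken from the pyproject.toml hunk above.
FLOORS = {"anemoi-datasets": "0.5.2", "anemoi-graphs": "0.4.1"}

for package, floor in FLOORS.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed")
        continue
    ok = Version(installed) >= Version(floor)
    print(f"{package}: {installed} (needs >={floor}) -> {'OK' if ok else 'UPGRADE'}")
```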
diff --git a/src/anemoi/training/config/graph/encoder_decoder_only.yaml b/src/anemoi/training/config/graph/encoder_decoder_only.yaml
index b813254d..76907d82 100644
--- a/src/anemoi/training/config/graph/encoder_decoder_only.yaml
+++ b/src/anemoi/training/config/graph/encoder_decoder_only.yaml
@@ -22,15 +22,15 @@ edges:
 # Encoder configuration
 - source_name: ${graph.data}
   target_name: ${graph.hidden}
-  edge_builder:
-    _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
     cutoff_factor: 0.6 # only for cutoff method
   attributes: ${graph.attributes.edges}
 
-- source_name: ${graph.hidden} # Decoder configuration
+- source_name: ${graph.hidden}
   target_name: ${graph.data}
-  edge_builder:
-    _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
     num_nearest_neighbours: 3 # only for knn method
   attributes: ${graph.attributes.edges}
diff --git a/src/anemoi/training/config/graph/limited_area.yaml b/src/anemoi/training/config/graph/limited_area.yaml
index f17bc384..a22405b6 100644
--- a/src/anemoi/training/config/graph/limited_area.yaml
+++ b/src/anemoi/training/config/graph/limited_area.yaml
@@ -23,23 +23,23 @@ edges:
 # Encoder configuration
 - source_name: ${graph.data}
   target_name: ${graph.hidden}
-  edge_builder:
-    _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
     cutoff_factor: 0.6 # only for cutoff method
   attributes: ${graph.attributes.edges}
 
 # Processor configuration
 - source_name: ${graph.hidden}
   target_name: ${graph.hidden}
-  edge_builder:
-    _target_: anemoi.graphs.edges.MultiScaleEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.MultiScaleEdges
     x_hops: 1
   attributes: ${graph.attributes.edges}
 
 # Decoder configuration
 - source_name: ${graph.hidden}
   target_name: ${graph.data}
-  target_mask_attr_name: cutout
-  edge_builder:
-    _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
+    target_mask_attr_name: cutout
     num_nearest_neighbours: 3 # only for knn method
   attributes: ${graph.attributes.edges}
diff --git a/src/anemoi/training/config/graph/multi_scale.yaml b/src/anemoi/training/config/graph/multi_scale.yaml
index 7e54535e..eec38d82 100644
--- a/src/anemoi/training/config/graph/multi_scale.yaml
+++ b/src/anemoi/training/config/graph/multi_scale.yaml
@@ -22,22 +22,22 @@ edges:
 # Encoder configuration
 - source_name: ${graph.data}
   target_name: ${graph.hidden}
-  edge_builder:
-    _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
     cutoff_factor: 0.6 # only for cutoff method
   attributes: ${graph.attributes.edges}
 
-- source_name: ${graph.hidden} # Processor configuration
+- source_name: ${graph.hidden}
   target_name: ${graph.hidden}
-  edge_builder:
-    _target_: anemoi.graphs.edges.MultiScaleEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.MultiScaleEdges
     x_hops: 1
   attributes: ${graph.attributes.edges}
 
-- source_name: ${graph.hidden} # Decoder configuration
+- source_name: ${graph.hidden}
   target_name: ${graph.data}
-  edge_builder:
-    _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
     num_nearest_neighbours: 3 # only for knn method
   attributes: ${graph.attributes.edges}
diff --git a/src/anemoi/training/config/graph/stretched_grid.yaml b/src/anemoi/training/config/graph/stretched_grid.yaml
index dad0172d..a92f319b 100644
--- a/src/anemoi/training/config/graph/stretched_grid.yaml
+++ b/src/anemoi/training/config/graph/stretched_grid.yaml
@@ -34,22 +34,22 @@ edges:
 # Encoder
 - source_name: ${graph.data}
   target_name: ${graph.hidden}
-  edge_builder:
-    _target_: anemoi.graphs.edges.KNNEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.KNNEdges
     num_nearest_neighbours: 12
   attributes: ${graph.attributes.edges}
 
 # Processor
 - source_name: ${graph.hidden}
   target_name: ${graph.hidden}
-  edge_builder:
-    _target_: anemoi.graphs.edges.MultiScaleEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.MultiScaleEdges
     x_hops: 1
   attributes: ${graph.attributes.edges}
 
 # Decoder
 - source_name: ${graph.hidden}
   target_name: ${graph.data}
-  edge_builder:
-    _target_: anemoi.graphs.edges.KNNEdges
+  edge_builders:
+  - _target_: anemoi.graphs.edges.KNNEdges
     num_nearest_neighbours: 3
   attributes: ${graph.attributes.edges}
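All four graph configs above move from a single `edge_builder` mapping to an `edge_builders` list (the schema introduced by `anemoi-graphs` 0.4.1, matching the version bump in `pyproject.toml`), so several builders can contribute edges to the same source/target pair. In `limited_area.yaml`, `target_mask_attr_name: cutout` correspondingly moves inside the builder entry it configures. A minimal sketch of the parsed structure as a consumer would see it — the builder class and option come from the configs above; the iteration itself is illustrative and uses PyYAML only for parsing:

```python
# Illustrative only: the old schema held one `edge_builder` mapping; the new
# schema holds an `edge_builders` list, so a consumer iterates over builders.
import yaml

edge_entry = yaml.safe_load(
    """
    source_name: data
    target_name: hidden
    edge_builders:
    - _target_: anemoi.graphs.edges.CutOffEdges
      cutoff_factor: 0.6
    """
)

for builder in edge_entry["edge_builders"]:
    options = {k: v for k, v in builder.items() if k != "_target_"}
    print(builder["_target_"], options)
    # -> anemoi.graphs.edges.CutOffEdges {'cutoff_factor': 0.6}
```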
diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py
index 40065e06..062d2d4d 100644
--- a/src/anemoi/training/data/dataset.py
+++ b/src/anemoi/training/data/dataset.py
@@ -201,6 +201,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
 
         low = shard_start + worker_id * self.n_samples_per_worker
         high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end)
+        self.chunk_index_range = np.arange(low, high, dtype=np.uint32)
 
         LOGGER.debug(
             "Worker %d (pid %d, global_rank %d, model comm group %d) has low/high range %d / %d",
@@ -212,27 +213,17 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
             high,
         )
 
-        self.chunk_index_range = self.valid_date_indices[np.arange(low, high, dtype=np.uint32)]
-
-        # each worker must have a different seed for its random number generator,
-        # otherwise all the workers will output exactly the same data
-        # should we check lightning env variable "PL_SEED_WORKERS" here?
-        # but we alwyas want to seed these anyways ...
         base_seed = get_base_seed()
-        seed = (
-            base_seed * (self.model_comm_group_id + 1) - worker_id
-        )  # note that test, validation etc. datasets get same seed
-        torch.manual_seed(seed)
-        random.seed(seed)
-        self.rng = np.random.default_rng(seed=seed)
+        torch.manual_seed(base_seed)
+        random.seed(base_seed)
+        self.rng = np.random.default_rng(seed=base_seed)
         sanity_rnd = self.rng.random(1)
 
         LOGGER.debug(
             (
                 "Worker %d (%s, pid %d, glob. rank %d, model comm group %d, "
-                "group_rank %d, base_seed %d) using seed %d, sanity rnd %f"
+                "group_rank %d, base_seed %d), sanity rnd %f"
             ),
             worker_id,
             self.label,
@@ -241,7 +232,6 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
             self.model_comm_group_id,
             self.model_comm_group_rank,
             base_seed,
-            seed,
             sanity_rnd,
         )
 
@@ -256,12 +246,12 @@ def __iter__(self) -> torch.Tensor:
         """
         if self.shuffle:
             shuffled_chunk_indices = self.rng.choice(
-                self.chunk_index_range,
-                size=self.n_samples_per_worker,
+                self.valid_date_indices,
+                size=len(self.valid_date_indices),
                 replace=False,
-            )
+            )[self.chunk_index_range]
         else:
-            shuffled_chunk_indices = self.chunk_index_range
+            shuffled_chunk_indices = self.valid_date_indices[self.chunk_index_range]
 
         LOGGER.debug(
             (
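This is the "full shuffle" from the changelog: instead of each worker shuffling only its own pre-assigned chunk with a worker-specific seed, every worker now draws the *same* global permutation of `valid_date_indices` (seeded by the shared `base_seed`) and then keeps only its own `chunk_index_range` slice. A standalone sketch of why this stays consistent across workers — the sizes and seed below are made up:

```python
# Sketch of the shuffle scheme above: every worker seeds its RNG with the same
# base seed, draws the same permutation of all valid dates, then keeps only
# its own slice of that permutation.
import numpy as np

valid_date_indices = np.arange(100)  # hypothetical dataset
base_seed = 1234                     # shared by all workers (cf. get_base_seed())

def worker_samples(worker_id: int, n_samples_per_worker: int = 25) -> np.ndarray:
    low = worker_id * n_samples_per_worker
    chunk_index_range = np.arange(low, low + n_samples_per_worker, dtype=np.uint32)
    rng = np.random.default_rng(seed=base_seed)
    # Identical permutation on every worker, because the seed is shared ...
    shuffled = rng.choice(valid_date_indices, size=len(valid_date_indices), replace=False)
    # ... so the per-worker slices are disjoint and together cover everything.
    return shuffled[chunk_index_range]

parts = [worker_samples(w) for w in range(4)]
assert len(np.unique(np.concatenate(parts))) == 100  # no duplicates, full coverage
```

The worker-specific seed had to go for this to work: if each worker seeded its RNG differently, each would draw a different permutation and the slices could overlap or drop samples.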
diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py
index 08f9d28b..a54bab70 100644
--- a/src/anemoi/training/diagnostics/callbacks/plot.py
+++ b/src/anemoi/training/diagnostics/callbacks/plot.py
@@ -938,7 +938,7 @@ def _plot(
             torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
             in_place=False,
         )
-        output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy()
+        output_tensor = pl_module.output_mask.apply(output_tensor, dim=1, fill_value=np.nan).numpy()
         data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan)
         data = data.numpy()
 
@@ -999,7 +999,7 @@ def process(
             torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
             in_place=False,
         )
-        output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy()
+        output_tensor = pl_module.output_mask.apply(output_tensor, dim=1, fill_value=np.nan).numpy()
         data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan)
         data = data.numpy()
         return data, output_tensor
diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py
index 71d5c475..78a80be9 100644
--- a/src/anemoi/training/diagnostics/mlflow/logger.py
+++ b/src/anemoi/training/diagnostics/mlflow/logger.py
@@ -501,7 +501,16 @@ def _clean_params(params: dict[str, Any]) -> dict[str, Any]:
     dict[str, Any]
         Cleaned up params ready for MlFlow.
     """
-    prefixes_to_remove = ["hardware", "data", "dataloader", "model", "training", "diagnostics", "metadata.config"]
+    prefixes_to_remove = [
+        "hardware",
+        "data",
+        "dataloader",
+        "model",
+        "training",
+        "diagnostics",
+        "metadata.config",
+        "metadata.dataset.variables_metadata",
+    ]
     keys_to_remove = [key for key in params if any(key.startswith(prefix) for prefix in prefixes_to_remove)]
     for key in keys_to_remove:
        del params[key]
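For context on the logger hunk: `_clean_params` drops every flattened config entry whose dotted key starts with one of the listed prefixes, so the new `metadata.dataset.variables_metadata` entry keeps the potentially very large per-variable dataset metadata out of the logged MLflow params. A standalone illustration — the example keys below are made up:

```python
# The prefix filter from _clean_params, applied to hypothetical params.
prefixes_to_remove = ["hardware", "metadata.dataset.variables_metadata"]

params = {
    "hardware.num_gpus": 4,
    "metadata.dataset.variables_metadata.2t.units": "K",
    "config.learning_rate": 1e-4,
}

# Same logic as in the diff: drop any key matching one of the prefixes.
keys_to_remove = [key for key in params if any(key.startswith(prefix) for prefix in prefixes_to_remove)]
for key in keys_to_remove:
    del params[key]

print(params)  # {'config.learning_rate': 0.0001}
```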
""" - prefixes_to_remove = ["hardware", "data", "dataloader", "model", "training", "diagnostics", "metadata.config"] + prefixes_to_remove = [ + "hardware", + "data", + "dataloader", + "model", + "training", + "diagnostics", + "metadata.config", + "metadata.dataset.variables_metadata", + ] keys_to_remove = [key for key in params if any(key.startswith(prefix) for prefix in prefixes_to_remove)] for key in keys_to_remove: del params[key] diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 93e2d324..45818b69 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -24,6 +24,7 @@ from matplotlib.collections import PathCollection from matplotlib.colors import BoundaryNorm from matplotlib.colors import ListedColormap +from matplotlib.colors import Normalize from matplotlib.colors import TwoSlopeNorm from pyshtools.expand import SHGLQ from pyshtools.expand import SHExpandGLQ @@ -568,8 +569,12 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: datashader=datashader, ) else: - single_plot(fig, ax[1], lon, lat, truth, title=f"{vname} target", datashader=datashader) - single_plot(fig, ax[2], lon, lat, pred, title=f"{vname} pred", datashader=datashader) + combined_data = np.concatenate((input_, truth, pred)) + # For 'errors', only persistence and increments need identical colorbar-limits + combined_error = np.concatenate(((pred - input_), (truth - input_))) + norm = Normalize(vmin=np.nanmin(combined_data), vmax=np.nanmax(combined_data)) + single_plot(fig, ax[1], lon, lat, truth, norm=norm, title=f"{vname} target", datashader=datashader) + single_plot(fig, ax[2], lon, lat, pred, norm=norm, title=f"{vname} pred", datashader=datashader) single_plot( fig, ax[3], @@ -619,7 +624,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: datashader=datashader, ) else: - single_plot(fig, ax[0], lon, lat, input_, title=f"{vname} input", datashader=datashader) + single_plot(fig, ax[0], lon, lat, input_, norm=norm, title=f"{vname} input", datashader=datashader) single_plot( fig, ax[4], @@ -627,7 +632,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: lat, pred - input_, cmap="bwr", - norm=TwoSlopeNorm(vcenter=0.0), + norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()), title=f"{vname} increment [pred - input]", datashader=datashader, ) @@ -638,7 +643,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: lat, truth - input_, cmap="bwr", - norm=TwoSlopeNorm(vcenter=0.0), + norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()), title=f"{vname} persist err", datashader=datashader, ) @@ -703,13 +708,13 @@ def single_plot( else: df = pd.DataFrame({"val": data, "x": lon, "y": lat}) # Adjust binning to match the resolution of the data - n_pixels = int(np.floor(data.shape[0] / 212)) + lower_limit = 25 + upper_limit = 500 + n_pixels = max(min(int(np.floor(data.shape[0] * 0.004)), upper_limit), lower_limit) psc = dsshow( df, dsh.Point("x", "y"), dsh.mean("val"), - vmin=data.min(), - vmax=data.max(), cmap=cmap, plot_width=n_pixels, plot_height=n_pixels, @@ -718,8 +723,10 @@ def single_plot( ax=ax, ) - ax.set_xlim((-np.pi, np.pi)) - ax.set_ylim((-np.pi / 2, np.pi / 2)) + xmin, xmax = max(lon.min(), -np.pi), min(lon.max(), np.pi) + ymin, ymax = max(lat.min(), -np.pi / 2), min(lat.max(), np.pi / 2) + ax.set_xlim((xmin - 0.1, xmax + 0.1)) + 
diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py
index 80fc70d3..a18ed4dc 100644
--- a/src/anemoi/training/train/train.py
+++ b/src/anemoi/training/train/train.py
@@ -20,6 +20,7 @@
 import numpy as np
 import pytorch_lightning as pl
 import torch
+from anemoi.utils.config import DotDict
 from anemoi.utils.provenance import gather_provenance_info
 from omegaconf import DictConfig
 from omegaconf import OmegaConf
@@ -128,7 +129,8 @@ def graph_data(self) -> HeteroData:
 
         from anemoi.graphs.create import GraphCreator
 
-        return GraphCreator(config=self.config.graph).create(
+        graph_config = DotDict(OmegaConf.to_container(self.config.graph, resolve=True))
+        return GraphCreator(config=graph_config).create(
             save_path=graph_filename,
             overwrite=self.config.graph.overwrite,
         )
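The `graph_data` change converts the live OmegaConf node into a plain container before handing it to `GraphCreator`: `resolve=True` bakes in interpolations such as `${graph.data}` from the graph configs earlier in this diff, and `DotDict` keeps attribute-style access. A small self-contained sketch of the same conversion — the config content here is invented for illustration:

```python
# Resolving OmegaConf interpolations into a plain DotDict, as in graph_data().
from anemoi.utils.config import DotDict
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "data": "data",
        "hidden": "hidden",
        "edges": [{"source_name": "${data}", "target_name": "${hidden}"}],
    }
)

graph_config = DotDict(OmegaConf.to_container(cfg, resolve=True))
# Interpolations are already resolved to plain strings:
print(graph_config["edges"][0]["source_name"])  # "data"
```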