From 0d88b83132aee12c78093f2e3fe0b9cf86c2abd6 Mon Sep 17 00:00:00 2001 From: Tobias Wicky Date: Thu, 30 May 2024 12:43:35 +0200 Subject: [PATCH 1/5] isthisanarmodel? --- neural_lam/data_config.yaml | 34 +++++ neural_lam/models/ar_model.py | 186 ++++++++++++++++++++++++-- neural_lam/models/base_graph_model.py | 2 +- 3 files changed, 211 insertions(+), 11 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index f16a4a30..c234521c 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -62,3 +62,37 @@ projection: central_longitude: 15.0 central_latitude: 63.3 standard_parallels: [63.3, 63.3] + +dataset2: + name: cosmo_example + var_names: + - "T" + - "U" + - "V" + - "RELHUM" + - "PMSL" + - "PP" + var_units: + - K + - m/s + - m/s + - Perc. + - Pa + - hPa + var_longnames: + - "Temperature" + - "Zonal wind component" + - "Meridional wind component" + - "Relative humidity" + - "Pressure at Mean Sea Level" + - "Pressure Perturbation" + var_is_3d: + - 1 + - 1 + - 1 + - 1 + - 0 + - 1 + vertical_levels: [1, 5, 13, 22, 38, 41, 60] + num_forcing_features: 16 + eval_plot_vars: ["TQV"] diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 29b169d4..ccc573f4 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,10 +1,13 @@ # Standard library +import glob import os # Third-party +import imageio import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only import torch import wandb @@ -93,6 +96,20 @@ def __init__(self, args): # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] + self.inference_output = [] + "Storage for the output of individual inference steps" + + self.variable_indices = self.pre_compute_variable_indices() + "Index mapping of variable names to their levels in the array." + self.selected_vars_units = [ + (var_name, var_unit) + for var_name, var_unit in zip( + self.config_loader.dataset.var_names, + self.config_loader.dataset.var_units, + ) + if var_name in self.config_loader.dataset.eval_plot_vars + ] + def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) @@ -106,6 +123,34 @@ def interior_mask_bool(self): """ return self.interior_mask[:, 0].to(torch.bool) + def pre_compute_variable_indices(self): + """ + Pre-compute indices for each variable in the input tensor + """ + variable_indices = {} + all_vars = [] + index = 0 + # Create a list of tuples for all variables, using level 0 for 2D + # variables + for var_name in self.config_loader.dataset.var_names: + if self.config_loader.dataset.var_is_3d: + for level in self.config_loader.dataset.vertical_levels: + all_vars.append((var_name, level)) + else: + all_vars.append((var_name, 0)) # Use level 0 for 2D variables + + # Sort the variables based on the tuples + sorted_vars = sorted(all_vars) + + for var in sorted_vars: + var_name, level = var + if var_name not in variable_indices: + variable_indices[var_name] = [] + variable_indices[var_name].append(index) + index += 1 + + return variable_indices + @staticmethod def expand_to_batch(x, batch_size): """ @@ -113,7 +158,7 @@ def expand_to_batch(x, batch_size): """ return x.unsqueeze(0).expand(batch_size, -1, -1) - def predict_step(self, prev_state, prev_prev_state, forcing): + def single_prediction(self, prev_state, prev_prev_state, forcing): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 prev_state: (B, num_grid_nodes, feature_dim), X_t @@ -122,6 +167,48 @@ def predict_step(self, prev_state, prev_prev_state, forcing): """ raise NotImplementedError("No prediction step implemented") + def predict_step(self, batch, batch_idx): + """ + Run the inference on batch. + """ + prediction, target, pred_std = self.common_step(batch) + + # Compute all evaluation metrics for error maps + # Note: explicitly list metrics here, as test_metrics can contain + # additional ones, computed differently, but that should be aggregated + # on_predict_epoch_end + for metric_name in ("mse", "mae"): + metric_func = metrics.get_metric(metric_name) + batch_metric_vals = metric_func( + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) + self.test_metrics[metric_name].append(batch_metric_vals) + + if self.output_std: + # Store output std. per variable, spatially averaged + mean_pred_std = torch.mean( + pred_std[..., self.interior_mask_bool, :], dim=-2 + ) # (B, pred_steps, d_f) + self.test_metrics["output_std"].append(mean_pred_std) + + # Save per-sample spatial loss for specific times + spatial_loss = self.loss( + prediction, target, pred_std, average_grid=False + ) # (B, pred_steps, num_grid_nodes) + log_spatial_losses = spatial_loss[ + :, [step - 1 for step in self.args.val_steps_to_log] + ] + self.spatial_loss_maps.append(log_spatial_losses) + # (B, N_log, num_grid_nodes) + + if self.trainer.global_rank == 0: + self.plot_examples(batch, batch_idx, prediction=prediction) + self.inference_output.append(prediction) + def unroll_prediction(self, init_states, forcing_features, true_states): """ Roll out prediction taking multiple autoregressive steps with model @@ -139,7 +226,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states): forcing = forcing_features[:, i] border_state = true_states[:, i] - pred_state, pred_std = self.predict_step( + pred_state, pred_std = self.single_prediction( prev_state, prev_prev_state, forcing ) # state: (B, num_grid_nodes, d_f) @@ -345,20 +432,50 @@ def test_step(self, batch, batch_idx): batch, n_additional_examples, prediction=prediction ) - def plot_examples(self, batch, n_examples, prediction=None): + @rank_zero_only + def plot_examples(self, batch, n_examples, batch_idx: int, prediction=None): """ - Plot the first n_examples forecasts from batch - - batch: batch with data to plot corresponding forecasts for - n_examples: number of forecasts to plot - prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction. - Generate if None. + Plot the first n_examples forecasts from batch. + + The function checks for the presence of test_dataset or + predict_dataset within the trainer's data module, + handles indexing within the batch for targeted analysis, + performs prediction rescaling, and plots results. + + Parameters: + - batch: batch with data to plot corresponding forecasts for + - n_examples: number of forecasts to plot + - batch_idx (int): index of the batch being processed + - prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction. + Generate if None. """ if prediction is None: prediction, target = self.common_step(batch) target = batch[1] + # Determine the dataset to work with (test_dataset or predict_dataset) + dataset = None + if ( + hasattr(self.trainer.datamodule, "test_dataset") + and self.trainer.datamodule.test_dataset + ): + dataset = self.trainer.datamodule.test_dataset + plot_name = "test" + elif ( + hasattr(self.trainer.datamodule, "predict_dataset") + and self.trainer.datamodule.predict_dataset + ): + dataset = self.trainer.datamodule.predict_dataset + plot_name = "prediction" + + if ( + dataset + and self.trainer.global_rank == 0 + and dataset.batch_index == batch_idx + ): + index_within_batch = dataset.index_within_batch + # Rescale to original data scale prediction_rescaled = prediction * self.data_std + self.data_mean target_rescaled = target * self.data_std + self.data_mean @@ -415,7 +532,7 @@ def plot_examples(self, batch, n_examples, prediction=None): example_i = self.plotted_examples wandb.log( { - f"{var_name}_example_{example_i}": wandb.Image(fig) + f"{var_name}_{plot_name}_{example_i}": wandb.Image(fig) for var_name, fig in zip( self.config_loader.dataset.var_names, var_figs ) @@ -573,6 +690,55 @@ def on_test_epoch_end(self): self.spatial_loss_maps.clear() + @rank_zero_only + def on_predict_epoch_end(self): + """ + Compute test metrics and make plots at the end of test epoch. + Will gather stored tensors and perform plotting and logging on rank 0. + """ + + plot_dir_path = f"{wandb.run.dir}/media/images" + value_dir_path = f"{wandb.run.dir}/results/inference" + # Ensure the directory for saving numpy arrays exists + os.makedirs(plot_dir_path, exist_ok=True) + os.makedirs(value_dir_path, exist_ok=True) + + # For values + for i, prediction in enumerate(self.inference_output): + + # Rescale to original data scale + prediction_rescaled = prediction * self.data_std + self.data_mean + + # Process and save the prediction + prediction_array = prediction_rescaled.cpu().numpy() + file_path = os.path.join(value_dir_path, f"prediction_{i}.npy") + np.save(file_path, prediction_array) + + dir_path = f"{wandb.run.dir}/media/images" + for var_name, _ in self.selected_vars_units: + var_indices = self.variable_indices[var_name] + for lvl_i, _ in enumerate(var_indices): + # Calculate var_vrange for each index + lvl = self.config_loader.dataset.vertical_levels[lvl_i] + + # Get all the images for the current variable and index + images = sorted( + glob.glob( + f"{dir_path}/{var_name}_test_lvl_{lvl:02}_t_*.png" + ) + ) + # Generate the GIF + with imageio.get_writer( + f"{dir_path}/{var_name}_lvl_{lvl:02}.gif", + mode="I", + fps=1, + ) as writer: + for filename in images: + image = imageio.imread(filename) + writer.append_data(image) + + self.spatial_loss_maps.clear() + def on_load_checkpoint(self, checkpoint): """ Perform any changes to state dict before loading checkpoint diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 256d4adc..dbe15a02 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -98,7 +98,7 @@ def process_step(self, mesh_rep): """ raise NotImplementedError("process_step not implemented") - def predict_step(self, prev_state, prev_prev_state, forcing): + def single_prediction(self, prev_state, prev_prev_state, forcing): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 prev_state: (B, num_grid_nodes, feature_dim), X_t From 7c8a629814ab683a36cafbbe4213e02a16dc75ba Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Thu, 30 May 2024 15:52:55 +0200 Subject: [PATCH 2/5] Implement prediction step to the trainer --- train_model.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/train_model.py b/train_model.py index fe064384..2cfed246 100644 --- a/train_model.py +++ b/train_model.py @@ -217,6 +217,7 @@ def main(): None, "val", "test", + "predict", ), f"Unknown eval setting: {args.eval}" # Get an (actual) random run id as a unique identifier @@ -294,6 +295,7 @@ def main(): callbacks=[checkpoint_callback], check_val_every_n_epoch=args.val_interval, precision=args.precision, + limit_predict_batches=1 ) # Only init once, on rank 0 only @@ -305,7 +307,7 @@ def main(): if args.eval: if args.eval == "val": eval_loader = val_loader - else: # Test + elif args.eval == "test": eval_loader = torch.utils.data.DataLoader( WeatherDataset( config_loader.dataset.name, @@ -318,9 +320,34 @@ def main(): shuffle=False, num_workers=args.n_workers, ) + elif args.eval == "predict": + pred_loader = torch.utils.data.DataLoader( + WeatherDataset( + config_loader.dataset.name, + pred_length=max_pred_length, + split="predict", + subsample_step=args.step_length, + subset=bool(args.subset_ds), + ), + args.batch_size, + shuffle=False, + num_workers=args.n_workers, + ) + print(f"Running prediction on {args.eval}") + trainer.predict( + model=model, + dataloaders=pred_loader, + return_predictions=True, + ckpt_path=args.load, + ) + else: + print(f"Unknown evaluation mode: {args.eval}") + raise ValueError(f"Unknown evaluation mode: {args.eval}") - print(f"Running evaluation on {args.eval}") - trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load) + if args.eval in ["val", "test"]: + print(f"Running evaluation on {args.eval}") + trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load) + else: # Train model trainer.fit( From b842ee06da2e9cd9314f3579bef4a0b1800ec26e Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Thu, 30 May 2024 17:02:33 +0200 Subject: [PATCH 3/5] improvements, still errors in num_samples --- neural_lam/data_config.yaml | 75 +++++------------------------------ neural_lam/weather_dataset.py | 2 +- 2 files changed, 10 insertions(+), 67 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index c234521c..5280eaf0 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -1,70 +1,5 @@ dataset: - name: meps_example - var_names: - - pres_0g - - pres_0s - - nlwrs_0 - - nswrs_0 - - r_2 - - r_65 - - t_2 - - t_65 - - t_500 - - t_850 - - u_65 - - u_850 - - v_65 - - v_850 - - wvint_0 - - z_1000 - - z_500 - var_units: - - Pa - - Pa - - r"$\mathrm{W}/\mathrm{m}^2$" - - r"$\mathrm{W}/\mathrm{m}^2$" - - "" - - "" - - K - - K - - K - - K - - m/s - - m/s - - m/s - - m/s - - r"$\mathrm{kg}/\mathrm{m}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" - var_longnames: - - pres_heightAboveGround_0_instant - - pres_heightAboveSea_0_instant - - nlwrs_heightAboveGround_0_accum - - nswrs_heightAboveGround_0_accum - - r_heightAboveGround_2_instant - - r_hybrid_65_instant - - t_heightAboveGround_2_instant - - t_hybrid_65_instant - - t_isobaricInhPa_500_instant - - t_isobaricInhPa_850_instant - - u_hybrid_65_instant - - u_isobaricInhPa_850_instant - - v_hybrid_65_instant - - v_isobaricInhPa_850_instant - - wvint_entireAtmosphere_0_instant - - z_isobaricInhPa_1000_instant - - z_isobaricInhPa_500_instant - num_forcing_features: 16 -grid_shape_state: [268, 238] -projection: - class: LambertConformal - kwargs: - central_longitude: 15.0 - central_latitude: 63.3 - standard_parallels: [63.3, 63.3] - -dataset2: - name: cosmo_example + name: cosmo var_names: - "T" - "U" @@ -96,3 +31,11 @@ dataset2: vertical_levels: [1, 5, 13, 22, 38, 41, 60] num_forcing_features: 16 eval_plot_vars: ["TQV"] + grid_shape_state: [268, 238] + projection: + class: LambertConformal + kwargs: + central_longitude: 15.0 + central_latitude: 63.3 + standard_parallels: [63.3, 63.3] + diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index a782806b..686bffcd 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -35,7 +35,7 @@ def __init__( ): super().__init__() - assert split in ("train", "val", "test"), "Unknown dataset split" + assert split in ("train", "val", "test", "pred"), "Unknown dataset split" self.sample_dir_path = os.path.join( "data", dataset_name, "samples", split ) From 36807e73a8dafbad01d83547e6e51209291f1c71 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Fri, 31 May 2024 09:30:14 +0200 Subject: [PATCH 4/5] current updates --- neural_lam/weather_dataset.py | 2 +- train_model.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 686bffcd..60cae426 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -35,7 +35,7 @@ def __init__( ): super().__init__() - assert split in ("train", "val", "test", "pred"), "Unknown dataset split" + assert split in ("train", "val", "test", "predict"), "Unknown dataset split" self.sample_dir_path = os.path.join( "data", dataset_name, "samples", split ) diff --git a/train_model.py b/train_model.py index 2cfed246..2df3d4d1 100644 --- a/train_model.py +++ b/train_model.py @@ -292,6 +292,7 @@ def main(): accelerator=device_name, logger=logger, log_every_n_steps=1, + devices=4, callbacks=[checkpoint_callback], check_val_every_n_epoch=args.val_interval, precision=args.precision, From 5c5aff1cc5eed0adc3a7b9dcfad0a2363e913018 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Fri, 31 May 2024 10:45:14 +0200 Subject: [PATCH 5/5] we need the number of nodes --- neural_lam/data_config.yaml | 41 ------------------------------------- train_model.py | 3 ++- 2 files changed, 2 insertions(+), 42 deletions(-) delete mode 100644 neural_lam/data_config.yaml diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml deleted file mode 100644 index 5280eaf0..00000000 --- a/neural_lam/data_config.yaml +++ /dev/null @@ -1,41 +0,0 @@ -dataset: - name: cosmo - var_names: - - "T" - - "U" - - "V" - - "RELHUM" - - "PMSL" - - "PP" - var_units: - - K - - m/s - - m/s - - Perc. - - Pa - - hPa - var_longnames: - - "Temperature" - - "Zonal wind component" - - "Meridional wind component" - - "Relative humidity" - - "Pressure at Mean Sea Level" - - "Pressure Perturbation" - var_is_3d: - - 1 - - 1 - - 1 - - 1 - - 0 - - 1 - vertical_levels: [1, 5, 13, 22, 38, 41, 60] - num_forcing_features: 16 - eval_plot_vars: ["TQV"] - grid_shape_state: [268, 238] - projection: - class: LambertConformal - kwargs: - central_longitude: 15.0 - central_latitude: 63.3 - standard_parallels: [63.3, 63.3] - diff --git a/train_model.py b/train_model.py index 2df3d4d1..51bcc053 100644 --- a/train_model.py +++ b/train_model.py @@ -293,6 +293,7 @@ def main(): logger=logger, log_every_n_steps=1, devices=4, + num_nodes=1, callbacks=[checkpoint_callback], check_val_every_n_epoch=args.val_interval, precision=args.precision, @@ -355,7 +356,7 @@ def main(): model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, - ckpt_path=args.load, + ckpt_path=args.load )