From e5c245f21d01ca7379ee1a47a3de448cc02e8a59 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sun, 5 May 2024 20:21:36 +0200 Subject: [PATCH 01/26] yaml_config for cosmo data --- neural_lam/data_config.yaml | 130 ++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 neural_lam/data_config.yaml diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml new file mode 100644 index 00000000..0aedd7fe --- /dev/null +++ b/neural_lam/data_config.yaml @@ -0,0 +1,130 @@ +zarrs: # List of zarrs containing fields related to state + state: + path: /scratch/sadamov/template.zarr # Path to zarr + dims: # Name of dimensions in zarr, to be used for indexing + time: time + level: z + x: x # Either give "grid" (flattened) dimension or "x" and "y" + y: y + static: + path: /scratch/sadamov/template.zarr + dims: + level: z + x: x + y: y + forcing: + path: /scratch/sadamov/template.zarr + dims: + time: time + level: z + x: x + y: y + boundary: + zarrs: + mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary. +state: # Variables forecasted by the model + surface: # Single-field variables + - CLCT + - PMSL + - PS + - T_2M + - TOT_PREC + - U_10M + - V_10M + surface_units: + - "%" + - Pa + - Pa + - K + - kg/m^2 + - m/s + - m/s + atmosphere: # Variables with vertical levels + - PP + - QV + - RELHUM + - T + - U + - V + - W + atmosphere_units: + - Pa + - kg/kg + - "%" + - K + - m/s + - m/s + - Pa/s + levels: # Levels to use for atmosphere variables + - 0 + - 5 + - 8 + - 11 + - 13 + - 15 + - 19 + - 22 + - 26 + - 30 + - 38 + - 44 + - 59 +static: # Static inputs + surface: + - HSURF + surface_units: + - m + atmosphere: + - FI + atmosphere_units: + - m^2/s^2 + levels: + - 0 + - 5 + - 8 + - 11 + - 13 + - 15 + - 19 + - 22 + - 26 + - 30 + - 38 + - 44 + - 59 +forcing: # Forcing variables, dynamic inputs to the model + surface: + - ASOB_S + surface_units: + - W/m^2 + atmosphere: + atmosphere_units: + levels: +boundary: # Boundary conditions + surface: + surface_units: + atmosphere: + atmosphere_units: + levels: +lat_lon_names: # Name of variables/coordinates in zarrs specifying latitude and longitude of grid cells + lat: lat + lon: lon +grid_shape: + x: 582 + y: 390 +splits: + train: + start: 2015-01-01T00 + end: 2024-12-31T23 + val: + start: 2015-01-01T00 + end: 2024-12-31T23 + test: + start: 2015-01-01T00 + end: 2024-12-31T23 +projection: + class: RotatedPole # Name of class in cartopy.crs + kwargs: # Parsed and used directly as kwargs to projection-class above + pole_longitude: 10.0 + pole_latitude: -43.0 +normalization_zarr: data/meps_example/norm.zarr From 33e7ecf5fdf1949e5f055f465a337a40076ad91c Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sun, 5 May 2024 20:21:55 +0200 Subject: [PATCH 02/26] initial version of single zarr dataset --- neural_lam/weather_dataset.py | 437 ++++++++++++++++------------------ 1 file changed, 205 insertions(+), 232 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index eeefc313..8c9e0072 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,260 +1,233 @@ # Standard library -import datetime as dt -import glob import os # Third-party -import numpy as np +import pytorch_lightning as pl import torch +import xarray as xr +import yaml -# First-party -from neural_lam import constants, utils + +class ConfigLoader: + """ + Class for loading configuration files. 
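    A minimal usage sketch of the accessor pattern this class provides (illustrative only, assuming the data_config.yaml shown above is saved as neural_lam/data_config.yaml):

        config = ConfigLoader("neural_lam/data_config.yaml")
        config.zarrs.state.path          # "/scratch/sadamov/template.zarr"
        config.zarrs["state"].dims.time  # "time"
        "state" in config                # True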
+ + This class loads a YAML configuration file and provides a way to access + its values as attributes. + """ + + def __init__(self, config_path, values=None): + self.config_path = config_path + if values is None: + self.values = self.load_config() + else: + self.values = values + + def load_config(self): + with open(self.config_path, "r") as file: + return yaml.safe_load(file) + + def __getattr__(self, name): + keys = name.split(".") + value = self.values + for key in keys: + if key in value: + value = value[key] + else: + None + if isinstance(value, dict): + return ConfigLoader(None, values=value) + return value + + def __getitem__(self, key): + value = self.values[key] + if isinstance(value, dict): + return ConfigLoader(None, values=value) + return value + + def __contains__(self, key): + return key in self.values class WeatherDataset(torch.utils.data.Dataset): """ - For our dataset: - N_t' = 65 - N_t = 65//subsample_step (= 21 for 3h steps) - dim_x = 268 - dim_y = 238 - N_grid = 268x238 = 63784 - d_features = 17 (d_features' = 18) - d_forcing = 5 + Dataset class for weather data. + + This class loads and processes weather data from zarr files based on the + provided configuration. It supports splitting the data into train, + validation, and test sets. """ + def process_dataset(self, dataset_name): + """ + Process a single dataset specified by the dataset name. + + Args: + dataset_name (str): Name of the dataset to process. + + Returns: + xarray.Dataset: Processed dataset. + """ + + dataset_path = os.path.join(self.config_loader.zarrs[dataset_name].path) + dataset = xr.open_zarr(dataset_path, consolidated=True) + + start, end = self.config_loader.splits[self.split].start, self.config_loader.splits[self.split].end + dataset = dataset.sel(time=slice(start, end)) + dataset = dataset.rename_dims( + {v: k for k, v in self.config_loader.zarrs[dataset_name].dims.values.items() + if k not in dataset.dims}) + if 'grid' not in dataset.dims: + dataset = dataset.stack(grid=('x', 'y')) + + vars_surface = [] + if self.config_loader[dataset_name].surface: + vars_surface = dataset[self.config_loader[dataset_name].surface] + + vars_atmosphere = [] + if self.config_loader[dataset_name].atmosphere: + vars_atmosphere = xr.merge( + [dataset[var].sel(level=level, drop=True).rename(f"{var}_{level}") + for var in self.config_loader[dataset_name].atmosphere + for level in self.config_loader[dataset_name].levels]) + + if vars_surface and vars_atmosphere: + dataset = xr.merge([vars_surface, vars_atmosphere]) + elif vars_surface: + dataset = vars_surface + elif vars_atmosphere: + dataset = vars_atmosphere + else: + raise ValueError(f"No variables specified for dataset: {dataset_name}") + + dataset = dataset.squeeze(drop=True).to_array() + if "time" in dataset.dims: + dataset = dataset.transpose("time", "grid", "variable") + else: + dataset = dataset.transpose("grid", "variable") + return dataset + def __init__( self, - dataset_name, - pred_length=19, split="train", - subsample_step=3, - standardize=True, - subset=False, + batch_size=4, + ar_steps=3, control_only=False, + yaml_path="neural_lam/data_config.yaml", ): super().__init__() - assert split in ("train", "val", "test"), "Unknown dataset split" - self.sample_dir_path = os.path.join( - "data", dataset_name, "samples", split - ) + assert split in ( + "train", + "val", + "test", + ), "Unknown dataset split" - member_file_regexp = ( - "nwp*mbr000.npy" if control_only else "nwp*mbr*.npy" - ) - sample_paths = glob.glob( - os.path.join(self.sample_dir_path, 
member_file_regexp) - ) - self.sample_names = [path.split("/")[-1][4:-4] for path in sample_paths] - # Now on form "yyymmddhh_mbrXXX" - - if subset: - self.sample_names = self.sample_names[:50] # Limit to 50 samples - - self.sample_length = pred_length + 2 # 2 init states - self.subsample_step = subsample_step - self.original_sample_length = ( - 65 // self.subsample_step - ) # 21 for 3h steps - assert ( - self.sample_length <= self.original_sample_length - ), "Requesting too long time series samples" - - # Set up for standardization - self.standardize = standardize - if standardize: - ds_stats = utils.load_dataset_stats(dataset_name, "cpu") - self.data_mean, self.data_std, self.flux_mean, self.flux_std = ( - ds_stats["data_mean"], - ds_stats["data_std"], - ds_stats["flux_mean"], - ds_stats["flux_std"], - ) + self.split = split + self.batch_size = batch_size + self.ar_steps = ar_steps + self.control_only = control_only + self.config_loader = ConfigLoader(yaml_path) - # If subsample index should be sampled (only duing training) - self.random_subsample = split == "train" + self.state = self.process_dataset("state") + self.static = self.process_dataset("static") + self.forcings = self.process_dataset("forcing") + # self.boundary = self.process_dataset("boundary") + + self.static = self.static.expand_dims({"time": self.state.time}, axis=0) + self.ds = xr.concat([self.state, self.static], dim="variable") def __len__(self): - return len(self.sample_names) + return len(self.ds.time) - self.ar_steps def __getitem__(self, idx): - # === Sample === - sample_name = self.sample_names[idx] - sample_path = os.path.join( - self.sample_dir_path, f"nwp_{sample_name}.npy" - ) - try: - full_sample = torch.tensor( - np.load(sample_path), dtype=torch.float32 - ) # (N_t', dim_x, dim_y, d_features') - except ValueError: - print(f"Failed to load {sample_path}") - - # Only use every ss_step:th time step, sample which of ss_step - # possible such time series - if self.random_subsample: - subsample_index = torch.randint(0, self.subsample_step, ()).item() - else: - subsample_index = 0 - subsample_end_index = self.original_sample_length * self.subsample_step - sample = full_sample[ - subsample_index : subsample_end_index : self.subsample_step - ] - # (N_t, dim_x, dim_y, d_features') - - # Remove feature 15, "z_height_above_ground" - sample = torch.cat( - (sample[:, :, :, :15], sample[:, :, :, 16:]), dim=3 - ) # (N_t, dim_x, dim_y, d_features) - - # Accumulate solar radiation instead of just subsampling - rad_features = full_sample[:, :, :, 2:4] # (N_t', dim_x, dim_y, 2) - # Accumulate for first time step - init_accum_rad = torch.sum( - rad_features[: (subsample_index + 1)], dim=0, keepdim=True - ) # (1, dim_x, dim_y, 2) - # Accumulate for rest of subsampled sequence - in_subsample_len = ( - subsample_end_index - self.subsample_step + subsample_index + 1 - ) - rad_features_in_subsample = rad_features[ - (subsample_index + 1) : in_subsample_len - ] # (N_t*, dim_x, dim_y, 2), N_t* = (N_t-1)*ss_step - _, dim_x, dim_y, _ = sample.shape - rest_accum_rad = torch.sum( - rad_features_in_subsample.view( - self.original_sample_length - 1, - self.subsample_step, - dim_x, - dim_y, - 2, - ), - dim=1, - ) # (N_t-1, dim_x, dim_y, 2) - accum_rad = torch.cat( - (init_accum_rad, rest_accum_rad), dim=0 - ) # (N_t, dim_x, dim_y, 2) - # Replace in sample - sample[:, :, :, 2:4] = accum_rad - - # Flatten spatial dim - sample = sample.flatten(1, 2) # (N_t, N_grid, d_features) - - # Uniformly sample time id to start sample from - init_id = 
torch.randint( - 0, 1 + self.original_sample_length - self.sample_length, () + sample = self.ds.isel(time=slice(idx, idx + self.ar_steps)) + forcings = self.forcings.isel(time=slice(idx, idx + self.ar_steps)) + sample = torch.tensor(sample.values, dtype=torch.float32) + forcings = torch.tensor(forcings.values, dtype=torch.float32) + + init_states = sample[:2] + target_states = sample[2:] + + batch_times = self.ds.isel( + time=slice( + idx, + idx + + self.ar_steps)).time.values.astype(str).tolist() + + # init_states: (2, N_grid, d_features) + # target_states: (ar_steps-2, N_grid, d_features) + # forcings: (ar_steps, N_grid, d_windowed_forcings) + # batch_times: (ar_steps,) + return init_states, target_states, forcings, batch_times + + +class WeatherDataModule(pl.LightningDataModule): + """DataModule for weather data.""" + + def __init__( + self, + batch_size=4, + num_workers=16, + ): + super().__init__() + self.batch_size = batch_size + self.num_workers = num_workers + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + + def setup(self, stage=None): + if stage == "fit" or stage is None: + self.train_dataset = WeatherDataset( + split="train", + batch_size=self.batch_size, + ) + self.val_dataset = WeatherDataset( + split="val", + batch_size=self.batch_size, + ) + + if stage == "test" or stage is None: + self.test_dataset = WeatherDataset( + split="test", + batch_size=self.batch_size, + ) + + def train_dataloader(self): + """Load train dataset.""" + return torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, ) - sample = sample[init_id : (init_id + self.sample_length)] - # (sample_length, N_grid, d_features) - - if self.standardize: - # Standardize sample - sample = (sample - self.data_mean) / self.data_std - - # Split up sample in init. 
states and target states - init_states = sample[:2] # (2, N_grid, d_features) - target_states = sample[2:] # (sample_length-2, N_grid, d_features) - - # === Forcing features === - # Now batch-static features are just part of forcing, - # repeated over temporal dimension - # Load water coverage - sample_datetime = sample_name[:10] - water_path = os.path.join( - self.sample_dir_path, f"wtr_{sample_datetime}.npy" + + def val_dataloader(self): + """Load validation dataset.""" + return torch.utils.data.DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, ) - water_cover_features = torch.tensor( - np.load(water_path), dtype=torch.float32 - ).unsqueeze( - -1 - ) # (dim_x, dim_y, 1) - # Flatten - water_cover_features = water_cover_features.flatten(0, 1) # (N_grid, 1) - # Expand over temporal dimension - water_cover_expanded = water_cover_features.unsqueeze(0).expand( - self.sample_length - 2, -1, -1 # -2 as added on after windowing - ) # (sample_len, N_grid, 1) - - # TOA flux - flux_path = os.path.join( - self.sample_dir_path, - f"nwp_toa_downwelling_shortwave_flux_{sample_datetime}.npy", + + def test_dataloader(self): + """Load test dataset.""" + return torch.utils.data.DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, ) - flux = torch.tensor(np.load(flux_path), dtype=torch.float32).unsqueeze( - -1 - ) # (N_t', dim_x, dim_y, 1) - - if self.standardize: - flux = (flux - self.flux_mean) / self.flux_std - - # Flatten and subsample flux forcing - flux = flux.flatten(1, 2) # (N_t, N_grid, 1) - flux = flux[subsample_index :: self.subsample_step] # (N_t, N_grid, 1) - flux = flux[ - init_id : (init_id + self.sample_length) - ] # (sample_len, N_grid, 1) - - # Time of day and year - dt_obj = dt.datetime.strptime(sample_datetime, "%Y%m%d%H") - dt_obj = dt_obj + dt.timedelta( - hours=2 + subsample_index - ) # Offset for first index - # Extract for initial step - init_hour_in_day = dt_obj.hour - start_of_year = dt.datetime(dt_obj.year, 1, 1) - init_seconds_into_year = (dt_obj - start_of_year).total_seconds() - - # Add increments for all steps - hour_inc = ( - torch.arange(self.sample_length) * self.subsample_step - ) # (sample_len,) - hour_of_day = ( - init_hour_in_day + hour_inc - ) # (sample_len,), Can be > 24 but ok - second_into_year = ( - init_seconds_into_year + hour_inc * 3600 - ) # (sample_len,) - # can roll over to next year, ok because periodicity - - # Encode as sin/cos - hour_angle = (hour_of_day / 12) * torch.pi # (sample_len,) - year_angle = ( - (second_into_year / constants.SECONDS_IN_YEAR) * 2 * torch.pi - ) # (sample_len,) - datetime_forcing = torch.stack( - ( - torch.sin(hour_angle), - torch.cos(hour_angle), - torch.sin(year_angle), - torch.cos(year_angle), - ), - dim=1, - ) # (N_t, 4) - datetime_forcing = (datetime_forcing + 1) / 2 # Rescale to [0,1] - datetime_forcing = datetime_forcing.unsqueeze(1).expand( - -1, flux.shape[1], -1 - ) # (sample_len, N_grid, 4) - - # Put forcing features together - forcing_features = torch.cat( - (flux, datetime_forcing), dim=-1 - ) # (sample_len, N_grid, d_forcing) - - # Combine forcing over each window of 3 time steps - forcing_windowed = torch.cat( - ( - forcing_features[:-2], - forcing_features[1:-1], - forcing_features[2:], - ), - dim=2, - ) # (sample_len-2, N_grid, 3*d_forcing) - # Now index 0 of ^ corresponds to forcing at index 0-2 of sample - - # batch-static water cover is added after windowing, - # as it is static over time - 
forcing = torch.cat((water_cover_expanded, forcing_windowed), dim=2) - # (sample_len-2, N_grid, forcing_dim) - - return init_states, target_states, forcing + + +data_module = WeatherDataModule(batch_size=4, num_workers=0) +data_module.setup() +train_dataloader = data_module.train_dataloader() +for batch in train_dataloader: + print(batch[0].shape) + print(batch[1].shape) + print(batch[2].shape) + print(batch[3]) + break From 9936e3bcb9b32a81908af3cdde196b89a5e3b5ee Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 6 May 2024 23:12:15 +0200 Subject: [PATCH 03/26] handling None zarrs --- neural_lam/data_config.yaml | 4 +- neural_lam/weather_dataset.py | 80 ++++++++++++++++++++++++----------- 2 files changed, 57 insertions(+), 27 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index 0aedd7fe..af2672af 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -20,7 +20,7 @@ zarrs: # List of zarrs containing fields related to state x: x y: y boundary: - zarrs: + path: mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary. state: # Variables forecasted by the model surface: # Single-field variables @@ -127,4 +127,4 @@ projection: kwargs: # Parsed and used directly as kwargs to projection-class above pole_longitude: 10.0 pole_latitude: -43.0 -normalization_zarr: data/meps_example/norm.zarr +normalization_zarr: /scratch/sadamov/norm.zarr diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 8c9e0072..968a985f 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,5 +1,6 @@ # Standard library import os +from functools import lru_cache # Third-party import pytorch_lightning as pl @@ -27,6 +28,7 @@ def load_config(self): with open(self.config_path, "r") as file: return yaml.safe_load(file) + @lru_cache(maxsize=None) def __getattr__(self, name): keys = name.split(".") value = self.values @@ -69,16 +71,26 @@ def process_dataset(self, dataset_name): xarray.Dataset: Processed dataset. 
""" - dataset_path = os.path.join(self.config_loader.zarrs[dataset_name].path) + dataset_path = self.config_loader.zarrs[dataset_name].path + if dataset_path is None or not os.path.exists(dataset_path): + print(f"Dataset '{dataset_name}' not found at path: {dataset_path}") + return None dataset = xr.open_zarr(dataset_path, consolidated=True) - start, end = self.config_loader.splits[self.split].start, self.config_loader.splits[self.split].end + start, end = ( + self.config_loader.splits[self.split].start, + self.config_loader.splits[self.split].end, + ) dataset = dataset.sel(time=slice(start, end)) dataset = dataset.rename_dims( - {v: k for k, v in self.config_loader.zarrs[dataset_name].dims.values.items() - if k not in dataset.dims}) - if 'grid' not in dataset.dims: - dataset = dataset.stack(grid=('x', 'y')) + { + v: k + for k, v in self.config_loader.zarrs[dataset_name].dims.values.items() + if k not in dataset.dims + } + ) + if "grid" not in dataset.dims: + dataset = dataset.stack(grid=("x", "y")) vars_surface = [] if self.config_loader[dataset_name].surface: @@ -87,9 +99,12 @@ def process_dataset(self, dataset_name): vars_atmosphere = [] if self.config_loader[dataset_name].atmosphere: vars_atmosphere = xr.merge( - [dataset[var].sel(level=level, drop=True).rename(f"{var}_{level}") - for var in self.config_loader[dataset_name].atmosphere - for level in self.config_loader[dataset_name].levels]) + [ + dataset[var].sel(level=level, drop=True).rename(f"{var}_{level}") + for var in self.config_loader[dataset_name].atmosphere + for level in self.config_loader[dataset_name].levels + ] + ) if vars_surface and vars_atmosphere: dataset = xr.merge([vars_surface, vars_atmosphere]) @@ -98,7 +113,8 @@ def process_dataset(self, dataset_name): elif vars_atmosphere: dataset = vars_atmosphere else: - raise ValueError(f"No variables specified for dataset: {dataset_name}") + print("No variables found in dataset {dataset_name}") + return None dataset = dataset.squeeze(drop=True).to_array() if "time" in dataset.dims: @@ -130,36 +146,49 @@ def __init__( self.config_loader = ConfigLoader(yaml_path) self.state = self.process_dataset("state") + assert self.state is not None, "State dataset not found" self.static = self.process_dataset("static") self.forcings = self.process_dataset("forcing") - # self.boundary = self.process_dataset("boundary") + self.boundary = self.process_dataset("boundary") - self.static = self.static.expand_dims({"time": self.state.time}, axis=0) - self.ds = xr.concat([self.state, self.static], dim="variable") + if self.static is not None: + self.static = self.static.expand_dims({"time": self.state.time}, axis=0) + self.state = xr.concat([self.state, self.static], dim="variable") def __len__(self): - return len(self.ds.time) - self.ar_steps + return len(self.state.time) - self.ar_steps def __getitem__(self, idx): - sample = self.ds.isel(time=slice(idx, idx + self.ar_steps)) - forcings = self.forcings.isel(time=slice(idx, idx + self.ar_steps)) - sample = torch.tensor(sample.values, dtype=torch.float32) - forcings = torch.tensor(forcings.values, dtype=torch.float32) + sample = torch.tensor( + self.state.isel(time=slice(idx, idx + self.ar_steps)).values, + dtype=torch.float32, + ) + + forcings = torch.tensor( + self.forcings.isel(time=slice(idx, idx + self.ar_steps)).values, + dtype=torch.float32, + ) if self.forcings is not None else torch.tensor([]) + + boundary = torch.tensor( + self.boundary.isel(time=slice(idx, idx + self.ar_steps)).values, + dtype=torch.float32, + ) if self.boundary is not 
None else torch.tensor([]) init_states = sample[:2] target_states = sample[2:] - batch_times = self.ds.isel( - time=slice( - idx, - idx + - self.ar_steps)).time.values.astype(str).tolist() + batch_times = ( + self.state.isel(time=slice(idx, idx + self.ar_steps)) + .time.values.astype(str) + .tolist() + ) # init_states: (2, N_grid, d_features) # target_states: (ar_steps-2, N_grid, d_features) # forcings: (ar_steps, N_grid, d_windowed_forcings) + # boundary: (ar_steps, N_grid, d_windowed_boundary) # batch_times: (ar_steps,) - return init_states, target_states, forcings, batch_times + return init_states, target_states, forcings, boundary, batch_times class WeatherDataModule(pl.LightningDataModule): @@ -229,5 +258,6 @@ def test_dataloader(self): print(batch[0].shape) print(batch[1].shape) print(batch[2].shape) - print(batch[3]) + print(batch[3].shape) + print(batch[4]) break From 774d16a7685f912a74f5a19cc0faba4d4ab09eb7 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 8 May 2024 11:08:18 +0200 Subject: [PATCH 04/26] removed all dependencies on constants.py user configs are retrieved either from data_config.yaml or they are set as flags to train_model.py Capabilities of ConfigLoader class extended --- create_parameter_weights.py | 13 +--- neural_lam/constants.py | 120 ---------------------------------- neural_lam/models/ar_model.py | 53 ++++++++------- neural_lam/utils.py | 6 +- neural_lam/vis.py | 24 ++++--- neural_lam/weather_dataset.py | 35 ++++++---- train_model.py | 37 ++++++++++- 7 files changed, 102 insertions(+), 186 deletions(-) delete mode 100644 neural_lam/constants.py diff --git a/create_parameter_weights.py b/create_parameter_weights.py index 494a5e81..6956d4ca 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -8,7 +8,6 @@ from tqdm import tqdm # First-party -from neural_lam import constants from neural_lam.weather_dataset import WeatherDataset @@ -45,6 +44,7 @@ def main(): static_dir_path = os.path.join("data", args.dataset, "static") + ds = WeatherDataset() # Create parameter weights based on height # based on fig A.1 in graph cast paper w_dict = { @@ -56,7 +56,7 @@ def main(): "500": 0.03, } w_list = np.array( - [w_dict[par.split("_")[-2]] for par in constants.PARAM_NAMES] + [w_dict[par.split("_")[-2]] for par in ds.config_loader.param_names()] ) print("Saving parameter weights...") np.save( @@ -65,13 +65,6 @@ def main(): ) # Load dataset without any subsampling - ds = WeatherDataset( - args.dataset, - split="train", - subsample_step=1, - pred_length=63, - standardize=False, - ) # Without standardization loader = torch.utils.data.DataLoader( ds, args.batch_size, shuffle=False, num_workers=args.n_workers ) @@ -133,7 +126,7 @@ def main(): # Note: batch contains only 1h-steps stepped_batch = torch.cat( [ - batch[:, ss_i : used_subsample_len : args.step_length] + batch[:, ss_i: used_subsample_len: args.step_length] for ss_i in range(args.step_length) ], dim=0, diff --git a/neural_lam/constants.py b/neural_lam/constants.py deleted file mode 100644 index 527c31d8..00000000 --- a/neural_lam/constants.py +++ /dev/null @@ -1,120 +0,0 @@ -# Third-party -import cartopy -import numpy as np - -WANDB_PROJECT = "neural-lam" - -SECONDS_IN_YEAR = ( - 365 * 24 * 60 * 60 -) # Assuming no leap years in dataset (2024 is next) - -# Log prediction error for these lead times -VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 5, 10, 15, 19]) - -# Log these metrics to wandb as scalar values for -# specific variables and lead times -# List of metrics to watch, including any prefix 
(e.g. val_rmse) -METRICS_WATCH = [] -# Dict with variables and lead times to log watched metrics for -# Format is a dictionary that maps from a variable index to -# a list of lead time steps -VAR_LEADS_METRICS_WATCH = { - 6: [2, 19], # t_2 - 14: [2, 19], # wvint_0 - 15: [2, 19], # z_1000 -} - -# Variable names -PARAM_NAMES = [ - "pres_heightAboveGround_0_instant", - "pres_heightAboveSea_0_instant", - "nlwrs_heightAboveGround_0_accum", - "nswrs_heightAboveGround_0_accum", - "r_heightAboveGround_2_instant", - "r_hybrid_65_instant", - "t_heightAboveGround_2_instant", - "t_hybrid_65_instant", - "t_isobaricInhPa_500_instant", - "t_isobaricInhPa_850_instant", - "u_hybrid_65_instant", - "u_isobaricInhPa_850_instant", - "v_hybrid_65_instant", - "v_isobaricInhPa_850_instant", - "wvint_entireAtmosphere_0_instant", - "z_isobaricInhPa_1000_instant", - "z_isobaricInhPa_500_instant", -] - -PARAM_NAMES_SHORT = [ - "pres_0g", - "pres_0s", - "nlwrs_0", - "nswrs_0", - "r_2", - "r_65", - "t_2", - "t_65", - "t_500", - "t_850", - "u_65", - "u_850", - "v_65", - "v_850", - "wvint_0", - "z_1000", - "z_500", -] -PARAM_UNITS = [ - "Pa", - "Pa", - "W/m\\textsuperscript{2}", - "W/m\\textsuperscript{2}", - "-", # unitless - "-", - "K", - "K", - "K", - "K", - "m/s", - "m/s", - "m/s", - "m/s", - "kg/m\\textsuperscript{2}", - "m\\textsuperscript{2}/s\\textsuperscript{2}", - "m\\textsuperscript{2}/s\\textsuperscript{2}", -] - -# Projection and grid -# Hard coded for now, but should eventually be part of dataset desc. files -GRID_SHAPE = (268, 238) # (y, x) - -LAMBERT_PROJ_PARAMS = { - "a": 6367470, - "b": 6367470, - "lat_0": 63.3, - "lat_1": 63.3, - "lat_2": 63.3, - "lon_0": 15.0, - "proj": "lcc", -} - -GRID_LIMITS = [ # In projection - -1059506.5523409774, # min x - 1310493.4476590226, # max x - -1331732.4471934352, # min y - 1338267.5528065648, # max y -] - -# Create projection -LAMBERT_PROJ = cartopy.crs.LambertConformal( - central_longitude=LAMBERT_PROJ_PARAMS["lon_0"], - central_latitude=LAMBERT_PROJ_PARAMS["lat_0"], - standard_parallels=( - LAMBERT_PROJ_PARAMS["lat_1"], - LAMBERT_PROJ_PARAMS["lat_2"], - ), -) - -# Data dimensions -GRID_FORCING_DIM = 5 * 3 + 1 # 5 feat. for 3 time-step window + 1 batch-static -GRID_STATE_DIM = 17 diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 7d0a8320..902a89e4 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -6,10 +6,12 @@ import numpy as np import pytorch_lightning as pl import torch + import wandb # First-party -from neural_lam import constants, metrics, utils, vis +from neural_lam import metrics, utils, vis +from neural_lam.weather_dataset import ConfigLoader class ARModel(pl.LightningModule): @@ -25,6 +27,7 @@ def __init__(self, args): super().__init__() self.save_hyperparameters() self.lr = args.lr + self.config_loader = ConfigLoader(args.data_config) # Load static features for grid/data static_data_dict = utils.load_static_data(args.dataset) @@ -37,11 +40,11 @@ def __init__(self, args): self.output_std = bool(args.output_std) if self.output_std: self.grid_output_dim = ( - 2 * constants.GRID_STATE_DIM + 2 * self.config_loader.num_data_vars("state") ) # Pred. dim. in grid cell else: self.grid_output_dim = ( - constants.GRID_STATE_DIM + self.config_loader.num_data_vars("state") ) # Pred. dim. in grid cell # Store constant per-variable std.-dev. 
weighting @@ -59,9 +62,9 @@ def __init__(self, args): grid_static_dim, ) = self.grid_static_features.shape # 63784 = 268x238 self.grid_dim = ( - 2 * constants.GRID_STATE_DIM + 2 * self.config_loader.num_data_vars("state") + grid_static_dim - + constants.GRID_FORCING_DIM + + self.config_loader.num_data_vars("forcing") ) # Instantiate loss function @@ -246,7 +249,7 @@ def validation_step(self, batch, batch_idx): # Log loss per time step forward and mean val_log_dict = { f"val_loss_unroll{step}": time_step_loss[step - 1] - for step in constants.VAL_STEP_LOG_ERRORS + for step in self.args.val_steps_log } val_log_dict["val_mean_loss"] = mean_loss self.log_dict( @@ -294,7 +297,7 @@ def test_step(self, batch, batch_idx): # Log loss per time step forward and mean test_log_dict = { f"test_loss_unroll{step}": time_step_loss[step - 1] - for step in constants.VAL_STEP_LOG_ERRORS + for step in self.args.val_steps_log } test_log_dict["test_mean_loss"] = mean_loss @@ -328,7 +331,7 @@ def test_step(self, batch, batch_idx): spatial_loss = self.loss( prediction, target, pred_std, average_grid=False ) # (B, pred_steps, num_grid_nodes) - log_spatial_losses = spatial_loss[:, constants.VAL_STEP_LOG_ERRORS - 1] + log_spatial_losses = spatial_loss[:, self.args.val_steps_log - 1] self.spatial_loss_maps.append(log_spatial_losses) # (B, N_log, num_grid_nodes) @@ -399,14 +402,15 @@ def plot_examples(self, batch, n_examples, prediction=None): pred_t[:, var_i], target_t[:, var_i], self.interior_mask[:, 0], + self.config_loader, title=f"{var_name} ({var_unit}), " - f"t={t_i} ({self.step_length*t_i} h)", + f"t={t_i} ({self.step_length * t_i} h)", vrange=var_vrange, ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( - constants.PARAM_NAMES_SHORT, - constants.PARAM_UNITS, + self.config_loader.param_names(), + self.config_loader.param_units(), var_vranges, ) ) @@ -417,7 +421,7 @@ def plot_examples(self, batch, n_examples, prediction=None): { f"{var_name}_example_{example_i}": wandb.Image(fig) for var_name, fig in zip( - constants.PARAM_NAMES_SHORT, var_figs + self.config_loader.param_names(), var_figs ) } ) @@ -453,7 +457,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): """ log_dict = {} metric_fig = vis.plot_error_map( - metric_tensor, step_length=self.step_length + metric_tensor, self.config_loader, step_length=self.step_length ) full_log_name = f"{prefix}_{metric_name}" log_dict[full_log_name] = wandb.Image(metric_fig) @@ -471,14 +475,14 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): ) # Check if metrics are watched, log exact values for specific vars - if full_log_name in constants.METRICS_WATCH: - for var_i, timesteps in constants.VAR_LEADS_METRICS_WATCH.items(): - var = constants.PARAM_NAMES_SHORT[var_i] + if full_log_name in self.args.metrics_watch: + for var_i, timesteps in self.args.var_leads_metrics_watch.items(): + var = self.config_loader.param_names()[var_i] log_dict.update( { f"{full_log_name}_{var}_step_{step}": metric_tensor[ step - 1, var_i - ] # 1-indexed in constants + ] # 1-indexed in data_config for step in timesteps } ) @@ -542,10 +546,11 @@ def on_test_epoch_end(self): vis.plot_spatial_error( loss_map, self.interior_mask[:, 0], - title=f"Test loss, t={t_i} ({self.step_length*t_i} h)", + self.config_loader, + title=f"Test loss, t={t_i} ({self.step_length * t_i} h)", ) for t_i, loss_map in zip( - constants.VAL_STEP_LOG_ERRORS, mean_spatial_loss + self.args.val_steps_log, mean_spatial_loss ) ] @@ -554,14 +559,14 @@ def 
on_test_epoch_end(self): wandb.log({"test_loss": wandb.Image(fig)}) # also make without title and save as pdf - pdf_loss_map_figs = [ - vis.plot_spatial_error(loss_map, self.interior_mask[:, 0]) - for loss_map in mean_spatial_loss - ] + pdf_loss_map_figs = [vis.plot_spatial_error( + loss_map, self.interior_mask[:, 0], + self.config_loader) + for loss_map in mean_spatial_loss] pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps") os.makedirs(pdf_loss_maps_dir, exist_ok=True) for t_i, fig in zip( - constants.VAL_STEP_LOG_ERRORS, pdf_loss_map_figs + self.args.val_steps_log, pdf_loss_map_figs ): fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) # save mean spatial loss as .pt file also diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 31715502..8b9e250b 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -7,8 +7,6 @@ from torch import nn from tueplots import bundles, figsizes -# First-party -from neural_lam import constants def load_dataset_stats(dataset_name, device="cpu"): @@ -263,11 +261,11 @@ def fractional_plot_bundle(fraction): return bundle -def init_wandb_metrics(wandb_logger): +def init_wandb_metrics(wandb_logger, val_steps): """ Set up wandb metrics to track """ experiment = wandb_logger.experiment experiment.define_metric("val_mean_loss", summary="min") - for step in constants.VAL_STEP_LOG_ERRORS: + for step in val_steps: experiment.define_metric(f"val_loss_unroll{step}", summary="min") diff --git a/neural_lam/vis.py b/neural_lam/vis.py index cef34a84..81adb935 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -4,11 +4,11 @@ import numpy as np # First-party -from neural_lam import constants, utils +from neural_lam import utils @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_error_map(errors, title=None, step_length=3): +def plot_error_map(errors, data_config, title=None, step_length=3): """ Plot a heatmap of errors of different variables at different predictions horizons @@ -51,7 +51,7 @@ def plot_error_map(errors, title=None, step_length=3): y_ticklabels = [ f"{name} ({unit})" for name, unit in zip( - constants.PARAM_NAMES_SHORT, constants.PARAM_UNITS + data_config.param_names(), data_config.param_units() ) ] ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) @@ -63,7 +63,7 @@ def plot_error_map(errors, title=None, step_length=3): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_prediction(pred, target, obs_mask, title=None, vrange=None): +def plot_prediction(pred, target, obs_mask, data_config, title=None, vrange=None): """ Plot example prediction and grond truth. 
Each has shape (N_grid,) @@ -76,23 +76,22 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None): vmin, vmax = vrange # Set up masking of border region - mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE) + mask_reshaped = obs_mask.reshape(*data_config.grid_shape) pixel_alpha = ( mask_reshaped.clamp(0.7, 1).cpu().numpy() ) # Faded border region fig, axes = plt.subplots( - 1, 2, figsize=(13, 7), subplot_kw={"projection": constants.LAMBERT_PROJ} + 1, 2, figsize=(13, 7), subplot_kw={"projection": data_config.projection()} ) # Plot pred and target for ax, data in zip(axes, (target, pred)): ax.coastlines() # Add coastline outlines - data_grid = data.reshape(*constants.GRID_SHAPE).cpu().numpy() + data_grid = data.reshape(*data_config.grid_shape).cpu().numpy() im = ax.imshow( data_grid, origin="lower", - extent=constants.GRID_LIMITS, alpha=pixel_alpha, vmin=vmin, vmax=vmax, @@ -112,7 +111,7 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_spatial_error(error, obs_mask, title=None, vrange=None): +def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): """ Plot errors over spatial map Error and obs_mask has shape (N_grid,) @@ -125,22 +124,21 @@ def plot_spatial_error(error, obs_mask, title=None, vrange=None): vmin, vmax = vrange # Set up masking of border region - mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE) + mask_reshaped = obs_mask.reshape(*data_config.grid_shape) pixel_alpha = ( mask_reshaped.clamp(0.7, 1).cpu().numpy() ) # Faded border region fig, ax = plt.subplots( - figsize=(5, 4.8), subplot_kw={"projection": constants.LAMBERT_PROJ} + figsize=(5, 4.8), subplot_kw={"projection": data_config.projection()} ) ax.coastlines() # Add coastline outlines - error_grid = error.reshape(*constants.GRID_SHAPE).cpu().numpy() + error_grid = error.reshape(*data_config.grid_shape).cpu().numpy() im = ax.imshow( error_grid, origin="lower", - extent=constants.GRID_LIMITS, alpha=pixel_alpha, vmin=vmin, vmax=vmax, diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 968a985f..b6181602 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -7,6 +7,7 @@ import torch import xarray as xr import yaml +import cartopy.crs as ccrs class ConfigLoader: @@ -36,7 +37,7 @@ def __getattr__(self, name): if key in value: value = value[key] else: - None + return None if isinstance(value, dict): return ConfigLoader(None, values=value) return value @@ -50,6 +51,24 @@ def __getitem__(self, key): def __contains__(self, key): return key in self.values + def param_names(self): + return self.values['state']['surface'] + self.values['state']['atmosphere'] + + def param_units(self): + return self.values['state']['surface_units'] + self.values['state']['atmosphere_units'] + + def num_data_vars(self, key): + surface_vars = len(self.values[key]['surface']) + atmosphere_vars = len(self.values[key]['atmosphere']) + levels = len(self.values[key]['levels']) + return surface_vars + atmosphere_vars * levels + + def projection(self): + proj_config = self.values["projections"]["class"] + proj_class = getattr(ccrs, proj_config["proj_class"]) + proj_params = proj_config["proj_params"] + return proj_class(**proj_params) + class WeatherDataset(torch.utils.data.Dataset): """ @@ -61,15 +80,7 @@ class WeatherDataset(torch.utils.data.Dataset): """ def process_dataset(self, dataset_name): - """ - Process a single dataset specified by the dataset name. 
- - Args: - dataset_name (str): Name of the dataset to process. - - Returns: - xarray.Dataset: Processed dataset. - """ + """Process a single dataset specified by the dataset name.""" dataset_path = self.config_loader.zarrs[dataset_name].path if dataset_path is None or not os.path.exists(dataset_path): @@ -129,7 +140,7 @@ def __init__( batch_size=4, ar_steps=3, control_only=False, - yaml_path="neural_lam/data_config.yaml", + data_config="neural_lam/data_config.yaml", ): super().__init__() @@ -143,7 +154,7 @@ def __init__( self.batch_size = batch_size self.ar_steps = ar_steps self.control_only = control_only - self.config_loader = ConfigLoader(yaml_path) + self.config_loader = ConfigLoader(data_config) self.state = self.process_dataset("state") assert self.state is not None, "State dataset not found" diff --git a/train_model.py b/train_model.py index 96d21a3f..767d575a 100644 --- a/train_model.py +++ b/train_model.py @@ -9,7 +9,7 @@ from lightning_fabric.utilities import seed # First-party -from neural_lam import constants, utils +from neural_lam import utils from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM from neural_lam.models.hi_lam_parallel import HiLAMParallel @@ -44,6 +44,12 @@ def main(): default="graph_lam", help="Model architecture to train/evaluate (default: graph_lam)", ) + parser. add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data configuration file (default: neural_lam/data_config.yaml)", + ) parser.add_argument( "--subset_ds", type=int, @@ -183,6 +189,30 @@ def main(): help="Number of example predictions to plot during evaluation " "(default: 1)", ) + parser.add_argument( + "--wandb_project", + type=str, + default="neural-lam", + help="Wandb project to log to (default: neural-lam)", + ) + parser.add_argument( + "--val_steps_log", + type=list, + default=[1, 2, 3, 5, 10, 15, 19], + help="Steps to log validation loss for (default: [1, 2, 3, 5, 10, 15, 19])", + ) + parser.add_argument( + "--metrics_watch", + type=list, + default=[], + help="List of metrics to watch, including any prefix (e.g. 
val_rmse)", + ) + parser.add_argument( + "--var_leads_metrics_watch", + type=dict, + default={}, + help="Dict with variables and lead times to log watched metrics for", + ) args = parser.parse_args() # Asserts for arguments @@ -264,7 +294,7 @@ def main(): save_last=True, ) logger = pl.loggers.WandbLogger( - project=constants.WANDB_PROJECT, name=run_name, config=args + project=args.wandb_project, name=run_name, config=args ) trainer = pl.Trainer( max_epochs=args.epochs, @@ -280,7 +310,8 @@ def main(): # Only init once, on rank 0 only if trainer.global_rank == 0: - utils.init_wandb_metrics(logger) # Do after wandb.init + utils.init_wandb_metrics( + logger, val_steps=args.val_steps_log) # Do after wandb.init if args.eval: if args.eval == "val": From 7bb139b2b54a64ca9e8ae5d73deea2b2579767c4 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 8 May 2024 11:22:37 +0200 Subject: [PATCH 05/26] fix linter --- create_parameter_weights.py | 10 ++---- neural_lam/data_config.yaml | 8 ++--- neural_lam/models/ar_model.py | 23 ++++++------- neural_lam/utils.py | 1 - neural_lam/vis.py | 9 +++-- neural_lam/weather_dataset.py | 64 +++++++++++++++++++++++------------ train_model.py | 36 +++++--------------- 7 files changed, 75 insertions(+), 76 deletions(-) diff --git a/create_parameter_weights.py b/create_parameter_weights.py index 6956d4ca..926d7741 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -105,13 +105,7 @@ def main(): # Compute mean and std.-dev. of one-step differences across the dataset print("Computing mean and std.-dev. for one-step differences...") - ds_standard = WeatherDataset( - args.dataset, - split="train", - subsample_step=1, - pred_length=63, - standardize=True, - ) # Re-load with standardization + ds_standard = WeatherDataset() # Re-load with standardization loader_standard = torch.utils.data.DataLoader( ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers ) @@ -126,7 +120,7 @@ def main(): # Note: batch contains only 1h-steps stepped_batch = torch.cat( [ - batch[:, ss_i: used_subsample_len: args.step_length] + batch[:, ss_i : used_subsample_len : args.step_length] for ss_i in range(args.step_length) ], dim=0, diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index af2672af..8d936154 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -12,7 +12,7 @@ zarrs: # List of zarrs containing fields related to state level: z x: x y: y - forcing: + forcing: path: /scratch/sadamov/template.zarr dims: time: time @@ -55,7 +55,7 @@ state: # Variables forecasted by the model - m/s - m/s - Pa/s - levels: # Levels to use for atmosphere variables + levels: # Levels to use for atmosphere variables - 0 - 5 - 8 @@ -71,7 +71,7 @@ state: # Variables forecasted by the model - 59 static: # Static inputs surface: - - HSURF + - HSURF surface_units: - m atmosphere: @@ -122,7 +122,7 @@ splits: test: start: 2015-01-01T00 end: 2024-12-31T23 -projection: +projection: class: RotatedPole # Name of class in cartopy.crs kwargs: # Parsed and used directly as kwargs to projection-class above pole_longitude: 10.0 diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 902a89e4..93f2b569 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -6,7 +6,6 @@ import numpy as np import pytorch_lightning as pl import torch - import wandb # First-party @@ -39,12 +38,12 @@ def __init__(self, args): # Double grid output dim. to also output std.-dev. 
self.output_std = bool(args.output_std) if self.output_std: - self.grid_output_dim = ( - 2 * self.config_loader.num_data_vars("state") + self.grid_output_dim = 2 * self.config_loader.num_data_vars( + "state" ) # Pred. dim. in grid cell else: - self.grid_output_dim = ( - self.config_loader.num_data_vars("state") + self.grid_output_dim = self.config_loader.num_data_vars( + "state" ) # Pred. dim. in grid cell # Store constant per-variable std.-dev. weighting @@ -559,15 +558,15 @@ def on_test_epoch_end(self): wandb.log({"test_loss": wandb.Image(fig)}) # also make without title and save as pdf - pdf_loss_map_figs = [vis.plot_spatial_error( - loss_map, self.interior_mask[:, 0], - self.config_loader) - for loss_map in mean_spatial_loss] + pdf_loss_map_figs = [ + vis.plot_spatial_error( + loss_map, self.interior_mask[:, 0], self.config_loader + ) + for loss_map in mean_spatial_loss + ] pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps") os.makedirs(pdf_loss_maps_dir, exist_ok=True) - for t_i, fig in zip( - self.args.val_steps_log, pdf_loss_map_figs - ): + for t_i, fig in zip(self.args.val_steps_log, pdf_loss_map_figs): fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) # save mean spatial loss as .pt file also torch.save( diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 8b9e250b..836b04ed 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -8,7 +8,6 @@ from tueplots import bundles, figsizes - def load_dataset_stats(dataset_name, device="cpu"): """ Load arrays with stored dataset statistics from pre-processing diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 81adb935..02b8dd35 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -63,7 +63,9 @@ def plot_error_map(errors, data_config, title=None, step_length=3): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_prediction(pred, target, obs_mask, data_config, title=None, vrange=None): +def plot_prediction( + pred, target, obs_mask, data_config, title=None, vrange=None +): """ Plot example prediction and grond truth. 
Each has shape (N_grid,) @@ -82,7 +84,10 @@ def plot_prediction(pred, target, obs_mask, data_config, title=None, vrange=None ) # Faded border region fig, axes = plt.subplots( - 1, 2, figsize=(13, 7), subplot_kw={"projection": data_config.projection()} + 1, + 2, + figsize=(13, 7), + subplot_kw={"projection": data_config.projection()}, ) # Plot pred and target diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index b6181602..28c29db6 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,13 +1,12 @@ # Standard library import os -from functools import lru_cache # Third-party +import cartopy.crs as ccrs import pytorch_lightning as pl import torch import xarray as xr import yaml -import cartopy.crs as ccrs class ConfigLoader: @@ -26,10 +25,10 @@ def __init__(self, config_path, values=None): self.values = values def load_config(self): - with open(self.config_path, "r") as file: + """Load configuration file.""" + with open(self.config_path, encoding="utf-8", mode="r") as file: return yaml.safe_load(file) - @lru_cache(maxsize=None) def __getattr__(self, name): keys = name.split(".") value = self.values @@ -52,18 +51,27 @@ def __contains__(self, key): return key in self.values def param_names(self): - return self.values['state']['surface'] + self.values['state']['atmosphere'] + """Return parameter names.""" + return ( + self.values["state"]["surface"] + self.values["state"]["atmosphere"] + ) def param_units(self): - return self.values['state']['surface_units'] + self.values['state']['atmosphere_units'] + """Return parameter units.""" + return ( + self.values["state"]["surface_units"] + + self.values["state"]["atmosphere_units"] + ) def num_data_vars(self, key): - surface_vars = len(self.values[key]['surface']) - atmosphere_vars = len(self.values[key]['atmosphere']) - levels = len(self.values[key]['levels']) + """Return the number of data variables for a given key.""" + surface_vars = len(self.values[key]["surface"]) + atmosphere_vars = len(self.values[key]["atmosphere"]) + levels = len(self.values[key]["levels"]) return surface_vars + atmosphere_vars * levels - + def projection(self): + """Return the projection.""" proj_config = self.values["projections"]["class"] proj_class = getattr(ccrs, proj_config["proj_class"]) proj_params = proj_config["proj_params"] @@ -96,7 +104,9 @@ def process_dataset(self, dataset_name): dataset = dataset.rename_dims( { v: k - for k, v in self.config_loader.zarrs[dataset_name].dims.values.items() + for k, v in self.config_loader.zarrs[ + dataset_name + ].dims.values.items() if k not in dataset.dims } ) @@ -111,7 +121,9 @@ def process_dataset(self, dataset_name): if self.config_loader[dataset_name].atmosphere: vars_atmosphere = xr.merge( [ - dataset[var].sel(level=level, drop=True).rename(f"{var}_{level}") + dataset[var] + .sel(level=level, drop=True) + .rename(f"{var}_{level}") for var in self.config_loader[dataset_name].atmosphere for level in self.config_loader[dataset_name].levels ] @@ -163,7 +175,9 @@ def __init__( self.boundary = self.process_dataset("boundary") if self.static is not None: - self.static = self.static.expand_dims({"time": self.state.time}, axis=0) + self.static = self.static.expand_dims( + {"time": self.state.time}, axis=0 + ) self.state = xr.concat([self.state, self.static], dim="variable") def __len__(self): @@ -175,15 +189,23 @@ def __getitem__(self, idx): dtype=torch.float32, ) - forcings = torch.tensor( - self.forcings.isel(time=slice(idx, idx + self.ar_steps)).values, - 
dtype=torch.float32, - ) if self.forcings is not None else torch.tensor([]) + forcings = ( + torch.tensor( + self.forcings.isel(time=slice(idx, idx + self.ar_steps)).values, + dtype=torch.float32, + ) + if self.forcings is not None + else torch.tensor([]) + ) - boundary = torch.tensor( - self.boundary.isel(time=slice(idx, idx + self.ar_steps)).values, - dtype=torch.float32, - ) if self.boundary is not None else torch.tensor([]) + boundary = ( + torch.tensor( + self.boundary.isel(time=slice(idx, idx + self.ar_steps)).values, + dtype=torch.float32, + ) + if self.boundary is not None + else torch.tensor([]) + ) init_states = sample[:2] target_states = sample[2:] diff --git a/train_model.py b/train_model.py index 767d575a..23a0330c 100644 --- a/train_model.py +++ b/train_model.py @@ -44,11 +44,11 @@ def main(): default="graph_lam", help="Model architecture to train/evaluate (default: graph_lam)", ) - parser. add_argument( + parser.add_argument( "--data_config", type=str, default="neural_lam/data_config.yaml", - help="Path to data configuration file (default: neural_lam/data_config.yaml)", + help="Path to data config file (default: neural_lam/data_config.yaml)", ) parser.add_argument( "--subset_ds", @@ -199,7 +199,7 @@ def main(): "--val_steps_log", type=list, default=[1, 2, 3, 5, 10, 15, 19], - help="Steps to log validation loss for (default: [1, 2, 3, 5, 10, 15, 19])", + help="Steps to log val loss for (default: [1, 2, 3, 5, 10, 15, 19])", ) parser.add_argument( "--metrics_watch", @@ -232,28 +232,13 @@ def main(): # Load data train_loader = torch.utils.data.DataLoader( - WeatherDataset( - args.dataset, - pred_length=args.ar_steps, - split="train", - subsample_step=args.step_length, - subset=bool(args.subset_ds), - control_only=args.control_only, - ), + WeatherDataset(control_only=args.control_only), args.batch_size, shuffle=True, num_workers=args.n_workers, ) - max_pred_length = (65 // args.step_length) - 2 # 19 val_loader = torch.utils.data.DataLoader( - WeatherDataset( - args.dataset, - pred_length=max_pred_length, - split="val", - subsample_step=args.step_length, - subset=bool(args.subset_ds), - control_only=args.control_only, - ), + WeatherDataset(control_only=args.control_only), args.batch_size, shuffle=False, num_workers=args.n_workers, @@ -311,20 +296,15 @@ def main(): # Only init once, on rank 0 only if trainer.global_rank == 0: utils.init_wandb_metrics( - logger, val_steps=args.val_steps_log) # Do after wandb.init + logger, val_steps=args.val_steps_log + ) # Do after wandb.init if args.eval: if args.eval == "val": eval_loader = val_loader else: # Test eval_loader = torch.utils.data.DataLoader( - WeatherDataset( - args.dataset, - pred_length=max_pred_length, - split="test", - subsample_step=args.step_length, - subset=bool(args.subset_ds), - ), + WeatherDataset(), args.batch_size, shuffle=False, num_workers=args.n_workers, From af076feb613165c7596a79e47d987867015cfd4f Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 8 May 2024 14:08:46 +0200 Subject: [PATCH 06/26] Fixed calls to new WeatherDataModule Class --- train_model.py | 39 ++++++++------------------------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/train_model.py b/train_model.py index 23a0330c..a303132f 100644 --- a/train_model.py +++ b/train_model.py @@ -13,7 +13,7 @@ from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM from neural_lam.models.hi_lam_parallel import HiLAMParallel -from neural_lam.weather_dataset import WeatherDataset +from 
neural_lam.weather_dataset import WeatherDataModule MODELS = { "graph_lam": GraphLAM, @@ -189,6 +189,8 @@ def main(): help="Number of example predictions to plot during evaluation " "(default: 1)", ) + + # Logging Options parser.add_argument( "--wandb_project", type=str, @@ -229,18 +231,9 @@ def main(): # Set seed seed.seed_everything(args.seed) - - # Load data - train_loader = torch.utils.data.DataLoader( - WeatherDataset(control_only=args.control_only), - args.batch_size, - shuffle=True, - num_workers=args.n_workers, - ) - val_loader = torch.utils.data.DataLoader( - WeatherDataset(control_only=args.control_only), - args.batch_size, - shuffle=False, + # Create datamodule + data_module = WeatherDataModule( + batch_size=args.batch_size, num_workers=args.n_workers, ) @@ -300,25 +293,9 @@ def main(): ) # Do after wandb.init if args.eval: - if args.eval == "val": - eval_loader = val_loader - else: # Test - eval_loader = torch.utils.data.DataLoader( - WeatherDataset(), - args.batch_size, - shuffle=False, - num_workers=args.n_workers, - ) - - print(f"Running evaluation on {args.eval}") - trainer.test(model=model, dataloaders=eval_loader) + trainer.test(model=model, datamodule=data_module, ckpt_path=args.load) else: - # Train model - trainer.fit( - model=model, - train_dataloaders=train_loader, - val_dataloaders=val_loader, - ) + trainer.fit(model=model, datamodule=data_module) if __name__ == "__main__": From 147caec913d8d490bdb5c155f1c61ef060d3a3d9 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 8 May 2024 14:12:43 +0200 Subject: [PATCH 07/26] fix linter --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 5a2111b2..0a921225 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ Cartopy>=0.22.0 pyproj>=3.4.1 tueplots>=0.0.8 plotly>=5.15.0 +xarray>=0.20.1 # for dev codespell>=2.0.0 black>=21.9b0 From 2b65416336b29d6f3347bd4b89284b0b867659df Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 8 May 2024 14:48:27 +0200 Subject: [PATCH 08/26] upload data config to wandb for history logs --- train_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train_model.py b/train_model.py index a303132f..1839474b 100644 --- a/train_model.py +++ b/train_model.py @@ -8,6 +8,8 @@ import torch from lightning_fabric.utilities import seed +import wandb + # First-party from neural_lam import utils from neural_lam.models.graph_lam import GraphLAM @@ -291,7 +293,7 @@ def main(): utils.init_wandb_metrics( logger, val_steps=args.val_steps_log ) # Do after wandb.init - + wandb.save(args.data_config) if args.eval: trainer.test(model=model, datamodule=data_module, ckpt_path=args.load) else: From ed9ed696a6256b63bd0283c5057d0d12424dc91a Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 8 May 2024 16:59:23 +0200 Subject: [PATCH 09/26] Improved handling of static data --- neural_lam/data_config.yaml | 19 ++- neural_lam/models/ar_model.py | 13 +-- neural_lam/utils.py | 211 ++++++++++++++++++++++------------ neural_lam/weather_dataset.py | 150 +----------------------- 4 files changed, 164 insertions(+), 229 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index 8d936154..6c4536f5 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -72,8 +72,12 @@ state: # Variables forecasted by the model static: # Static inputs surface: - HSURF + - lat + - lon surface_units: - m + - °N + - °E atmosphere: - FI atmosphere_units: @@ -127,4 +131,17 @@ projection: kwargs: # 
Parsed and used directly as kwargs to projection-class above pole_longitude: 10.0 pole_latitude: -43.0 -normalization_zarr: /scratch/sadamov/norm.zarr +normalization: + zarr: /scratch/sadamov/norm.zarr + vars: + data_mean: data_mean + data_std: data_std + forcing_mean: forcing_mean + forcing_std: forcing_std + boundary_mean: boundary_mean + boundary_std: boundary_std + diff_mean: diff_mean + diff_std: diff_std + grid_static_features: grid_static_features + param_weights: param_weights + diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 93f2b569..04679022 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -6,11 +6,11 @@ import numpy as np import pytorch_lightning as pl import torch + import wandb # First-party from neural_lam import metrics, utils, vis -from neural_lam.weather_dataset import ConfigLoader class ARModel(pl.LightningModule): @@ -26,14 +26,11 @@ def __init__(self, args): super().__init__() self.save_hyperparameters() self.lr = args.lr - self.config_loader = ConfigLoader(args.data_config) + self.config_loader = utils.ConfigLoader(args.data_config) # Load static features for grid/data - static_data_dict = utils.load_static_data(args.dataset) - for static_data_name, static_data_tensor in static_data_dict.items(): - self.register_buffer( - static_data_name, static_data_tensor, persistent=False - ) + static = self.config_loader.process_dataset("static") + self.register_buffer("grid_static_features", torch.tensor(static.values)) # Double grid output dim. to also output std.-dev. self.output_std = bool(args.output_std) @@ -59,7 +56,7 @@ def __init__(self, args): ( self.num_grid_nodes, grid_static_dim, - ) = self.grid_static_features.shape # 63784 = 268x238 + ) = self.grid_static_features.shape self.grid_dim = ( 2 * self.config_loader.num_data_vars("state") + grid_static_dim diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 836b04ed..f4c34141 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,87 +1,16 @@ # Standard library import os +import cartopy.crs as ccrs + # Third-party -import numpy as np import torch +import xarray as xr +import yaml from torch import nn from tueplots import bundles, figsizes -def load_dataset_stats(dataset_name, device="cpu"): - """ - Load arrays with stored dataset statistics from pre-processing - """ - static_dir_path = os.path.join("data", dataset_name, "static") - - def loads_file(fn): - return torch.load( - os.path.join(static_dir_path, fn), map_location=device - ) - - data_mean = loads_file("parameter_mean.pt") # (d_features,) - data_std = loads_file("parameter_std.pt") # (d_features,) - - flux_stats = loads_file("flux_stats.pt") # (2,) - flux_mean, flux_std = flux_stats - - return { - "data_mean": data_mean, - "data_std": data_std, - "flux_mean": flux_mean, - "flux_std": flux_std, - } - - -def load_static_data(dataset_name, device="cpu"): - """ - Load static files related to dataset - """ - static_dir_path = os.path.join("data", dataset_name, "static") - - def loads_file(fn): - return torch.load( - os.path.join(static_dir_path, fn), map_location=device - ) - - # Load border mask, 1. if node is part of border, else 0. 
- border_mask_np = np.load(os.path.join(static_dir_path, "border_mask.npy")) - border_mask = ( - torch.tensor(border_mask_np, dtype=torch.float32, device=device) - .flatten(0, 1) - .unsqueeze(1) - ) # (N_grid, 1) - - grid_static_features = loads_file( - "grid_features.pt" - ) # (N_grid, d_grid_static) - - # Load step diff stats - step_diff_mean = loads_file("diff_mean.pt") # (d_f,) - step_diff_std = loads_file("diff_std.pt") # (d_f,) - - # Load parameter std for computing validation errors in original data scale - data_mean = loads_file("parameter_mean.pt") # (d_features,) - data_std = loads_file("parameter_std.pt") # (d_features,) - - # Load loss weighting vectors - param_weights = torch.tensor( - np.load(os.path.join(static_dir_path, "parameter_weights.npy")), - dtype=torch.float32, - device=device, - ) # (d_f,) - - return { - "border_mask": border_mask, - "grid_static_features": grid_static_features, - "step_diff_mean": step_diff_mean, - "step_diff_std": step_diff_std, - "data_mean": data_mean, - "data_std": data_std, - "param_weights": param_weights, - } - - class BufferList(nn.Module): """ A list of torch buffer tensors that sit together as a Module with no @@ -268,3 +197,135 @@ def init_wandb_metrics(wandb_logger, val_steps): experiment.define_metric("val_mean_loss", summary="min") for step in val_steps: experiment.define_metric(f"val_loss_unroll{step}", summary="min") + + +class ConfigLoader: + """ + Class for loading configuration files. + + This class loads a YAML configuration file and provides a way to access + its values as attributes. + """ + + def __init__(self, config_path, values=None): + self.config_path = config_path + if values is None: + self.values = self.load_config() + else: + self.values = values + + def load_config(self): + """Load configuration file.""" + with open(self.config_path, encoding="utf-8", mode="r") as file: + return yaml.safe_load(file) + + def __getattr__(self, name): + keys = name.split(".") + value = self.values + for key in keys: + if key in value: + value = value[key] + else: + return None + if isinstance(value, dict): + return ConfigLoader(None, values=value) + return value + + def __getitem__(self, key): + value = self.values[key] + if isinstance(value, dict): + return ConfigLoader(None, values=value) + return value + + def __contains__(self, key): + return key in self.values + + def param_names(self): + """Return parameter names.""" + return self.values["state"]["surface"] + self.values["state"]["atmosphere"] + + def param_units(self): + """Return parameter units.""" + return ( + self.values["state"]["surface_units"] + + self.values["state"]["atmosphere_units"] + ) + + def num_data_vars(self, key): + """Return the number of data variables for a given key.""" + surface_vars = len(self.values[key]["surface"]) + atmosphere_vars = len(self.values[key]["atmosphere"]) + levels = len(self.values[key]["levels"]) + return surface_vars + atmosphere_vars * levels + + def projection(self): + """Return the projection.""" + proj_config = self.values["projections"]["class"] + proj_class = getattr(ccrs, proj_config["proj_class"]) + proj_params = proj_config["proj_params"] + return proj_class(**proj_params) + + def open_zarr(self, dataset_name, split): + """Open a dataset specified by the dataset name.""" + dataset_path = self.zarrs[dataset_name].path + if dataset_path is None or not os.path.exists(dataset_path): + print(f"Dataset '{dataset_name}' not found at path: {dataset_path}") + return None + dataset = xr.open_zarr(dataset_path, consolidated=True) + 
return dataset + + def process_dataset(self, dataset_name, split): + """Process a single dataset specified by the dataset name.""" + + dataset = self.open_zarr(dataset_name, split) + + start, end = ( + self.splits[split].start, + self.splits[split].end, + ) + dataset = dataset.sel(time=slice(start, end)) + dataset = dataset.rename_dims( + { + v: k + for k, v in self.zarrs[ + dataset_name + ].dims.values.items() + if k not in dataset.dims + } + ) + if "grid" not in dataset.dims: + dataset = dataset.stack(grid=("x", "y")) + + vars_surface = [] + if self[dataset_name].surface: + vars_surface = dataset[self[dataset_name].surface] + + vars_atmosphere = [] + if self[dataset_name].atmosphere: + vars_atmosphere = xr.merge( + [ + dataset[var] + .sel(level=level, drop=True) + .rename(f"{var}_{level}") + for var in self[dataset_name].atmosphere + for level in self[dataset_name].levels + ] + ) + + if vars_surface and vars_atmosphere: + dataset = xr.merge([vars_surface, vars_atmosphere]) + elif vars_surface: + dataset = vars_surface + elif vars_atmosphere: + dataset = vars_atmosphere + else: + print("No variables found in dataset {dataset_name}") + return None + + if "time" in dataset.dims: + dataset = dataset.squeeze( + drop=True).to_array().transpose( + "time", "grid", "variable") + else: + dataset = dataset.to_array().transpose("grid", "variable") + return dataset diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 28c29db6..87aa3f56 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,81 +1,7 @@ -# Standard library -import os - -# Third-party -import cartopy.crs as ccrs import pytorch_lightning as pl import torch -import xarray as xr -import yaml - - -class ConfigLoader: - """ - Class for loading configuration files. - - This class loads a YAML configuration file and provides a way to access - its values as attributes. 
- """ - - def __init__(self, config_path, values=None): - self.config_path = config_path - if values is None: - self.values = self.load_config() - else: - self.values = values - - def load_config(self): - """Load configuration file.""" - with open(self.config_path, encoding="utf-8", mode="r") as file: - return yaml.safe_load(file) - - def __getattr__(self, name): - keys = name.split(".") - value = self.values - for key in keys: - if key in value: - value = value[key] - else: - return None - if isinstance(value, dict): - return ConfigLoader(None, values=value) - return value - - def __getitem__(self, key): - value = self.values[key] - if isinstance(value, dict): - return ConfigLoader(None, values=value) - return value - - def __contains__(self, key): - return key in self.values - - def param_names(self): - """Return parameter names.""" - return ( - self.values["state"]["surface"] + self.values["state"]["atmosphere"] - ) - - def param_units(self): - """Return parameter units.""" - return ( - self.values["state"]["surface_units"] - + self.values["state"]["atmosphere_units"] - ) - - def num_data_vars(self, key): - """Return the number of data variables for a given key.""" - surface_vars = len(self.values[key]["surface"]) - atmosphere_vars = len(self.values[key]["atmosphere"]) - levels = len(self.values[key]["levels"]) - return surface_vars + atmosphere_vars * levels - def projection(self): - """Return the projection.""" - proj_config = self.values["projections"]["class"] - proj_class = getattr(ccrs, proj_config["proj_class"]) - proj_params = proj_config["proj_params"] - return proj_class(**proj_params) +from neural_lam import utils class WeatherDataset(torch.utils.data.Dataset): @@ -87,65 +13,6 @@ class WeatherDataset(torch.utils.data.Dataset): validation, and test sets. 
""" - def process_dataset(self, dataset_name): - """Process a single dataset specified by the dataset name.""" - - dataset_path = self.config_loader.zarrs[dataset_name].path - if dataset_path is None or not os.path.exists(dataset_path): - print(f"Dataset '{dataset_name}' not found at path: {dataset_path}") - return None - dataset = xr.open_zarr(dataset_path, consolidated=True) - - start, end = ( - self.config_loader.splits[self.split].start, - self.config_loader.splits[self.split].end, - ) - dataset = dataset.sel(time=slice(start, end)) - dataset = dataset.rename_dims( - { - v: k - for k, v in self.config_loader.zarrs[ - dataset_name - ].dims.values.items() - if k not in dataset.dims - } - ) - if "grid" not in dataset.dims: - dataset = dataset.stack(grid=("x", "y")) - - vars_surface = [] - if self.config_loader[dataset_name].surface: - vars_surface = dataset[self.config_loader[dataset_name].surface] - - vars_atmosphere = [] - if self.config_loader[dataset_name].atmosphere: - vars_atmosphere = xr.merge( - [ - dataset[var] - .sel(level=level, drop=True) - .rename(f"{var}_{level}") - for var in self.config_loader[dataset_name].atmosphere - for level in self.config_loader[dataset_name].levels - ] - ) - - if vars_surface and vars_atmosphere: - dataset = xr.merge([vars_surface, vars_atmosphere]) - elif vars_surface: - dataset = vars_surface - elif vars_atmosphere: - dataset = vars_atmosphere - else: - print("No variables found in dataset {dataset_name}") - return None - - dataset = dataset.squeeze(drop=True).to_array() - if "time" in dataset.dims: - dataset = dataset.transpose("time", "grid", "variable") - else: - dataset = dataset.transpose("grid", "variable") - return dataset - def __init__( self, split="train", @@ -166,19 +33,12 @@ def __init__( self.batch_size = batch_size self.ar_steps = ar_steps self.control_only = control_only - self.config_loader = ConfigLoader(data_config) + self.config_loader = utils.ConfigLoader(data_config) - self.state = self.process_dataset("state") + self.state = self.config_loader("state", self.split) assert self.state is not None, "State dataset not found" - self.static = self.process_dataset("static") - self.forcings = self.process_dataset("forcing") - self.boundary = self.process_dataset("boundary") - - if self.static is not None: - self.static = self.static.expand_dims( - {"time": self.state.time}, axis=0 - ) - self.state = xr.concat([self.state, self.static], dim="variable") + self.forcings = self.config_loader("forcing", self.split) + self.boundary = self.config_loader("boundary", self.split) def __len__(self): return len(self.state.time) - self.ar_steps From 0b69f4e04466d52e9ee6e36df7ed5accc93e915e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 8 May 2024 23:04:36 +0200 Subject: [PATCH 10/26] dask and zarr are required backends to xarray --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 0a921225..cb9bd425 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,8 @@ pyproj>=3.4.1 tueplots>=0.0.8 plotly>=5.15.0 xarray>=0.20.1 +zarr>=2.10.0 +dask>=2022.0.0 # for dev codespell>=2.0.0 black>=21.9b0 From b76d078a66ab5635e82abfcacdb0befd80336181 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 9 May 2024 07:00:41 +0200 Subject: [PATCH 11/26] Implements windowed forcing and boundary --- create_grid_features.py | 59 ----------------------------------- neural_lam/data_config.yaml | 13 +++++++- neural_lam/models/ar_model.py | 17 +++++----- neural_lam/utils.py | 29 
+++++++++-------- neural_lam/weather_dataset.py | 57 ++++++++++++++++++++++++++------- plot_graph.py | 18 +++++------ train_model.py | 3 +- 7 files changed, 88 insertions(+), 108 deletions(-) delete mode 100644 create_grid_features.py diff --git a/create_grid_features.py b/create_grid_features.py deleted file mode 100644 index c9038103..00000000 --- a/create_grid_features.py +++ /dev/null @@ -1,59 +0,0 @@ -# Standard library -import os -from argparse import ArgumentParser - -# Third-party -import numpy as np -import torch - - -def main(): - """ - Pre-compute all static features related to the grid nodes - """ - parser = ArgumentParser(description="Training arguments") - parser.add_argument( - "--dataset", - type=str, - default="meps_example", - help="Dataset to compute weights for (default: meps_example)", - ) - args = parser.parse_args() - - static_dir_path = os.path.join("data", args.dataset, "static") - - # -- Static grid node features -- - grid_xy = torch.tensor( - np.load(os.path.join(static_dir_path, "nwp_xy.npy")) - ) # (2, N_x, N_y) - grid_xy = grid_xy.flatten(1, 2).T # (N_grid, 2) - pos_max = torch.max(torch.abs(grid_xy)) - grid_xy = grid_xy / pos_max # Divide by maximum coordinate - - geopotential = torch.tensor( - np.load(os.path.join(static_dir_path, "surface_geopotential.npy")) - ) # (N_x, N_y) - geopotential = geopotential.flatten(0, 1).unsqueeze(1) # (N_grid,1) - gp_min = torch.min(geopotential) - gp_max = torch.max(geopotential) - # Rescale geopotential to [0,1] - geopotential = (geopotential - gp_min) / (gp_max - gp_min) # (N_grid, 1) - - grid_border_mask = torch.tensor( - np.load(os.path.join(static_dir_path, "border_mask.npy")), - dtype=torch.int64, - ) # (N_x, N_y) - grid_border_mask = ( - grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1) - ) # (N_grid, 1) - - # Concatenate grid features - grid_features = torch.cat( - (grid_xy, geopotential, grid_border_mask), dim=1 - ) # (N_grid, 4) - - torch.save(grid_features, os.path.join(static_dir_path, "grid_features.pt")) - - -if __name__ == "__main__": - main() diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index 6c4536f5..cdfb57dc 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -102,14 +102,26 @@ forcing: # Forcing variables, dynamic inputs to the model surface_units: - W/m^2 atmosphere: + - T atmosphere_units: + - K levels: + - 0 + - 5 + - 8 + - 11 + - 13 + - 38 + - 44 + - 59 + window: 3 # Number of time steps to use for forcing (odd) boundary: # Boundary conditions surface: surface_units: atmosphere: atmosphere_units: levels: + window: 3 # Number of time steps to use for boundary (odd) lat_lon_names: # Name of variables/coordinates in zarrs specifying latitude and longitude of grid cells lat: lat lon: lon @@ -144,4 +156,3 @@ normalization: diff_std: diff_std grid_static_features: grid_static_features param_weights: param_weights - diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 04679022..8353327d 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -6,7 +6,6 @@ import numpy as np import pytorch_lightning as pl import torch - import wandb # First-party @@ -29,19 +28,19 @@ def __init__(self, args): self.config_loader = utils.ConfigLoader(args.data_config) # Load static features for grid/data - static = self.config_loader.process_dataset("static") - self.register_buffer("grid_static_features", torch.tensor(static.values)) + static = self.config_loader.process_dataset("static", self.split) + 
self.register_buffer( + "grid_static_features", torch.tensor(static.values) + ) # Double grid output dim. to also output std.-dev. self.output_std = bool(args.output_std) if self.output_std: - self.grid_output_dim = 2 * self.config_loader.num_data_vars( - "state" - ) # Pred. dim. in grid cell + # Pred. dim. in grid cell + self.grid_output_dim = 2 * self.config_loader.num_data_vars("state") else: - self.grid_output_dim = self.config_loader.num_data_vars( - "state" - ) # Pred. dim. in grid cell + # Pred. dim. in grid cell + self.grid_output_dim = self.config_loader.num_data_vars("state") # Store constant per-variable std.-dev. weighting # Note that this is the inverse of the multiplicative weighting diff --git a/neural_lam/utils.py b/neural_lam/utils.py index f4c34141..3992bc6c 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,9 +1,8 @@ # Standard library import os -import cartopy.crs as ccrs - # Third-party +import cartopy.crs as ccrs import torch import xarray as xr import yaml @@ -242,7 +241,9 @@ def __contains__(self, key): def param_names(self): """Return parameter names.""" - return self.values["state"]["surface"] + self.values["state"]["atmosphere"] + return ( + self.values["state"]["surface"] + self.values["state"]["atmosphere"] + ) def param_units(self): """Return parameter units.""" @@ -265,7 +266,7 @@ def projection(self): proj_params = proj_config["proj_params"] return proj_class(**proj_params) - def open_zarr(self, dataset_name, split): + def open_zarr(self, dataset_name): """Open a dataset specified by the dataset name.""" dataset_path = self.zarrs[dataset_name].path if dataset_path is None or not os.path.exists(dataset_path): @@ -274,10 +275,12 @@ def open_zarr(self, dataset_name, split): dataset = xr.open_zarr(dataset_path, consolidated=True) return dataset - def process_dataset(self, dataset_name, split): + def process_dataset(self, dataset_name, split="train"): """Process a single dataset specified by the dataset name.""" - dataset = self.open_zarr(dataset_name, split) + dataset = self.open_zarr(dataset_name) + if dataset is None: + return None start, end = ( self.splits[split].start, @@ -287,14 +290,10 @@ def process_dataset(self, dataset_name, split): dataset = dataset.rename_dims( { v: k - for k, v in self.zarrs[ - dataset_name - ].dims.values.items() + for k, v in self.zarrs[dataset_name].dims.values.items() if k not in dataset.dims } ) - if "grid" not in dataset.dims: - dataset = dataset.stack(grid=("x", "y")) vars_surface = [] if self[dataset_name].surface: @@ -322,10 +321,10 @@ def process_dataset(self, dataset_name, split): print("No variables found in dataset {dataset_name}") return None + dataset = dataset.squeeze().stack(grid=("x", "y")).to_array() + if "time" in dataset.dims: - dataset = dataset.squeeze( - drop=True).to_array().transpose( - "time", "grid", "variable") + dataset = dataset.transpose("time", "grid", "variable") else: - dataset = dataset.to_array().transpose("grid", "variable") + dataset = dataset.transpose("grid", "variable") return dataset diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 87aa3f56..03c8114f 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,6 +1,8 @@ +# Third-party import pytorch_lightning as pl import torch +# First-party from neural_lam import utils @@ -35,34 +37,64 @@ def __init__( self.control_only = control_only self.config_loader = utils.ConfigLoader(data_config) - self.state = self.config_loader("state", self.split) + self.state = 
self.config_loader.process_dataset("state", self.split) assert self.state is not None, "State dataset not found" - self.forcings = self.config_loader("forcing", self.split) - self.boundary = self.config_loader("boundary", self.split) + self.forcings = self.config_loader.process_dataset( + "forcing", self.split + ) + self.boundary = self.config_loader.process_dataset( + "boundary", self.split + ) + + self.state_times = self.state.time.values + self.forcing_window = self.config_loader.forcing.window + self.boundary_window = self.config_loader.boundary.window + self.idx_max = max( + (self.boundary_window - 1), (self.forcing_window - 1) + ) + + if self.forcings is not None: + self.forcings_windowed = ( + self.forcings.sel( + time=self.forcings.time.isin(self.state.time), + method="nearest", + ) + .rolling(time=self.forcing_window, center=True) + .construct("window") + ) + if self.boundary is not None: + self.boundary_windowed = ( + self.boundary.sel( + time=self.forcings.time.isin(self.state.time), + method="nearest", + ) + .rolling(time=self.boundary_window, center=True) + .construct("window") + ) def __len__(self): - return len(self.state.time) - self.ar_steps + # Skip first and last time step + return len(self.state.time) - self.ar_steps - self.idx_max def __getitem__(self, idx): + idx += self.idx_max / 2 # Skip first time step sample = torch.tensor( self.state.isel(time=slice(idx, idx + self.ar_steps)).values, dtype=torch.float32, ) forcings = ( - torch.tensor( - self.forcings.isel(time=slice(idx, idx + self.ar_steps)).values, - dtype=torch.float32, - ) + self.forcings_windowed.isel(time=slice(idx, idx + self.ar_steps)) + .stack(variable_window=("variable", "window")) + .values if self.forcings is not None else torch.tensor([]) ) boundary = ( - torch.tensor( - self.boundary.isel(time=slice(idx, idx + self.ar_steps)).values, - dtype=torch.float32, - ) + self.boundary_windowed.isel(time=slice(idx, idx + self.ar_steps)) + .stack(variable_window=("variable", "window")) + .values if self.boundary is not None else torch.tensor([]) ) @@ -153,4 +185,5 @@ def test_dataloader(self): print(batch[2].shape) print(batch[3].shape) print(batch[4]) + print(batch[2][0, 0, 0, :]) break diff --git a/plot_graph.py b/plot_graph.py index 48427d5c..c82b4e04 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -19,12 +19,6 @@ def main(): Plot graph structure in 3D using plotly """ parser = ArgumentParser(description="Plot graph") - parser.add_argument( - "--dataset", - type=str, - default="meps_example", - help="Datast to load grid coordinates from (default: meps_example)", - ) parser.add_argument( "--graph", type=str, @@ -42,6 +36,12 @@ def main(): default=0, help="If the axis should be displayed (default: 0 (No))", ) + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", + ) args = parser.parse_args() @@ -62,10 +62,8 @@ def main(): ) mesh_static_features = graph_ldict["mesh_static_features"] - grid_static_features = utils.load_static_data(args.dataset)[ - "grid_static_features" - ] - + config_loader = utils.ConfigLoader(args.data_config) + grid_static_features = config_loader.process_dataset("static") # Extract values needed, turn to numpy grid_pos = grid_static_features[:, :2].numpy() # Add in z-dimension diff --git a/train_model.py b/train_model.py index 1839474b..4f57ca24 100644 --- a/train_model.py +++ b/train_model.py @@ -6,9 +6,8 @@ # Third-party import pytorch_lightning as pl import torch -from 
lightning_fabric.utilities import seed - import wandb +from lightning_fabric.utilities import seed # First-party from neural_lam import utils From 5d27a4ce21c8894a69711f324bc14087f95fae8f Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 9 May 2024 07:01:09 +0200 Subject: [PATCH 12/26] Some project related stuff (simple setup to pip install -e .) --- .gitignore | 1 + pyproject.toml | 21 +++++++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 7bb826a2..590c7e12 100644 --- a/.gitignore +++ b/.gitignore @@ -72,3 +72,4 @@ tags # Coc configuration directory .vim +.vscode diff --git a/pyproject.toml b/pyproject.toml index b513a258..619f444f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,10 @@ +[project] +name = "neural_lam" +version = "0.1.0" + +[tool.setuptools] +packages = ["neural_lam"] + [tool.black] line-length = 80 @@ -42,12 +49,9 @@ ignore = [ "create_mesh.py", # Disable linting for now, as major rework is planned/expected ] # Temporary fix for import neural_lam statements until set up as proper package -init-hook='import sys; sys.path.append(".")' +init-hook = 'import sys; sys.path.append(".")' [tool.pylint.TYPECHECK] -generated-members = [ - "numpy.*", - "torch.*", -] +generated-members = ["numpy.*", "torch.*"] [tool.pylint.'MESSAGES CONTROL'] disable = [ "C0114", # 'missing-module-docstring', Do not require module docstrings @@ -56,10 +60,11 @@ disable = [ "R0913", # 'too-many-arguments', Allow many function arguments "R0914", # 'too-many-locals', Allow many local variables "W0223", # 'abstract-method', Subclasses do not have to override all abstract methods + "C0411", # 'wrong-import-order', Allow for isort to handle import order ] [tool.pylint.DESIGN] -max-statements=100 # Allow for some more involved functions +max-statements = 100 # Allow for some more involved functions [tool.pylint.IMPORTS] -allow-any-import-level="neural_lam" +allow-any-import-level = "neural_lam" [tool.pylint.SIMILARITIES] -min-similarity-lines=10 +min-similarity-lines = 10 From 4dadf2985591f7cacc536f906f8bad2eab98878a Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 9 May 2024 08:22:37 +0200 Subject: [PATCH 13/26] introducing realistic boundaries --- neural_lam/data_config.yaml | 75 ++++++++++++++++++++++++++++++++++- neural_lam/utils.py | 2 +- neural_lam/weather_dataset.py | 15 +++++-- 3 files changed, 87 insertions(+), 5 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index cdfb57dc..e6a0d506 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -20,7 +20,12 @@ zarrs: # List of zarrs containing fields related to state x: x y: y boundary: - path: + path: /scratch/sadamov/era5.zarr + dims: + time: time + level: level + x: longitude + y: latitude mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary. 
state: # Variables forecasted by the model surface: # Single-field variables @@ -111,16 +116,84 @@ forcing: # Forcing variables, dynamic inputs to the model - 8 - 11 - 13 + - 15 + - 19 + - 22 + - 26 + - 30 - 38 - 44 - 59 window: 3 # Number of time steps to use for forcing (odd) boundary: # Boundary conditions surface: + - 10m_u_component_of_wind + # - 10m_v_component_of_wind + # - 2m_dewpoint_temperature + # - 2m_temperature + # - mean_sea_level_pressure + # - mean_surface_latent_heat_flux + # - mean_surface_net_long_wave_radiation_flux + # - mean_surface_net_short_wave_radiation_flux + # - mean_surface_sensible_heat_flux + # - surface_pressure + # - total_cloud_cover + # - total_column_water_vapour + # - total_precipitation_12hr + # - total_precipitation_24hr + # - total_precipitation_6hr + # - geopotential_at_surface surface_units: + - m/s + # - m/s + # - K + # - K + # - Pa + # - W/m^2 + # - W/m^2 + # - W/m^2 + # - W/m^2 + # - Pa + # - "%" + # - kg/m^2 + # - kg/m^2 + # - kg/m^2 + # - kg/m^2 + # - m^2/s^2 atmosphere: + - divergence + # - geopotential + # - relative_humidity + # - specific_humidity + # - temperature + # - u_component_of_wind + # - v_component_of_wind + # - vertical_velocity + # - vorticity atmosphere_units: + - 1/s + # - m^2/s^2 + # - "%" + # - kg/kg + # - K + # - m/s + # - m/s + # - m/s + # - 1/s levels: + - 50 + - 100 + - 150 + - 200 + - 250 + - 300 + - 400 + - 500 + - 600 + - 700 + - 850 + - 925 + - 1000 window: 3 # Number of time steps to use for boundary (odd) lat_lon_names: # Name of variables/coordinates in zarrs specifying latitude and longitude of grid cells lat: lat diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 3992bc6c..b4855eff 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -318,7 +318,7 @@ def process_dataset(self, dataset_name, split="train"): elif vars_atmosphere: dataset = vars_atmosphere else: - print("No variables found in dataset {dataset_name}") + print(f"No variables found in dataset {dataset_name}") return None dataset = dataset.squeeze().stack(grid=("x", "y")).to_array() diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 03c8114f..d6662cfb 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -56,18 +56,27 @@ def __init__( if self.forcings is not None: self.forcings_windowed = ( self.forcings.sel( - time=self.forcings.time.isin(self.state.time), + time=self.state.time, method="nearest", ) + .pad( + time=(self.forcing_window // 2, self.forcing_window // 2), + mode="edge", + ) .rolling(time=self.forcing_window, center=True) .construct("window") ) + if self.boundary is not None: self.boundary_windowed = ( self.boundary.sel( - time=self.forcings.time.isin(self.state.time), + time=self.state.time, method="nearest", ) + .pad( + time=(self.boundary_window // 2, self.boundary_window // 2), + mode="edge", + ) .rolling(time=self.boundary_window, center=True) .construct("window") ) @@ -77,7 +86,7 @@ def __len__(self): return len(self.state.time) - self.ar_steps - self.idx_max def __getitem__(self, idx): - idx += self.idx_max / 2 # Skip first time step + idx += self.idx_max // 2 # Skip first time step sample = torch.tensor( self.state.isel(time=slice(idx, idx + self.ar_steps)).values, dtype=torch.float32, From 7524c4ddad149a16b4bdc8cfa8e4f41e967b0223 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 9 May 2024 12:48:46 +0200 Subject: [PATCH 14/26] Adapted nwp_xy related code to new data loading procedure --- .gitignore | 1 + create_mesh.py | 24 ++++++++++++++---------- 
neural_lam/utils.py | 35 ++++++++++++++++++++++++++++++++++- neural_lam/vis.py | 8 ++++---- plot_graph.py | 8 +++++--- 5 files changed, 58 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index 590c7e12..1ecd1dfe 100644 --- a/.gitignore +++ b/.gitignore @@ -73,3 +73,4 @@ tags # Coc configuration directory .vim .vscode +cosmo_hilam.html diff --git a/create_mesh.py b/create_mesh.py index cb524cd6..2b6af9fd 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -12,6 +12,11 @@ import torch_geometric as pyg from torch_geometric.utils.convert import from_networkx +# First-party +from neural_lam import utils + +# matplotlib.use('TkAgg') + def plot_graph(graph, title=None): fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H @@ -152,13 +157,6 @@ def prepend_node_index(graph, new_index): def main(): parser = ArgumentParser(description="Graph generation arguments") - parser.add_argument( - "--dataset", - type=str, - default="meps_example", - help="Dataset to load grid point coordinates from " - "(default: meps_example)", - ) parser.add_argument( "--graph", type=str, @@ -184,15 +182,21 @@ def main(): default=0, help="Generate hierarchical mesh graph (default: 0, no)", ) + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", + ) + args = parser.parse_args() # Load grid positions - static_dir_path = os.path.join("data", args.dataset, "static") graph_dir_path = os.path.join("graphs", args.graph) os.makedirs(graph_dir_path, exist_ok=True) - xy = np.load(os.path.join(static_dir_path, "nwp_xy.npy")) - + config_loader = utils.ConfigLoader(args.data_config) + xy = config_loader.get_nwp_xy() grid_xy = torch.tensor(xy) pos_max = torch.max(torch.abs(grid_xy)) diff --git a/neural_lam/utils.py b/neural_lam/utils.py index b4855eff..172cef95 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -3,6 +3,7 @@ # Third-party import cartopy.crs as ccrs +import numpy as np import torch import xarray as xr import yaml @@ -275,7 +276,7 @@ def open_zarr(self, dataset_name): dataset = xr.open_zarr(dataset_path, consolidated=True) return dataset - def process_dataset(self, dataset_name, split="train"): + def process_dataset(self, dataset_name, split="train", stack=True): """Process a single dataset specified by the dataset name.""" dataset = self.open_zarr(dataset_name) @@ -321,6 +322,29 @@ def process_dataset(self, dataset_name, split="train"): print(f"No variables found in dataset {dataset_name}") return None + if not all( + lat_lon in self.zarrs[dataset_name].dims.values.values() + for lat_lon in self.zarrs[ + dataset_name + ].lat_lon_names.values.values() + ): + lat_name = self.zarrs[dataset_name].lat_lon_names.lat + lon_name = self.zarrs[dataset_name].lat_lon_names.lon + if dataset[lat_name].ndim == 2: + dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True) + if dataset[lon_name].ndim == 2: + dataset[lon_name] = dataset[lon_name].isel(y=0, drop=True) + dataset = dataset.assign_coords( + x=dataset[lon_name], y=dataset[lat_name] + ) + + if stack: + dataset = self.stack_grid(dataset) + + return dataset + + def stack_grid(self, dataset): + """Stack grid dimensions.""" dataset = dataset.squeeze().stack(grid=("x", "y")).to_array() if "time" in dataset.dims: @@ -328,3 +352,12 @@ def process_dataset(self, dataset_name, split="train"): else: dataset = dataset.transpose("grid", "variable") return dataset + + def get_nwp_xy(self): + """Get the x and y coordinates for the NWP grid.""" + 
x = self.process_dataset("static", stack=False).x.values + y = self.process_dataset("static", stack=False).y.values + xx, yy = np.meshgrid(y, x) + xy = np.stack((xx, yy), axis=0) + + return xy diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 02b8dd35..8c36a9a7 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -78,7 +78,7 @@ def plot_prediction( vmin, vmax = vrange # Set up masking of border region - mask_reshaped = obs_mask.reshape(*data_config.grid_shape) + mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state) pixel_alpha = ( mask_reshaped.clamp(0.7, 1).cpu().numpy() ) # Faded border region @@ -93,7 +93,7 @@ def plot_prediction( # Plot pred and target for ax, data in zip(axes, (target, pred)): ax.coastlines() # Add coastline outlines - data_grid = data.reshape(*data_config.grid_shape).cpu().numpy() + data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy() im = ax.imshow( data_grid, origin="lower", @@ -129,7 +129,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): vmin, vmax = vrange # Set up masking of border region - mask_reshaped = obs_mask.reshape(*data_config.grid_shape) + mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state) pixel_alpha = ( mask_reshaped.clamp(0.7, 1).cpu().numpy() ) # Faded border region @@ -139,7 +139,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): ) ax.coastlines() # Add coastline outlines - error_grid = error.reshape(*data_config.grid_shape).cpu().numpy() + error_grid = error.reshape(*data_config.grid_shape_state).cpu().numpy() im = ax.imshow( error_grid, diff --git a/plot_graph.py b/plot_graph.py index c82b4e04..e246200d 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -63,9 +63,11 @@ def main(): mesh_static_features = graph_ldict["mesh_static_features"] config_loader = utils.ConfigLoader(args.data_config) - grid_static_features = config_loader.process_dataset("static") - # Extract values needed, turn to numpy - grid_pos = grid_static_features[:, :2].numpy() + xy = config_loader.get_nwp_xy() + grid_xy = xy.transpose(1, 2, 0).reshape(-1, 2) # (N_grid, 2) + pos_max = np.max(np.abs(grid_xy)) + grid_pos = grid_xy / pos_max # Divide by maximum coordinate + # Add in z-dimension z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],)) grid_pos = np.concatenate( From 45fd375b3ff3c2760906b866c3c26e631162dbbf Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 9 May 2024 12:49:53 +0200 Subject: [PATCH 15/26] only state requires units for plotting lat lon specifications make the code more flexible --- neural_lam/data_config.yaml | 56 +++++++++---------------------------- 1 file changed, 13 insertions(+), 43 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index e6a0d506..cce477ed 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -6,12 +6,18 @@ zarrs: # List of zarrs containing fields related to state level: z x: x # Either give "grid" (flattened) dimension or "x" and "y" y: y + lat_lon_names: + lon: lon + lat: lat static: path: /scratch/sadamov/template.zarr dims: level: z x: x y: y + lat_lon_names: + lon: lon + lat: lat forcing: path: /scratch/sadamov/template.zarr dims: @@ -19,6 +25,9 @@ zarrs: # List of zarrs containing fields related to state level: z x: x y: y + lat_lon_names: + lon: lon + lat: lat boundary: path: /scratch/sadamov/era5.zarr dims: @@ -26,6 +35,9 @@ zarrs: # List of zarrs containing fields related to state level: level x: longitude y: latitude + lat_lon_names: + lon: 
longitude + lat: latitude mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary. state: # Variables forecasted by the model surface: # Single-field variables @@ -77,16 +89,8 @@ state: # Variables forecasted by the model static: # Static inputs surface: - HSURF - - lat - - lon - surface_units: - - m - - °N - - °E atmosphere: - FI - atmosphere_units: - - m^2/s^2 levels: - 0 - 5 @@ -104,12 +108,8 @@ static: # Static inputs forcing: # Forcing variables, dynamic inputs to the model surface: - ASOB_S - surface_units: - - W/m^2 atmosphere: - T - atmosphere_units: - - K levels: - 0 - 5 @@ -143,23 +143,6 @@ boundary: # Boundary conditions # - total_precipitation_24hr # - total_precipitation_6hr # - geopotential_at_surface - surface_units: - - m/s - # - m/s - # - K - # - K - # - Pa - # - W/m^2 - # - W/m^2 - # - W/m^2 - # - W/m^2 - # - Pa - # - "%" - # - kg/m^2 - # - kg/m^2 - # - kg/m^2 - # - kg/m^2 - # - m^2/s^2 atmosphere: - divergence # - geopotential @@ -170,16 +153,6 @@ boundary: # Boundary conditions # - v_component_of_wind # - vertical_velocity # - vorticity - atmosphere_units: - - 1/s - # - m^2/s^2 - # - "%" - # - kg/kg - # - K - # - m/s - # - m/s - # - m/s - # - 1/s levels: - 50 - 100 @@ -195,10 +168,7 @@ boundary: # Boundary conditions - 925 - 1000 window: 3 # Number of time steps to use for boundary (odd) -lat_lon_names: # Name of variables/coordinates in zarrs specifying latitude and longitude of grid cells - lat: lat - lon: lon -grid_shape: +grid_shape_state: x: 582 y: 390 splits: From 812323ddce5a86cc1fef0f1305ad27d5dfce629f Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 9 May 2024 19:23:04 +0200 Subject: [PATCH 16/26] small bugfixes and improvements --- .gitignore | 3 ++- neural_lam/data_config.yaml | 4 +--- neural_lam/weather_dataset.py | 33 +++++++++------------------------ train_model.py | 4 ++-- 4 files changed, 14 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 1ecd1dfe..08cc014e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ graphs *.sif sweeps test_*.sh +cosmo_hilam.html +normalization.zarr ### Python ### # Byte-compiled / optimized / DLL files @@ -73,4 +75,3 @@ tags # Coc configuration directory .vim .vscode -cosmo_hilam.html diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index cce477ed..faaabd32 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -29,7 +29,7 @@ zarrs: # List of zarrs containing fields related to state lon: lon lat: lat boundary: - path: /scratch/sadamov/era5.zarr + path: /scratch/sadamov/era5_template.zarr dims: time: time level: level @@ -197,5 +197,3 @@ normalization: boundary_std: boundary_std diff_mean: diff_mean diff_std: diff_std - grid_static_features: grid_static_features - param_weights: param_weights diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index d6662cfb..d51fb896 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -39,9 +39,7 @@ def __init__( self.state = self.config_loader.process_dataset("state", self.split) assert self.state is not None, "State dataset not found" - self.forcings = self.config_loader.process_dataset( - "forcing", self.split - ) + self.forcing = self.config_loader.process_dataset("forcing", self.split) self.boundary = self.config_loader.process_dataset( "boundary", self.split ) @@ -53,9 +51,9 @@ def __init__( (self.boundary_window - 1), (self.forcing_window - 1) ) - if self.forcings is not None: - self.forcings_windowed = ( 
- self.forcings.sel( + if self.forcing is not None: + self.forcing_windowed = ( + self.forcing.sel( time=self.state.time, method="nearest", ) @@ -92,11 +90,11 @@ def __getitem__(self, idx): dtype=torch.float32, ) - forcings = ( - self.forcings_windowed.isel(time=slice(idx, idx + self.ar_steps)) + forcing = ( + self.forcing_windowed.isel(time=slice(idx, idx + self.ar_steps)) .stack(variable_window=("variable", "window")) .values - if self.forcings is not None + if self.forcing is not None else torch.tensor([]) ) @@ -119,10 +117,10 @@ def __getitem__(self, idx): # init_states: (2, N_grid, d_features) # target_states: (ar_steps-2, N_grid, d_features) - # forcings: (ar_steps, N_grid, d_windowed_forcings) + # forcing: (ar_steps, N_grid, d_windowed_forcing) # boundary: (ar_steps, N_grid, d_windowed_boundary) # batch_times: (ar_steps,) - return init_states, target_states, forcings, boundary, batch_times + return init_states, target_states, forcing, boundary, batch_times class WeatherDataModule(pl.LightningDataModule): @@ -183,16 +181,3 @@ def test_dataloader(self): num_workers=self.num_workers, shuffle=False, ) - - -data_module = WeatherDataModule(batch_size=4, num_workers=0) -data_module.setup() -train_dataloader = data_module.train_dataloader() -for batch in train_dataloader: - print(batch[0].shape) - print(batch[1].shape) - print(batch[2].shape) - print(batch[3].shape) - print(batch[4]) - print(batch[2][0, 0, 0, :]) - break diff --git a/train_model.py b/train_model.py index 4f57ca24..e5dfd528 100644 --- a/train_model.py +++ b/train_model.py @@ -62,7 +62,7 @@ def main(): "--seed", type=int, default=42, help="random seed (default: 42)" ) parser.add_argument( - "--n_workers", + "--num_workers", type=int, default=4, help="Number of workers in data loader (default: 4)", @@ -235,7 +235,7 @@ def main(): # Create datamodule data_module = WeatherDataModule( batch_size=args.batch_size, - num_workers=args.n_workers, + num_workers=args.num_workers, ) # Instantiate model + trainer From 500f2fbeafd386844c813a7bc20a3bf65aed86f0 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 9 May 2024 19:24:09 +0200 Subject: [PATCH 17/26] Calculate stats and store in zarr archive Zarr is registered to model buffer Normalization happens on device on_after_batch_transfer --- create_parameter_weights.py | 144 ++++++++++++++-------------------- neural_lam/models/ar_model.py | 33 ++++++++ neural_lam/utils.py | 18 ++++- 3 files changed, 108 insertions(+), 87 deletions(-) diff --git a/create_parameter_weights.py b/create_parameter_weights.py index 926d7741..1eda7a24 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -1,14 +1,13 @@ # Standard library -import os from argparse import ArgumentParser # Third-party -import numpy as np import torch +import xarray as xr from tqdm import tqdm # First-party -from neural_lam.weather_dataset import WeatherDataset +from neural_lam.weather_dataset import WeatherDataModule def main(): @@ -16,12 +15,6 @@ def main(): Pre-compute parameter weights to be used in loss function """ parser = ArgumentParser(description="Training arguments") - parser.add_argument( - "--dataset", - type=str, - default="meps_example", - help="Dataset to compute weights for (default: meps_example)", - ) parser.add_argument( "--batch_size", type=int, @@ -29,107 +22,77 @@ def main(): help="Batch size when iterating over the dataset", ) parser.add_argument( - "--step_length", - type=int, - default=3, - help="Step length in hours to consider single time step (default: 3)", - ) - 
parser.add_argument( - "--n_workers", + "--num_workers", type=int, default=4, help="Number of workers in data loader (default: 4)", ) + parser.add_argument( + "--zarr_path", + type=str, + default="normalization.zarr", + help="Directory where data is stored", + ) + args = parser.parse_args() - static_dir_path = os.path.join("data", args.dataset, "static") - - ds = WeatherDataset() - # Create parameter weights based on height - # based on fig A.1 in graph cast paper - w_dict = { - "2": 1.0, - "0": 0.1, - "65": 0.065, - "1000": 0.1, - "850": 0.05, - "500": 0.03, - } - w_list = np.array( - [w_dict[par.split("_")[-2]] for par in ds.config_loader.param_names()] - ) - print("Saving parameter weights...") - np.save( - os.path.join(static_dir_path, "parameter_weights.npy"), - w_list.astype("float32"), + data_module = WeatherDataModule( + batch_size=args.batch_size, num_workers=args.num_workers ) + data_module.setup() + loader = data_module.train_dataloader() # Load dataset without any subsampling - loader = torch.utils.data.DataLoader( - ds, args.batch_size, shuffle=False, num_workers=args.n_workers - ) - # Compute mean and std.-dev. of each parameter (+ flux forcing) + # Compute mean and std.-dev. of each parameter (+ forcing forcing) # across full dataset print("Computing mean and std.-dev. for parameters...") means = [] squares = [] - flux_means = [] - flux_squares = [] - for init_batch, target_batch, forcing_batch in tqdm(loader): + fb_means = {"forcing": [], "boundary": []} + fb_squares = {"forcing": [], "boundary": []} + + for init_batch, target_batch, forcing_batch, boundary_batch, _ in tqdm( + loader + ): batch = torch.cat( (init_batch, target_batch), dim=1 ) # (N_batch, N_t, N_grid, d_features) means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,) - squares.append( - torch.mean(batch**2, dim=(1, 2)) - ) # (N_batch, d_features,) + squares.append(torch.mean(batch**2, dim=(1, 2))) - # Flux at 1st windowed position is index 1 in forcing - flux_batch = forcing_batch[:, :, :, 1] - flux_means.append(torch.mean(flux_batch)) # (,) - flux_squares.append(torch.mean(flux_batch**2)) # (,) + for fb_type, fb_batch in zip( + ["forcing", "boundary"], [forcing_batch, boundary_batch] + ): + fb_batch = fb_batch[:, :, :, 1] + fb_means[fb_type].append(torch.mean(fb_batch)) # (,) + fb_squares[fb_type].append(torch.mean(fb_batch**2)) # (,) mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features) second_moment = torch.mean(torch.cat(squares, dim=0), dim=0) std = torch.sqrt(second_moment - mean**2) # (d_features) - flux_mean = torch.mean(torch.stack(flux_means)) # (,) - flux_second_moment = torch.mean(torch.stack(flux_squares)) # (,) - flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) - flux_stats = torch.stack((flux_mean, flux_std)) - - print("Saving mean, std.-dev, flux_stats...") - torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt")) - torch.save(std, os.path.join(static_dir_path, "parameter_std.pt")) - torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt")) + fb_stats = {} + for fb_type in ["forcing", "boundary"]: + fb_stats[f"{fb_type}_mean"] = torch.mean( + torch.stack(fb_means[fb_type]) + ) # (,) + fb_second_moment = torch.mean(torch.stack(fb_squares[fb_type])) # (,) + fb_stats[f"{fb_type}_std"] = torch.sqrt( + fb_second_moment - fb_stats[f"{fb_type}_mean"] ** 2 + ) # (,) # Compute mean and std.-dev. of one-step differences across the dataset print("Computing mean and std.-dev. 
for one-step differences...") - ds_standard = WeatherDataset() # Re-load with standardization - loader_standard = torch.utils.data.DataLoader( - ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers - ) - used_subsample_len = (65 // args.step_length) * args.step_length - diff_means = [] diff_squares = [] - for init_batch, target_batch, _ in tqdm(loader_standard): - batch = torch.cat( - (init_batch, target_batch), dim=1 - ) # (N_batch, N_t', N_grid, d_features) - # Note: batch contains only 1h-steps - stepped_batch = torch.cat( - [ - batch[:, ss_i : used_subsample_len : args.step_length] - for ss_i in range(args.step_length) - ], - dim=0, - ) - # (N_batch', N_t, N_grid, d_features), - # N_batch' = args.step_length*N_batch - - batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1] - # (N_batch', N_t-1, N_grid, d_features) + for init_batch, target_batch, _, _, _ in tqdm(loader): + # normalize the batch + init_batch = (init_batch - mean) / std + target_batch = (target_batch - mean) / std + + batch = torch.cat((init_batch, target_batch), dim=1) + batch_diffs = batch[:, 1:] - batch[:, :-1] + # (N_batch, N_t-1, N_grid, d_features) diff_means.append( torch.mean(batch_diffs, dim=(1, 2)) @@ -142,9 +105,20 @@ def main(): diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0) diff_std = torch.sqrt(diff_second_moment - diff_mean**2) # (d_features) - print("Saving one-step difference mean and std.-dev...") - torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt")) - torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt")) + # Create xarray dataset + ds = xr.Dataset( + { + "mean": (["d_features"], mean), + "std": (["d_features"], std), + "diff_mean": (["d_features"], diff_mean), + "diff_std": (["d_features"], diff_std), + **fb_stats, + } + ) + + # Save dataset as Zarr + print("Saving dataset as Zarr...") + ds.to_zarr(args.zarr_path, mode="w") if __name__ == "__main__": diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 8353327d..8976990b 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -91,6 +91,17 @@ def __init__(self, args): # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] + # Load normalization statistics + self.normalization_stats = self.config_loader.load_normalization_stats() + if self.normalization_stats is not None: + for ( + var_name, + var_data, + ) in self.normalization_stats.data_vars.items(): + self.register_buffer( + f"data_{var_name}", torch.tensor(var_data.values) + ) + def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.lr, betas=(0.9, 0.95) @@ -195,6 +206,28 @@ def common_step(self, batch): return prediction, target_states, pred_std + def on_after_batch_transfer(self, batch, dataloader_idx): + """Normalize Batch data after transferring to the device.""" + if self.normalization_stats is not None: + init_states, target_states, forcing_features, boundary_features = ( + batch + ) + init_states = (init_states - self.data_mean) / self.data_std + target_states = (target_states - self.data_mean) / self.data_std + forcing_features = ( + forcing_features - self.forcing_mean + ) / self.forcing_std + boundary_features = ( + boundary_features - self.boundary_mean + ) / self.boundary_std + batch = ( + init_states, + target_states, + forcing_features, + boundary_features, + ) + return batch + def training_step(self, batch): """ Train on single batch diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 
172cef95..c86418c8 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -276,6 +276,20 @@ def open_zarr(self, dataset_name): dataset = xr.open_zarr(dataset_path, consolidated=True) return dataset + def load_normalization_stats(self): + """Load normalization statistics from Zarr archive.""" + normalization_path = "normalization.zarr" + if not os.path.exists(normalization_path): + print( + f"Normalization statistics not found at " + f"path: {normalization_path}" + ) + return None + normalization_stats = xr.open_zarr( + normalization_path, consolidated=True + ) + return normalization_stats + def process_dataset(self, dataset_name, split="train", stack=True): """Process a single dataset specified by the dataset name.""" @@ -338,8 +352,8 @@ def process_dataset(self, dataset_name, split="train", stack=True): x=dataset[lon_name], y=dataset[lat_name] ) - if stack: - dataset = self.stack_grid(dataset) + if stack: + dataset = self.stack_grid(dataset) return dataset From 9293fe1b6e69f7960e4174e9d14a18e01cfa6521 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 9 May 2024 21:55:42 +0200 Subject: [PATCH 18/26] latex support --- neural_lam/data_config.yaml | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index faaabd32..55e59a72 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -50,12 +50,12 @@ state: # Variables forecasted by the model - V_10M surface_units: - "%" - - Pa - - Pa - - K - - kg/m^2 - - m/s - - m/s + - r"$\mathrm{Pa}$" + - r"$\mathrm{Pa}$" + - r"$\mathrm{K}$" + - r"$\mathrm{kg}/\mathrm{m}^2$" + - r"$\mathrm{m}/\mathrm{s}$" + - r"$\mathrm{m}/\mathrm{s}$" atmosphere: # Variables with vertical levels - PP - QV @@ -65,13 +65,13 @@ state: # Variables forecasted by the model - V - W atmosphere_units: - - Pa - - kg/kg + - r"$\mathrm{Pa}$" + - r"$\mathrm{kg}/\mathrm{kg}$" - "%" - - K - - m/s - - m/s - - Pa/s + - r"$\mathrm{K}$" + - r"$\mathrm{m}/\mathrm{s}$" + - r"$\mathrm{m}/\mathrm{s}$" + - r"$\mathrm{Pa}/\mathrm{s}$" levels: # Levels to use for atmosphere variables - 0 - 5 From e80aa5899c3ff9cad63f09f364d48bf4780e1dfe Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 9 May 2024 21:58:18 +0200 Subject: [PATCH 19/26] ar_steps for training and eval --- neural_lam/weather_dataset.py | 29 +++++++++++++++---------- train_model.py | 41 +++++++++++++++-------------------- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index d51fb896..4b5da0a8 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -18,8 +18,8 @@ class WeatherDataset(torch.utils.data.Dataset): def __init__( self, split="train", - batch_size=4, ar_steps=3, + batch_size=4, control_only=False, data_config="neural_lam/data_config.yaml", ): @@ -47,9 +47,6 @@ def __init__( self.state_times = self.state.time.values self.forcing_window = self.config_loader.forcing.window self.boundary_window = self.config_loader.boundary.window - self.idx_max = max( - (self.boundary_window - 1), (self.forcing_window - 1) - ) if self.forcing is not None: self.forcing_windowed = ( @@ -81,17 +78,16 @@ def __init__( def __len__(self): # Skip first and last time step - return len(self.state.time) - self.ar_steps - self.idx_max + return len(self.state.time) - self.ar_steps def __getitem__(self, idx): - idx += self.idx_max // 2 # Skip first time step sample = torch.tensor( self.state.isel(time=slice(idx, idx + 
self.ar_steps)).values, dtype=torch.float32, ) forcing = ( - self.forcing_windowed.isel(time=slice(idx, idx + self.ar_steps)) + self.forcing_windowed.isel(time=slice(idx + 2, idx + self.ar_steps)) .stack(variable_window=("variable", "window")) .values if self.forcing is not None @@ -99,7 +95,9 @@ def __getitem__(self, idx): ) boundary = ( - self.boundary_windowed.isel(time=slice(idx, idx + self.ar_steps)) + self.boundary_windowed.isel( + time=slice(idx + 2, idx + self.ar_steps) + ) .stack(variable_window=("variable", "window")) .values if self.boundary is not None @@ -110,16 +108,16 @@ def __getitem__(self, idx): target_states = sample[2:] batch_times = ( - self.state.isel(time=slice(idx, idx + self.ar_steps)) + self.state.isel(time=slice(idx + 2, idx + self.ar_steps)) .time.values.astype(str) .tolist() ) # init_states: (2, N_grid, d_features) # target_states: (ar_steps-2, N_grid, d_features) - # forcing: (ar_steps, N_grid, d_windowed_forcing) - # boundary: (ar_steps, N_grid, d_windowed_boundary) - # batch_times: (ar_steps,) + # forcing: (ar_steps-2, N_grid, d_windowed_forcing) + # boundary: (ar_steps-2, N_grid, d_windowed_boundary) + # batch_times: (ar_steps-2,) return init_states, target_states, forcing, boundary, batch_times @@ -128,10 +126,14 @@ class WeatherDataModule(pl.LightningDataModule): def __init__( self, + ar_steps_train=3, + ar_steps_eval=25, batch_size=4, num_workers=16, ): super().__init__() + self.ar_steps_train = ar_steps_train + self.ar_steps_eval = ar_steps_eval self.batch_size = batch_size self.num_workers = num_workers self.train_dataset = None @@ -142,16 +144,19 @@ def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( split="train", + ar_steps=self.ar_steps_train, batch_size=self.batch_size, ) self.val_dataset = WeatherDataset( split="val", + ar_steps=self.ar_steps_eval, batch_size=self.batch_size, ) if stage == "test" or stage is None: self.test_dataset = WeatherDataset( split="test", + ar_steps=self.ar_steps_eval, batch_size=self.batch_size, ) diff --git a/train_model.py b/train_model.py index e5dfd528..a8b02f58 100644 --- a/train_model.py +++ b/train_model.py @@ -31,14 +31,6 @@ def main(): description="Train or evaluate NeurWP models for LAM" ) - # General options - parser.add_argument( - "--dataset", - type=str, - default="meps_example", - help="Dataset, corresponding to name in data directory " - "(default: meps_example)", - ) parser.add_argument( "--model", type=str, @@ -51,13 +43,6 @@ def main(): default="neural_lam/data_config.yaml", help="Path to data config file (default: neural_lam/data_config.yaml)", ) - parser.add_argument( - "--subset_ds", - type=int, - default=0, - help="Use only a small subset of the dataset, for debugging" - "(default: 0=false)", - ) parser.add_argument( "--seed", type=int, default=42, help="random seed (default: 42)" ) @@ -139,11 +124,11 @@ def main(): # Training options parser.add_argument( - "--ar_steps", + "--ar_steps_train", type=int, - default=1, - help="Number of steps to unroll prediction for in loss (1-19) " - "(default: 1)", + default=3, + help="Number of steps to unroll prediction for in loss function " + "(default: 3)", ) parser.add_argument( "--control_only", @@ -161,9 +146,9 @@ def main(): parser.add_argument( "--step_length", type=int, - default=3, + default=1, help="Step length in hours to consider single time step 1-3 " - "(default: 3)", + "(default: 1)", ) parser.add_argument( "--lr", type=float, default=1e-3, help="learning rate (default: 0.001)" @@ -183,6 +168,13 @@ 
def main(): help="Eval model on given data split (val/test) " "(default: None (train model))", ) + parser.add_argument( + "--ar_steps_eval", + type=int, + default=25, + help="Number of steps to unroll prediction for in loss function " + "(default: 25)", + ) parser.add_argument( "--n_example_pred", type=int, @@ -234,6 +226,8 @@ def main(): seed.seed_everything(args.seed) # Create datamodule data_module = WeatherDataModule( + ar_steps_train=args.ar_steps_train, + ar_steps_eval=args.ar_steps_eval, batch_size=args.batch_size, num_workers=args.num_workers, ) @@ -258,9 +252,10 @@ def main(): else: model = model_class(args) - prefix = "subset-" if args.subset_ds else "" if args.eval: - prefix = prefix + f"eval-{args.eval}-" + prefix = f"eval-{args.eval}-" + else: + prefix = "train-" run_name = ( f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-" f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}" From a86fc0788c30a1f5f364f26ba4c68816b4af23f3 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 9 May 2024 21:58:26 +0200 Subject: [PATCH 20/26] smaller ammendments --- neural_lam/models/ar_model.py | 50 ++++++++++++--------------- neural_lam/models/base_graph_model.py | 4 +-- neural_lam/utils.py | 21 +++++++---- 3 files changed, 38 insertions(+), 37 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 8976990b..0c0e5a55 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -24,13 +24,15 @@ class ARModel(pl.LightningModule): def __init__(self, args): super().__init__() self.save_hyperparameters() - self.lr = args.lr + self.args = args self.config_loader = utils.ConfigLoader(args.data_config) # Load static features for grid/data - static = self.config_loader.process_dataset("static", self.split) + static = self.config_loader.process_dataset("static") self.register_buffer( - "grid_static_features", torch.tensor(static.values) + "grid_static_features", + torch.tensor(static.values), + persistent=False, ) # Double grid output dim. to also output std.-dev. @@ -42,15 +44,6 @@ def __init__(self, args): # Pred. dim. in grid cell self.grid_output_dim = self.config_loader.num_data_vars("state") - # Store constant per-variable std.-dev. 
weighting - # Note that this is the inverse of the multiplicative weighting - # in wMSE/wMAE - self.register_buffer( - "per_var_std", - self.step_diff_std / torch.sqrt(self.param_weights), - persistent=False, - ) - # grid_dim from data + static ( self.num_grid_nodes, @@ -60,11 +53,14 @@ def __init__(self, args): 2 * self.config_loader.num_data_vars("state") + grid_static_dim + self.config_loader.num_data_vars("forcing") + * self.config_loader.forcing.window ) # Instantiate loss function self.loss = metrics.get_metric(args.loss) + border_mask = torch.ones(self.num_grid_nodes, 1) + self.register_buffer("border_mask", border_mask, persistent=False) # Pre-compute interior mask for use in loss function self.register_buffer( "interior_mask", 1.0 - self.border_mask, persistent=False @@ -99,12 +95,14 @@ def __init__(self, args): var_data, ) in self.normalization_stats.data_vars.items(): self.register_buffer( - f"data_{var_name}", torch.tensor(var_data.values) + f"{var_name}", + torch.tensor(var_data.values), + persistent=False, ) def configure_optimizers(self): opt = torch.optim.AdamW( - self.parameters(), lr=self.lr, betas=(0.9, 0.95) + self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) ) if self.opt_state: opt.load_state_dict(self.opt_state) @@ -179,7 +177,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states): pred_std_list, dim=1 ) # (B, pred_steps, num_grid_nodes, d_f) else: - pred_std = self.per_var_std # (d_f,) + pred_std = self.diff_std # (d_f,) return prediction, pred_std @@ -209,22 +207,20 @@ def common_step(self, batch): def on_after_batch_transfer(self, batch, dataloader_idx): """Normalize Batch data after transferring to the device.""" if self.normalization_stats is not None: - init_states, target_states, forcing_features, boundary_features = ( - batch - ) - init_states = (init_states - self.data_mean) / self.data_std - target_states = (target_states - self.data_mean) / self.data_std + init_states, target_states, forcing_features, _, _ = batch + init_states = (init_states - self.mean) / self.std + target_states = (target_states - self.mean) / self.std forcing_features = ( forcing_features - self.forcing_mean ) / self.forcing_std - boundary_features = ( - boundary_features - self.boundary_mean - ) / self.boundary_std + # boundary_features = ( + # boundary_features - self.boundary_mean + # ) / self.boundary_std batch = ( init_states, target_states, forcing_features, - boundary_features, + # boundary_features, ) return batch @@ -392,8 +388,8 @@ def plot_examples(self, batch, n_examples, prediction=None): target = batch[1] # Rescale to original data scale - prediction_rescaled = prediction * self.data_std + self.data_mean - target_rescaled = target * self.data_std + self.data_mean + prediction_rescaled = prediction * self.std + self.mean + target_rescaled = target * self.std + self.mean # Iterate over the examples for pred_slice, target_slice in zip( @@ -541,7 +537,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): metric_name = metric_name.replace("mse", "rmse") # Note: we here assume rescaling for all metrics is linear - metric_rescaled = metric_tensor_averaged * self.data_std + metric_rescaled = metric_tensor_averaged * self.std # (pred_steps, d_f) log_dict.update( self.create_metric_log_dict( diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 256d4adc..fb5df62d 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -166,9 +166,7 @@ def predict_step(self, 
prev_state, prev_prev_state, forcing):
         pred_std = None

         # Rescale with one-step difference statistics
-        rescaled_delta_mean = (
-            pred_delta_mean * self.step_diff_std + self.step_diff_mean
-        )
+        rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean

         # Residual connection for full state
         return prev_state + rescaled_delta_mean, pred_std
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index c86418c8..71ef9512 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -242,16 +242,23 @@ def __contains__(self, key):
     def param_names(self):
         """Return parameter names."""
-        return (
-            self.values["state"]["surface"] + self.values["state"]["atmosphere"]
-        )
+        surface_names = self.values["state"]["surface"]
+        atmosphere_names = [
+            f"{var}_{level}"
+            for var in self.values["state"]["atmosphere"]
+            for level in self.values["state"]["levels"]
+        ]
+        return surface_names + atmosphere_names

     def param_units(self):
         """Return parameter units."""
-        return (
-            self.values["state"]["surface_units"]
-            + self.values["state"]["atmosphere_units"]
-        )
+        surface_units = self.values["state"]["surface_units"]
+        atmosphere_units = [
+            unit
+            for unit in self.values["state"]["atmosphere_units"]
+            for _ in self.values["state"]["levels"]
+        ]
+        return surface_units + atmosphere_units

     def num_data_vars(self, key):
         """Return the number of data variables for a given key."""

From 7ae9c872359b94283cb278f24f73bb7e050ae5bf Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 22:42:32 +0200
Subject: [PATCH 21/26] Dummy mask was inverted - fixed

---
 neural_lam/models/ar_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 0c0e5a55..f49eb094 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -59,7 +59,7 @@ def __init__(self, args):
         # Instantiate loss function
         self.loss = metrics.get_metric(args.loss)

-        border_mask = torch.ones(self.num_grid_nodes, 1)
+        border_mask = torch.zeros(self.num_grid_nodes, 1)
         self.register_buffer("border_mask", border_mask, persistent=False)
         # Pre-compute interior mask for use in loss function
         self.register_buffer(
             "interior_mask", 1.0 - self.border_mask, persistent=False

From 93674a2b437cf46681ddced3859f8978e5e03200 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 22:54:54 +0200
Subject: [PATCH 22/26] replace hardcoded normalization path

---
 neural_lam/data_config.yaml | 2 +-
 neural_lam/utils.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 55e59a72..140eb9b7 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -187,7 +187,7 @@ projection:
     pole_longitude: 10.0
     pole_latitude: -43.0
 normalization:
-  zarr: /scratch/sadamov/norm.zarr
+  zarr: normalization.zarr
   vars:
     data_mean: data_mean
     data_std: data_std
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 71ef9512..96e1549e 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -285,7 +285,7 @@ def open_zarr(self, dataset_name):
     def load_normalization_stats(self):
         """Load normalization statistics from Zarr archive."""
-        normalization_path = "normalization.zarr"
+        normalization_path = self.normalization.zarr
         if not os.path.exists(normalization_path):
             print(
                 f"Normalization statistics not found at "
                 f"path: {normalization_path}"

From 244284ce7759b3734c1d987b9b233ba0d55b4f96 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 14 May 2024 17:13:38 +0200
Subject: [PATCH 23/26] constants.py converted into yaml-file test-case based on meps-example

---
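With this commit the per-dataset constants move out of `neural_lam/constants.py` and into `neural_lam/data_config.yaml`, read through the config loader. A minimal sketch of the resulting access pattern, assuming the `meps_example` defaults added below (the commented values are only what those defaults imply):

    # Sketch only: mirrors how the scripts in this commit consume the YAML config
    from neural_lam import utils

    config_loader = utils.ConfigLoader("neural_lam/data_config.yaml")
    print(config_loader.dataset.name)     # "meps_example"
    print(config_loader.num_data_vars())  # 17, i.e. len(config_loader.dataset.vars)
    proj = config_loader.projection()     # a cartopy.crs.LambertConformal instance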
 .gitignore | 4 +-
 README.md | 5 +-
 create_grid_features.py | 63 +++++
 create_mesh.py | 21 +-
 create_parameter_weights.py | 158 ++++++-----
 neural_lam/data_config.yaml | 259 +++++-------------
 neural_lam/models/ar_model.py | 83 ++----
 neural_lam/models/base_graph_model.py | 4 +--
 neural_lam/utils.py | 218 ++++++---------
 neural_lam/vis.py | 2 +-
 neural_lam/weather_dataset.py | 365 +++++++++++++++-----------
 plot_graph.py | 23 +-
 pyproject.toml | 21 +-
 requirements.txt | 3 -
 train_model.py | 117 ++++++---
 15 files changed, 669 insertions(+), 677 deletions(-)
 create mode 100644 create_grid_features.py

diff --git a/.gitignore b/.gitignore
index 08cc014e..c9d914c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,8 +7,7 @@ graphs
 *.sif
 sweeps
 test_*.sh
-cosmo_hilam.html
-normalization.zarr
+.vscode

 ### Python ###
 # Byte-compiled / optimized / DLL files
@@ -74,4 +73,3 @@ tags

 # Coc configuration directory
 .vim
-.vscode
diff --git a/README.md b/README.md
index 67d9d9b1..fc5675e8 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,7 @@ Still, some restrictions are inevitable:

 ## A note on the limited area setting
 Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
 There are still some parts of the code that is quite specific for the MEPS area use case.
-This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/constants.py`).
+This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/data_config.yaml`).
 If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
 We would be happy to support such enhancements.
 See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
@@ -104,13 +104,12 @@ The graph-related files are stored in a directory called `graphs`.

 ### Create remaining static features
 To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`.
-The main option to set for these is just which dataset to use.

 ## Weights & Biases Integration
 The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it.
 When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface.
 If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`.
-The W&B project name is set to `neural-lam`, but this can be changed in `neural_lam/constants.py`.
+The W&B project name is set to `neural-lam`, but this can be changed in the flags of `train_model.py` (using argparse).
 See the [W&B documentation](https://docs.wandb.ai/) for details.
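For example, runs can be sent to a different project by overriding the corresponding flag at launch, e.g. `python train_model.py --wandb_project my-project` (here `my-project` is only a placeholder name).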
If you would like to login and use W&B, run: diff --git a/create_grid_features.py b/create_grid_features.py new file mode 100644 index 00000000..e5b9c49a --- /dev/null +++ b/create_grid_features.py @@ -0,0 +1,63 @@ +# Standard library +import os +from argparse import ArgumentParser + +# Third-party +import numpy as np +import torch + +# First-party +from neural_lam import utils + + +def main(): + """ + Pre-compute all static features related to the grid nodes + """ + parser = ArgumentParser(description="Training arguments") + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", + ) + args = parser.parse_args() + config_loader = utils.ConfigLoader(args.data_config) + + static_dir_path = os.path.join("data", config_loader.dataset.name, "static") + + # -- Static grid node features -- + grid_xy = torch.tensor( + np.load(os.path.join(static_dir_path, "nwp_xy.npy")) + ) # (2, N_x, N_y) + grid_xy = grid_xy.flatten(1, 2).T # (N_grid, 2) + pos_max = torch.max(torch.abs(grid_xy)) + grid_xy = grid_xy / pos_max # Divide by maximum coordinate + + geopotential = torch.tensor( + np.load(os.path.join(static_dir_path, "surface_geopotential.npy")) + ) # (N_x, N_y) + geopotential = geopotential.flatten(0, 1).unsqueeze(1) # (N_grid,1) + gp_min = torch.min(geopotential) + gp_max = torch.max(geopotential) + # Rescale geopotential to [0,1] + geopotential = (geopotential - gp_min) / (gp_max - gp_min) # (N_grid, 1) + + grid_border_mask = torch.tensor( + np.load(os.path.join(static_dir_path, "border_mask.npy")), + dtype=torch.int64, + ) # (N_x, N_y) + grid_border_mask = ( + grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1) + ) # (N_grid, 1) + + # Concatenate grid features + grid_features = torch.cat( + (grid_xy, geopotential, grid_border_mask), dim=1 + ) # (N_grid, 4) + + torch.save(grid_features, os.path.join(static_dir_path, "grid_features.pt")) + + +if __name__ == "__main__": + main() diff --git a/create_mesh.py b/create_mesh.py index 2b6af9fd..477ddf55 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -15,8 +15,6 @@ # First-party from neural_lam import utils -# matplotlib.use('TkAgg') - def plot_graph(graph, title=None): fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H @@ -157,6 +155,12 @@ def prepend_node_index(graph, new_index): def main(): parser = ArgumentParser(description="Graph generation arguments") + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", + ) parser.add_argument( "--graph", type=str, @@ -182,21 +186,16 @@ def main(): default=0, help="Generate hierarchical mesh graph (default: 0, no)", ) - parser.add_argument( - "--data_config", - type=str, - default="neural_lam/data_config.yaml", - help="Path to data config file (default: neural_lam/data_config.yaml)", - ) - args = parser.parse_args() # Load grid positions + config_loader = utils.ConfigLoader(args.data_config) + static_dir_path = os.path.join("data", config_loader.dataset.name, "static") graph_dir_path = os.path.join("graphs", args.graph) os.makedirs(graph_dir_path, exist_ok=True) - config_loader = utils.ConfigLoader(args.data_config) - xy = config_loader.get_nwp_xy() + xy = np.load(os.path.join(static_dir_path, "nwp_xy.npy")) + grid_xy = torch.tensor(xy) pos_max = torch.max(torch.abs(grid_xy)) diff --git a/create_parameter_weights.py b/create_parameter_weights.py index 1eda7a24..fd8c38cd 100644 
--- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -1,13 +1,15 @@ # Standard library +import os from argparse import ArgumentParser # Third-party +import numpy as np import torch -import xarray as xr from tqdm import tqdm # First-party -from neural_lam.weather_dataset import WeatherDataModule +from neural_lam import utils +from neural_lam.weather_dataset import WeatherDataset def main(): @@ -15,6 +17,12 @@ def main(): Pre-compute parameter weights to be used in loss function """ parser = ArgumentParser(description="Training arguments") + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", + ) parser.add_argument( "--batch_size", type=int, @@ -22,77 +30,120 @@ def main(): help="Batch size when iterating over the dataset", ) parser.add_argument( - "--num_workers", + "--step_length", type=int, - default=4, - help="Number of workers in data loader (default: 4)", + default=3, + help="Step length in hours to consider single time step (default: 3)", ) parser.add_argument( - "--zarr_path", - type=str, - default="normalization.zarr", - help="Directory where data is stored", + "--n_workers", + type=int, + default=4, + help="Number of workers in data loader (default: 4)", ) - args = parser.parse_args() - data_module = WeatherDataModule( - batch_size=args.batch_size, num_workers=args.num_workers + config_loader = utils.ConfigLoader(args.data_config) + static_dir_path = os.path.join("data", config_loader.dataset.name, "static") + + # Create parameter weights based on height + # based on fig A.1 in graph cast paper + w_dict = { + "2": 1.0, + "0": 0.1, + "65": 0.065, + "1000": 0.1, + "850": 0.05, + "500": 0.03, + } + w_list = np.array( + [w_dict[par.split("_")[-2]] for par in config_loader.dataset.var_names] + ) + print("Saving parameter weights...") + np.save( + os.path.join(static_dir_path, "parameter_weights.npy"), + w_list.astype("float32"), ) - data_module.setup() - loader = data_module.train_dataloader() # Load dataset without any subsampling - # Compute mean and std.-dev. of each parameter (+ forcing forcing) + ds = WeatherDataset( + config_loader.dataset.name, + split="train", + subsample_step=1, + pred_length=63, + standardize=False, + ) # Without standardization + loader = torch.utils.data.DataLoader( + ds, args.batch_size, shuffle=False, num_workers=args.n_workers + ) + # Compute mean and std.-dev. of each parameter (+ flux forcing) # across full dataset print("Computing mean and std.-dev. 
for parameters...") means = [] squares = [] - fb_means = {"forcing": [], "boundary": []} - fb_squares = {"forcing": [], "boundary": []} - - for init_batch, target_batch, forcing_batch, boundary_batch, _ in tqdm( - loader - ): + flux_means = [] + flux_squares = [] + for init_batch, target_batch, forcing_batch in tqdm(loader): batch = torch.cat( (init_batch, target_batch), dim=1 ) # (N_batch, N_t, N_grid, d_features) means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,) - squares.append(torch.mean(batch**2, dim=(1, 2))) + squares.append( + torch.mean(batch**2, dim=(1, 2)) + ) # (N_batch, d_features,) - for fb_type, fb_batch in zip( - ["forcing", "boundary"], [forcing_batch, boundary_batch] - ): - fb_batch = fb_batch[:, :, :, 1] - fb_means[fb_type].append(torch.mean(fb_batch)) # (,) - fb_squares[fb_type].append(torch.mean(fb_batch**2)) # (,) + # Flux at 1st windowed position is index 1 in forcing + flux_batch = forcing_batch[:, :, :, 1] + flux_means.append(torch.mean(flux_batch)) # (,) + flux_squares.append(torch.mean(flux_batch**2)) # (,) mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features) second_moment = torch.mean(torch.cat(squares, dim=0), dim=0) std = torch.sqrt(second_moment - mean**2) # (d_features) - fb_stats = {} - for fb_type in ["forcing", "boundary"]: - fb_stats[f"{fb_type}_mean"] = torch.mean( - torch.stack(fb_means[fb_type]) - ) # (,) - fb_second_moment = torch.mean(torch.stack(fb_squares[fb_type])) # (,) - fb_stats[f"{fb_type}_std"] = torch.sqrt( - fb_second_moment - fb_stats[f"{fb_type}_mean"] ** 2 - ) # (,) + flux_mean = torch.mean(torch.stack(flux_means)) # (,) + flux_second_moment = torch.mean(torch.stack(flux_squares)) # (,) + flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) + flux_stats = torch.stack((flux_mean, flux_std)) + + print("Saving mean, std.-dev, flux_stats...") + torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt")) + torch.save(std, os.path.join(static_dir_path, "parameter_std.pt")) + torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt")) # Compute mean and std.-dev. of one-step differences across the dataset print("Computing mean and std.-dev. 
for one-step differences...") + ds_standard = WeatherDataset( + config_loader.dataset.name, + split="train", + subsample_step=1, + pred_length=63, + standardize=True, + ) # Re-load with standardization + loader_standard = torch.utils.data.DataLoader( + ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers + ) + used_subsample_len = (65 // args.step_length) * args.step_length + diff_means = [] diff_squares = [] - for init_batch, target_batch, _, _, _ in tqdm(loader): - # normalize the batch - init_batch = (init_batch - mean) / std - target_batch = (target_batch - mean) / std - - batch = torch.cat((init_batch, target_batch), dim=1) - batch_diffs = batch[:, 1:] - batch[:, :-1] - # (N_batch, N_t-1, N_grid, d_features) + for init_batch, target_batch, _ in tqdm(loader_standard): + batch = torch.cat( + (init_batch, target_batch), dim=1 + ) # (N_batch, N_t', N_grid, d_features) + # Note: batch contains only 1h-steps + stepped_batch = torch.cat( + [ + batch[:, ss_i : used_subsample_len : args.step_length] + for ss_i in range(args.step_length) + ], + dim=0, + ) + # (N_batch', N_t, N_grid, d_features), + # N_batch' = args.step_length*N_batch + + batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1] + # (N_batch', N_t-1, N_grid, d_features) diff_means.append( torch.mean(batch_diffs, dim=(1, 2)) @@ -105,20 +156,9 @@ def main(): diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0) diff_std = torch.sqrt(diff_second_moment - diff_mean**2) # (d_features) - # Create xarray dataset - ds = xr.Dataset( - { - "mean": (["d_features"], mean), - "std": (["d_features"], std), - "diff_mean": (["d_features"], diff_mean), - "diff_std": (["d_features"], diff_std), - **fb_stats, - } - ) - - # Save dataset as Zarr - print("Saving dataset as Zarr...") - ds.to_zarr(args.zarr_path, mode="w") + print("Saving one-step difference mean and std.-dev...") + torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt")) + torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt")) if __name__ == "__main__": diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index 140eb9b7..213825ff 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -1,199 +1,64 @@ -zarrs: # List of zarrs containing fields related to state - state: - path: /scratch/sadamov/template.zarr # Path to zarr - dims: # Name of dimensions in zarr, to be used for indexing - time: time - level: z - x: x # Either give "grid" (flattened) dimension or "x" and "y" - y: y - lat_lon_names: - lon: lon - lat: lat - static: - path: /scratch/sadamov/template.zarr - dims: - level: z - x: x - y: y - lat_lon_names: - lon: lon - lat: lat - forcing: - path: /scratch/sadamov/template.zarr - dims: - time: time - level: z - x: x - y: y - lat_lon_names: - lon: lon - lat: lat - boundary: - path: /scratch/sadamov/era5_template.zarr - dims: - time: time - level: level - x: longitude - y: latitude - lat_lon_names: - lon: longitude - lat: latitude - mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary. 
-state: # Variables forecasted by the model - surface: # Single-field variables - - CLCT - - PMSL - - PS - - T_2M - - TOT_PREC - - U_10M - - V_10M - surface_units: - - "%" - - r"$\mathrm{Pa}$" - - r"$\mathrm{Pa}$" - - r"$\mathrm{K}$" +dataset: + name: meps_example + vars: + - pres_0g + - pres_0s + - nlwrs_0 + - nswrs_0 + - r_2 + - r_65 + - t_2 + - t_65 + - t_500 + - t_850 + - u_65 + - u_850 + - v_65 + - v_850 + - wvint_0 + - z_1000 + - z_500 + units: + - Pa + - Pa + - r"$\mathrm{W}/\mathrm{m}^2$" + - r"$\mathrm{W}/\mathrm{m}^2$" + - "" + - "" + - K + - K + - K + - K + - m/s + - m/s + - m/s + - m/s - r"$\mathrm{kg}/\mathrm{m}^2$" - - r"$\mathrm{m}/\mathrm{s}$" - - r"$\mathrm{m}/\mathrm{s}$" - atmosphere: # Variables with vertical levels - - PP - - QV - - RELHUM - - T - - U - - V - - W - atmosphere_units: - - r"$\mathrm{Pa}$" - - r"$\mathrm{kg}/\mathrm{kg}$" - - "%" - - r"$\mathrm{K}$" - - r"$\mathrm{m}/\mathrm{s}$" - - r"$\mathrm{m}/\mathrm{s}$" - - r"$\mathrm{Pa}/\mathrm{s}$" - levels: # Levels to use for atmosphere variables - - 0 - - 5 - - 8 - - 11 - - 13 - - 15 - - 19 - - 22 - - 26 - - 30 - - 38 - - 44 - - 59 -static: # Static inputs - surface: - - HSURF - atmosphere: - - FI - levels: - - 0 - - 5 - - 8 - - 11 - - 13 - - 15 - - 19 - - 22 - - 26 - - 30 - - 38 - - 44 - - 59 -forcing: # Forcing variables, dynamic inputs to the model - surface: - - ASOB_S - atmosphere: - - T - levels: - - 0 - - 5 - - 8 - - 11 - - 13 - - 15 - - 19 - - 22 - - 26 - - 30 - - 38 - - 44 - - 59 - window: 3 # Number of time steps to use for forcing (odd) -boundary: # Boundary conditions - surface: - - 10m_u_component_of_wind - # - 10m_v_component_of_wind - # - 2m_dewpoint_temperature - # - 2m_temperature - # - mean_sea_level_pressure - # - mean_surface_latent_heat_flux - # - mean_surface_net_long_wave_radiation_flux - # - mean_surface_net_short_wave_radiation_flux - # - mean_surface_sensible_heat_flux - # - surface_pressure - # - total_cloud_cover - # - total_column_water_vapour - # - total_precipitation_12hr - # - total_precipitation_24hr - # - total_precipitation_6hr - # - geopotential_at_surface - atmosphere: - - divergence - # - geopotential - # - relative_humidity - # - specific_humidity - # - temperature - # - u_component_of_wind - # - v_component_of_wind - # - vertical_velocity - # - vorticity - levels: - - 50 - - 100 - - 150 - - 200 - - 250 - - 300 - - 400 - - 500 - - 600 - - 700 - - 850 - - 925 - - 1000 - window: 3 # Number of time steps to use for boundary (odd) -grid_shape_state: - x: 582 - y: 390 -splits: - train: - start: 2015-01-01T00 - end: 2024-12-31T23 - val: - start: 2015-01-01T00 - end: 2024-12-31T23 - test: - start: 2015-01-01T00 - end: 2024-12-31T23 + - r"$\mathrm{m}^2/\mathrm{s}^2$" + - r"$\mathrm{m}^2/\mathrm{s}^2$" + var_names: + - pres_heightAboveGround_0_instant + - pres_heightAboveSea_0_instant + - nlwrs_heightAboveGround_0_accum + - nswrs_heightAboveGround_0_accum + - r_heightAboveGround_2_instant + - r_hybrid_65_instant + - t_heightAboveGround_2_instant + - t_hybrid_65_instant + - t_isobaricInhPa_500_instant + - t_isobaricInhPa_850_instant + - u_hybrid_65_instant + - u_isobaricInhPa_850_instant + - v_hybrid_65_instant + - v_isobaricInhPa_850_instant + - wvint_entireAtmosphere_0_instant + - z_isobaricInhPa_1000_instant + - z_isobaricInhPa_500_instant + forcing_dim: 16 +grid_shape_state: [268, 238] projection: - class: RotatedPole # Name of class in cartopy.crs - kwargs: # Parsed and used directly as kwargs to projection-class above - pole_longitude: 10.0 - pole_latitude: -43.0 
-normalization: - zarr: normalization.zarr - vars: - data_mean: data_mean - data_std: data_std - forcing_mean: forcing_mean - forcing_std: forcing_std - boundary_mean: boundary_mean - boundary_std: boundary_std - diff_mean: diff_mean - diff_std: diff_std + class: LambertConformal + kwargs: + central_longitude: 15.0 + central_latitude: 63.3 + standard_parallels: [63.3, 63.3] diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index f49eb094..da2654f0 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -28,21 +28,30 @@ def __init__(self, args): self.config_loader = utils.ConfigLoader(args.data_config) # Load static features for grid/data - static = self.config_loader.process_dataset("static") - self.register_buffer( - "grid_static_features", - torch.tensor(static.values), - persistent=False, + static_data_dict = utils.load_static_data( + self.config_loader.dataset.name ) + for static_data_name, static_data_tensor in static_data_dict.items(): + self.register_buffer( + static_data_name, static_data_tensor, persistent=False + ) # Double grid output dim. to also output std.-dev. self.output_std = bool(args.output_std) if self.output_std: # Pred. dim. in grid cell - self.grid_output_dim = 2 * self.config_loader.num_data_vars("state") + self.grid_output_dim = 2 * self.config_loader.num_data_vars() else: # Pred. dim. in grid cell - self.grid_output_dim = self.config_loader.num_data_vars("state") + self.grid_output_dim = self.config_loader.num_data_vars() + # Store constant per-variable std.-dev. weighting + # Note that this is the inverse of the multiplicative weighting + # in wMSE/wMAE + self.register_buffer( + "per_var_std", + self.step_diff_std / torch.sqrt(self.param_weights), + persistent=False, + ) # grid_dim from data + static ( @@ -50,17 +59,14 @@ def __init__(self, args): grid_static_dim, ) = self.grid_static_features.shape self.grid_dim = ( - 2 * self.config_loader.num_data_vars("state") + 2 * self.config_loader.num_data_vars() + grid_static_dim - + self.config_loader.num_data_vars("forcing") - * self.config_loader.forcing.window + + self.config_loader.dataset.forcing_dim ) # Instantiate loss function self.loss = metrics.get_metric(args.loss) - border_mask = torch.zeros(self.num_grid_nodes, 1) - self.register_buffer("border_mask", border_mask, persistent=False) # Pre-compute interior mask for use in loss function self.register_buffer( "interior_mask", 1.0 - self.border_mask, persistent=False @@ -87,19 +93,6 @@ def __init__(self, args): # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] - # Load normalization statistics - self.normalization_stats = self.config_loader.load_normalization_stats() - if self.normalization_stats is not None: - for ( - var_name, - var_data, - ) in self.normalization_stats.data_vars.items(): - self.register_buffer( - f"{var_name}", - torch.tensor(var_data.values), - persistent=False, - ) - def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) @@ -177,7 +170,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states): pred_std_list, dim=1 ) # (B, pred_steps, num_grid_nodes, d_f) else: - pred_std = self.diff_std # (d_f,) + pred_std = self.per_var_std # (d_f,) return prediction, pred_std @@ -204,26 +197,6 @@ def common_step(self, batch): return prediction, target_states, pred_std - def on_after_batch_transfer(self, batch, dataloader_idx): - """Normalize Batch data after transferring to the device.""" - if 
self.normalization_stats is not None: - init_states, target_states, forcing_features, _, _ = batch - init_states = (init_states - self.mean) / self.std - target_states = (target_states - self.mean) / self.std - forcing_features = ( - forcing_features - self.forcing_mean - ) / self.forcing_std - # boundary_features = ( - # boundary_features - self.boundary_mean - # ) / self.boundary_std - batch = ( - init_states, - target_states, - forcing_features, - # boundary_features, - ) - return batch - def training_step(self, batch): """ Train on single batch @@ -355,7 +328,9 @@ def test_step(self, batch, batch_idx): spatial_loss = self.loss( prediction, target, pred_std, average_grid=False ) # (B, pred_steps, num_grid_nodes) - log_spatial_losses = spatial_loss[:, self.args.val_steps_log - 1] + log_spatial_losses = spatial_loss[ + :, [step - 1 for step in self.args.val_steps_log] + ] self.spatial_loss_maps.append(log_spatial_losses) # (B, N_log, num_grid_nodes) @@ -388,8 +363,8 @@ def plot_examples(self, batch, n_examples, prediction=None): target = batch[1] # Rescale to original data scale - prediction_rescaled = prediction * self.std + self.mean - target_rescaled = target * self.std + self.mean + prediction_rescaled = prediction * self.data_std + self.data_mean + target_rescaled = target * self.data_std + self.data_mean # Iterate over the examples for pred_slice, target_slice in zip( @@ -433,8 +408,8 @@ def plot_examples(self, batch, n_examples, prediction=None): ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( - self.config_loader.param_names(), - self.config_loader.param_units(), + self.config_loader.dataset.vars, + self.config_loader.dataset.units, var_vranges, ) ) @@ -445,7 +420,7 @@ def plot_examples(self, batch, n_examples, prediction=None): { f"{var_name}_example_{example_i}": wandb.Image(fig) for var_name, fig in zip( - self.config_loader.param_names(), var_figs + self.config_loader.dataset.vars, var_figs ) } ) @@ -501,7 +476,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): # Check if metrics are watched, log exact values for specific vars if full_log_name in self.args.metrics_watch: for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var = self.config_loader.param_names()[var_i] + var = self.config_loader.dataset.vars[var_i] log_dict.update( { f"{full_log_name}_{var}_step_{step}": metric_tensor[ @@ -537,7 +512,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): metric_name = metric_name.replace("mse", "rmse") # Note: we here assume rescaling for all metrics is linear - metric_rescaled = metric_tensor_averaged * self.std + metric_rescaled = metric_tensor_averaged * self.data_std # (pred_steps, d_f) log_dict.update( self.create_metric_log_dict( diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index fb5df62d..256d4adc 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -166,7 +166,9 @@ def predict_step(self, prev_state, prev_prev_state, forcing): pred_std = None # Rescale with one-step difference statistics - rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean + rescaled_delta_mean = ( + pred_delta_mean * self.step_diff_std + self.step_diff_mean + ) # Residual connection for full state return prev_state + rescaled_delta_mean, pred_std diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 96e1549e..528560e3 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -5,12 +5,85 @@ import cartopy.crs as ccrs 
import numpy as np import torch -import xarray as xr import yaml from torch import nn from tueplots import bundles, figsizes +def load_dataset_stats(dataset_name, device="cpu"): + """ + Load arrays with stored dataset statistics from pre-processing + """ + static_dir_path = os.path.join("data", dataset_name, "static") + + def loads_file(fn): + return torch.load( + os.path.join(static_dir_path, fn), map_location=device + ) + + data_mean = loads_file("parameter_mean.pt") # (d_features,) + data_std = loads_file("parameter_std.pt") # (d_features,) + + flux_stats = loads_file("flux_stats.pt") # (2,) + flux_mean, flux_std = flux_stats + + return { + "data_mean": data_mean, + "data_std": data_std, + "flux_mean": flux_mean, + "flux_std": flux_std, + } + + +def load_static_data(dataset_name, device="cpu"): + """ + Load static files related to dataset + """ + static_dir_path = os.path.join("data", dataset_name, "static") + + def loads_file(fn): + return torch.load( + os.path.join(static_dir_path, fn), map_location=device + ) + + # Load border mask, 1. if node is part of border, else 0. + border_mask_np = np.load(os.path.join(static_dir_path, "border_mask.npy")) + border_mask = ( + torch.tensor(border_mask_np, dtype=torch.float32, device=device) + .flatten(0, 1) + .unsqueeze(1) + ) # (N_grid, 1) + + grid_static_features = loads_file( + "grid_features.pt" + ) # (N_grid, d_grid_static) + + # Load step diff stats + step_diff_mean = loads_file("diff_mean.pt") # (d_f,) + step_diff_std = loads_file("diff_std.pt") # (d_f,) + + # Load parameter std for computing validation errors in original data scale + data_mean = loads_file("parameter_mean.pt") # (d_features,) + data_std = loads_file("parameter_std.pt") # (d_features,) + + # Load loss weighting vectors + param_weights = torch.tensor( + np.load(os.path.join(static_dir_path, "parameter_weights.npy")), + dtype=torch.float32, + device=device, + ) # (d_f,) + + return { + "border_mask": border_mask, + "grid_static_features": grid_static_features, + "step_diff_mean": step_diff_mean, + "step_diff_std": step_diff_std, + "data_mean": data_mean, + "data_std": data_std, + "param_weights": param_weights, + } + + class BufferList(nn.Module): """ A list of torch buffer tensors that sit together as a Module with no @@ -240,145 +313,14 @@ def __getitem__(self, key): def __contains__(self, key): return key in self.values - def param_names(self): - """Return parameter names.""" - surface_names = self.values["state"]["surface"] - atmosphere_names = [ - f"{var}_{level}" - for var in self.values["state"]["atmosphere"] - for level in self.values["state"]["levels"] - ] - return surface_names + atmosphere_names - - def param_units(self): - """Return parameter units.""" - surface_units = self.values["state"]["surface_units"] - atmosphere_units = [ - unit - for unit in self.values["state"]["atmosphere_units"] - for _ in self.values["state"]["levels"] - ] - return surface_units + atmosphere_units - - def num_data_vars(self, key): + def num_data_vars(self): """Return the number of data variables for a given key.""" - surface_vars = len(self.values[key]["surface"]) - atmosphere_vars = len(self.values[key]["atmosphere"]) - levels = len(self.values[key]["levels"]) - return surface_vars + atmosphere_vars * levels + return len(self.dataset.vars) def projection(self): """Return the projection.""" - proj_config = self.values["projections"]["class"] - proj_class = getattr(ccrs, proj_config["proj_class"]) - proj_params = proj_config["proj_params"] + proj_config = self.values["projection"] + 
proj_class_name = proj_config["class"] + proj_class = getattr(ccrs, proj_class_name) + proj_params = proj_config.get("kwargs", {}) return proj_class(**proj_params) - - def open_zarr(self, dataset_name): - """Open a dataset specified by the dataset name.""" - dataset_path = self.zarrs[dataset_name].path - if dataset_path is None or not os.path.exists(dataset_path): - print(f"Dataset '{dataset_name}' not found at path: {dataset_path}") - return None - dataset = xr.open_zarr(dataset_path, consolidated=True) - return dataset - - def load_normalization_stats(self): - """Load normalization statistics from Zarr archive.""" - normalization_path = self.normalization.zarr - if not os.path.exists(normalization_path): - print( - f"Normalization statistics not found at " - f"path: {normalization_path}" - ) - return None - normalization_stats = xr.open_zarr( - normalization_path, consolidated=True - ) - return normalization_stats - - def process_dataset(self, dataset_name, split="train", stack=True): - """Process a single dataset specified by the dataset name.""" - - dataset = self.open_zarr(dataset_name) - if dataset is None: - return None - - start, end = ( - self.splits[split].start, - self.splits[split].end, - ) - dataset = dataset.sel(time=slice(start, end)) - dataset = dataset.rename_dims( - { - v: k - for k, v in self.zarrs[dataset_name].dims.values.items() - if k not in dataset.dims - } - ) - - vars_surface = [] - if self[dataset_name].surface: - vars_surface = dataset[self[dataset_name].surface] - - vars_atmosphere = [] - if self[dataset_name].atmosphere: - vars_atmosphere = xr.merge( - [ - dataset[var] - .sel(level=level, drop=True) - .rename(f"{var}_{level}") - for var in self[dataset_name].atmosphere - for level in self[dataset_name].levels - ] - ) - - if vars_surface and vars_atmosphere: - dataset = xr.merge([vars_surface, vars_atmosphere]) - elif vars_surface: - dataset = vars_surface - elif vars_atmosphere: - dataset = vars_atmosphere - else: - print(f"No variables found in dataset {dataset_name}") - return None - - if not all( - lat_lon in self.zarrs[dataset_name].dims.values.values() - for lat_lon in self.zarrs[ - dataset_name - ].lat_lon_names.values.values() - ): - lat_name = self.zarrs[dataset_name].lat_lon_names.lat - lon_name = self.zarrs[dataset_name].lat_lon_names.lon - if dataset[lat_name].ndim == 2: - dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True) - if dataset[lon_name].ndim == 2: - dataset[lon_name] = dataset[lon_name].isel(y=0, drop=True) - dataset = dataset.assign_coords( - x=dataset[lon_name], y=dataset[lat_name] - ) - - if stack: - dataset = self.stack_grid(dataset) - - return dataset - - def stack_grid(self, dataset): - """Stack grid dimensions.""" - dataset = dataset.squeeze().stack(grid=("x", "y")).to_array() - - if "time" in dataset.dims: - dataset = dataset.transpose("time", "grid", "variable") - else: - dataset = dataset.transpose("grid", "variable") - return dataset - - def get_nwp_xy(self): - """Get the x and y coordinates for the NWP grid.""" - x = self.process_dataset("static", stack=False).x.values - y = self.process_dataset("static", stack=False).y.values - xx, yy = np.meshgrid(y, x) - xy = np.stack((xx, yy), axis=0) - - return xy diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 8c36a9a7..7a4d3730 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -51,7 +51,7 @@ def plot_error_map(errors, data_config, title=None, step_length=3): y_ticklabels = [ f"{name} ({unit})" for name, unit in zip( - data_config.param_names(), 
data_config.param_units() + data_config.dataset.vars, data_config.dataset.units ) ] ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 4b5da0a8..0c01ae1d 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,5 +1,10 @@ +# Standard library +import datetime as dt +import glob +import os + # Third-party -import pytorch_lightning as pl +import numpy as np import torch # First-party @@ -8,181 +13,249 @@ class WeatherDataset(torch.utils.data.Dataset): """ - Dataset class for weather data. - - This class loads and processes weather data from zarr files based on the - provided configuration. It supports splitting the data into train, - validation, and test sets. + For our dataset: + N_t' = 65 + N_t = 65//subsample_step (= 21 for 3h steps) + dim_x = 268 + dim_y = 238 + N_grid = 268x238 = 63784 + d_features = 17 (d_features' = 18) + d_forcing = 5 """ def __init__( self, + dataset_name, + pred_length=19, split="train", - ar_steps=3, - batch_size=4, + subsample_step=3, + standardize=True, + subset=False, control_only=False, - data_config="neural_lam/data_config.yaml", ): super().__init__() - assert split in ( - "train", - "val", - "test", - ), "Unknown dataset split" - - self.split = split - self.batch_size = batch_size - self.ar_steps = ar_steps - self.control_only = control_only - self.config_loader = utils.ConfigLoader(data_config) - - self.state = self.config_loader.process_dataset("state", self.split) - assert self.state is not None, "State dataset not found" - self.forcing = self.config_loader.process_dataset("forcing", self.split) - self.boundary = self.config_loader.process_dataset( - "boundary", self.split + assert split in ("train", "val", "test"), "Unknown dataset split" + self.sample_dir_path = os.path.join( + "data", dataset_name, "samples", split ) - self.state_times = self.state.time.values - self.forcing_window = self.config_loader.forcing.window - self.boundary_window = self.config_loader.boundary.window - - if self.forcing is not None: - self.forcing_windowed = ( - self.forcing.sel( - time=self.state.time, - method="nearest", - ) - .pad( - time=(self.forcing_window // 2, self.forcing_window // 2), - mode="edge", - ) - .rolling(time=self.forcing_window, center=True) - .construct("window") - ) + member_file_regexp = ( + "nwp*mbr000.npy" if control_only else "nwp*mbr*.npy" + ) + sample_paths = glob.glob( + os.path.join(self.sample_dir_path, member_file_regexp) + ) + self.sample_names = [path.split("/")[-1][4:-4] for path in sample_paths] + # Now on form "yyymmddhh_mbrXXX" + + if subset: + self.sample_names = self.sample_names[:50] # Limit to 50 samples + + self.sample_length = pred_length + 2 # 2 init states + self.subsample_step = subsample_step + self.original_sample_length = ( + 65 // self.subsample_step + ) # 21 for 3h steps + assert ( + self.sample_length <= self.original_sample_length + ), "Requesting too long time series samples" - if self.boundary is not None: - self.boundary_windowed = ( - self.boundary.sel( - time=self.state.time, - method="nearest", - ) - .pad( - time=(self.boundary_window // 2, self.boundary_window // 2), - mode="edge", - ) - .rolling(time=self.boundary_window, center=True) - .construct("window") + # Set up for standardization + self.standardize = standardize + if standardize: + ds_stats = utils.load_dataset_stats(dataset_name, "cpu") + self.data_mean, self.data_std, self.flux_mean, self.flux_std = ( + ds_stats["data_mean"], + 
ds_stats["data_std"], + ds_stats["flux_mean"], + ds_stats["flux_std"], ) + # If subsample index should be sampled (only duing training) + self.random_subsample = split == "train" + def __len__(self): - # Skip first and last time step - return len(self.state.time) - self.ar_steps + return len(self.sample_names) def __getitem__(self, idx): - sample = torch.tensor( - self.state.isel(time=slice(idx, idx + self.ar_steps)).values, - dtype=torch.float32, + # === Sample === + sample_name = self.sample_names[idx] + sample_path = os.path.join( + self.sample_dir_path, f"nwp_{sample_name}.npy" ) + try: + full_sample = torch.tensor( + np.load(sample_path), dtype=torch.float32 + ) # (N_t', dim_x, dim_y, d_features') + except ValueError: + print(f"Failed to load {sample_path}") + + # Only use every ss_step:th time step, sample which of ss_step + # possible such time series + if self.random_subsample: + subsample_index = torch.randint(0, self.subsample_step, ()).item() + else: + subsample_index = 0 + subsample_end_index = self.original_sample_length * self.subsample_step + sample = full_sample[ + subsample_index : subsample_end_index : self.subsample_step + ] + # (N_t, dim_x, dim_y, d_features') + + # Remove feature 15, "z_height_above_ground" + sample = torch.cat( + (sample[:, :, :, :15], sample[:, :, :, 16:]), dim=3 + ) # (N_t, dim_x, dim_y, d_features) - forcing = ( - self.forcing_windowed.isel(time=slice(idx + 2, idx + self.ar_steps)) - .stack(variable_window=("variable", "window")) - .values - if self.forcing is not None - else torch.tensor([]) + # Accumulate solar radiation instead of just subsampling + rad_features = full_sample[:, :, :, 2:4] # (N_t', dim_x, dim_y, 2) + # Accumulate for first time step + init_accum_rad = torch.sum( + rad_features[: (subsample_index + 1)], dim=0, keepdim=True + ) # (1, dim_x, dim_y, 2) + # Accumulate for rest of subsampled sequence + in_subsample_len = ( + subsample_end_index - self.subsample_step + subsample_index + 1 ) + rad_features_in_subsample = rad_features[ + (subsample_index + 1) : in_subsample_len + ] # (N_t*, dim_x, dim_y, 2), N_t* = (N_t-1)*ss_step + _, dim_x, dim_y, _ = sample.shape + rest_accum_rad = torch.sum( + rad_features_in_subsample.view( + self.original_sample_length - 1, + self.subsample_step, + dim_x, + dim_y, + 2, + ), + dim=1, + ) # (N_t-1, dim_x, dim_y, 2) + accum_rad = torch.cat( + (init_accum_rad, rest_accum_rad), dim=0 + ) # (N_t, dim_x, dim_y, 2) + # Replace in sample + sample[:, :, :, 2:4] = accum_rad - boundary = ( - self.boundary_windowed.isel( - time=slice(idx + 2, idx + self.ar_steps) - ) - .stack(variable_window=("variable", "window")) - .values - if self.boundary is not None - else torch.tensor([]) + # Flatten spatial dim + sample = sample.flatten(1, 2) # (N_t, N_grid, d_features) + + # Uniformly sample time id to start sample from + init_id = torch.randint( + 0, 1 + self.original_sample_length - self.sample_length, () ) + sample = sample[init_id : (init_id + self.sample_length)] + # (sample_length, N_grid, d_features) + + if self.standardize: + # Standardize sample + sample = (sample - self.data_mean) / self.data_std - init_states = sample[:2] - target_states = sample[2:] + # Split up sample in init. 
states and target states + init_states = sample[:2] # (2, N_grid, d_features) + target_states = sample[2:] # (sample_length-2, N_grid, d_features) - batch_times = ( - self.state.isel(time=slice(idx + 2, idx + self.ar_steps)) - .time.values.astype(str) - .tolist() + # === Forcing features === + # Now batch-static features are just part of forcing, + # repeated over temporal dimension + # Load water coverage + sample_datetime = sample_name[:10] + water_path = os.path.join( + self.sample_dir_path, f"wtr_{sample_datetime}.npy" ) + water_cover_features = torch.tensor( + np.load(water_path), dtype=torch.float32 + ).unsqueeze( + -1 + ) # (dim_x, dim_y, 1) + # Flatten + water_cover_features = water_cover_features.flatten(0, 1) # (N_grid, 1) + # Expand over temporal dimension + water_cover_expanded = water_cover_features.unsqueeze(0).expand( + self.sample_length - 2, -1, -1 # -2 as added on after windowing + ) # (sample_len, N_grid, 1) - # init_states: (2, N_grid, d_features) - # target_states: (ar_steps-2, N_grid, d_features) - # forcing: (ar_steps-2, N_grid, d_windowed_forcing) - # boundary: (ar_steps-2, N_grid, d_windowed_boundary) - # batch_times: (ar_steps-2,) - return init_states, target_states, forcing, boundary, batch_times + # TOA flux + flux_path = os.path.join( + self.sample_dir_path, + f"nwp_toa_downwelling_shortwave_flux_{sample_datetime}.npy", + ) + flux = torch.tensor(np.load(flux_path), dtype=torch.float32).unsqueeze( + -1 + ) # (N_t', dim_x, dim_y, 1) + if self.standardize: + flux = (flux - self.flux_mean) / self.flux_std -class WeatherDataModule(pl.LightningDataModule): - """DataModule for weather data.""" + # Flatten and subsample flux forcing + flux = flux.flatten(1, 2) # (N_t, N_grid, 1) + flux = flux[subsample_index :: self.subsample_step] # (N_t, N_grid, 1) + flux = flux[ + init_id : (init_id + self.sample_length) + ] # (sample_len, N_grid, 1) - def __init__( - self, - ar_steps_train=3, - ar_steps_eval=25, - batch_size=4, - num_workers=16, - ): - super().__init__() - self.ar_steps_train = ar_steps_train - self.ar_steps_eval = ar_steps_eval - self.batch_size = batch_size - self.num_workers = num_workers - self.train_dataset = None - self.val_dataset = None - self.test_dataset = None - - def setup(self, stage=None): - if stage == "fit" or stage is None: - self.train_dataset = WeatherDataset( - split="train", - ar_steps=self.ar_steps_train, - batch_size=self.batch_size, - ) - self.val_dataset = WeatherDataset( - split="val", - ar_steps=self.ar_steps_eval, - batch_size=self.batch_size, - ) + # Time of day and year + dt_obj = dt.datetime.strptime(sample_datetime, "%Y%m%d%H") + dt_obj = dt_obj + dt.timedelta( + hours=2 + subsample_index + ) # Offset for first index + # Extract for initial step + init_hour_in_day = dt_obj.hour + start_of_year = dt.datetime(dt_obj.year, 1, 1) + init_seconds_into_year = (dt_obj - start_of_year).total_seconds() - if stage == "test" or stage is None: - self.test_dataset = WeatherDataset( - split="test", - ar_steps=self.ar_steps_eval, - batch_size=self.batch_size, - ) + # Add increments for all steps + hour_inc = ( + torch.arange(self.sample_length) * self.subsample_step + ) # (sample_len,) + hour_of_day = ( + init_hour_in_day + hour_inc + ) # (sample_len,), Can be > 24 but ok + second_into_year = ( + init_seconds_into_year + hour_inc * 3600 + ) # (sample_len,) + # can roll over to next year, ok because periodicity - def train_dataloader(self): - """Load train dataset.""" - return torch.utils.data.DataLoader( - self.train_dataset, - 
batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + # Encode as sin/cos + seconds_in_year = 365 * 24 * 3600 + hour_angle = (hour_of_day / 12) * torch.pi # (sample_len,) + year_angle = ( + (second_into_year / seconds_in_year) * 2 * torch.pi + ) # (sample_len,) + datetime_forcing = torch.stack( + ( + torch.sin(hour_angle), + torch.cos(hour_angle), + torch.sin(year_angle), + torch.cos(year_angle), + ), + dim=1, + ) # (N_t, 4) + datetime_forcing = (datetime_forcing + 1) / 2 # Rescale to [0,1] + datetime_forcing = datetime_forcing.unsqueeze(1).expand( + -1, flux.shape[1], -1 + ) # (sample_len, N_grid, 4) - def val_dataloader(self): - """Load validation dataset.""" - return torch.utils.data.DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + # Put forcing features together + forcing_features = torch.cat( + (flux, datetime_forcing), dim=-1 + ) # (sample_len, N_grid, d_forcing) - def test_dataloader(self): - """Load test dataset.""" - return torch.utils.data.DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + # Combine forcing over each window of 3 time steps + forcing_windowed = torch.cat( + ( + forcing_features[:-2], + forcing_features[1:-1], + forcing_features[2:], + ), + dim=2, + ) # (sample_len-2, N_grid, 3*d_forcing) + # Now index 0 of ^ corresponds to forcing at index 0-2 of sample + + # batch-static water cover is added after windowing, + # as it is static over time + forcing = torch.cat((water_cover_expanded, forcing_windowed), dim=2) + # (sample_len-2, N_grid, forcing_dim) + + return init_states, target_states, forcing diff --git a/plot_graph.py b/plot_graph.py index e246200d..0670963f 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -19,6 +19,12 @@ def main(): Plot graph structure in 3D using plotly """ parser = ArgumentParser(description="Plot graph") + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", + ) parser.add_argument( "--graph", type=str, @@ -36,14 +42,9 @@ def main(): default=0, help="If the axis should be displayed (default: 0 (No))", ) - parser.add_argument( - "--data_config", - type=str, - default="neural_lam/data_config.yaml", - help="Path to data config file (default: neural_lam/data_config.yaml)", - ) args = parser.parse_args() + config_loader = utils.ConfigLoader(args.data_config) # Load graph data hierarchical, graph_ldict = utils.load_graph(args.graph) @@ -62,12 +63,12 @@ def main(): ) mesh_static_features = graph_ldict["mesh_static_features"] - config_loader = utils.ConfigLoader(args.data_config) - xy = config_loader.get_nwp_xy() - grid_xy = xy.transpose(1, 2, 0).reshape(-1, 2) # (N_grid, 2) - pos_max = np.max(np.abs(grid_xy)) - grid_pos = grid_xy / pos_max # Divide by maximum coordinate + grid_static_features = utils.load_static_data(config_loader.dataset.name)[ + "grid_static_features" + ] + # Extract values needed, turn to numpy + grid_pos = grid_static_features[:, :2].numpy() # Add in z-dimension z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],)) grid_pos = np.concatenate( diff --git a/pyproject.toml b/pyproject.toml index 619f444f..b513a258 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,3 @@ -[project] -name = "neural_lam" -version = "0.1.0" - -[tool.setuptools] -packages = ["neural_lam"] - [tool.black] line-length = 80 @@ -49,9 +42,12 @@ ignore = [ "create_mesh.py", # 
Disable linting for now, as major rework is planned/expected ] # Temporary fix for import neural_lam statements until set up as proper package -init-hook = 'import sys; sys.path.append(".")' +init-hook='import sys; sys.path.append(".")' [tool.pylint.TYPECHECK] -generated-members = ["numpy.*", "torch.*"] +generated-members = [ + "numpy.*", + "torch.*", +] [tool.pylint.'MESSAGES CONTROL'] disable = [ "C0114", # 'missing-module-docstring', Do not require module docstrings @@ -60,11 +56,10 @@ disable = [ "R0913", # 'too-many-arguments', Allow many function arguments "R0914", # 'too-many-locals', Allow many local variables "W0223", # 'abstract-method', Subclasses do not have to override all abstract methods - "C0411", # 'wrong-import-order', Allow for isort to handle import order ] [tool.pylint.DESIGN] -max-statements = 100 # Allow for some more involved functions +max-statements=100 # Allow for some more involved functions [tool.pylint.IMPORTS] -allow-any-import-level = "neural_lam" +allow-any-import-level="neural_lam" [tool.pylint.SIMILARITIES] -min-similarity-lines = 10 +min-similarity-lines=10 diff --git a/requirements.txt b/requirements.txt index cb9bd425..5a2111b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,9 +10,6 @@ Cartopy>=0.22.0 pyproj>=3.4.1 tueplots>=0.0.8 plotly>=5.15.0 -xarray>=0.20.1 -zarr>=2.10.0 -dask>=2022.0.0 # for dev codespell>=2.0.0 black>=21.9b0 diff --git a/train_model.py b/train_model.py index a8b02f58..da109fdf 100644 --- a/train_model.py +++ b/train_model.py @@ -6,7 +6,6 @@ # Third-party import pytorch_lightning as pl import torch -import wandb from lightning_fabric.utilities import seed # First-party @@ -14,7 +13,7 @@ from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM from neural_lam.models.hi_lam_parallel import HiLAMParallel -from neural_lam.weather_dataset import WeatherDataModule +from neural_lam.weather_dataset import WeatherDataset MODELS = { "graph_lam": GraphLAM, @@ -30,7 +29,12 @@ def main(): parser = ArgumentParser( description="Train or evaluate NeurWP models for LAM" ) - + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file (default: neural_lam/data_config.yaml)", + ) parser.add_argument( "--model", type=str, @@ -38,16 +42,17 @@ def main(): help="Model architecture to train/evaluate (default: graph_lam)", ) parser.add_argument( - "--data_config", - type=str, - default="neural_lam/data_config.yaml", - help="Path to data config file (default: neural_lam/data_config.yaml)", + "--subset_ds", + type=int, + default=0, + help="Use only a small subset of the dataset, for debugging" + "(default: 0=false)", ) parser.add_argument( "--seed", type=int, default=42, help="random seed (default: 42)" ) parser.add_argument( - "--num_workers", + "--n_workers", type=int, default=4, help="Number of workers in data loader (default: 4)", @@ -124,11 +129,11 @@ def main(): # Training options parser.add_argument( - "--ar_steps_train", + "--ar_steps", type=int, - default=3, - help="Number of steps to unroll prediction for in loss function " - "(default: 3)", + default=1, + help="Number of steps to unroll prediction for in loss (1-19) " + "(default: 1)", ) parser.add_argument( "--control_only", @@ -146,9 +151,9 @@ def main(): parser.add_argument( "--step_length", type=int, - default=1, + default=3, help="Step length in hours to consider single time step 1-3 " - "(default: 1)", + "(default: 3)", ) parser.add_argument( "--lr", type=float, default=1e-3, 
help="learning rate (default: 0.001)" @@ -168,13 +173,6 @@ def main(): help="Eval model on given data split (val/test) " "(default: None (train model))", ) - parser.add_argument( - "--ar_steps_eval", - type=int, - default=25, - help="Number of steps to unroll prediction for in loss function " - "(default: 25)", - ) parser.add_argument( "--n_example_pred", type=int, @@ -183,12 +181,12 @@ def main(): "(default: 1)", ) - # Logging Options + # Logger Settings parser.add_argument( "--wandb_project", type=str, - default="neural-lam", - help="Wandb project to log to (default: neural-lam)", + default="neural_lam", + help="Wandb project name (default: neural_lam)", ) parser.add_argument( "--val_steps_log", @@ -210,6 +208,8 @@ def main(): ) args = parser.parse_args() + config_loader = utils.ConfigLoader(args.data_config) + # Asserts for arguments assert args.model in MODELS, f"Unknown model: {args.model}" assert args.step_length <= 3, "Too high step length" @@ -224,12 +224,34 @@ def main(): # Set seed seed.seed_everything(args.seed) - # Create datamodule - data_module = WeatherDataModule( - ar_steps_train=args.ar_steps_train, - ar_steps_eval=args.ar_steps_eval, - batch_size=args.batch_size, - num_workers=args.num_workers, + + # Load data + train_loader = torch.utils.data.DataLoader( + WeatherDataset( + config_loader.dataset.name, + pred_length=args.ar_steps, + split="train", + subsample_step=args.step_length, + subset=bool(args.subset_ds), + control_only=args.control_only, + ), + args.batch_size, + shuffle=True, + num_workers=args.n_workers, + ) + max_pred_length = (65 // args.step_length) - 2 # 19 + val_loader = torch.utils.data.DataLoader( + WeatherDataset( + config_loader.dataset.name, + pred_length=max_pred_length, + split="val", + subsample_step=args.step_length, + subset=bool(args.subset_ds), + control_only=args.control_only, + ), + args.batch_size, + shuffle=False, + num_workers=args.n_workers, ) # Instantiate model + trainer @@ -252,10 +274,9 @@ def main(): else: model = model_class(args) + prefix = "subset-" if args.subset_ds else "" if args.eval: - prefix = f"eval-{args.eval}-" - else: - prefix = "train-" + prefix = prefix + f"eval-{args.eval}-" run_name = ( f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-" f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}" @@ -285,13 +306,35 @@ def main(): # Only init once, on rank 0 only if trainer.global_rank == 0: utils.init_wandb_metrics( - logger, val_steps=args.val_steps_log + logger, args.val_steps_log ) # Do after wandb.init - wandb.save(args.data_config) + if args.eval: - trainer.test(model=model, datamodule=data_module, ckpt_path=args.load) + if args.eval == "val": + eval_loader = val_loader + else: # Test + eval_loader = torch.utils.data.DataLoader( + WeatherDataset( + config_loader.dataset.name, + pred_length=max_pred_length, + split="test", + subsample_step=args.step_length, + subset=bool(args.subset_ds), + ), + args.batch_size, + shuffle=False, + num_workers=args.n_workers, + ) + + print(f"Running evaluation on {args.eval}") + trainer.test(model=model, dataloaders=eval_loader) else: - trainer.fit(model=model, datamodule=data_module) + # Train model + trainer.fit( + model=model, + train_dataloaders=train_loader, + val_dataloaders=val_loader, + ) if __name__ == "__main__": From 0ba441bc8e34b95eab6550a0e279136548e5db53 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 21 May 2024 12:24:53 +0200 Subject: [PATCH 24/26] Implementation PR-review feedback --- README.md | 2 +- create_grid_features.py | 4 +-- create_mesh.py | 4 
+-- create_parameter_weights.py | 9 ++++-- neural_lam/config.py | 59 +++++++++++++++++++++++++++++++++++ neural_lam/data_config.yaml | 8 ++--- neural_lam/models/ar_model.py | 24 +++++++------- neural_lam/utils.py | 56 --------------------------------- neural_lam/vis.py | 7 +++-- plot_graph.py | 4 +-- train_model.py | 8 ++--- 11 files changed, 96 insertions(+), 89 deletions(-) create mode 100644 neural_lam/config.py diff --git a/README.md b/README.md index fc5675e8..ba0bb3fe 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Still, some restrictions are inevitable: ## A note on the limited area setting Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)). There are still some parts of the code that is quite specific for the MEPS area use case. -This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/data_config.yaml`). +This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified in `train_model.py --data_config` ). If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic. We would be happy to support such enhancements. See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done. diff --git a/create_grid_features.py b/create_grid_features.py index e5b9c49a..c3714368 100644 --- a/create_grid_features.py +++ b/create_grid_features.py @@ -7,7 +7,7 @@ import torch # First-party -from neural_lam import utils +from neural_lam import config def main(): @@ -22,7 +22,7 @@ def main(): help="Path to data config file (default: neural_lam/data_config.yaml)", ) args = parser.parse_args() - config_loader = utils.ConfigLoader(args.data_config) + config_loader = config.Config.from_file(args.data_config) static_dir_path = os.path.join("data", config_loader.dataset.name, "static") diff --git a/create_mesh.py b/create_mesh.py index 477ddf55..f04b4d4b 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -13,7 +13,7 @@ from torch_geometric.utils.convert import from_networkx # First-party -from neural_lam import utils +from neural_lam import config def plot_graph(graph, title=None): @@ -189,7 +189,7 @@ def main(): args = parser.parse_args() # Load grid positions - config_loader = utils.ConfigLoader(args.data_config) + config_loader = config.Config.from_file(args.data_config) static_dir_path = os.path.join("data", config_loader.dataset.name, "static") graph_dir_path = os.path.join("graphs", args.graph) os.makedirs(graph_dir_path, exist_ok=True) diff --git a/create_parameter_weights.py b/create_parameter_weights.py index fd8c38cd..cae1ae3e 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -8,7 +8,7 @@ from tqdm import tqdm # First-party -from neural_lam import utils +from neural_lam import config from neural_lam.weather_dataset import WeatherDataset @@ -43,7 +43,7 @@ def main(): ) args = parser.parse_args() - config_loader = utils.ConfigLoader(args.data_config) + config_loader = config.Config.from_file(args.data_config) static_dir_path = os.path.join("data", config_loader.dataset.name, "static") # Create parameter weights based on height @@ -57,7 +57,10 @@ def main(): "500": 0.03, } w_list = 
np.array( - [w_dict[par.split("_")[-2]] for par in config_loader.dataset.var_names] + [ + w_dict[par.split("_")[-2]] + for par in config_loader.dataset.var_longnames + ] ) print("Saving parameter weights...") np.save( diff --git a/neural_lam/config.py b/neural_lam/config.py new file mode 100644 index 00000000..e758e09c --- /dev/null +++ b/neural_lam/config.py @@ -0,0 +1,59 @@ +import functools +from pathlib import Path + +import cartopy.crs as ccrs +import yaml + + +class Config: + """ + Class for loading configuration files. + + This class loads a configuration file and provides a way to access its + values as attributes. + """ + + def __init__(self, values): + self.values = values + + @classmethod + def from_file(cls, filepath): + if filepath.endswith(".yaml"): + with open(filepath, encoding="utf-8", mode="r") as file: + return cls(values=yaml.safe_load(file)) + else: + raise NotImplementedError(Path(filepath).suffix) + + def __getattr__(self, name): + keys = name.split(".") + value = self.values + for key in keys: + if key in value: + value = value[key] + else: + return None + if isinstance(value, dict): + return Config(values=value) + return value + + def __getitem__(self, key): + value = self.values[key] + if isinstance(value, dict): + return Config(values=value) + return value + + def __contains__(self, key): + return key in self.values + + def num_data_vars(self): + """Return the number of data variables for a given key.""" + return len(self.dataset.var_names) + + @functools.cached_property + def coords_projection(self): + """Return the projection.""" + proj_config = self.values["projection"] + proj_class_name = proj_config["class"] + proj_class = getattr(ccrs, proj_class_name) + proj_params = proj_config.get("kwargs", {}) + return proj_class(**proj_params) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index 213825ff..f16a4a30 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -1,6 +1,6 @@ dataset: name: meps_example - vars: + var_names: - pres_0g - pres_0s - nlwrs_0 @@ -18,7 +18,7 @@ dataset: - wvint_0 - z_1000 - z_500 - units: + var_units: - Pa - Pa - r"$\mathrm{W}/\mathrm{m}^2$" @@ -36,7 +36,7 @@ dataset: - r"$\mathrm{kg}/\mathrm{m}^2$" - r"$\mathrm{m}^2/\mathrm{s}^2$" - r"$\mathrm{m}^2/\mathrm{s}^2$" - var_names: + var_longnames: - pres_heightAboveGround_0_instant - pres_heightAboveSea_0_instant - nlwrs_heightAboveGround_0_accum @@ -54,7 +54,7 @@ dataset: - wvint_entireAtmosphere_0_instant - z_isobaricInhPa_1000_instant - z_isobaricInhPa_500_instant - forcing_dim: 16 + num_forcing_features: 16 grid_shape_state: [268, 238] projection: class: LambertConformal diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index da2654f0..9cda9fc2 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -9,7 +9,7 @@ import wandb # First-party -from neural_lam import metrics, utils, vis +from neural_lam import config, metrics, utils, vis class ARModel(pl.LightningModule): @@ -25,7 +25,7 @@ def __init__(self, args): super().__init__() self.save_hyperparameters() self.args = args - self.config_loader = utils.ConfigLoader(args.data_config) + self.config_loader = config.Config.from_file(args.data_config) # Load static features for grid/data static_data_dict = utils.load_static_data( @@ -61,7 +61,7 @@ def __init__(self, args): self.grid_dim = ( 2 * self.config_loader.num_data_vars() + grid_static_dim - + self.config_loader.dataset.forcing_dim + + self.config_loader.dataset.num_forcing_features ) # 
Instantiate loss function @@ -246,7 +246,7 @@ def validation_step(self, batch, batch_idx): # Log loss per time step forward and mean val_log_dict = { f"val_loss_unroll{step}": time_step_loss[step - 1] - for step in self.args.val_steps_log + for step in self.args.val_steps_to_log } val_log_dict["val_mean_loss"] = mean_loss self.log_dict( @@ -294,7 +294,7 @@ def test_step(self, batch, batch_idx): # Log loss per time step forward and mean test_log_dict = { f"test_loss_unroll{step}": time_step_loss[step - 1] - for step in self.args.val_steps_log + for step in self.args.val_steps_to_log } test_log_dict["test_mean_loss"] = mean_loss @@ -329,7 +329,7 @@ def test_step(self, batch, batch_idx): prediction, target, pred_std, average_grid=False ) # (B, pred_steps, num_grid_nodes) log_spatial_losses = spatial_loss[ - :, [step - 1 for step in self.args.val_steps_log] + :, [step - 1 for step in self.args.val_steps_to_log] ] self.spatial_loss_maps.append(log_spatial_losses) # (B, N_log, num_grid_nodes) @@ -408,8 +408,8 @@ def plot_examples(self, batch, n_examples, prediction=None): ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( - self.config_loader.dataset.vars, - self.config_loader.dataset.units, + self.config_loader.dataset.var_names, + self.config_loader.dataset.var_units, var_vranges, ) ) @@ -420,7 +420,7 @@ def plot_examples(self, batch, n_examples, prediction=None): { f"{var_name}_example_{example_i}": wandb.Image(fig) for var_name, fig in zip( - self.config_loader.dataset.vars, var_figs + self.config_loader.dataset.var_names, var_figs ) } ) @@ -476,7 +476,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): # Check if metrics are watched, log exact values for specific vars if full_log_name in self.args.metrics_watch: for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var = self.config_loader.dataset.vars[var_i] + var = self.config_loader.dataset.var_nums[var_i] log_dict.update( { f"{full_log_name}_{var}_step_{step}": metric_tensor[ @@ -549,7 +549,7 @@ def on_test_epoch_end(self): title=f"Test loss, t={t_i} ({self.step_length * t_i} h)", ) for t_i, loss_map in zip( - self.args.val_steps_log, mean_spatial_loss + self.args.val_steps_to_log, mean_spatial_loss ) ] @@ -566,7 +566,7 @@ def on_test_epoch_end(self): ] pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps") os.makedirs(pdf_loss_maps_dir, exist_ok=True) - for t_i, fig in zip(self.args.val_steps_log, pdf_loss_map_figs): + for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs): fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) # save mean spatial loss as .pt file also torch.save( diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 528560e3..836b04ed 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -2,10 +2,8 @@ import os # Third-party -import cartopy.crs as ccrs import numpy as np import torch -import yaml from torch import nn from tueplots import bundles, figsizes @@ -270,57 +268,3 @@ def init_wandb_metrics(wandb_logger, val_steps): experiment.define_metric("val_mean_loss", summary="min") for step in val_steps: experiment.define_metric(f"val_loss_unroll{step}", summary="min") - - -class ConfigLoader: - """ - Class for loading configuration files. - - This class loads a YAML configuration file and provides a way to access - its values as attributes. 
- """ - - def __init__(self, config_path, values=None): - self.config_path = config_path - if values is None: - self.values = self.load_config() - else: - self.values = values - - def load_config(self): - """Load configuration file.""" - with open(self.config_path, encoding="utf-8", mode="r") as file: - return yaml.safe_load(file) - - def __getattr__(self, name): - keys = name.split(".") - value = self.values - for key in keys: - if key in value: - value = value[key] - else: - return None - if isinstance(value, dict): - return ConfigLoader(None, values=value) - return value - - def __getitem__(self, key): - value = self.values[key] - if isinstance(value, dict): - return ConfigLoader(None, values=value) - return value - - def __contains__(self, key): - return key in self.values - - def num_data_vars(self): - """Return the number of data variables for a given key.""" - return len(self.dataset.vars) - - def projection(self): - """Return the projection.""" - proj_config = self.values["projection"] - proj_class_name = proj_config["class"] - proj_class = getattr(ccrs, proj_class_name) - proj_params = proj_config.get("kwargs", {}) - return proj_class(**proj_params) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 7a4d3730..2b6abf15 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -51,7 +51,7 @@ def plot_error_map(errors, data_config, title=None, step_length=3): y_ticklabels = [ f"{name} ({unit})" for name, unit in zip( - data_config.dataset.vars, data_config.dataset.units + data_config.dataset.var_names, data_config.dataset.var_units ) ] ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) @@ -87,7 +87,7 @@ def plot_prediction( 1, 2, figsize=(13, 7), - subplot_kw={"projection": data_config.projection()}, + subplot_kw={"projection": data_config.coords_projection()}, ) # Plot pred and target @@ -135,7 +135,8 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): ) # Faded border region fig, ax = plt.subplots( - figsize=(5, 4.8), subplot_kw={"projection": data_config.projection()} + figsize=(5, 4.8), + subplot_kw={"projection": data_config.coords_projection()}, ) ax.coastlines() # Add coastline outlines diff --git a/plot_graph.py b/plot_graph.py index 0670963f..40b2b41d 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -7,7 +7,7 @@ import torch_geometric as pyg # First-party -from neural_lam import utils +from neural_lam import config, utils MESH_HEIGHT = 0.1 MESH_LEVEL_DIST = 0.2 @@ -44,7 +44,7 @@ def main(): ) args = parser.parse_args() - config_loader = utils.ConfigLoader(args.data_config) + config_loader = config.Config.from_file(args.data_config) # Load graph data hierarchical, graph_ldict = utils.load_graph(args.graph) diff --git a/train_model.py b/train_model.py index da109fdf..390da6d4 100644 --- a/train_model.py +++ b/train_model.py @@ -9,7 +9,7 @@ from lightning_fabric.utilities import seed # First-party -from neural_lam import utils +from neural_lam import config, utils from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM from neural_lam.models.hi_lam_parallel import HiLAMParallel @@ -189,7 +189,7 @@ def main(): help="Wandb project name (default: neural_lam)", ) parser.add_argument( - "--val_steps_log", + "--val_steps_to_log", type=list, default=[1, 2, 3, 5, 10, 15, 19], help="Steps to log val loss for (default: [1, 2, 3, 5, 10, 15, 19])", @@ -208,7 +208,7 @@ def main(): ) args = parser.parse_args() - config_loader = utils.ConfigLoader(args.data_config) + config_loader = 
config.Config.from_file(args.data_config) # Asserts for arguments assert args.model in MODELS, f"Unknown model: {args.model}" assert args.step_length <= 3, "Too high step length" @@ -306,7 +306,7 @@ def main(): # Only init once, on rank 0 only if trainer.global_rank == 0: utils.init_wandb_metrics( - logger, args.val_steps_log + logger, args.val_steps_to_log ) # Do after wandb.init if args.eval: From 37bdf8f0cb62c46f642bb96eff779224be3fa45d Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 21 May 2024 12:29:28 +0200 Subject: [PATCH 25/26] fix linter --- neural_lam/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/neural_lam/config.py b/neural_lam/config.py index e758e09c..5891ea74 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -1,6 +1,8 @@ +# Standard library import functools from pathlib import Path +# Third-party import cartopy.crs as ccrs import yaml @@ -18,6 +20,7 @@ def __init__(self, values): @classmethod def from_file(cls, filepath): + """Load a configuration file.""" if filepath.endswith(".yaml"): with open(filepath, encoding="utf-8", mode="r") as file: return cls(values=yaml.safe_load(file)) From 5d10591ea6cc6b48ab6d817a96385bc9c86a923e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 22 May 2024 10:19:22 +0200 Subject: [PATCH 26/26] Updated changelog for future references --- CHANGELOG.md | 16 +++++++++++++--- neural_lam/weather_dataset.py | 1 + 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 19ecdd41..823ac8b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,11 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - ## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD) ### Added +- Replaced `constants.py` with `data_config.yaml` for data configuration management + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + - new metrics (`nll` and `crps_gauss`) and `metrics` submodule, stddiv output option [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) @joeloskarsson @@ -24,6 +27,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Updated scripts and modules to use `data_config.yaml` instead of `constants.py` + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + +- Added new flags in `train_model.py` for configuration previously in `constants.py` + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + - moved batch-static features ("water cover") into forcing component return by `WeatherDataset` [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) @joeloskarsson @@ -44,8 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) @joeloskarsson - ## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0) First tagged release of `neural-lam`, matching Oskarsson et al 2023 publication -(https://arxiv.org/abs/2309.17370) +(<https://arxiv.org/abs/2309.17370>) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 0c01ae1d..a782806b 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -218,6 +218,7 @@ def __getitem__(self, idx): # can roll over to next year, ok because periodicity # Encode as sin/cos + # !
Make this more flexible in a separate create_forcings.py script seconds_in_year = 365 * 24 * 3600 hour_angle = (hour_of_day / 12) * torch.pi # (sample_len,) year_angle = (