diff --git a/create_mesh.py b/create_mesh.py index da881594..04d7468b 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -125,7 +125,11 @@ def mk_2d_graph(xy, nx, ny): # add diagonal edges g.add_edges_from( - [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] + [ + ((x, y), (x + 1, y + 1)) + for x in range(nx - 1) + for y in range(ny - 1) + ] + [ ((x + 1, y), (x, y + 1)) for x in range(nx - 1) @@ -343,7 +347,9 @@ def main(): .reshape(int(n / nx) ** 2, 2) ) ij = [tuple(x) for x in ij] - G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij))) + G[lev] = networkx.relabel_nodes( + G[lev], dict(zip(G[lev].nodes, ij)) + ) G_tot = networkx.compose(G_tot, G[lev]) # Relabel mesh nodes to start with 0 diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index fc78e638..d29f84ec 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -38,7 +38,9 @@ def __init__(self, args): self.output_std = bool(args.output_std) if self.output_std: # Pred. dim. in grid cell - self.grid_output_dim = 2 * self.config_loader.num_data_vars("state") + self.grid_output_dim = 2 * self.config_loader.num_data_vars( + "state" + ) else: # Pred. dim. in grid cell self.grid_output_dim = self.config_loader.num_data_vars("state") @@ -87,7 +89,9 @@ def __init__(self, args): self.spatial_loss_maps = [] # Load normalization statistics - self.normalization_stats = self.config_loader.load_normalization_stats() + self.normalization_stats = ( + self.config_loader.load_normalization_stats() + ) if self.normalization_stats is not None: for ( var_name, @@ -236,7 +240,11 @@ def training_step(self, batch): log_dict = {"train_loss": batch_loss} self.log_dict( - log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True + log_dict, + prog_bar=True, + on_step=True, + on_epoch=True, + sync_dist=True, ) return batch_loss @@ -362,7 +370,8 @@ def test_step(self, batch, batch_idx): ): # Need to plot more example predictions n_additional_examples = min( - prediction.shape[0], self.n_example_pred - self.plotted_examples + prediction.shape[0], + self.n_example_pred - self.plotted_examples, ) self.plot_examples( @@ -584,10 +593,14 @@ def on_test_epoch_end(self): ) for loss_map in mean_spatial_loss ] - pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps") + pdf_loss_maps_dir = os.path.join( + wandb.run.dir, "spatial_loss_maps" + ) os.makedirs(pdf_loss_maps_dir, exist_ok=True) for t_i, fig in zip(self.args.val_steps_log, pdf_loss_map_figs): - fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) + fig.savefig( + os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf") + ) # save mean spatial loss as .pt file also torch.save( mean_spatial_loss.cpu(), diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index fb5df62d..723a3f3c 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -118,8 +118,8 @@ def predict_step(self, prev_state, prev_prev_state, forcing): dim=-1, ) - # Embed all features - grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h) + # Embed all features # (B, num_grid_nodes, d_h) + grid_emb = self.grid_embedder(grid_features) g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h) m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h) mesh_emb = self.embedd_mesh_nodes() @@ -149,9 +149,8 @@ def predict_step(self, prev_state, prev_prev_state, forcing): ) # (B, num_grid_nodes, d_h) # Map to output dimension, only for grid - net_output = self.output_map( - grid_rep - ) # (B, num_grid_nodes, d_grid_out) + # (B, num_grid_nodes, d_grid_out) + net_output = self.output_map(grid_rep) if self.output_std: pred_delta_mean, pred_std_raw = net_output.chunk( diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py index f767fba0..e4dc74ac 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graph_lam.py @@ -32,7 +32,9 @@ def __init__(self, args): # Define sub-models # Feature embedders for mesh - self.mesh_embedder = utils.make_mlp([mesh_dim] + self.mlp_blueprint_end) + self.mesh_embedder = utils.make_mlp( + [mesh_dim] + self.mlp_blueprint_end + ) self.m2m_embedder = utils.make_mlp([m2m_dim] + self.mlp_blueprint_end) # GNNs diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py index 4d7eb94c..335ea8c7 100644 --- a/neural_lam/models/hi_lam.py +++ b/neural_lam/models/hi_lam.py @@ -101,9 +101,8 @@ def mesh_down_step( reversed(same_gnns[:-1]), ): # Extract representations - send_node_rep = mesh_rep_levels[ - level_l + 1 - ] # (B, N_mesh[l+1], d_h) + # (B, N_mesh[l+1], d_h) + send_node_rep = mesh_rep_levels[level_l + 1] rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) down_edge_rep = mesh_down_rep[level_l] same_edge_rep = mesh_same_rep[level_l] @@ -139,9 +138,8 @@ def mesh_up_step( zip(up_gnns, same_gnns[1:]), start=1 ): # Extract representations - send_node_rep = mesh_rep_levels[ - level_l - 1 - ] # (B, N_mesh[l-1], d_h) + # (B, N_mesh[l-1], d_h) + send_node_rep = mesh_rep_levels[level_l - 1] rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) up_edge_rep = mesh_up_rep[level_l - 1] same_edge_rep = mesh_same_rep[level_l] @@ -183,7 +181,11 @@ def hi_processor_step( self.mesh_up_same_gnns, ): # Down - mesh_rep_levels, mesh_same_rep, mesh_down_rep = self.mesh_down_step( + ( + mesh_rep_levels, + mesh_same_rep, + mesh_down_rep, + ) = self.mesh_down_step( mesh_rep_levels, mesh_same_rep, mesh_down_rep, @@ -200,5 +202,6 @@ def hi_processor_step( up_same_gnns, ) - # Note: We return all, even though only down edges really are used later + # Note: We return all, even though only down edges really are used + # later return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py index 740824e1..b6f619d1 100644 --- a/neural_lam/models/hi_lam_parallel.py +++ b/neural_lam/models/hi_lam_parallel.py @@ -27,7 +27,9 @@ def __init__(self, args): + list(self.mesh_down_edge_index) ) total_edge_index = torch.cat(total_edge_index_list, dim=1) - self.edge_split_sections = [ei.shape[1] for ei in total_edge_index_list] + self.edge_split_sections = [ + ei.shape[1] for ei in total_edge_index_list + ] if args.processor_layers == 0: self.processor = lambda x, edge_attr: (x, edge_attr) @@ -86,11 +88,12 @@ def hi_processor_step( mesh_same_rep = mesh_edge_rep_sections[: self.num_levels] mesh_up_rep = mesh_edge_rep_sections[ - self.num_levels : self.num_levels + (self.num_levels - 1) + self.num_levels : self.num_levels + (self.num_levels - 1) # noqa ] mesh_down_rep = mesh_edge_rep_sections[ - self.num_levels + (self.num_levels - 1) : + self.num_levels + (self.num_levels - 1) : # noqa ] # Last are down edges - # Note: We return all, even though only down edges really are used later + # Note: We return all, even though only down edges really are used + # later return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 18584d2e..f7ecafb3 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -40,7 +40,9 @@ def load_graph(graph_name, device="cpu"): graph_dir_path = os.path.join("graphs", graph_name) def loads_file(fn): - return torch.load(os.path.join(graph_dir_path, fn), map_location=device) + return torch.load( + os.path.join(graph_dir_path, fn), map_location=device + ) # Load edges (edge_index) m2m_edge_index = BufferList( @@ -53,7 +55,8 @@ def loads_file(fn): hierarchical = n_levels > 1 # Nor just single level mesh graph # Load static edge features - m2m_features = loads_file("m2m_features.pt") # List of (M_m2m[l], d_edge_f) + # List of (M_m2m[l], d_edge_f) + m2m_features = loads_file("m2m_features.pt") g2m_features = loads_file("g2m_features.pt") # (M_g2m, d_edge_f) m2g_features = loads_file("m2g_features.pt") # (M_m2g, d_edge_f) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 6ce630c7..6762a450 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -39,7 +39,9 @@ def __init__( self.state = self.config_loader.process_dataset("state", self.split) assert self.state is not None, "State dataset not found" - self.forcing = self.config_loader.process_dataset("forcing", self.split) + self.forcing = self.config_loader.process_dataset( + "forcing", self.split + ) self.boundary = self.config_loader.process_dataset( "boundary", self.split ) @@ -69,7 +71,10 @@ def __init__( method="nearest", ) .pad( - time=(self.boundary_window // 2, self.boundary_window // 2), + time=( + self.boundary_window // 2, + self.boundary_window // 2, + ), mode="edge", ) .rolling(time=self.boundary_window, center=True) @@ -87,7 +92,9 @@ def __getitem__(self, idx): ) forcing = ( - self.forcing_windowed.isel(time=slice(idx + 2, idx + self.ar_steps)) + self.forcing_windowed.isel( + time=slice(idx + 2, idx + self.ar_steps) + ) .stack(variable_window=("variable", "window")) .values if self.forcing is not None diff --git a/pyproject.toml b/pyproject.toml index 619f444f..192afbc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,10 +6,10 @@ version = "0.1.0" packages = ["neural_lam"] [tool.black] -line-length = 80 +line-length = 79 [tool.isort] -default_section = "THIRDPARTY" +default_section = "THIRDPARTY" #codespell:ignore profile = "black" # Headings import_heading_stdlib = "Standard library"