diff --git a/CHANGELOG.md b/CHANGELOG.md index d4a90b14b..574eaaf86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Refactored StormCast training example +- Enhancements and bug fixes to DoMINO model and training example +- Enhancement to parameterize DoMINO model with inlet velocity ### Deprecated diff --git a/examples/cfd/external_aerodynamics/domino/README.md b/examples/cfd/external_aerodynamics/domino/README.md index 3b5cf9ad8..6ee408887 100644 --- a/examples/cfd/external_aerodynamics/domino/README.md +++ b/examples/cfd/external_aerodynamics/domino/README.md @@ -57,12 +57,13 @@ To train and test the DoMINO model on AWS dataset, follow these steps: 1. The DoMINO model allows for training both volume and surface fields using a single model but currently the recommendation is to train the volume and surface models separately. This can be controlled through the config file. -2. MSE loss for the volume model and RMSE for surface model gives the best results. +2. MSE loss for both volume and surface model gives the best results. 3. The surface and volume variable names can change but currently the code only supports the variables in that specific order. For example, Pressure, wall-shear and turb-visc for surface and velocity, pressure and turb-visc for volume. 4. Bounding box is configurable and will depend on the usecase. The presets are suitable for the AWS DriveAer-ML dataset. +5. Integral loss factor is currently set to 0.0 as it adversely impacts the training. The DoMINO model architecture is used to support the Real Time Wind Tunnel OV Blueprint demo presented at Supercomputing' 24. Some of the results are shown below. diff --git a/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml b/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml index bc2a07816..5478a8cce 100644 --- a/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml +++ b/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml @@ -67,6 +67,7 @@ model: use_only_normals: true # Use only surface normals and not surface area integral_loss_scaling_factor: 0 # Scale integral loss by this factor normalization: min_max_scaling # or mean_std_scaling + encode_parameters: false # encode inlet velocity and air density in the model geometry_rep: # Hyperparameters for geometry representation network base_filters: 16 geo_conv: @@ -89,6 +90,9 @@ model: neighbors_in_radius: 64 radius: 0.05 # 0.2 in expt 7 base_layer: 512 + parameter_model: + base_layer: 512 + scaling_params: [30.0, 1.226] # [inlet_velocity, air_density] train: # Training configurable parameters epochs: 500 diff --git a/examples/cfd/external_aerodynamics/domino/src/openfoam_datapipe.py b/examples/cfd/external_aerodynamics/domino/src/openfoam_datapipe.py index 3f756097b..544761f42 100644 --- a/examples/cfd/external_aerodynamics/domino/src/openfoam_datapipe.py +++ b/examples/cfd/external_aerodynamics/domino/src/openfoam_datapipe.py @@ -32,6 +32,9 @@ from modulus.utils.domino.utils import * from torch.utils.data import Dataset +AIR_DENSITY = 1.205 +STREAM_VELOCITY = 30.00 + class DriveSimPaths: @staticmethod @@ -187,6 +190,8 @@ def __getitem__(self, idx): "volume_mesh_centers": np.float32(volume_coordinates), "surface_fields": np.float32(surface_fields), "filename": cfd_filename, + "stream_velocity": STREAM_VELOCITY, + "air_density": AIR_DENSITY, } diff --git a/examples/cfd/external_aerodynamics/domino/src/test.py b/examples/cfd/external_aerodynamics/domino/src/test.py index 2a2885b0e..51d1a84e2 100644 --- a/examples/cfd/external_aerodynamics/domino/src/test.py +++ b/examples/cfd/external_aerodynamics/domino/src/test.py @@ -86,9 +86,9 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): data_dict = dict_to_device(data_dict, device) # Non-dimensionalization factors - air_density = data_dict["air_density"].cpu().numpy() - stream_velocity = data_dict["stream_velocity"].cpu().numpy() - length_scale = data_dict["length_scale"].cpu().numpy() + air_density = data_dict["air_density"] + stream_velocity = data_dict["stream_velocity"] + length_scale = data_dict["length_scale"] # STL nodes geo_centers = data_dict["geometry_coordinates"] @@ -187,6 +187,8 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): volume_mesh_centers_batch, geo_encoding_local, pos_encoding, + stream_velocity, + air_density, num_sample_points=20, eval_mode="volume", ) @@ -195,17 +197,20 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): :, start_idx:end_idx ] = tpredictions_batch.cpu().numpy() - # print( - # f"Volume predictions calculated, Time taken={float(time.time()-start_time)}" - # ) prediction_vol = unnormalize(prediction_vol, vol_factors[0], vol_factors[1]) - prediction_vol[:, :, :3] = prediction_vol[:, :, :3] * stream_velocity[0] + prediction_vol[:, :, :3] = ( + prediction_vol[:, :, :3] * stream_velocity[0, 0].cpu().numpy() + ) prediction_vol[:, :, 3] = ( - prediction_vol[:, :, 3] * stream_velocity[0] ** 2.0 * air_density[0] + prediction_vol[:, :, 3] + * stream_velocity[0, 0].cpu().numpy() ** 2.0 + * air_density[0, 0].cpu().numpy() ) prediction_vol[:, :, 4] = ( - prediction_vol[:, :, 4] * stream_velocity[0] * length_scale[0] + prediction_vol[:, :, 4] + * stream_velocity[0, 0].cpu().numpy() + * length_scale[0].cpu().numpy() ) else: prediction_vol = None @@ -276,6 +281,8 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): surface_neighbors_normals_batch, surface_areas_batch, surface_neighbors_areas_batch, + stream_velocity, + air_density, ) ) else: @@ -283,6 +290,8 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): surface_mesh_centers_batch, geo_encoding_local, pos_encoding, + stream_velocity, + air_density, num_sample_points=1, eval_mode="surface", ) @@ -293,12 +302,10 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): prediction_surf = ( unnormalize(prediction_surf, surf_factors[0], surf_factors[1]) - * stream_velocity[0] ** 2.0 - * air_density[0] + * stream_velocity[0, 0].cpu().numpy() ** 2.0 + * air_density[0, 0].cpu().numpy() ) - # print( - # f"Surface predictions calculated, Time taken={float(time.time()-start_time)}" - # ) + else: prediction_surf = None @@ -638,8 +645,12 @@ def main(cfg: DictConfig): "volume_min_max": vol_grid_max_min, "surface_min_max": surf_grid_max_min, "length_scale": np.array(length_scale, dtype=np.float32), - "stream_velocity": np.array(STREAM_VELOCITY, dtype=np.float32), - "air_density": np.array(AIR_DENSITY, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), } elif model_type == "surface": data_dict = { @@ -656,8 +667,12 @@ def main(cfg: DictConfig): "surface_fields": np.float32(surface_fields), "surface_min_max": np.float32(surf_grid_max_min), "length_scale": np.array(length_scale, dtype=np.float32), - "stream_velocity": np.array(STREAM_VELOCITY, dtype=np.float32), - "air_density": np.array(AIR_DENSITY, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), } elif model_type == "volume": data_dict = { @@ -674,8 +689,12 @@ def main(cfg: DictConfig): "volume_min_max": vol_grid_max_min, "surface_min_max": surf_grid_max_min, "length_scale": np.array(length_scale, dtype=np.float32), - "stream_velocity": np.array(STREAM_VELOCITY, dtype=np.float32), - "air_density": np.array(AIR_DENSITY, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), } data_dict = { diff --git a/examples/cfd/external_aerodynamics/domino/src/train.py b/examples/cfd/external_aerodynamics/domino/src/train.py index 4c788acf3..a4ef19ed3 100644 --- a/examples/cfd/external_aerodynamics/domino/src/train.py +++ b/examples/cfd/external_aerodynamics/domino/src/train.py @@ -166,7 +166,7 @@ def relative_loss_fn_surface(output, target, normals, padded_value=-10): def relative_loss_fn_area(output, target, normals, area, padded_value=-10): scale_factor = 1.0 # Get this from the dataset - area = area * 10**5 + area = area * 10**4 ws_pred = torch.sqrt( output[:, :, 1:2] ** 2.0 + output[:, :, 2:3] ** 2.0 + output[:, :, 3:4] ** 2.0 ) @@ -232,7 +232,7 @@ def relative_loss_fn_area(output, target, normals, area, padded_value=-10): def mse_loss_fn_area(output, target, normals, area, padded_value=-10): scale_factor = 1.0 # Get this from the dataset - area = area * 10**5 + area = area * 10**4 ws_pred = torch.sqrt( output[:, :, 1:2] ** 2.0 + output[:, :, 2:3] ** 2.0 + output[:, :, 3:4] ** 2.0 ) diff --git a/modulus/datapipes/cae/domino_datapipe.py b/modulus/datapipes/cae/domino_datapipe.py index e2764b459..bc4abe712 100644 --- a/modulus/datapipes/cae/domino_datapipe.py +++ b/modulus/datapipes/cae/domino_datapipe.py @@ -51,9 +51,6 @@ ) from modulus.utils.sdf import signed_distance_field -AIR_DENSITY = 1.205 -STREAM_VELOCITY = 30.00 - class DoMINODataPipe(Dataset): """ @@ -169,6 +166,14 @@ def __getitem__(self, idx): mesh_indices_flattened = data_dict["stl_faces"] stl_sizes = data_dict["stl_areas"] + # Check if stream velocity in keys + if "stream_velocity" in data_dict.keys(): + STREAM_VELOCITY = data_dict["stream_velocity"] + AIR_DENSITY = data_dict["air_density"] + else: + AIR_DENSITY = 1.205 + STREAM_VELOCITY = 30.00 + # length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) @@ -486,8 +491,12 @@ def __getitem__(self, idx): "volume_min_max": vol_grid_max_min, "surface_min_max": surf_grid_max_min, "length_scale": length_scale, - "stream_velocity": STREAM_VELOCITY, - "air_density": AIR_DENSITY, + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), -1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), -1 + ), } elif self.model_type == "surface": return { @@ -504,8 +513,12 @@ def __getitem__(self, idx): "surface_fields": surface_fields, "surface_min_max": surf_grid_max_min, "length_scale": length_scale, - "stream_velocity": STREAM_VELOCITY, - "air_density": AIR_DENSITY, + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), -1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), -1 + ), } elif self.model_type == "volume": return { @@ -522,8 +535,12 @@ def __getitem__(self, idx): "volume_min_max": vol_grid_max_min, "surface_min_max": surf_grid_max_min, "length_scale": length_scale, - "stream_velocity": STREAM_VELOCITY, - "air_density": AIR_DENSITY, + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), -1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), -1 + ), } diff --git a/modulus/models/domino/model.py b/modulus/models/domino/model.py index 61867d106..1574e6c21 100644 --- a/modulus/models/domino/model.py +++ b/modulus/models/domino/model.py @@ -347,7 +347,7 @@ def __init__(self, input_features, model_parameters=None): self.bn2 = nn.BatchNorm1d(int(base_layer)) self.bn3 = nn.BatchNorm1d(int(base_layer)) - self.activation = F.gelu + self.activation = F.relu def forward(self, x, padded_value=-10): facets = x @@ -358,6 +358,32 @@ def forward(self, x, padded_value=-10): return facets +class ParameterModel(nn.Module): + """Layer to encode parameters such as inlet velocity and air density""" + + def __init__(self, input_features, model_parameters=None): + super(ParameterModel, self).__init__() + self.input_features = input_features + + base_layer = model_parameters.base_layer + self.fc1 = nn.Linear(self.input_features, base_layer) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.bn1 = nn.BatchNorm1d(base_layer) + self.bn2 = nn.BatchNorm1d(int(base_layer)) + self.bn3 = nn.BatchNorm1d(int(base_layer)) + + self.activation = F.relu + + def forward(self, x, padded_value=-10): + params = x + params = self.activation(self.fc1(params)) + params = self.activation(self.fc2(params)) + params = self.fc3(params) + + return params + + class AggregationModel(nn.Module): """Layer to aggregate local geometry encoding with basis functions""" @@ -370,15 +396,15 @@ def __init__( self.new_change = new_change base_layer = model_parameters.base_layer self.fc1 = nn.Linear(self.input_features, base_layer) - self.fc2 = nn.Linear(base_layer, int(base_layer / 2)) - self.fc3 = nn.Linear(int(base_layer / 2), int(base_layer / 4)) - self.fc4 = nn.Linear(int(base_layer / 4), int(base_layer / 8)) - self.fc5 = nn.Linear(int(base_layer / 8), self.output_features) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.fc4 = nn.Linear(int(base_layer), int(base_layer)) + self.fc5 = nn.Linear(int(base_layer), self.output_features) self.bn1 = nn.BatchNorm1d(base_layer) - self.bn2 = nn.BatchNorm1d(int(base_layer / 2)) - self.bn3 = nn.BatchNorm1d(int(base_layer / 4)) - self.bn4 = nn.BatchNorm1d(int(base_layer / 8)) - self.activation = F.gelu + self.bn2 = nn.BatchNorm1d(int(base_layer)) + self.bn3 = nn.BatchNorm1d(int(base_layer)) + self.bn4 = nn.BatchNorm1d(int(base_layer)) + self.activation = F.relu def forward(self, x): out = self.activation(self.fc1(x)) @@ -461,6 +487,8 @@ class DoMINO(nn.Module): >>> volume_coordinates = torch.randn(bsize, 100, 3).to(device) >>> vol_grid_max_min = torch.randn(bsize, 2, 3).to(device) >>> surf_grid_max_min = torch.randn(bsize, 2, 3).to(device) + >>> stream_velocity = torch.randn(bsize, 1).to(device) + >>> air_density = torch.randn(bsize, 1).to(device) >>> input_dict = { ... "pos_volume_closest": pos_normals_closest_vol, ... "pos_volume_center_of_mass": pos_normals_com_vol, @@ -480,6 +508,8 @@ class DoMINO(nn.Module): ... "volume_mesh_centers": volume_coordinates, ... "volume_min_max": vol_grid_max_min, ... "surface_min_max": surf_grid_max_min, + ... "stream_velocity": stream_velocity, + ... "air_density": air_density, ... } >>> output = model(input_dict) Module ... @@ -508,6 +538,8 @@ def __init__( self.surface_neighbors = model_parameters.surface_neighbors self.use_surface_normals = model_parameters.use_surface_normals self.use_only_normals = model_parameters.use_only_normals + self.encode_parameters = model_parameters.encode_parameters + self.param_scaling_factors = model_parameters.parameter_model.scaling_params if self.use_surface_normals: if self.use_only_normals: @@ -517,6 +549,15 @@ def __init__( else: input_features_surface = input_features + if self.encode_parameters: + # Defining the parameter model + base_layer_p = model_parameters.parameter_model.base_layer + self.parameter_model = ParameterModel( + input_features=2, model_parameters=model_parameters.parameter_model + ) + else: + base_layer_p = 0 + self.geo_rep = GeometryRep( input_features=input_features, model_parameters=model_parameters, @@ -583,7 +624,7 @@ def __init__( base_layer_geo = model_parameters.geometry_local.base_layer self.fc_1 = nn.Linear(self.neighbors_in_radius * 3, base_layer_geo) self.fc_2 = nn.Linear(base_layer_geo, base_layer_geo) - self.activation = F.gelu + self.activation = F.relu # Aggregation model if self.output_features_surf is not None: @@ -594,7 +635,8 @@ def __init__( AggregationModel( input_features=position_encoder_base_neurons + base_layer_nn - + base_layer_geo, + + base_layer_geo + + base_layer_p, output_features=1, model_parameters=model_parameters.aggregation_model, ) @@ -608,7 +650,8 @@ def __init__( AggregationModel( input_features=position_encoder_base_neurons + base_layer_nn - + base_layer_geo, + + base_layer_geo + + base_layer_p, output_features=1, model_parameters=model_parameters.aggregation_model, ) @@ -724,6 +767,8 @@ def calculate_solution_with_neighbors( surface_neighbors_normals, surface_areas, surface_neighbors_areas, + inlet_velocity, + air_density, ): """Function to approximate solution given the neighborhood information""" num_variables = self.num_variables_surf @@ -731,6 +776,26 @@ def calculate_solution_with_neighbors( agg_model = self.agg_model_surf num_sample_points = surface_mesh_neighbors.shape[2] + 1 + if self.encode_parameters: + inlet_velocity = torch.unsqueeze(inlet_velocity, 1) + inlet_velocity = inlet_velocity.expand( + inlet_velocity.shape[0], + surface_mesh_centers.shape[1], + inlet_velocity.shape[2], + ) + inlet_velocity = inlet_velocity / self.param_scaling_factors[0] + + air_density = torch.unsqueeze(air_density, 1) + air_density = air_density.expand( + air_density.shape[0], + surface_mesh_centers.shape[1], + air_density.shape[2], + ) + air_density = air_density / self.param_scaling_factors[1] + + params = torch.cat((inlet_velocity, air_density), axis=-1) + param_encoding = self.parameter_model(params) + if self.use_surface_normals: if self.use_only_normals: surface_mesh_centers = torch.cat( @@ -773,6 +838,8 @@ def calculate_solution_with_neighbors( ) basis_f = nn_basis[f](volume_m_c) output = torch.cat((basis_f, encoding_node, encoding_g), axis=-1) + if self.encode_parameters: + output = torch.cat((output, param_encoding), axis=-1) if p == 0: output_center = agg_model[f](output) else: @@ -798,6 +865,8 @@ def calculate_solution( volume_mesh_centers, encoding_g, encoding_node, + inlet_velocity, + air_density, eval_mode, num_sample_points=20, noise_intensity=50, @@ -811,6 +880,25 @@ def calculate_solution( num_variables = self.num_variables_surf nn_basis = self.nn_basis_surf agg_model = self.agg_model_surf + + if self.encode_parameters: + inlet_velocity = torch.unsqueeze(inlet_velocity, 1) + inlet_velocity = inlet_velocity.expand( + inlet_velocity.shape[0], + volume_mesh_centers.shape[1], + inlet_velocity.shape[2], + ) + inlet_velocity = inlet_velocity / self.param_scaling_factors[0] + + air_density = torch.unsqueeze(air_density, 1) + air_density = air_density.expand( + air_density.shape[0], volume_mesh_centers.shape[1], air_density.shape[2] + ) + air_density = air_density / self.param_scaling_factors[1] + + params = torch.cat((inlet_velocity, air_density), axis=-1) + param_encoding = self.parameter_model(params) + for f in range(num_variables): for p in range(num_sample_points): if p == 0: @@ -827,6 +915,8 @@ def calculate_solution( volume_m_c = volume_mesh_centers + noise basis_f = nn_basis[f](volume_m_c) output = torch.cat((basis_f, encoding_node, encoding_g), axis=-1) + if self.encode_parameters: + output = torch.cat((output, param_encoding), axis=-1) if p == 0: output_center = agg_model[f](output) else: @@ -863,6 +953,10 @@ def forward( surf_max = data_dict["surface_min_max"][:, 1] surf_min = data_dict["surface_min_max"][:, 0] + # Parameters + stream_velocity = data_dict["stream_velocity"] + air_density = data_dict["air_density"] + if self.output_features_vol is not None: # Represent geometry on computational grid # Computational domain grid @@ -931,6 +1025,8 @@ def forward( volume_mesh_centers, encoding_g_vol, encoding_node_vol, + stream_velocity, + air_density, eval_mode="volume", ) else: @@ -950,7 +1046,7 @@ def forward( surface_neighbors_areas = torch.unsqueeze(surface_neighbors_areas, -1) # Calculate local geometry encoding for surface encoding_g_surf = self.geo_encoding_local_surface( - encoding_g, surface_mesh_centers, s_grid + 0.5 * encoding_g_surf, surface_mesh_centers, s_grid ) # Approximate solution on surface cell center @@ -959,6 +1055,8 @@ def forward( surface_mesh_centers, encoding_g_surf, encoding_node_surf, + stream_velocity, + air_density, eval_mode="surface", num_sample_points=1, noise_intensity=500, @@ -973,6 +1071,8 @@ def forward( surface_neighbors_normals, surface_areas, surface_neighbors_areas, + stream_velocity, + air_density, ) else: output_surf = None diff --git a/test/models/data/domino_output.pth b/test/models/data/domino_output.pth index 07236993b..8826bdc63 100644 Binary files a/test/models/data/domino_output.pth and b/test/models/data/domino_output.pth differ diff --git a/test/models/test_domino.py b/test/models/test_domino.py index cf3786fda..ce3744bb0 100644 --- a/test/models/test_domino.py +++ b/test/models/test_domino.py @@ -106,6 +106,11 @@ class aggregation_model: class position_encoder: base_neurons: int = 128 + @dataclass + class parameter_model: + base_layer: int = 128 + scaling_params: Sequence = (30.0, 1.226) + model_type: str = "combined" interp_res: Sequence = (128, 64, 48) use_sdf_in_basis_func: bool = True @@ -114,6 +119,7 @@ class position_encoder: num_surface_neighbors: int = 21 use_surface_normals: bool = True use_only_normals: bool = True + encode_parameters: bool = False geometry_rep = geometry_rep nn_basis_functions = nn_basis_functions aggregation_model = aggregation_model @@ -149,6 +155,8 @@ class position_encoder: volume_coordinates = torch.randn(bsize, 100, 3).to(device) vol_grid_max_min = torch.randn(bsize, 2, 3).to(device) surf_grid_max_min = torch.randn(bsize, 2, 3).to(device) + stream_velocity = torch.randn(bsize, 1).to(device) + air_density = torch.randn(bsize, 1).to(device) input_dict = { "pos_volume_closest": pos_normals_closest_vol, "pos_volume_center_of_mass": pos_normals_com_vol, @@ -168,6 +176,8 @@ class position_encoder: "volume_mesh_centers": volume_coordinates, "volume_min_max": vol_grid_max_min, "surface_min_max": surf_grid_max_min, + "stream_velocity": stream_velocity, + "air_density": air_density, } # assert common.validate_forward_accuracy(