Skip to content

Commit

Permalink
Inlet velocity and density parameterization in domino (#760)
Browse files Browse the repository at this point in the history
* adding inlet velocity and air density to domino

* bug fixes

* fixing bug test script

* updating test

* fixing bugs and model updates

---------

Co-authored-by: Mohammad Amin Nabian <[email protected]>
  • Loading branch information
RishikeshRanade and mnabian authored Jan 21, 2025
1 parent b7b6265 commit 90c8f68
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 45 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Refactored StormCast training example
- Enhancements and bug fixes to DoMINO model and training example
- Enhancement to parameterize DoMINO model with inlet velocity

### Deprecated

Expand Down
3 changes: 2 additions & 1 deletion examples/cfd/external_aerodynamics/domino/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ To train and test the DoMINO model on AWS dataset, follow these steps:
1. The DoMINO model allows for training both volume and surface fields using a single model
but currently the recommendation is to train the volume and surface models separately. This
can be controlled through the config file.
2. MSE loss for the volume model and RMSE for surface model gives the best results.
2. MSE loss for both volume and surface model gives the best results.
3. The surface and volume variable names can change but currently the code only
supports the variables in that specific order. For example, Pressure, wall-shear
and turb-visc for surface and velocity, pressure and turb-visc for volume.
4. Bounding box is configurable and will depend on the usecase. The presets are
suitable for the AWS DriveAer-ML dataset.
5. Integral loss factor is currently set to 0.0 as it adversely impacts the training.

The DoMINO model architecture is used to support the Real Time Wind Tunnel OV Blueprint
demo presented at Supercomputing' 24. Some of the results are shown below.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ model:
use_only_normals: true # Use only surface normals and not surface area
integral_loss_scaling_factor: 0 # Scale integral loss by this factor
normalization: min_max_scaling # or mean_std_scaling
encode_parameters: false # encode inlet velocity and air density in the model
geometry_rep: # Hyperparameters for geometry representation network
base_filters: 16
geo_conv:
Expand All @@ -89,6 +90,9 @@ model:
neighbors_in_radius: 64
radius: 0.05 # 0.2 in expt 7
base_layer: 512
parameter_model:
base_layer: 512
scaling_params: [30.0, 1.226] # [inlet_velocity, air_density]

train: # Training configurable parameters
epochs: 500
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from modulus.utils.domino.utils import *
from torch.utils.data import Dataset

AIR_DENSITY = 1.205
STREAM_VELOCITY = 30.00


class DriveSimPaths:
@staticmethod
Expand Down Expand Up @@ -187,6 +190,8 @@ def __getitem__(self, idx):
"volume_mesh_centers": np.float32(volume_coordinates),
"surface_fields": np.float32(surface_fields),
"filename": cfd_filename,
"stream_velocity": STREAM_VELOCITY,
"air_density": AIR_DENSITY,
}


Expand Down
59 changes: 39 additions & 20 deletions examples/cfd/external_aerodynamics/domino/src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors):
data_dict = dict_to_device(data_dict, device)

# Non-dimensionalization factors
air_density = data_dict["air_density"].cpu().numpy()
stream_velocity = data_dict["stream_velocity"].cpu().numpy()
length_scale = data_dict["length_scale"].cpu().numpy()
air_density = data_dict["air_density"]
stream_velocity = data_dict["stream_velocity"]
length_scale = data_dict["length_scale"]

# STL nodes
geo_centers = data_dict["geometry_coordinates"]
Expand Down Expand Up @@ -187,6 +187,8 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors):
volume_mesh_centers_batch,
geo_encoding_local,
pos_encoding,
stream_velocity,
air_density,
num_sample_points=20,
eval_mode="volume",
)
Expand All @@ -195,17 +197,20 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors):
:, start_idx:end_idx
] = tpredictions_batch.cpu().numpy()

# print(
# f"Volume predictions calculated, Time taken={float(time.time()-start_time)}"
# )
prediction_vol = unnormalize(prediction_vol, vol_factors[0], vol_factors[1])

prediction_vol[:, :, :3] = prediction_vol[:, :, :3] * stream_velocity[0]
prediction_vol[:, :, :3] = (
prediction_vol[:, :, :3] * stream_velocity[0, 0].cpu().numpy()
)
prediction_vol[:, :, 3] = (
prediction_vol[:, :, 3] * stream_velocity[0] ** 2.0 * air_density[0]
prediction_vol[:, :, 3]
* stream_velocity[0, 0].cpu().numpy() ** 2.0
* air_density[0, 0].cpu().numpy()
)
prediction_vol[:, :, 4] = (
prediction_vol[:, :, 4] * stream_velocity[0] * length_scale[0]
prediction_vol[:, :, 4]
* stream_velocity[0, 0].cpu().numpy()
* length_scale[0].cpu().numpy()
)
else:
prediction_vol = None
Expand Down Expand Up @@ -276,13 +281,17 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors):
surface_neighbors_normals_batch,
surface_areas_batch,
surface_neighbors_areas_batch,
stream_velocity,
air_density,
)
)
else:
tpredictions_batch = model.module.calculate_solution(
surface_mesh_centers_batch,
geo_encoding_local,
pos_encoding,
stream_velocity,
air_density,
num_sample_points=1,
eval_mode="surface",
)
Expand All @@ -293,12 +302,10 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors):

prediction_surf = (
unnormalize(prediction_surf, surf_factors[0], surf_factors[1])
* stream_velocity[0] ** 2.0
* air_density[0]
* stream_velocity[0, 0].cpu().numpy() ** 2.0
* air_density[0, 0].cpu().numpy()
)
# print(
# f"Surface predictions calculated, Time taken={float(time.time()-start_time)}"
# )

else:
prediction_surf = None

Expand Down Expand Up @@ -638,8 +645,12 @@ def main(cfg: DictConfig):
"volume_min_max": vol_grid_max_min,
"surface_min_max": surf_grid_max_min,
"length_scale": np.array(length_scale, dtype=np.float32),
"stream_velocity": np.array(STREAM_VELOCITY, dtype=np.float32),
"air_density": np.array(AIR_DENSITY, dtype=np.float32),
"stream_velocity": np.expand_dims(
np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1
),
"air_density": np.expand_dims(
np.array(AIR_DENSITY, dtype=np.float32), axis=-1
),
}
elif model_type == "surface":
data_dict = {
Expand All @@ -656,8 +667,12 @@ def main(cfg: DictConfig):
"surface_fields": np.float32(surface_fields),
"surface_min_max": np.float32(surf_grid_max_min),
"length_scale": np.array(length_scale, dtype=np.float32),
"stream_velocity": np.array(STREAM_VELOCITY, dtype=np.float32),
"air_density": np.array(AIR_DENSITY, dtype=np.float32),
"stream_velocity": np.expand_dims(
np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1
),
"air_density": np.expand_dims(
np.array(AIR_DENSITY, dtype=np.float32), axis=-1
),
}
elif model_type == "volume":
data_dict = {
Expand All @@ -674,8 +689,12 @@ def main(cfg: DictConfig):
"volume_min_max": vol_grid_max_min,
"surface_min_max": surf_grid_max_min,
"length_scale": np.array(length_scale, dtype=np.float32),
"stream_velocity": np.array(STREAM_VELOCITY, dtype=np.float32),
"air_density": np.array(AIR_DENSITY, dtype=np.float32),
"stream_velocity": np.expand_dims(
np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1
),
"air_density": np.expand_dims(
np.array(AIR_DENSITY, dtype=np.float32), axis=-1
),
}

data_dict = {
Expand Down
4 changes: 2 additions & 2 deletions examples/cfd/external_aerodynamics/domino/src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def relative_loss_fn_surface(output, target, normals, padded_value=-10):

def relative_loss_fn_area(output, target, normals, area, padded_value=-10):
scale_factor = 1.0 # Get this from the dataset
area = area * 10**5
area = area * 10**4
ws_pred = torch.sqrt(
output[:, :, 1:2] ** 2.0 + output[:, :, 2:3] ** 2.0 + output[:, :, 3:4] ** 2.0
)
Expand Down Expand Up @@ -232,7 +232,7 @@ def relative_loss_fn_area(output, target, normals, area, padded_value=-10):

def mse_loss_fn_area(output, target, normals, area, padded_value=-10):
scale_factor = 1.0 # Get this from the dataset
area = area * 10**5
area = area * 10**4
ws_pred = torch.sqrt(
output[:, :, 1:2] ** 2.0 + output[:, :, 2:3] ** 2.0 + output[:, :, 3:4] ** 2.0
)
Expand Down
35 changes: 26 additions & 9 deletions modulus/datapipes/cae/domino_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@
)
from modulus.utils.sdf import signed_distance_field

AIR_DENSITY = 1.205
STREAM_VELOCITY = 30.00


class DoMINODataPipe(Dataset):
"""
Expand Down Expand Up @@ -169,6 +166,14 @@ def __getitem__(self, idx):
mesh_indices_flattened = data_dict["stl_faces"]
stl_sizes = data_dict["stl_areas"]

# Check if stream velocity in keys
if "stream_velocity" in data_dict.keys():
STREAM_VELOCITY = data_dict["stream_velocity"]
AIR_DENSITY = data_dict["air_density"]
else:
AIR_DENSITY = 1.205
STREAM_VELOCITY = 30.00

#
length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0))

Expand Down Expand Up @@ -486,8 +491,12 @@ def __getitem__(self, idx):
"volume_min_max": vol_grid_max_min,
"surface_min_max": surf_grid_max_min,
"length_scale": length_scale,
"stream_velocity": STREAM_VELOCITY,
"air_density": AIR_DENSITY,
"stream_velocity": np.expand_dims(
np.array(STREAM_VELOCITY, dtype=np.float32), -1
),
"air_density": np.expand_dims(
np.array(AIR_DENSITY, dtype=np.float32), -1
),
}
elif self.model_type == "surface":
return {
Expand All @@ -504,8 +513,12 @@ def __getitem__(self, idx):
"surface_fields": surface_fields,
"surface_min_max": surf_grid_max_min,
"length_scale": length_scale,
"stream_velocity": STREAM_VELOCITY,
"air_density": AIR_DENSITY,
"stream_velocity": np.expand_dims(
np.array(STREAM_VELOCITY, dtype=np.float32), -1
),
"air_density": np.expand_dims(
np.array(AIR_DENSITY, dtype=np.float32), -1
),
}
elif self.model_type == "volume":
return {
Expand All @@ -522,8 +535,12 @@ def __getitem__(self, idx):
"volume_min_max": vol_grid_max_min,
"surface_min_max": surf_grid_max_min,
"length_scale": length_scale,
"stream_velocity": STREAM_VELOCITY,
"air_density": AIR_DENSITY,
"stream_velocity": np.expand_dims(
np.array(STREAM_VELOCITY, dtype=np.float32), -1
),
"air_density": np.expand_dims(
np.array(AIR_DENSITY, dtype=np.float32), -1
),
}


Expand Down
Loading

0 comments on commit 90c8f68

Please sign in to comment.