Add DrivAerML dataset support to FIGConvNet example. #753

Open: wants to merge 10 commits into `main`.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- DoMINO model architecture, datapipe and training recipe
- DrivAerML dataset support in FIGConvNet example.

### Changed

47 changes: 47 additions & 0 deletions examples/cfd/external_aerodynamics/figconvnet/README.md
@@ -19,6 +19,52 @@ We demonstrate a 140k× speed-up compared to GPU-accelerated
computational fluid dynamics (CFD) simulators and over 2× improvement in pressure prediction
over prior deep learning arts.

## Supported datasets

The current version of the code supports the following datasets:

### DrivAerNet

Both DrivAerNet and DrivAerNet++ datasets [[4](#references)] are supported.
Please follow the instructions on the [dataset GitHub](https://github.com/Mohamedelrefaie/DrivAerNet)
page to download the dataset.

The corresponding experiment configuration file can be found at: `./configs/experiment/drivaernet/figconv_unet.yaml`.
For more details, refer to the [Training section](#training).

### DrivAerML

The DrivAerML dataset [[6](#references)] is supported, but it must first be
converted to a more efficient binary format.
This format, which is used by models such as XAeroNet and FIGConvNet,
stores the original meshes as partitioned graphs.
For details on how to convert the original DrivAerML dataset
to the partitioned format, refer to steps 1 to 5 of the
[XAeroNet example README](https://github.com/NVIDIA/modulus/tree/main/examples/cfd/external_aerodynamics/xaeronet#training-the-xaeronet-s-model).

The binary dataset should have the following structure:

```text
├─ partitions
│ ├─ graph_partitions_1.bin
│ ├─ graph_partitions_2.bin
│ ├─ ...
├─ test_partitions
│ ├─ graph_partitions_100.bin
│ ├─ graph_partitions_101.bin
│ ├─ ...
├─ validation_partitions
│ ├─ graph_partitions_200.bin
│ ├─ graph_partitions_201.bin
│ ├─ ...
└─ global_stats.json
```
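
Before launching training, it can help to confirm that a converted dataset matches this layout. The following is a minimal sketch and not part of the example code: the helper name `check_partitioned_layout` is hypothetical, and it relies only on the directory and file names shown above.

```python
from pathlib import Path
from typing import Union

# Names taken from the layout above; the helper itself is illustrative only.
EXPECTED_DIRS = ("partitions", "validation_partitions", "test_partitions")
STATS_FILE = "global_stats.json"


def check_partitioned_layout(data_path: Union[str, Path]) -> dict:
    """Return the number of .bin files per split, raising if anything is missing."""
    root = Path(data_path)
    if not (root / STATS_FILE).is_file():
        raise FileNotFoundError(f"Missing {STATS_FILE} in {root}")

    counts = {}
    for name in EXPECTED_DIRS:
        split_dir = root / name
        if not split_dir.is_dir():
            raise FileNotFoundError(f"Missing directory: {split_dir}")
        # Each split is expected to hold at least one partitioned graph file.
        counts[name] = len(list(split_dir.glob("*.bin")))
        if counts[name] == 0:
            raise ValueError(f"No .bin files found in {split_dir}")
    return counts
```

Calling `check_partitioned_layout("/path/to/dataset")` returns a dictionary of file counts per split, which is a quick way to sanity-check the train, validation and test sizes.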

The corresponding experiment configuration file can be found at:
`./configs/experiment/drivaerml/figconv_unet.yaml`.
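
The data module added for this dataset reads the `mean` and `std_dev` sections of `global_stats.json` and uses them to standardize fields such as pressure, then to map predictions back to physical units. The sketch below illustrates that round trip in plain Python. It is a simplified stand-in: the statistics values are made up, scalar per-field entries are an assumption about the file contents, and the real module operates on tensors.

```python
import json

# Hypothetical statistics shaped like the "mean" / "std_dev" sections
# of global_stats.json. The values are invented for illustration.
stats = json.loads('{"mean": {"pressure": -100.0}, "std_dev": {"pressure": 50.0}}')


def encode(x: float, name: str) -> float:
    """Standardize a raw value: (x - mean) / std."""
    return (x - stats["mean"][name]) / stats["std_dev"][name]


def decode(x: float, name: str) -> float:
    """Invert the standardization: x * std + mean."""
    return x * stats["std_dev"][name] + stats["mean"][name]


raw = -25.0
normalized = encode(raw, "pressure")  # (-25 - (-100)) / 50 = 1.5
restored = decode(normalized, "pressure")  # 1.5 * 50 + (-100) = -25.0
```

Because `decode` is the exact inverse of `encode`, model outputs trained on standardized targets can be converted back before computing metrics in physical units.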

## Installation

FIGConvUNet dependencies can be installed with `pip install`, for example:
@@ -107,3 +153,4 @@ mpirun -np 2 python train.py \
3. [Ahmed body wiki](https://www.cfd-online.com/Wiki/Ahmed_body)
4. [DrivAerNet: A Parametric Car Dataset for Data-Driven Aerodynamic Design and Graph-Based Drag Prediction](https://arxiv.org/abs/2403.08055)
5. [Deep Learning for Real-Time Aerodynamic Evaluations of Arbitrary Vehicle Shapes](https://arxiv.org/abs/2108.05798)
6. [DrivAerML: High-Fidelity Computational Fluid Dynamics Dataset for Road-Car External Aerodynamics](https://arxiv.org/abs/2408.11969)
@@ -36,6 +36,7 @@ train:
shuffle: true # can also specify the shuffle buffer size, e.g. shuffle_buffer_size: 100
num_workers: 0
pin_memory: true
lr_scheduler_mode: epoch # epoch or iteration.

eval:
loss: null
@@ -83,6 +84,7 @@ loggers:
run_name: default
entity: modulus # nvr-ai-algo
group_name:
mode: online

log_pointcloud: false # save pointclouds

@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

_target_: src.data.DrivAerMLDataModule
_convert_: all

data_path: ???
num_points: 100_000
@@ -0,0 +1,58 @@
# @package _global_

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

defaults:
- /data: drivaerml
- /model: figconv_unet_drivaerml
- /loss: mseloss
- /optimizer: adam
- /lr_scheduler: steplr

train:
  batch_size: 8
  num_epochs: 200

model:
  aabb_max: [ 2.0, 1.8, 2.6]
  aabb_min: [-2.0, -1.8, -1.5]
  hidden_channels: [16, 16, 16]
  kernel_size: 5
  # mlp_channels: [2048, 2048]
  neighbor_search_type: radius
  num_down_blocks: 1
  num_levels: 2
  pooling_layers: [2]
  pooling_type: max
  reductions: [mean]
  resolution_memory_format_pairs:
    - ${res_mem_pair:b_xc_y_z, [ 5, 150, 100]}
    - ${res_mem_pair:b_yc_x_z, [250, 3, 100]}
    - ${res_mem_pair:b_zc_x_y, [250, 150, 2]}
  use_rel_pos_encode: true

lr_scheduler:
  step_size: 50

loggers:
  wandb:
    entity: modulus
    project_name: car-cfd
    group_name: fignet-drivaerml
    run_name: FIGConvNet-level2-16,16,16-res250-150-100-pool-max-2-aabb-20x18x26-ks5-np32768-b8x2

seed: 0
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

_target_: torch.optim.lr_scheduler.OneCycleLR
max_lr: 0.001
epochs: ${..train.num_epochs}
steps_per_epoch: 59
pct_start: 0.2
@@ -0,0 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

defaults:
- figconv_unet_drivaer

_target_: src.networks.FIGConvUNetDrivAerML

out_channels: 4
@@ -15,4 +15,5 @@
# limitations under the License.

from .base_datamodule import BaseDataModule
from .drivaerml_datamodule import DrivAerMLDataModule
from .drivaernet_datamodule import DrivAerNetDataModule
@@ -0,0 +1,154 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable, Iterable
import json
from pathlib import Path
from typing import Any

import dgl
import numpy as np

import torch
from torch import Tensor
from torch.utils.data import Dataset

from src.data.base_datamodule import BaseDataModule


class DrivAerMLPartitionedDataset(Dataset):
    """DrivAerML partitioned dataset.

    The dataset enables reading meshes from binary files
    generated by the Modulus XAeroNet data processing utility.
    It reads data from all partitions in the source file and
    samples a predefined number of points,
    as specified by the `num_points` parameter.

    The dataset expects the data to have the following structure:

    ```
    ├─ partitions
    │ ├─ graph_partitions_1.bin
    │ ├─ graph_partitions_2.bin
    │ ├─ ...
    ├─ test_partitions
    │ ├─ graph_partitions_100.bin
    │ ├─ graph_partitions_101.bin
    │ ├─ ...
    ├─ validation_partitions
    │ ├─ graph_partitions_200.bin
    │ ├─ graph_partitions_201.bin
    │ ├─ ...
    └─ global_stats.json
    ```

    where the `partitions` directory contains the training samples.

    For further details and examples on how to create a partitioned dataset,
    refer to: `modulus/examples/cfd/external_aerodynamics/xaeronet/surface`.

    Parameters:
    ----------
        data_path (Path): path to the directory that contains binary partitioned files.
        num_points (int): number of points to sample from the mesh.
    """

    def __init__(
        self,
        data_path: Path,
        num_points: int = 0,
    ) -> None:
        # Sort the files so that sample order is deterministic.
        self.p_files = sorted(data_path.glob("*.bin"))
        self.num_points = num_points

    def __len__(self) -> int:
        return len(self.p_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        if not 0 <= index < len(self):
            raise IndexError(f"Invalid {index = } expected in [0, {len(self)})")

        # Each file stores one sample as a list of graph partitions.
        gs, _ = dgl.load_graphs(str(self.p_files[index]))

        coords = torch.cat([g.ndata["coordinates"] for g in gs], dim=0)
        # Sample indices from the combined graph.
        n_total = coords.shape[0]
        if n_total >= self.num_points:
            # np.random.choice samples with replacement by default,
            # so duplicate points are possible.
            indices = np.random.choice(n_total, self.num_points)
        else:
            # Too few points: keep all of them and pad with random repeats.
            indices = np.concatenate(
                (
                    np.arange(n_total),
                    np.random.choice(n_total, self.num_points - n_total),
                )
            )
        coords = coords[indices]
        pressure = torch.cat([g.ndata["pressure"] for g in gs], dim=0)[indices]
        shear_stress = torch.cat([g.ndata["shear_stress"] for g in gs], dim=0)[indices]

        return {
            "coordinates": coords,
            "pressure": pressure,
            "shear_stress": shear_stress,
            "design": self.p_files[index].stem.removeprefix("graph_partitions_"),
        }


class DrivAerMLDataModule(BaseDataModule):
    """DrivAerML data module"""

    def __init__(
        self,
        data_path: str | Path,
        num_points: int = 0,
        stats_filename: str = "global_stats.json",
        **kwargs,
    ):
        data_path = Path(data_path)
        self._train_dataset = DrivAerMLPartitionedDataset(
            data_path / "partitions", num_points
        )
        self._val_dataset = DrivAerMLPartitionedDataset(
            data_path / "validation_partitions", num_points
        )
        self._test_dataset = DrivAerMLPartitionedDataset(
            data_path / "test_partitions", num_points
        )

        with open(data_path / stats_filename, "r", encoding="utf-8") as f:
            stats = json.load(f)

        self.mean = {k: torch.tensor(v) for k, v in stats["mean"].items()}
        self.std = {k: torch.tensor(v) for k, v in stats["std_dev"].items()}

    @property
    def train_dataset(self):
        return self._train_dataset

    @property
    def val_dataset(self):
        return self._val_dataset

    @property
    def test_dataset(self):
        return self._test_dataset

    def encode(self, x: Tensor, name: str):
        return (x - self.mean[name].to(x.device)) / self.std[name].to(x.device)

    def decode(self, x: Tensor, name: str):
        return x * self.std[name].to(x.device) + self.mean[name].to(x.device)
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .figconvunet_drivaer import FIGConvUNetDrivAerNet
from .figconvunet_drivaer import FIGConvUNetDrivAerML, FIGConvUNetDrivAerNet