From b012f589a40f680a4c26add5b77e9755b2fee6c3 Mon Sep 17 00:00:00 2001 From: Alexey Kamenev Date: Wed, 6 Nov 2024 09:43:32 -0800 Subject: [PATCH 1/6] Add surface DrivAerML dataset support to FIGConvNet. --- examples/cfd/figconvnet/configs/base.yaml | 2 + .../figconvnet/configs/data/drivaerml.yaml | 21 + .../experiment/drivaerml/figconv_unet.yaml | 58 ++ .../configs/lr_scheduler/onecyclelr.yaml | 21 + .../configs/model/figconv_unet_drivaerml.yaml | 22 + .../notebooks/figconvnet_drivaerml.ipynb | 747 ++++++++++++++++++ examples/cfd/figconvnet/src/data/__init__.py | 1 + .../src/data/drivaerml_datamodule.py | 119 +++ .../cfd/figconvnet/src/networks/__init__.py | 2 +- .../src/networks/figconvunet_drivaer.py | 136 ++++ .../cfd/figconvnet/src/utils/eval_funcs.py | 7 + examples/cfd/figconvnet/src/utils/loggers.py | 3 + examples/cfd/figconvnet/train.py | 16 +- 13 files changed, 1145 insertions(+), 10 deletions(-) create mode 100644 examples/cfd/figconvnet/configs/data/drivaerml.yaml create mode 100644 examples/cfd/figconvnet/configs/experiment/drivaerml/figconv_unet.yaml create mode 100644 examples/cfd/figconvnet/configs/lr_scheduler/onecyclelr.yaml create mode 100644 examples/cfd/figconvnet/configs/model/figconv_unet_drivaerml.yaml create mode 100644 examples/cfd/figconvnet/notebooks/figconvnet_drivaerml.ipynb create mode 100644 examples/cfd/figconvnet/src/data/drivaerml_datamodule.py diff --git a/examples/cfd/figconvnet/configs/base.yaml b/examples/cfd/figconvnet/configs/base.yaml index 4e2af078a..8efde34c0 100644 --- a/examples/cfd/figconvnet/configs/base.yaml +++ b/examples/cfd/figconvnet/configs/base.yaml @@ -36,6 +36,7 @@ train: shuffle: true # can also specify the shuffle buffer size, e.g. shuffle_buffer_size: 100 num_workers: 0 pin_memory: true + lr_scheduler_mode: epoch # epoch or iteration. 
eval: loss: null @@ -83,6 +84,7 @@ loggers: run_name: default entity: modulus # nvr-ai-algo group_name: + mode: online log_pointcloud: false # save pointclouds diff --git a/examples/cfd/figconvnet/configs/data/drivaerml.yaml b/examples/cfd/figconvnet/configs/data/drivaerml.yaml new file mode 100644 index 000000000..26799b106 --- /dev/null +++ b/examples/cfd/figconvnet/configs/data/drivaerml.yaml @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_target_: src.data.DrivAerMLDataModule +_convert_: all + +data_path: ??? +num_points: 100_000 diff --git a/examples/cfd/figconvnet/configs/experiment/drivaerml/figconv_unet.yaml b/examples/cfd/figconvnet/configs/experiment/drivaerml/figconv_unet.yaml new file mode 100644 index 000000000..e3c25820d --- /dev/null +++ b/examples/cfd/figconvnet/configs/experiment/drivaerml/figconv_unet.yaml @@ -0,0 +1,58 @@ +# @package _global_ + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - /data: drivaerml + - /model: figconv_unet_drivaerml + - /loss: mseloss + - /optimizer: adam + - /lr_scheduler: steplr + +train: + batch_size: 8 + num_epochs: 200 + +model: + aabb_max: [ 2.0, 1.8, 2.6] + aabb_min: [-2.0, -1.8, -1.5] + hidden_channels: [16, 16, 16] + kernel_size: 5 + # mlp_channels: [2048, 2048] + neighbor_search_type: radius + num_down_blocks: 1 + num_levels: 2 + pooling_layers: [2] + pooling_type: max + reductions: [mean] + resolution_memory_format_pairs: + - ${res_mem_pair:b_xc_y_z, [ 5, 150, 100]} + - ${res_mem_pair:b_yc_x_z, [250, 3, 100]} + - ${res_mem_pair:b_zc_x_y, [250, 150, 2]} + use_rel_pos_encode: true + +lr_scheduler: + step_size: 50 + +loggers: + wandb: + entity: modulus + project_name: car-cfd + group_name: fignet-drivaerml + run_name: FIGConvNet-level2-16,16,16-res250-150-100-pool-max-2-aabb-275x15x1-ks5-np32768-b8x2 + +seed: 0 diff --git a/examples/cfd/figconvnet/configs/lr_scheduler/onecyclelr.yaml b/examples/cfd/figconvnet/configs/lr_scheduler/onecyclelr.yaml new file mode 100644 index 000000000..16295746e --- /dev/null +++ b/examples/cfd/figconvnet/configs/lr_scheduler/onecyclelr.yaml @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_target_: torch.optim.lr_scheduler.OneCycleLR +max_lr: 0.001 +epochs: ${..train.num_epochs} +steps_per_epoch: 59 +pct_start: 0.2 diff --git a/examples/cfd/figconvnet/configs/model/figconv_unet_drivaerml.yaml b/examples/cfd/figconvnet/configs/model/figconv_unet_drivaerml.yaml new file mode 100644 index 000000000..a8e20433b --- /dev/null +++ b/examples/cfd/figconvnet/configs/model/figconv_unet_drivaerml.yaml @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +defaults: + - figconv_unet_drivaer + +_target_: src.networks.FIGConvUNetDrivAerML + +out_channels: 4 diff --git a/examples/cfd/figconvnet/notebooks/figconvnet_drivaerml.ipynb b/examples/cfd/figconvnet/notebooks/figconvnet_drivaerml.ipynb new file mode 100644 index 000000000..1f802b109 --- /dev/null +++ b/examples/cfd/figconvnet/notebooks/figconvnet_drivaerml.ipynb @@ -0,0 +1,747 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warp 1.0.2 initialized:\n", + " CUDA Toolkit 11.5, Driver 12.2\n", + " Devices:\n", + " \"cpu\" : \"x86_64\"\n", + " \"cuda:0\" : \"NVIDIA GeForce RTX 3090\" (24 GiB, sm_86, mempool enabled)\n", + " \"cuda:1\" : \"NVIDIA TITAN RTX\" (24 GiB, sm_75, mempool enabled)\n", + " CUDA peer access:\n", + " Not supported\n", + " Kernel cache:\n", + " /home/du/.cache/warp/1.0.2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/src/modulus/src/modulus/modulus/distributed/manager.py:346: UserWarning: Could not initialize using ENV, SLURM or OPENMPI methods. 
Assuming this is a single process job\n", + " warn(\n" + ] + } + ], + "source": [ + "from pathlib import Path\n", + "import sys\n", + "\n", + "import numpy as np\n", + "import pyvista as pv\n", + "import torch\n", + "import vtk\n", + "import warp as wp\n", + "\n", + "if sys.path[0] != \"..\":\n", + " sys.path.insert(0, \"..\")\n", + "\n", + "device = torch.device(\"cuda:0\")\n", + "torch.cuda.device(device)\n", + "wp.init()\n", + "wp.set_device(str(device))\n", + "\n", + "from modulus.distributed import DistributedManager\n", + "\n", + "DistributedManager.initialize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "if os.environ.get(\"SLURM_JOB_NAME\", None) is None:\n", + " dataset_orig_path = Path(\"/data/src/modulus/data/drivaer_aws/\")\n", + " dataset_part_path = Path(\"/data/src/modulus/data/drivaer_aws/partitions\")\n", + " output_path = dataset_orig_path / f\"inference/\"\n", + " model_path = Path(\"/data/src/modulus/models/fignet/drivaerml/lrsoc/model_00999.pth\")\n", + " pc_path = Path(\"/data/src/modulus/data/drivaer_aws/original_pointclouds\")\n", + "else:\n", + " dataset_orig_path = Path(\"/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/drivaer_data_full\")\n", + " dataset_part_path = Path(\"/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/partitions/100_200_400/\")\n", + " output_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/inference/\")\n", + " model_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/9/model_00999.pth\")\n", + " pc_path = Path(\"/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/aero-benchmarking/original_pointclouds/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import src.data\n", + "\n", + "\n", + "num_points = 
500_000\n", + "datamodule = src.data.DrivAerMLDataModule(\n", + " data_path=dataset_part_path,\n", + " num_points=num_points\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/functional.py:512: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp:3559.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" + ] + }, + { + "data": { + "text/plain": [ + "FIGConvUNetDrivAerML(\n", + " (point_feature_to_grids): ModuleList(\n", + " (0): Sequential(\n", + " (0): PointFeatureToGrid(\n", + " (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)\n", + " )\n", + " (1): GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_xc_y_z)\n", + " )\n", + " (1): Sequential(\n", + " (0): PointFeatureToGrid(\n", + " (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)\n", + " )\n", + " (1): GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_yc_x_z)\n", + " )\n", + " (2): Sequential(\n", + " (0): PointFeatureToGrid(\n", + " (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)\n", + " )\n", + " (1): GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_zc_x_y)\n", + " )\n", + " )\n", + " (down_blocks): ModuleList(\n", + " (0-1): 2 x Sequential(\n", + " (0): GridFeatureConv2DBlocksAndIntraCommunication(\n", + " (convs): ModuleList(\n", + " (0): GridFeatureConv2dBlock(\n", + " (conv1): GridFeatureConv2d(\n", + " (conv): Conv2d(80, 80, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))\n", + " )\n", + " (conv2): 
GridFeatureConv2d(\n", + " (conv): Conv2d(80, 80, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", + " )\n", + " (norm1): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((80,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (norm2): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((80,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (shortcut): GridFeatureConv2d(\n", + " (conv): Conv2d(80, 80, kernel_size=(2, 2), stride=(2, 2))\n", + " )\n", + " (pad_to_match): GridFeaturePadToMatch()\n", + " (nonlinear): GridFeatureTransform(\n", + " (feature_transform): GELU(approximate='none')\n", + " )\n", + " )\n", + " (1): GridFeatureConv2dBlock(\n", + " (conv1): GridFeatureConv2d(\n", + " (conv): Conv2d(48, 48, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))\n", + " )\n", + " (conv2): GridFeatureConv2d(\n", + " (conv): Conv2d(48, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", + " )\n", + " (norm1): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((48,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (norm2): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((48,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (shortcut): GridFeatureConv2d(\n", + " (conv): Conv2d(48, 48, kernel_size=(2, 2), stride=(2, 2))\n", + " )\n", + " (pad_to_match): GridFeaturePadToMatch()\n", + " (nonlinear): GridFeatureTransform(\n", + " (feature_transform): GELU(approximate='none')\n", + " )\n", + " )\n", + " (2): GridFeatureConv2dBlock(\n", + " (conv1): GridFeatureConv2d(\n", + " (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))\n", + " )\n", + " (conv2): GridFeatureConv2d(\n", + " (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", + " )\n", + " (norm1): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((32,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (norm2): GridFeatureTransform(\n", + " (feature_transform): 
LayerNorm2d((32,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (shortcut): GridFeatureConv2d(\n", + " (conv): Conv2d(32, 32, kernel_size=(2, 2), stride=(2, 2))\n", + " )\n", + " (pad_to_match): GridFeaturePadToMatch()\n", + " (nonlinear): GridFeatureTransform(\n", + " (feature_transform): GELU(approximate='none')\n", + " )\n", + " )\n", + " )\n", + " (intra_communications): GridFeatureGroupIntraCommunications(\n", + " (intra_communications): ModuleList(\n", + " (0): GridFeaturesGroupIntraCommunication()\n", + " )\n", + " (grid_cat): GridFeatureGroupCat(\n", + " (grid_cat): GridFeatureCat()\n", + " )\n", + " )\n", + " (proj): Identity()\n", + " (nonlinear): GridFeatureGroupTransform(\n", + " (transform): GELU(approximate='none')\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (up_blocks): ModuleList(\n", + " (0-1): 2 x Sequential(\n", + " (0): GridFeatureConv2DBlocksAndIntraCommunication(\n", + " (convs): ModuleList(\n", + " (0): GridFeatureConv2dBlock(\n", + " (conv1): GridFeatureConv2d(\n", + " (conv): ConvTranspose2d(80, 80, kernel_size=(2, 2), stride=(2, 2))\n", + " )\n", + " (conv2): GridFeatureConv2d(\n", + " (conv): Conv2d(80, 80, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", + " )\n", + " (norm1): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((80,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (norm2): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((80,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (shortcut): GridFeatureConv2d(\n", + " (conv): ConvTranspose2d(80, 80, kernel_size=(2, 2), stride=(2, 2))\n", + " )\n", + " (pad_to_match): GridFeaturePadToMatch()\n", + " (nonlinear): GridFeatureTransform(\n", + " (feature_transform): GELU(approximate='none')\n", + " )\n", + " )\n", + " (1): GridFeatureConv2dBlock(\n", + " (conv1): GridFeatureConv2d(\n", + " (conv): ConvTranspose2d(48, 48, kernel_size=(2, 2), stride=(2, 2))\n", + " )\n", + " (conv2): GridFeatureConv2d(\n", + " (conv): 
Conv2d(48, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", + " )\n", + " (norm1): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((48,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (norm2): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((48,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (shortcut): GridFeatureConv2d(\n", + " (conv): ConvTranspose2d(48, 48, kernel_size=(2, 2), stride=(2, 2))\n", + " )\n", + " (pad_to_match): GridFeaturePadToMatch()\n", + " (nonlinear): GridFeatureTransform(\n", + " (feature_transform): GELU(approximate='none')\n", + " )\n", + " )\n", + " (2): GridFeatureConv2dBlock(\n", + " (conv1): GridFeatureConv2d(\n", + " (conv): ConvTranspose2d(32, 32, kernel_size=(2, 2), stride=(2, 2))\n", + " )\n", + " (conv2): GridFeatureConv2d(\n", + " (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", + " )\n", + " (norm1): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((32,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (norm2): GridFeatureTransform(\n", + " (feature_transform): LayerNorm2d((32,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (shortcut): GridFeatureConv2d(\n", + " (conv): ConvTranspose2d(32, 32, kernel_size=(2, 2), stride=(2, 2))\n", + " )\n", + " (pad_to_match): GridFeaturePadToMatch()\n", + " (nonlinear): GridFeatureTransform(\n", + " (feature_transform): GELU(approximate='none')\n", + " )\n", + " )\n", + " )\n", + " (intra_communications): GridFeatureGroupIntraCommunications(\n", + " (intra_communications): ModuleList(\n", + " (0): GridFeaturesGroupIntraCommunication()\n", + " )\n", + " (grid_cat): GridFeatureGroupCat(\n", + " (grid_cat): GridFeatureCat()\n", + " )\n", + " )\n", + " (proj): Identity()\n", + " (nonlinear): GridFeatureGroupTransform(\n", + " (transform): GELU(approximate='none')\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (convert_to_orig): 
GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_x_y_z_c)\n", + " (grid_pools): ModuleList(\n", + " (0): GridFeatureGroupPool(\n", + " (pools): ModuleList(\n", + " (0): GridFeaturePool(\n", + " (conv): Conv2d(80, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (pool): AdaptiveMaxPool1d(output_size=1)\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): GridFeaturePool(\n", + " (conv): Conv2d(48, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (pool): AdaptiveMaxPool1d(output_size=1)\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): GridFeaturePool(\n", + " (conv): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n", + " (pool): AdaptiveMaxPool1d(output_size=1)\n", + " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): ResidualLinearBlock(\n", + " (blocks): Sequential(\n", + " (0): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n", + " (2): GELU(approximate='none')\n", + " (3): Linear(in_features=1536, out_features=512, bias=True)\n", + " (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (shortcut): Linear(in_features=1536, out_features=512, bias=True)\n", + " (activation): GELU(approximate='none')\n", + " )\n", + " (1): ResidualLinearBlock(\n", + " (blocks): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (2): GELU(approximate='none')\n", + " (3): Linear(in_features=512, out_features=512, bias=True)\n", + " (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (shortcut): Identity()\n", + " (activation): GELU(approximate='none')\n", + " )\n", + " (2): LinearBlock(\n", + " (block): Sequential(\n", + " (0): 
Linear(in_features=512, out_features=512, bias=False)\n", + " (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (2): GELU(approximate='none')\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (mlp_projection): Linear(in_features=512, out_features=1, bias=True)\n", + " (to_point): GridFeatureGroupToPoint(\n", + " (conv_list): ModuleList(\n", + " (0-2): 3 x GridFeatureToPoint(\n", + " (conv): GridFeatureToPointGraphConv(\n", + " (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (projection): PointFeatureTransform(\n", + " (feature_transform): Sequential(\n", + " (0): Linear(in_features=32, out_features=32, bias=True)\n", + " (1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", + " (2): GELU(approximate='none')\n", + " (3): Linear(in_features=32, out_features=4, bias=True)\n", + " )\n", + " )\n", + " (pad_to_match): GridFeatureGroupPadToMatch(\n", + " (match): GridFeaturePadToMatch()\n", + " )\n", + " (vertex_to_point_features): VerticesToPointFeatures(\n", + " (pos_embed): SinusoidalEncoding()\n", + " (mlp): MLP(\n", + " (layers): ModuleList(\n", + " (0): LinearBlock(\n", + " (block): Sequential(\n", + " (0): Linear(in_features=96, out_features=16, bias=False)\n", + " (1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)\n", + " (2): GELU(approximate='none')\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import src.networks\n", + "from modulus.models.figconvnet.geometries import GridFeaturesMemoryFormat\n", + "\n", + "\n", + "model = src.networks.FIGConvUNetDrivAerML(\n", + " aabb_max=[2.0, 1.8, 2.6],\n", + " aabb_min=[-2.0, -1.8, -1.5],\n", + " hidden_channels=[16, 16, 16],\n", + " in_channels=1,\n", + " kernel_size=5,\n", + " mlp_channels=[512, 512], #[2048, 2048],\n", + " 
neighbor_search_type=\"radius\",\n", + " num_down_blocks=1,\n", + " num_levels=2,\n", + " out_channels=4,\n", + " pooling_layers=[2],\n", + " pooling_type=\"max\",\n", + " reductions=[\"mean\"],\n", + " resolution_memory_format_pairs=[\n", + " (GridFeaturesMemoryFormat.b_xc_y_z, [ 5, 150, 100]),\n", + " (GridFeaturesMemoryFormat.b_yc_x_z, [250, 3, 100]),\n", + " (GridFeaturesMemoryFormat.b_zc_x_y, [250, 150, 2]),\n", + " ],\n", + " use_rel_pos_encode=True,\n", + ")\n", + "# Load checkpoint.\n", + "chk = torch.load(model_path)\n", + "model.load_state_dict(chk[\"model\"])\n", + "model = model.to(device)\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.neighbors import NearestNeighbors\n", + "\n", + "from modulus.datapipes.cae.readers import read_vtp\n", + "\n", + "\n", + "def convert_to_triangular_mesh(\n", + " polydata, write=False, output_filename=\"surface_mesh_triangular.vtu\"\n", + "):\n", + " \"\"\"Converts a vtkPolyData object to a triangular mesh.\"\"\"\n", + " tet_filter = vtk.vtkDataSetTriangleFilter()\n", + " tet_filter.SetInputData(polydata)\n", + " tet_filter.Update()\n", + "\n", + " tet_mesh = pv.wrap(tet_filter.GetOutput())\n", + "\n", + " if write:\n", + " tet_mesh.save(output_filename)\n", + "\n", + " return tet_mesh\n", + "\n", + "\n", + "def fetch_mesh_vertices(mesh):\n", + " \"\"\"Fetches the vertices of a mesh.\"\"\"\n", + " points = mesh.GetPoints()\n", + " num_points = points.GetNumberOfPoints()\n", + " vertices = [points.GetPoint(i) for i in range(num_points)]\n", + " return vertices" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from src.utils.eval_funcs import rrmse\n", + "\n", + "torch.set_grad_enabled(False)\n", + "\n", + "torch.cuda.empty_cache()\n", + "\n", + "@torch.no_grad\n", + "def run_inference(is_pointcloud: bool):\n", + " for sample in datamodule.test_dataloader():\n", + 
" vertices_denorm = model.data_dict_to_input(sample)\n", + " vertices = datamodule.encode(vertices_denorm, \"coordinates\")\n", + " normalized_pred, _ = model(vertices)\n", + " normalized_p_pred = normalized_pred[..., :1]\n", + " denorm_p_pred = datamodule.decode(normalized_p_pred, \"pressure\")\n", + " normalized_wss_pred = normalized_pred[..., 1:]\n", + " denorm_wss_pred = datamodule.decode(normalized_wss_pred, \"shear_stress\")\n", + "\n", + " # Read the original surface mesh.\n", + " idx = sample[\"design\"][0]\n", + " vtp_file = dataset_orig_path / f\"run_{idx}/boundary_{idx}.vtp\"\n", + " print(f\"Reading {vtp_file}\")\n", + " mesh = pv.read(vtp_file)\n", + " mesh = mesh.cell_data_to_point_data()\n", + "\n", + " # Interpolate predictions on GT mesh.\n", + " k = 4\n", + " nbrs_surface = NearestNeighbors(\n", + " n_neighbors=k, algorithm=\"ball_tree\"\n", + " ).fit(vertices_denorm[0].cpu().numpy())\n", + "\n", + " distances, indices = nbrs_surface.kneighbors(mesh.points)\n", + " if k == 1:\n", + " indices = indices.flatten()\n", + " pressure_pred_mesh = denorm_p_pred[0][indices]\n", + " shear_stress_pred_mesh = denorm_wss_pred[0][indices]\n", + " else:\n", + " # distances = distances.astype(np.float32)\n", + " # Weighted kNN interpolation\n", + " # Avoid division by zero by adding a small epsilon\n", + " epsilon = 1e-8\n", + " weights = 1 / (distances + epsilon)\n", + " weights_sum = np.sum(weights, axis=1, keepdims=True)\n", + " normalized_weights = weights / weights_sum\n", + " # Fetch the predictions of the k nearest neighbors\n", + " pressure_neighbors = denorm_p_pred[0][indices] # Shape: (n_samples, k, 1)\n", + " shear_stress_neighbors = denorm_wss_pred[0][indices] # Shape: (n_samples, k, 3)\n", + "\n", + " # Compute the weighted average\n", + " pressure_pred_mesh = np.sum(normalized_weights[:, :, np.newaxis] * pressure_neighbors.cpu().numpy(), axis=1)\n", + " shear_stress_pred_mesh = np.sum(normalized_weights[:, :, np.newaxis] * 
shear_stress_neighbors.cpu().numpy(), axis=1)\n", + "\n", + " # Convert back to torch tensors\n", + " pressure_pred_mesh = torch.from_numpy(pressure_pred_mesh).to(device)\n", + " shear_stress_pred_mesh = torch.from_numpy(shear_stress_pred_mesh).to(device)\n", + "\n", + " mesh.point_data[\"pMeanTrimPred\"] = pressure_pred_mesh.cpu().float().numpy()\n", + " mesh.point_data[\"wallShearStressMeanTrimPred\"] = shear_stress_pred_mesh.cpu().float().numpy()\n", + " mesh.save(output_path / f\"500K_k4_pc/inference_mesh_{idx}.vtp\")\n", + " print(\"Done.\")\n", + " print(\n", + " rrmse(torch.tensor(mesh.point_data[\"pMeanTrim\"]), torch.tensor(mesh.point_data[\"pMeanTrimPred\"])),\n", + " rrmse(torch.tensor(mesh.point_data[\"wallShearStressMeanTrim\"][:, 0]), torch.tensor(mesh.point_data[\"wallShearStressMeanTrimPred\"][:, 0])),\n", + " rrmse(torch.tensor(mesh.point_data[\"wallShearStressMeanTrim\"][:, 1]), torch.tensor(mesh.point_data[\"wallShearStressMeanTrimPred\"][:, 1])),\n", + " rrmse(torch.tensor(mesh.point_data[\"wallShearStressMeanTrim\"][:, 2]), torch.tensor(mesh.point_data[\"wallShearStressMeanTrimPred\"][:, 2])),\n", + " )\n", + " torch.cuda.empty_cache()\n", + " # break\n", + "\n", + "# run_inference()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reading /data/src/modulus/data/drivaer_aws/original_pointclouds/input_pc_5000000_run_100_final.vtp\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[0m\u001b[33m2024-11-20 10:17:40.100 ( 117.435s) [ 7F914D57C280] vtkMath.cxx:778 WARN| vtkMath::Jacobi: Error extracting eigenfunctions\u001b[0m\n", + "ERROR:root:No data to measure...!\n", + "\u001b[0m\u001b[31m2024-11-20 10:17:40.737 ( 118.071s) [ 7F914D57C280] vtkMassProperties.cxx:60 ERR| vtkMassProperties (0x559237b10530): No data to measure...!\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + 
"text": [ + "Done.\n", + "Reading /data/src/modulus/data/drivaer_aws/original_pointclouds/input_pc_5000000_run_200_final.vtp\n", + "Done.\n", + "Reading /data/src/modulus/data/drivaer_aws/original_pointclouds/input_pc_5000000_run_300_final.vtp\n", + "Done.\n", + "Reading /data/src/modulus/data/drivaer_aws/original_pointclouds/input_pc_5000000_run_400_final.vtp\n", + "Done.\n", + "Reading /data/src/modulus/data/drivaer_aws/original_pointclouds/input_pc_5000000_run_500_final.vtp\n", + "Done.\n" + ] + } + ], + "source": [ + "@torch.no_grad\n", + "def run_inference_on_pc(pc_size: int, k: int = 4):\n", + " for sample in datamodule.test_dataloader():\n", + " vertices_denorm = model.data_dict_to_input(sample)\n", + " vertices = datamodule.encode(vertices_denorm, \"coordinates\")\n", + " normalized_pred, _ = model(vertices)\n", + " normalized_p_pred = normalized_pred[..., :1]\n", + " denorm_p_pred = datamodule.decode(normalized_p_pred, \"pressure\")\n", + " normalized_wss_pred = normalized_pred[..., 1:]\n", + " denorm_wss_pred = datamodule.decode(normalized_wss_pred, \"shear_stress\")\n", + "\n", + " # Read the input point cloud for this design.\n", + " idx = sample[\"design\"][0]\n", + " vtp_file = pc_path / f\"input_pc_{pc_size}_run_{idx}_final.vtp\"\n", + " print(f\"Reading {vtp_file}\")\n", + " mesh = pv.read(vtp_file)\n", + "\n", + " # Interpolate predictions onto the point cloud points.\n", + " nbrs_surface = NearestNeighbors(\n", + " n_neighbors=k, algorithm=\"ball_tree\"\n", + " ).fit(vertices_denorm[0].cpu().numpy())\n", + "\n", + " distances, indices = nbrs_surface.kneighbors(mesh.points)\n", + " if k == 1:\n", + " indices = indices.flatten()\n", + " pressure_pred_mesh = denorm_p_pred[0][indices]\n", + " shear_stress_pred_mesh = denorm_wss_pred[0][indices]\n", + " else:\n", + " # distances = distances.astype(np.float32)\n", + " # Weighted kNN interpolation\n", + " # Avoid division by zero by adding a small epsilon\n", + " epsilon = 1e-8\n", + " weights = 1 / (distances 
+ epsilon)\n", + " 
weights_sum = np.sum(weights, axis=1, keepdims=True)\n", + " normalized_weights = weights / weights_sum\n", + " # Fetch the predictions of the k nearest neighbors\n", + " pressure_neighbors = denorm_p_pred[0][indices] # Shape: (n_samples, k, 1)\n", + " shear_stress_neighbors = denorm_wss_pred[0][indices] # Shape: (n_samples, k, 3)\n", + "\n", + " # Compute the weighted average\n", + " pressure_pred_mesh = np.sum(normalized_weights[:, :, np.newaxis] * pressure_neighbors.cpu().numpy(), axis=1)\n", + " shear_stress_pred_mesh = np.sum(normalized_weights[:, :, np.newaxis] * shear_stress_neighbors.cpu().numpy(), axis=1)\n", + "\n", + " # Convert back to torch tensors\n", + " pressure_pred_mesh = torch.from_numpy(pressure_pred_mesh).to(device)\n", + " shear_stress_pred_mesh = torch.from_numpy(shear_stress_pred_mesh).to(device)\n", + "\n", + " mesh.point_data[\"pMeanTrimPred\"] = pressure_pred_mesh.cpu().float().numpy()\n", + " mesh.point_data[\"wallShearStressMeanTrimPred\"] = shear_stress_pred_mesh.cpu().float().numpy()\n", + " out_path = output_path / f\"pc/500K_k{k}\"\n", + " out_path.mkdir(parents=True, exist_ok=True)\n", + " mesh.save(out_path / f\"inference_pc_{pc_size}_{idx}.vtp\")\n", + " print(\"Done.\")\n", + " torch.cuda.empty_cache()\n", + " # break\n", + "\n", + "run_inference_on_pc(5_000_000)\n", + "run_inference_on_pc(10_000_000)\n", + "run_inference_on_pc(20_000_000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Module modulus.models.figconvnet.warp_neighbor_search load on device 'cuda:0' took 145.78 ms\n", + "Done.\n" + ] + } + ], + "source": [ + "@torch.no_grad()\n", + "def inference_on_sim_mesh():\n", + " for idx in [100, 200, 300, 400, 500][:1]:\n", + " mesh_gt = pv.read(dataset_orig_path / f\"run_{idx}/boundary_{idx}.vtp\")\n", + " mesh_gt = mesh_gt.cell_data_to_point_data()\n", + " step = 500_000 # num_points\n", + " p_chunks = []\n", +
" wss_chunks = []\n", + " rng = np.random.default_rng(1)\n", + " indices = rng.permutation(range(mesh_gt.number_of_points))\n", + " for i_start in range(0, mesh_gt.number_of_points, step):\n", + " vertices_denorm = torch.as_tensor(\n", + " mesh_gt.points[indices[i_start : i_start + step]], device=device\n", + " ).unsqueeze(0)\n", + " vertices = datamodule.encode(vertices_denorm, \"coordinates\")\n", + " normalized_pred, _ = model(vertices)\n", + "\n", + " normalized_p_pred = normalized_pred[..., :1]\n", + " denorm_p_pred = datamodule.decode(normalized_p_pred, \"pressure\")\n", + " p_chunks.append(denorm_p_pred.cpu())\n", + "\n", + " normalized_wss_pred = normalized_pred[..., 1:]\n", + " denorm_wss_pred = datamodule.decode(normalized_wss_pred, \"shear_stress\")\n", + " wss_chunks.append(denorm_wss_pred.cpu())\n", + " torch.cuda.empty_cache()\n", + "\n", + " pressure_pred_mesh = torch.cat(p_chunks, dim=1)[0][torch.from_numpy(np.argsort(indices))] # chunks are in permuted order: restore mesh point order\n", + " shear_stress_pred_mesh = torch.cat(wss_chunks, dim=1)[0][torch.from_numpy(np.argsort(indices))] # chunks are in permuted order: restore mesh point order\n", + " mesh_gt.point_data[\"pMeanTrimPred\"] = pressure_pred_mesh.cpu().float().numpy()\n", + " mesh_gt.point_data[\"wallShearStressMeanTrimPred\"] = shear_stress_pred_mesh.cpu().float().numpy()\n", + " mesh_gt.save(output_path / f\"inference_mesh_{idx}.vtp\")\n", + " print(\"Done.\")\n", + " torch.cuda.empty_cache()\n", + "\n", + "\n", + "# inference_on_sim_mesh()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# v = read_vtp(str(dataset_orig_path / f\"inference_point_cloud_{idx}.vtp\"))\n", + "# v = read_vtp(str(dataset_orig_path / f\"run_100/boundary_100.vtp\"))\n", + "# mesh_gt = pv.read(dataset_orig_path / f\"run_100/boundary_100.vtp\")\n", + "\n", + "# mesh_gt = mesh_gt.cell_data_to_point_data()\n", + "# mesh_pred = pv.read(dataset_orig_path / f\"inference/inference_mesh_100.vtp\")\n", + "# rrmse(torch.tensor(mesh_pred.point_data[\"pMeanTrim\"]), torch.tensor(mesh_pred.point_data[\"pMeanTrimPred\"]))" + ] + }, + { +
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.2085)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/cfd/figconvnet/src/data/__init__.py b/examples/cfd/figconvnet/src/data/__init__.py index fc0463711..bc5d48823 100644 --- a/examples/cfd/figconvnet/src/data/__init__.py +++ b/examples/cfd/figconvnet/src/data/__init__.py @@ -15,4 +15,5 @@ # limitations under the License. from .base_datamodule import BaseDataModule +from .drivaerml_datamodule import DrivAerMLDataModule from .drivaernet_datamodule import DrivAerNetDataModule diff --git a/examples/cfd/figconvnet/src/data/drivaerml_datamodule.py b/examples/cfd/figconvnet/src/data/drivaerml_datamodule.py new file mode 100644 index 000000000..4ff7903dc --- /dev/null +++ b/examples/cfd/figconvnet/src/data/drivaerml_datamodule.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Iterable +import json +from pathlib import Path +from typing import Any + +import dgl +import numpy as np + +import torch +from torch import Tensor +from torch.utils.data import Dataset + +from src.data.base_datamodule import BaseDataModule + + +class DrivAerMLPartitionedDataset(Dataset): + """DrivAerML partitioned dataset: each item loads one DGL graph file (*.bin), concatenates its graphs and samples num_points nodes.""" + + def __init__( + self, + data_path: Path, + num_points: int = 0, + ) -> None: + self.p_files = sorted(data_path.glob("*.bin")) + self.num_points = num_points + + def __len__(self) -> int: + return len(self.p_files) + + def __getitem__(self, index: int) -> dict[str, Any]: + if not 0 <= index < len(self): + raise IndexError(f"Invalid {index = } expected in [0, {len(self)})") + + gs, _ = dgl.load_graphs(str(self.p_files[index])) + + coords = torch.cat([g.ndata["coordinates"] for g in gs], dim=0) + # Sample indices from the combined graph (without replacement when it has enough nodes).
+ n_total = coords.shape[0] + if n_total >= self.num_points: + indices = np.random.choice(n_total, self.num_points, replace=False) + else: + indices = np.concatenate( + ( + np.arange(n_total), + np.random.choice(n_total, self.num_points - n_total), + ) + ) + coords = coords[indices] + pressure = torch.cat([g.ndata["pressure"] for g in gs], dim=0)[indices] + shear_stress = torch.cat([g.ndata["shear_stress"] for g in gs], dim=0)[indices] + + return { + "coordinates": coords, + "pressure": pressure, + "shear_stress": shear_stress, + "design": self.p_files[index].stem.removeprefix("graph_partitions_"), + } + + +class DrivAerMLDataModule(BaseDataModule): + """DrivAerML data module: train/validation/test partitioned datasets plus per-field mean/std statistics used by encode/decode.""" + + def __init__( + self, + data_path: str | Path, + num_points: int = 0, + stats_filename: str = "global_stats.json", + **kwargs, + ): + data_path = Path(data_path) + self._train_dataset = DrivAerMLPartitionedDataset( + data_path / "partitions", num_points + ) + self._val_dataset = DrivAerMLPartitionedDataset( + data_path / "validation_partitions", num_points + ) + self._test_dataset = DrivAerMLPartitionedDataset( + data_path / "test_partitions", num_points + ) + + with open(data_path / stats_filename, "r", encoding="utf-8") as f: + stats = json.load(f) + + self.mean = {k: torch.tensor(v) for k, v in stats["mean"].items()} + self.std = {k: torch.tensor(v) for k, v in stats["std_dev"].items()} + + @property + def train_dataset(self): + return self._train_dataset + + @property + def val_dataset(self): + return self._val_dataset + + @property + def test_dataset(self): + return self._test_dataset + + def encode(self, x: Tensor, name: str): + return (x - self.mean[name].to(x.device)) / self.std[name].to(x.device) + + def decode(self, x: Tensor, name: str): + return x * self.std[name].to(x.device) + self.mean[name].to(x.device) diff --git a/examples/cfd/figconvnet/src/networks/__init__.py b/examples/cfd/figconvnet/src/networks/__init__.py index 760ea8713..b0176a4f6 100644 ---
a/examples/cfd/figconvnet/src/networks/__init__.py +++ b/examples/cfd/figconvnet/src/networks/__init__.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .figconvunet_drivaer import FIGConvUNetDrivAerNet +from .figconvunet_drivaer import FIGConvUNetDrivAerML, FIGConvUNetDrivAerNet diff --git a/examples/cfd/figconvnet/src/networks/figconvunet_drivaer.py b/examples/cfd/figconvnet/src/networks/figconvunet_drivaer.py index 93587de7c..cd6616bb4 100644 --- a/examples/cfd/figconvnet/src/networks/figconvunet_drivaer.py +++ b/examples/cfd/figconvnet/src/networks/figconvunet_drivaer.py @@ -244,6 +244,142 @@ def image_pointcloud_dict(self, data_dict, datamodule) -> Tuple[Dict, Dict]: # return {"vis": im}, {"pred": pred_points, "gt": gt_points, "diff": diff_points} +class FIGConvUNetDrivAerML(FIGConvUNet): + """FIGConvUNetDrivAerML + + FIGConvUNetDrivAerML is a variant of FIGConvUNet that is specialized for the DrivAerML dataset. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + hidden_channels: List[int], + num_levels: int = 3, + num_down_blocks: Union[int, List[int]] = 1, + num_up_blocks: Union[int, List[int]] = 1, + mlp_channels: List[int] = [512, 512], + aabb_max: Tuple[float, float, float] = (2.5, 1.5, 1.0), + aabb_min: Tuple[float, float, float] = (-2.5, -1.5, -1.0), + voxel_size: Optional[float] = None, + resolution_memory_format_pairs: List[ + Tuple[GridFeaturesMemoryFormat, Tuple[int, int, int]] + ] = [ + (GridFeaturesMemoryFormat.b_xc_y_z, (2, 128, 128)), + (GridFeaturesMemoryFormat.b_yc_x_z, (128, 2, 128)), + (GridFeaturesMemoryFormat.b_zc_x_y, (128, 128, 2)), + ], + use_rel_pos: bool = True, + use_rel_pos_encode: bool = True, + pos_encode_dim: int = 32, + communication_types: List[Literal["mul", "sum"]] = ["sum"], + to_point_sample_method: Literal["graphconv", "interp"] = "graphconv", + neighbor_search_type: Literal["knn", "radius"] = "knn", + knn_k: int
= 16, + reductions: List[REDUCTION_TYPES] = ["mean"], + drag_loss_weight: Optional[float] = None, + pooling_type: Literal["attention", "max", "mean"] = "max", + pooling_layers: List[int] = None, + ): + super().__init__( + in_channels=hidden_channels[0], + out_channels=out_channels, + kernel_size=kernel_size, + hidden_channels=hidden_channels, + num_levels=num_levels, + num_down_blocks=num_down_blocks, + num_up_blocks=num_up_blocks, + mlp_channels=mlp_channels, + aabb_max=aabb_max, + aabb_min=aabb_min, + voxel_size=voxel_size, + resolution_memory_format_pairs=resolution_memory_format_pairs, + use_rel_pos=use_rel_pos, + use_rel_pos_embed=use_rel_pos_encode, + pos_encode_dim=pos_encode_dim, + communication_types=communication_types, + to_point_sample_method=to_point_sample_method, + neighbor_search_type=neighbor_search_type, + knn_k=knn_k, + reductions=reductions, + drag_loss_weight=drag_loss_weight, + pooling_type=pooling_type, + pooling_layers=pooling_layers, + ) + + def data_dict_to_input(self, data_dict) -> torch.Tensor: + vertices = data_dict["coordinates"].float() + return vertices.to(self.device) + + @torch.no_grad() + def eval_dict(self, data_dict, loss_fn=None, datamodule=None, **kwargs) -> Dict: + vertices = self.data_dict_to_input(data_dict) + vertices = datamodule.encode(vertices, "coordinates") + normalized_pred, _ = self(vertices) + p_gt = datamodule.encode(data_dict["pressure"], "pressure").to(self.device) + wss_gt = datamodule.encode(data_dict["shear_stress"], "shear_stress").to( + self.device + ) + normalized_gt = torch.cat((p_gt, wss_gt), -1) + + if loss_fn is None: + loss_fn = self.loss + + out_dict = {"l2": loss_fn(normalized_pred, normalized_gt)} + + normalized_pred = normalized_pred.clone() + normalized_p_pred = normalized_pred[..., :1] + normalized_wss_pred = normalized_pred[..., 1:] + + denorm_p_pred = datamodule.decode(normalized_p_pred, "pressure") + denorm_p_gt = data_dict["pressure"].to(self.device).view_as(denorm_p_pred) + 
out_dict["p_l2_denorm"] = loss_fn(denorm_p_pred, denorm_p_gt) + + denorm_wss_pred = datamodule.decode(normalized_wss_pred, "shear_stress") + denorm_wss_gt = ( + data_dict["shear_stress"].to(self.device).view_as(denorm_wss_pred) + ) + out_dict["wss_l2_denorm"] = loss_fn(denorm_wss_pred, denorm_wss_gt) + + # Pressure evaluation + out_dict.update( + eval_all_metrics(p_gt, normalized_p_pred, prefix="norm_pressure") + ) + # WSS evaluation + out_dict.update( + eval_all_metrics(wss_gt, normalized_wss_pred, prefix="norm_wss") + ) + + return out_dict + + def loss_dict(self, data_dict, loss_fn=None, datamodule=None, **kwargs) -> Dict: + vertices = self.data_dict_to_input(data_dict) + vertices = datamodule.encode(vertices, "coordinates") + normalized_pred, _ = self(vertices) + p_gt = datamodule.encode(data_dict["pressure"], "pressure").to(self.device) + wss_gt = datamodule.encode(data_dict["shear_stress"], "shear_stress").to( + self.device + ) + normalized_gt = torch.cat((p_gt, wss_gt), -1) + + return_dict = {} + if loss_fn is None: + loss_fn = self.loss + + # return_dict["p_wss_loss"] = loss_fn(normalized_pred, normalized_gt) + # return return_dict + p_pred = normalized_pred[..., :1] + wss_pred = normalized_pred[..., 1:4] + return { + "p_loss": loss_fn(p_pred, p_gt), + "wss_loss": loss_fn(wss_pred, wss_gt), + } + + def image_pointcloud_dict(self, data_dict, datamodule) -> Tuple[Dict, Dict]: + return {}, {} + + def drivaer_create_subplot(ax, vertices, data, title): # Flip along x axis vertices = vertices.clone() diff --git a/examples/cfd/figconvnet/src/utils/eval_funcs.py b/examples/cfd/figconvnet/src/utils/eval_funcs.py index d877f14a9..3b82c176d 100644 --- a/examples/cfd/figconvnet/src/utils/eval_funcs.py +++ b/examples/cfd/figconvnet/src/utils/eval_funcs.py @@ -17,6 +17,7 @@ from typing import Dict, Optional from jaxtyping import Float +import torch from torch import Tensor @@ -51,6 +52,11 @@ def max_absolute_error( return (y_true - y_pred).abs().max() +def rrmse(y_true: 
Float[Tensor, "B"], y_pred: Float[Tensor, "B"]) -> Float[Tensor, "1"]: + """Compute the relative RMSE.""" + return torch.linalg.vector_norm(y_pred - y_true) / torch.linalg.vector_norm(y_true) + + def eval_all_metrics( y_true: Float[Tensor, "B"], y_pred: Float[Tensor, "B"], prefix: Optional[str] = None ) -> Dict[str, float]: @@ -67,6 +73,7 @@ def eval_all_metrics( "mse": mean_squared_error(y_true, y_pred).cpu().item(), "mae": mean_absolute_error(y_true, y_pred).cpu().item(), "maxae": max_absolute_error(y_true, y_pred).cpu().item(), + "rrmse": rrmse(y_true, y_pred).cpu().item(), } if prefix is not None: out_dict = {f"{prefix}_{k}": v for k, v in out_dict.items()} diff --git a/examples/cfd/figconvnet/src/utils/loggers.py b/examples/cfd/figconvnet/src/utils/loggers.py index c8e78bd0a..6d6e0a09a 100644 --- a/examples/cfd/figconvnet/src/utils/loggers.py +++ b/examples/cfd/figconvnet/src/utils/loggers.py @@ -130,6 +130,7 @@ def __init__( entity: Optional[str] = None, resume: Optional[bool] = False, wandb_id: Optional[str] = None, + mode: Optional[str] = "online", ): super().__init__() if resume: @@ -154,6 +155,7 @@ def __init__( entity=entity, resume=resume, id=wandb_id, + mode=mode, ) # log config to wandb if config is not None and resume != "must": @@ -298,6 +300,7 @@ def init_logger(config: dict) -> Logger: config=config, resume=resume, wandb_id=wandb_id, + mode=logger_cfg.mode, ) ) else: diff --git a/examples/cfd/figconvnet/train.py b/examples/cfd/figconvnet/train.py index 4734527ce..48b04f1de 100644 --- a/examples/cfd/figconvnet/train.py +++ b/examples/cfd/figconvnet/train.py @@ -286,12 +286,7 @@ def train(config: DictConfig, signal_handler: SignalHandler): loss = 0 for k, v in loss_dict.items(): - weight_name = k + "_weight" - if ( - hasattr(config, weight_name) - and getattr(config, weight_name) is not None - ): - v = v * getattr(config, weight_name) + v = v * getattr(config, k + "_weight", 1) loss = loss + v.mean() # Assert loss is valid @@ -318,7 +313,7 @@ def 
train(config: DictConfig, signal_handler: SignalHandler): scaler.update() train_l2_meter.update(loss.item()) - loggers.log_scalar("train/iter_lr", scheduler.get_lr()[0], tot_iter) + loggers.log_scalar("train/iter_lr", scheduler.get_last_lr()[0], tot_iter) loggers.log_scalar("train/iter_loss", loss.item(), tot_iter) for k, v in loss_dict.items(): loggers.log_scalar(f"train/{k}", v.item(), tot_iter) @@ -328,10 +323,13 @@ def train(config: DictConfig, signal_handler: SignalHandler): print_str += f"{k}: {v.item():.4f}, " # only print the number logger.info(print_str) + if config.train.lr_scheduler_mode == "iteration": + scheduler.step() tot_iter += 1 torch.cuda.empty_cache() - scheduler.step() + if config.train.lr_scheduler_mode == "epoch": + scheduler.step() t2 = default_timer() logger.info( f"Training epoch {ep} took {t2 - t1:.2f} seconds. L2 loss: {train_l2_meter.avg:.4f}" @@ -387,7 +385,7 @@ def _slurm_setup(config: DictConfig) -> None: # Detect if it is running on a SLURM cluster. if "SLURM_JOB_ID" in os.environ: # The output directory is set to simply ${output}/SLURM_JOB_ID. - config.output = os.path.join(config.output, os.environ["SLURM_JOB_ID"]) + # config.output = os.path.join(config.output, os.environ["SLURM_JOB_ID"]) # Check for the checkpoints and model_*.pth files in the output directory. if os.path.exists(config.output) and any( f.startswith("model_") and f.endswith(".pth") From 29305d708d5bfd34d3a4d9cb502f0dbeef46482e Mon Sep 17 00:00:00 2001 From: Alexey Kamenev Date: Mon, 2 Dec 2024 16:45:32 -0800 Subject: [PATCH 2/6] Add FIGNet DrivAerML experiment config. 
--- .../experiment/drivaerml/figconv_unet.yaml | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 examples/cfd/external_aerodynamics/figconvnet/configs/experiment/drivaerml/figconv_unet.yaml diff --git a/examples/cfd/external_aerodynamics/figconvnet/configs/experiment/drivaerml/figconv_unet.yaml b/examples/cfd/external_aerodynamics/figconvnet/configs/experiment/drivaerml/figconv_unet.yaml new file mode 100644 index 000000000..27c69ec0d --- /dev/null +++ b/examples/cfd/external_aerodynamics/figconvnet/configs/experiment/drivaerml/figconv_unet.yaml @@ -0,0 +1,58 @@ +# @package _global_ + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +defaults: + - /data: drivaerml + - /model: figconv_unet_drivaerml + - /loss: mseloss + - /optimizer: adam + - /lr_scheduler: steplr + +train: + batch_size: 8 + num_epochs: 200 + +model: + aabb_max: [ 2.0, 1.8, 2.6] + aabb_min: [-2.0, -1.8, -1.5] + hidden_channels: [16, 16, 16] + kernel_size: 5 + # mlp_channels: [2048, 2048] + neighbor_search_type: radius + num_down_blocks: 1 + num_levels: 2 + pooling_layers: [2] + pooling_type: max + reductions: [mean] + resolution_memory_format_pairs: + - ${res_mem_pair:b_xc_y_z, [ 5, 150, 100]} + - ${res_mem_pair:b_yc_x_z, [250, 3, 100]} + - ${res_mem_pair:b_zc_x_y, [250, 150, 2]} + use_rel_pos_encode: true + +lr_scheduler: + step_size: 50 + +loggers: + wandb: + entity: modulus + project_name: car-cfd + group_name: fignet-drivaerml + run_name: FIGConvNet-level2-16,16,16-res250-150-100-pool-max-2-aabb-20x18x26-ks5-np32768-b8x2 + +seed: 0 From 628a4ac145dc0d02f1405fe4f71515f04e42c7e5 Mon Sep 17 00:00:00 2001 From: Alexey Kamenev Date: Tue, 10 Dec 2024 09:06:46 -0800 Subject: [PATCH 3/6] Update notebook. 
--- .../notebooks/figconvnet_drivaerml.ipynb | 428 ++---------------- 1 file changed, 27 insertions(+), 401 deletions(-) diff --git a/examples/cfd/external_aerodynamics/figconvnet/notebooks/figconvnet_drivaerml.ipynb b/examples/cfd/external_aerodynamics/figconvnet/notebooks/figconvnet_drivaerml.ipynb index 1f802b109..7aa7d9752 100644 --- a/examples/cfd/external_aerodynamics/figconvnet/notebooks/figconvnet_drivaerml.ipynb +++ b/examples/cfd/external_aerodynamics/figconvnet/notebooks/figconvnet_drivaerml.ipynb @@ -2,34 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Warp 1.0.2 initialized:\n", - " CUDA Toolkit 11.5, Driver 12.2\n", - " Devices:\n", - " \"cpu\" : \"x86_64\"\n", - " \"cuda:0\" : \"NVIDIA GeForce RTX 3090\" (24 GiB, sm_86, mempool enabled)\n", - " \"cuda:1\" : \"NVIDIA TITAN RTX\" (24 GiB, sm_75, mempool enabled)\n", - " CUDA peer access:\n", - " Not supported\n", - " Kernel cache:\n", - " /home/du/.cache/warp/1.0.2\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/src/modulus/src/modulus/modulus/distributed/manager.py:346: UserWarning: Could not initialize using ENV, SLURM or OPENMPI methods. 
Assuming this is a single process job\n", - " warn(\n" - ] - } - ], + "outputs": [], "source": [ "from pathlib import Path\n", "import sys\n", @@ -64,20 +39,24 @@ "if os.environ.get(\"SLURM_JOB_NAME\", None) is None:\n", " dataset_orig_path = Path(\"/data/src/modulus/data/drivaer_aws/\")\n", " dataset_part_path = Path(\"/data/src/modulus/data/drivaer_aws/partitions\")\n", - " output_path = dataset_orig_path / f\"inference/\"\n", - " model_path = Path(\"/data/src/modulus/models/fignet/drivaerml/lrsoc/model_00999.pth\")\n", + " output_path = dataset_orig_path / f\"inference/seploss4-01/\"\n", + " # model_path = Path(\"/data/src/modulus/models/fignet/drivaerml/lrsoc/model_00999.pth\")\n", + " model_path = Path(\"/data/src/modulus/models/fignet/drivaerml/seploss4-01/model_00999.pth\")\n", " pc_path = Path(\"/data/src/modulus/data/drivaer_aws/original_pointclouds\")\n", "else:\n", " dataset_orig_path = Path(\"/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/drivaer_data_full\")\n", " dataset_part_path = Path(\"/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/partitions/100_200_400/\")\n", - " output_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/inference/\")\n", - " model_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/9/model_00999.pth\")\n", - " pc_path = Path(\"/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/aero-benchmarking/original_pointclouds/\")" + " # output_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/inference/new/sharedloss\")\n", + " output_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/inference/new/seploss\")\n", + " # model_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/9/model_00999.pth\")\n", + " model_path = 
Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/seploss-01/best/model_00999.pth\")\n", + " # pc_path = Path(\"/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/aero-benchmarking/original_pointclouds/\")\n", + " pc_path = Path(\"/lustre/fsw/portfolios/coreai/users/ktangsali/inference_pc_generation/original_pointclouds/\")" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -93,305 +72,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/torch/functional.py:512: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp:3559.)\n", - " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" - ] - }, - { - "data": { - "text/plain": [ - "FIGConvUNetDrivAerML(\n", - " (point_feature_to_grids): ModuleList(\n", - " (0): Sequential(\n", - " (0): PointFeatureToGrid(\n", - " (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)\n", - " )\n", - " (1): GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_xc_y_z)\n", - " )\n", - " (1): Sequential(\n", - " (0): PointFeatureToGrid(\n", - " (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)\n", - " )\n", - " (1): GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_yc_x_z)\n", - " )\n", - " (2): Sequential(\n", - " (0): PointFeatureToGrid(\n", - " (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)\n", - " )\n", - " (1): 
GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_zc_x_y)\n", - " )\n", - " )\n", - " (down_blocks): ModuleList(\n", - " (0-1): 2 x Sequential(\n", - " (0): GridFeatureConv2DBlocksAndIntraCommunication(\n", - " (convs): ModuleList(\n", - " (0): GridFeatureConv2dBlock(\n", - " (conv1): GridFeatureConv2d(\n", - " (conv): Conv2d(80, 80, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))\n", - " )\n", - " (conv2): GridFeatureConv2d(\n", - " (conv): Conv2d(80, 80, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", - " )\n", - " (norm1): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((80,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (norm2): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((80,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (shortcut): GridFeatureConv2d(\n", - " (conv): Conv2d(80, 80, kernel_size=(2, 2), stride=(2, 2))\n", - " )\n", - " (pad_to_match): GridFeaturePadToMatch()\n", - " (nonlinear): GridFeatureTransform(\n", - " (feature_transform): GELU(approximate='none')\n", - " )\n", - " )\n", - " (1): GridFeatureConv2dBlock(\n", - " (conv1): GridFeatureConv2d(\n", - " (conv): Conv2d(48, 48, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))\n", - " )\n", - " (conv2): GridFeatureConv2d(\n", - " (conv): Conv2d(48, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", - " )\n", - " (norm1): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((48,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (norm2): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((48,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (shortcut): GridFeatureConv2d(\n", - " (conv): Conv2d(48, 48, kernel_size=(2, 2), stride=(2, 2))\n", - " )\n", - " (pad_to_match): GridFeaturePadToMatch()\n", - " (nonlinear): GridFeatureTransform(\n", - " (feature_transform): GELU(approximate='none')\n", - " )\n", - " )\n", - " (2): GridFeatureConv2dBlock(\n", - " (conv1): 
GridFeatureConv2d(\n", - " (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))\n", - " )\n", - " (conv2): GridFeatureConv2d(\n", - " (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", - " )\n", - " (norm1): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (norm2): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (shortcut): GridFeatureConv2d(\n", - " (conv): Conv2d(32, 32, kernel_size=(2, 2), stride=(2, 2))\n", - " )\n", - " (pad_to_match): GridFeaturePadToMatch()\n", - " (nonlinear): GridFeatureTransform(\n", - " (feature_transform): GELU(approximate='none')\n", - " )\n", - " )\n", - " )\n", - " (intra_communications): GridFeatureGroupIntraCommunications(\n", - " (intra_communications): ModuleList(\n", - " (0): GridFeaturesGroupIntraCommunication()\n", - " )\n", - " (grid_cat): GridFeatureGroupCat(\n", - " (grid_cat): GridFeatureCat()\n", - " )\n", - " )\n", - " (proj): Identity()\n", - " (nonlinear): GridFeatureGroupTransform(\n", - " (transform): GELU(approximate='none')\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (up_blocks): ModuleList(\n", - " (0-1): 2 x Sequential(\n", - " (0): GridFeatureConv2DBlocksAndIntraCommunication(\n", - " (convs): ModuleList(\n", - " (0): GridFeatureConv2dBlock(\n", - " (conv1): GridFeatureConv2d(\n", - " (conv): ConvTranspose2d(80, 80, kernel_size=(2, 2), stride=(2, 2))\n", - " )\n", - " (conv2): GridFeatureConv2d(\n", - " (conv): Conv2d(80, 80, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", - " )\n", - " (norm1): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((80,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (norm2): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((80,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (shortcut): GridFeatureConv2d(\n", - " (conv): 
ConvTranspose2d(80, 80, kernel_size=(2, 2), stride=(2, 2))\n", - " )\n", - " (pad_to_match): GridFeaturePadToMatch()\n", - " (nonlinear): GridFeatureTransform(\n", - " (feature_transform): GELU(approximate='none')\n", - " )\n", - " )\n", - " (1): GridFeatureConv2dBlock(\n", - " (conv1): GridFeatureConv2d(\n", - " (conv): ConvTranspose2d(48, 48, kernel_size=(2, 2), stride=(2, 2))\n", - " )\n", - " (conv2): GridFeatureConv2d(\n", - " (conv): Conv2d(48, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", - " )\n", - " (norm1): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((48,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (norm2): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((48,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (shortcut): GridFeatureConv2d(\n", - " (conv): ConvTranspose2d(48, 48, kernel_size=(2, 2), stride=(2, 2))\n", - " )\n", - " (pad_to_match): GridFeaturePadToMatch()\n", - " (nonlinear): GridFeatureTransform(\n", - " (feature_transform): GELU(approximate='none')\n", - " )\n", - " )\n", - " (2): GridFeatureConv2dBlock(\n", - " (conv1): GridFeatureConv2d(\n", - " (conv): ConvTranspose2d(32, 32, kernel_size=(2, 2), stride=(2, 2))\n", - " )\n", - " (conv2): GridFeatureConv2d(\n", - " (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", - " )\n", - " (norm1): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (norm2): GridFeatureTransform(\n", - " (feature_transform): LayerNorm2d((32,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (shortcut): GridFeatureConv2d(\n", - " (conv): ConvTranspose2d(32, 32, kernel_size=(2, 2), stride=(2, 2))\n", - " )\n", - " (pad_to_match): GridFeaturePadToMatch()\n", - " (nonlinear): GridFeatureTransform(\n", - " (feature_transform): GELU(approximate='none')\n", - " )\n", - " )\n", - " )\n", - " (intra_communications): GridFeatureGroupIntraCommunications(\n", 
- " (intra_communications): ModuleList(\n", - " (0): GridFeaturesGroupIntraCommunication()\n", - " )\n", - " (grid_cat): GridFeatureGroupCat(\n", - " (grid_cat): GridFeatureCat()\n", - " )\n", - " )\n", - " (proj): Identity()\n", - " (nonlinear): GridFeatureGroupTransform(\n", - " (transform): GELU(approximate='none')\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (convert_to_orig): GridFeatureMemoryFormatConverter(memory_format=GridFeaturesMemoryFormat.b_x_y_z_c)\n", - " (grid_pools): ModuleList(\n", - " (0): GridFeatureGroupPool(\n", - " (pools): ModuleList(\n", - " (0): GridFeaturePool(\n", - " (conv): Conv2d(80, 512, kernel_size=(1, 1), stride=(1, 1))\n", - " (pool): AdaptiveMaxPool1d(output_size=1)\n", - " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (1): GridFeaturePool(\n", - " (conv): Conv2d(48, 512, kernel_size=(1, 1), stride=(1, 1))\n", - " (pool): AdaptiveMaxPool1d(output_size=1)\n", - " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (2): GridFeaturePool(\n", - " (conv): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n", - " (pool): AdaptiveMaxPool1d(output_size=1)\n", - " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (mlp): MLP(\n", - " (layers): ModuleList(\n", - " (0): ResidualLinearBlock(\n", - " (blocks): Sequential(\n", - " (0): Linear(in_features=1536, out_features=1536, bias=True)\n", - " (1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n", - " (2): GELU(approximate='none')\n", - " (3): Linear(in_features=1536, out_features=512, bias=True)\n", - " (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (shortcut): Linear(in_features=1536, out_features=512, bias=True)\n", - " (activation): GELU(approximate='none')\n", - " )\n", - " (1): ResidualLinearBlock(\n", - " (blocks): Sequential(\n", - " (0): Linear(in_features=512, out_features=512, bias=True)\n", - " (1): LayerNorm((512,), 
eps=1e-05, elementwise_affine=True)\n", - " (2): GELU(approximate='none')\n", - " (3): Linear(in_features=512, out_features=512, bias=True)\n", - " (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (shortcut): Identity()\n", - " (activation): GELU(approximate='none')\n", - " )\n", - " (2): LinearBlock(\n", - " (block): Sequential(\n", - " (0): Linear(in_features=512, out_features=512, bias=False)\n", - " (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", - " (2): GELU(approximate='none')\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (mlp_projection): Linear(in_features=512, out_features=1, bias=True)\n", - " (to_point): GridFeatureGroupToPoint(\n", - " (conv_list): ModuleList(\n", - " (0-2): 3 x GridFeatureToPoint(\n", - " (conv): GridFeatureToPointGraphConv(\n", - " (conv): PointFeatureConv(in_channels=16 out_channels=16 search_type=radius reductions=['mean'] rel_pos_encode=True)\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (projection): PointFeatureTransform(\n", - " (feature_transform): Sequential(\n", - " (0): Linear(in_features=32, out_features=32, bias=True)\n", - " (1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n", - " (2): GELU(approximate='none')\n", - " (3): Linear(in_features=32, out_features=4, bias=True)\n", - " )\n", - " )\n", - " (pad_to_match): GridFeatureGroupPadToMatch(\n", - " (match): GridFeaturePadToMatch()\n", - " )\n", - " (vertex_to_point_features): VerticesToPointFeatures(\n", - " (pos_embed): SinusoidalEncoding()\n", - " (mlp): MLP(\n", - " (layers): ModuleList(\n", - " (0): LinearBlock(\n", - " (block): Sequential(\n", - " (0): Linear(in_features=96, out_features=16, bias=False)\n", - " (1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)\n", - " (2): GELU(approximate='none')\n", - " )\n", - " )\n", - " )\n", - " )\n", - " )\n", - ")" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import src.networks\n", 
"from modulus.models.figconvnet.geometries import GridFeaturesMemoryFormat\n", @@ -427,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -462,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -473,7 +156,7 @@ "torch.cuda.empty_cache()\n", "\n", "@torch.no_grad\n", - "def run_inference(is_pointcloud: bool):\n", + "def run_inference(k: int = 4):\n", " for sample in datamodule.test_dataloader():\n", " vertices_denorm = model.data_dict_to_input(sample)\n", " vertices = datamodule.encode(vertices_denorm, \"coordinates\")\n", @@ -491,7 +174,6 @@ " mesh = mesh.cell_data_to_point_data()\n", "\n", " # Interpolate predictions on GT mesh.\n", - " k = 4\n", " nbrs_surface = NearestNeighbors(\n", " n_neighbors=k, algorithm=\"ball_tree\"\n", " ).fit(vertices_denorm[0].cpu().numpy())\n", @@ -523,7 +205,10 @@ "\n", " mesh.point_data[\"pMeanTrimPred\"] = pressure_pred_mesh.cpu().float().numpy()\n", " mesh.point_data[\"wallShearStressMeanTrimPred\"] = shear_stress_pred_mesh.cpu().float().numpy()\n", - " mesh.save(output_path / f\"500K_k4_pc/inference_mesh_{idx}.vtp\")\n", + " out_path = output_path / f\"simmesh/500K_k{k}\"\n", + " out_path.mkdir(parents=True, exist_ok=True)\n", + " mesh.save(out_path / f\"inference_mesh_{idx}.vtp\")\n", + "\n", " print(\"Done.\")\n", " print(\n", " rrmse(torch.tensor(mesh.point_data[\"pMeanTrim\"]), torch.tensor(mesh.point_data[\"pMeanTrimPred\"])),\n", @@ -534,46 +219,14 @@ " torch.cuda.empty_cache()\n", " # break\n", "\n", - "# run_inference()" + "run_inference()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reading /data/src/modulus/data/drivaer_aws/original_pointclouds/input_pc_5000000_run_100_final.vtp\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": 
[ - "\u001b[0m\u001b[33m2024-11-20 10:17:40.100 ( 117.435s) [ 7F914D57C280] vtkMath.cxx:778 WARN| vtkMath::Jacobi: Error extracting eigenfunctions\u001b[0m\n", - "ERROR:root:No data to measure...!\n", - "\u001b[0m\u001b[31m2024-11-20 10:17:40.737 ( 118.071s) [ 7F914D57C280] vtkMassProperties.cxx:60 ERR| vtkMassProperties (0x559237b10530): No data to measure...!\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Done.\n", - "Reading /data/src/modulus/data/drivaer_aws/original_pointclouds/input_pc_5000000_run_200_final.vtp\n", - "Done.\n", - "Reading /data/src/modulus/data/drivaer_aws/original_pointclouds/input_pc_5000000_run_300_final.vtp\n", - "Done.\n", - "Reading /data/src/modulus/data/drivaer_aws/original_pointclouds/input_pc_5000000_run_400_final.vtp\n", - "Done.\n", - "Reading /data/src/modulus/data/drivaer_aws/original_pointclouds/input_pc_5000000_run_500_final.vtp\n", - "Done.\n" - ] - } - ], + "outputs": [], "source": [ "@torch.no_grad\n", "def run_inference_on_pc(pc_size: int, k: int = 4):\n", @@ -631,25 +284,16 @@ " torch.cuda.empty_cache()\n", " # break\n", "\n", - "run_inference_on_pc(5_000_000)\n", - "run_inference_on_pc(10_000_000)\n", - "run_inference_on_pc(20_000_000)" + "# run_inference_on_pc(5_000_000)\n", + "# run_inference_on_pc(10_000_000)\n", + "# run_inference_on_pc(20_000_000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Module modulus.models.figconvnet.warp_neighbor_search load on device 'cuda:0' took 145.78 ms\n", - "Done.\n" - ] - } - ], + "outputs": [], "source": [ "@torch.no_grad\n", "def inference_on_sim_mesh():\n", @@ -703,24 +347,6 @@ "# mesh_pred = pv.read(dataset_orig_path / f\"inference/inference_mesh_100.vtp\")\n", "# rrmse(torch.tensor(mesh_pred.point_data[\"pMeanTrim\"]), torch.tensor(mesh_pred.point_data[\"pMeanTrimPred\"]))" ] - }, - { - "cell_type": "code", - "execution_count": null, - 
"metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(0.2085)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] } ], "metadata": { From 22579f5b390539ebe12a5485bb079cd428ac4195 Mon Sep 17 00:00:00 2001 From: Alexey Kamenev Date: Wed, 8 Jan 2025 14:28:48 -0800 Subject: [PATCH 4/6] Update README and dataset docstring. --- .../figconvnet/README.md | 47 +++ .../notebooks/figconvnet_drivaerml.ipynb | 373 ------------------ .../src/data/drivaerml_datamodule.py | 37 +- 3 files changed, 83 insertions(+), 374 deletions(-) delete mode 100644 examples/cfd/external_aerodynamics/figconvnet/notebooks/figconvnet_drivaerml.ipynb diff --git a/examples/cfd/external_aerodynamics/figconvnet/README.md b/examples/cfd/external_aerodynamics/figconvnet/README.md index 923bf10fb..96d51cef5 100644 --- a/examples/cfd/external_aerodynamics/figconvnet/README.md +++ b/examples/cfd/external_aerodynamics/figconvnet/README.md @@ -19,6 +19,52 @@ We demonstrate a 140k× speed-up compared to GPU-accelerated computational fluid dynamics (CFD) simulators and over 2× improvement in pressure prediction over prior deep learning arts. +## Supported datasets + +The current version of the code supports the following datasets: + +### DrivAerNet + +Both DrivAerNet and DrivAerNet++ datasets [[4](#references)] are supported. +Please follow the instructions on the [dataset GitHub](https://github.com/Mohamedelrefaie/DrivAerNet) +page to download the dataset. + +The corresponding experiment configuration file can be found at: `./configs/experiment/drivaernet/figconv_unet.yaml`. +For more details, refer to the [Training section](#training). + +### DrivAerML + +DrivAerML dataset [[6](#references)] is supported but requires +conversion of the dataset to a more efficient binary format. +This format is supported by models like XAeroNet and FIGConvNet +and represents efficient storage of the original meshes as +partitioned graphs. 
+For more details on how to convert the original DrivAerML dataset +to a partitioned dataset, refer to +[XAeroNet example README](https://github.com/NVIDIA/modulus/tree/main/examples/cfd/external_aerodynamics/xaeronet#training-the-xaeronet-s-model), +steps 1 to 5. + +The binary dataset should have the following structure: + +```text +├─ partitions +│ ├─ graph_partitions_1.bin +│ ├─ graph_partitions_2.bin +│ ├─ ... +├─ test_partitions +│ ├─ graph_partitions_100.bin +│ ├─ graph_partitions_101.bin +│ ├─ ... +├─ validation_partitions +│ ├─ graph_partitions_200.bin +│ ├─ graph_partitions_201.bin +│ ├─ ... +└─ global_stats.json +``` + +The corresponding experiment configuration file can be found at: +`./configs/experiment/drivaerml/figconv_unet.yaml`. + ## Installation FIGConvUNet dependencies can be installed with `pip install`, for example: @@ -107,3 +153,4 @@ mpirun -np 2 python train.py \ 3. [Ahmed body wiki](https://www.cfd-online.com/Wiki/Ahmed_body) 4. [DrivAerNet: A Parametric Car Dataset for Data-Driven Aerodynamic Design and Graph-Based Drag Prediction](https://arxiv.org/abs/2403.08055) 5. [Deep Learning for Real-Time Aerodynamic Evaluations of Arbitrary Vehicle Shapes](https://arxiv.org/abs/2108.05798) +6. 
[DrivAerML: High-Fidelity Computational Fluid Dynamics Dataset for Road-Car External Aerodynamics](https://arxiv.org/abs/2408.11969) diff --git a/examples/cfd/external_aerodynamics/figconvnet/notebooks/figconvnet_drivaerml.ipynb b/examples/cfd/external_aerodynamics/figconvnet/notebooks/figconvnet_drivaerml.ipynb deleted file mode 100644 index 7aa7d9752..000000000 --- a/examples/cfd/external_aerodynamics/figconvnet/notebooks/figconvnet_drivaerml.ipynb +++ /dev/null @@ -1,373 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "import sys\n", - "\n", - "import numpy as np\n", - "import pyvista as pv\n", - "import torch\n", - "import vtk\n", - "import warp as wp\n", - "\n", - "if sys.path[0] != \"..\":\n", - " sys.path.insert(0, \"..\")\n", - "\n", - "device = torch.device(\"cuda:0\")\n", - "torch.cuda.device(device)\n", - "wp.init()\n", - "wp.set_device(str(device))\n", - "\n", - "from modulus.distributed import DistributedManager\n", - "\n", - "DistributedManager.initialize()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "if os.environ.get(\"SLURM_JOB_NAME\", None) is None:\n", - " dataset_orig_path = Path(\"/data/src/modulus/data/drivaer_aws/\")\n", - " dataset_part_path = Path(\"/data/src/modulus/data/drivaer_aws/partitions\")\n", - " output_path = dataset_orig_path / f\"inference/seploss4-01/\"\n", - " # model_path = Path(\"/data/src/modulus/models/fignet/drivaerml/lrsoc/model_00999.pth\")\n", - " model_path = Path(\"/data/src/modulus/models/fignet/drivaerml/seploss4-01/model_00999.pth\")\n", - " pc_path = Path(\"/data/src/modulus/data/drivaer_aws/original_pointclouds\")\n", - "else:\n", - " dataset_orig_path = Path(\"/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/drivaer_data_full\")\n", - " dataset_part_path = 
Path(\"/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/partitions/100_200_400/\")\n", - " # output_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/inference/new/sharedloss\")\n", - " output_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/inference/new/seploss\")\n", - " # model_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/9/model_00999.pth\")\n", - " model_path = Path(\"/lustre/fsw/portfolios/coreai/users/akamenev/outputs/fignet/drivaerml/seploss-01/best/model_00999.pth\")\n", - " # pc_path = Path(\"/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/aero-benchmarking/original_pointclouds/\")\n", - " pc_path = Path(\"/lustre/fsw/portfolios/coreai/users/ktangsali/inference_pc_generation/original_pointclouds/\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import src.data\n", - "\n", - "\n", - "num_points = 500_000\n", - "datamodule = src.data.DrivAerMLDataModule(\n", - " data_path=dataset_part_path,\n", - " num_points=num_points\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import src.networks\n", - "from modulus.models.figconvnet.geometries import GridFeaturesMemoryFormat\n", - "\n", - "\n", - "model = src.networks.FIGConvUNetDrivAerML(\n", - " aabb_max=[2.0, 1.8, 2.6],\n", - " aabb_min=[-2.0, -1.8, -1.5],\n", - " hidden_channels=[16, 16, 16],\n", - " in_channels=1,\n", - " kernel_size=5,\n", - " mlp_channels=[512, 512], #[2048, 2048],\n", - " neighbor_search_type=\"radius\",\n", - " num_down_blocks=1,\n", - " num_levels=2,\n", - " out_channels=4,\n", - " pooling_layers=[2],\n", - " pooling_type=\"max\",\n", - " reductions=[\"mean\"],\n", - " resolution_memory_format_pairs=[\n", - " (GridFeaturesMemoryFormat.b_xc_y_z, [ 5, 150, 100]),\n", - " 
(GridFeaturesMemoryFormat.b_yc_x_z, [250, 3, 100]),\n", - " (GridFeaturesMemoryFormat.b_zc_x_y, [250, 150, 2]),\n", - " ],\n", - " use_rel_pos_encode=True,\n", - ")\n", - "# Load checkpoint.\n", - "chk = torch.load(model_path)\n", - "model.load_state_dict(chk[\"model\"])\n", - "model = model.to(device)\n", - "model.eval()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.neighbors import NearestNeighbors\n", - "\n", - "from modulus.datapipes.cae.readers import read_vtp\n", - "\n", - "\n", - "def convert_to_triangular_mesh(\n", - " polydata, write=False, output_filename=\"surface_mesh_triangular.vtu\"\n", - "):\n", - " \"\"\"Converts a vtkPolyData object to a triangular mesh.\"\"\"\n", - " tet_filter = vtk.vtkDataSetTriangleFilter()\n", - " tet_filter.SetInputData(polydata)\n", - " tet_filter.Update()\n", - "\n", - " tet_mesh = pv.wrap(tet_filter.GetOutput())\n", - "\n", - " if write:\n", - " tet_mesh.save(output_filename)\n", - "\n", - " return tet_mesh\n", - "\n", - "\n", - "def fetch_mesh_vertices(mesh):\n", - " \"\"\"Fetches the vertices of a mesh.\"\"\"\n", - " points = mesh.GetPoints()\n", - " num_points = points.GetNumberOfPoints()\n", - " vertices = [points.GetPoint(i) for i in range(num_points)]\n", - " return vertices" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from src.utils.eval_funcs import rrmse\n", - "\n", - "torch.set_grad_enabled(False)\n", - "\n", - "torch.cuda.empty_cache()\n", - "\n", - "@torch.no_grad\n", - "def run_inference(k: int = 4):\n", - " for sample in datamodule.test_dataloader():\n", - " vertices_denorm = model.data_dict_to_input(sample)\n", - " vertices = datamodule.encode(vertices_denorm, \"coordinates\")\n", - " normalized_pred, _ = model(vertices)\n", - " normalized_p_pred = normalized_pred[..., :1]\n", - " denorm_p_pred = datamodule.decode(normalized_p_pred, \"pressure\")\n", - " 
normalized_wss_pred = normalized_pred[..., 1:]\n", - " denorm_wss_pred = datamodule.decode(normalized_wss_pred, \"shear_stress\")\n", - "\n", - " # Read the original surface mesh.\n", - " idx = sample[\"design\"][0]\n", - " vtp_file = dataset_orig_path / f\"run_{idx}/boundary_{idx}.vtp\"\n", - " print(f\"Reading {vtp_file}\")\n", - " mesh = pv.read(vtp_file)\n", - " mesh = mesh.cell_data_to_point_data()\n", - "\n", - " # Interpolate predictions on GT mesh.\n", - " nbrs_surface = NearestNeighbors(\n", - " n_neighbors=k, algorithm=\"ball_tree\"\n", - " ).fit(vertices_denorm[0].cpu().numpy())\n", - "\n", - " distances, indices = nbrs_surface.kneighbors(mesh.points)\n", - " if k == 1:\n", - " indices = indices.flatten()\n", - " pressure_pred_mesh = denorm_p_pred[0][indices]\n", - " shear_stress_pred_mesh = denorm_wss_pred[0][indices]\n", - " else:\n", - " # distances = distances.astype(np.float32)\n", - " # Weighted kNN interpolation\n", - " # Avoid division by zero by adding a small epsilon\n", - " epsilon = 1e-8\n", - " weights = 1 / (distances + epsilon)\n", - " weights_sum = np.sum(weights, axis=1, keepdims=True)\n", - " normalized_weights = weights / weights_sum\n", - " # Fetch the predictions of the k nearest neighbors\n", - " pressure_neighbors = denorm_p_pred[0][indices] # Shape: (n_samples, k, 1)\n", - " shear_stress_neighbors = denorm_wss_pred[0][indices] # Shape: (n_samples, k, 3)\n", - "\n", - " # Compute the weighted average\n", - " pressure_pred_mesh = np.sum(normalized_weights[:, :, np.newaxis] * pressure_neighbors.cpu().numpy(), axis=1)\n", - " shear_stress_pred_mesh = np.sum(normalized_weights[:, :, np.newaxis] * shear_stress_neighbors.cpu().numpy(), axis=1)\n", - "\n", - " # Convert back to torch tensors\n", - " pressure_pred_mesh = torch.from_numpy(pressure_pred_mesh).to(device)\n", - " shear_stress_pred_mesh = torch.from_numpy(shear_stress_pred_mesh).to(device)\n", - "\n", - " mesh.point_data[\"pMeanTrimPred\"] = 
pressure_pred_mesh.cpu().float().numpy()\n", - " mesh.point_data[\"wallShearStressMeanTrimPred\"] = shear_stress_pred_mesh.cpu().float().numpy()\n", - " out_path = output_path / f\"simmesh/500K_k{k}\"\n", - " out_path.mkdir(parents=True, exist_ok=True)\n", - " mesh.save(out_path / f\"inference_mesh_{idx}.vtp\")\n", - "\n", - " print(\"Done.\")\n", - " print(\n", - " rrmse(torch.tensor(mesh.point_data[\"pMeanTrim\"]), torch.tensor(mesh.point_data[\"pMeanTrimPred\"])),\n", - " rrmse(torch.tensor(mesh.point_data[\"wallShearStressMeanTrim\"][:, 0]), torch.tensor(mesh.point_data[\"wallShearStressMeanTrimPred\"][:, 0])),\n", - " rrmse(torch.tensor(mesh.point_data[\"wallShearStressMeanTrim\"][:, 1]), torch.tensor(mesh.point_data[\"wallShearStressMeanTrimPred\"][:, 1])),\n", - " rrmse(torch.tensor(mesh.point_data[\"wallShearStressMeanTrim\"][:, 2]), torch.tensor(mesh.point_data[\"wallShearStressMeanTrimPred\"][:, 2])),\n", - " )\n", - " torch.cuda.empty_cache()\n", - " # break\n", - "\n", - "run_inference()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@torch.no_grad\n", - "def run_inference_on_pc(pc_size: int, k: int = 4):\n", - " for sample in datamodule.test_dataloader():\n", - " vertices_denorm = model.data_dict_to_input(sample)\n", - " vertices = datamodule.encode(vertices_denorm, \"coordinates\")\n", - " normalized_pred, _ = model(vertices)\n", - " normalized_p_pred = normalized_pred[..., :1]\n", - " denorm_p_pred = datamodule.decode(normalized_p_pred, \"pressure\")\n", - " normalized_wss_pred = normalized_pred[..., 1:]\n", - " denorm_wss_pred = datamodule.decode(normalized_wss_pred, \"shear_stress\")\n", - "\n", - " # Read the original surface mesh.\n", - " idx = sample[\"design\"][0]\n", - " vtp_file = pc_path / f\"input_pc_{pc_size}_run_{idx}_final.vtp\"\n", - " print(f\"Reading {vtp_file}\")\n", - " mesh = pv.read(vtp_file)\n", - "\n", - " # Interpolate predictions on GT mesh.\n", - " 
nbrs_surface = NearestNeighbors(\n", - " n_neighbors=k, algorithm=\"ball_tree\"\n", - " ).fit(vertices_denorm[0].cpu().numpy())\n", - "\n", - " distances, indices = nbrs_surface.kneighbors(mesh.points)\n", - " if k == 1:\n", - " indices = indices.flatten()\n", - " pressure_pred_mesh = denorm_p_pred[0][indices]\n", - " shear_stress_pred_mesh = denorm_wss_pred[0][indices]\n", - " else:\n", - " # distances = distances.astype(np.float32)\n", - " # Weighted kNN interpolation\n", - " # Avoid division by zero by adding a small epsilon\n", - " epsilon = 1e-8\n", - " weights = 1 / (distances + epsilon)\n", - " weights_sum = np.sum(weights, axis=1, keepdims=True)\n", - " normalized_weights = weights / weights_sum\n", - " # Fetch the predictions of the k nearest neighbors\n", - " pressure_neighbors = denorm_p_pred[0][indices] # Shape: (n_samples, k, 1)\n", - " shear_stress_neighbors = denorm_wss_pred[0][indices] # Shape: (n_samples, k, 3)\n", - "\n", - " # Compute the weighted average\n", - " pressure_pred_mesh = np.sum(normalized_weights[:, :, np.newaxis] * pressure_neighbors.cpu().numpy(), axis=1)\n", - " shear_stress_pred_mesh = np.sum(normalized_weights[:, :, np.newaxis] * shear_stress_neighbors.cpu().numpy(), axis=1)\n", - "\n", - " # Convert back to torch tensors\n", - " pressure_pred_mesh = torch.from_numpy(pressure_pred_mesh).to(device)\n", - " shear_stress_pred_mesh = torch.from_numpy(shear_stress_pred_mesh).to(device)\n", - "\n", - " mesh.point_data[\"pMeanTrimPred\"] = pressure_pred_mesh.cpu().float().numpy()\n", - " mesh.point_data[\"wallShearStressMeanTrimPred\"] = shear_stress_pred_mesh.cpu().float().numpy()\n", - " out_path = output_path / f\"pc/500K_k{k}\"\n", - " out_path.mkdir(parents=True, exist_ok=True)\n", - " mesh.save(out_path / f\"inference_pc_{pc_size}_{idx}.vtp\")\n", - " print(\"Done.\")\n", - " torch.cuda.empty_cache()\n", - " # break\n", - "\n", - "# run_inference_on_pc(5_000_000)\n", - "# run_inference_on_pc(10_000_000)\n", - "# 
run_inference_on_pc(20_000_000)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@torch.no_grad\n", - "def inference_on_sim_mesh():\n", - " for idx in [100, 200, 300, 400, 500][:1]:\n", - " mesh_gt = pv.read(dataset_orig_path / f\"run_{idx}/boundary_{idx}.vtp\")\n", - " mesh_gt = mesh_gt.cell_data_to_point_data()\n", - " step = 500_000 # num_points\n", - " p_chunks = []\n", - " wss_chunks = []\n", - " rng = np.random.default_rng(1)\n", - " indices = rng.permutation(range(mesh_gt.number_of_points))\n", - " for i_start in range(0, mesh_gt.number_of_points, step):\n", - " vertices_denorm = torch.as_tensor(\n", - " mesh_gt.points[indices[i_start : i_start + step]], device=device\n", - " ).unsqueeze(0)\n", - " vertices = datamodule.encode(vertices_denorm, \"coordinates\")\n", - " normalized_pred, _ = model(vertices)\n", - "\n", - " normalized_p_pred = normalized_pred[..., :1]\n", - " denorm_p_pred = datamodule.decode(normalized_p_pred, \"pressure\")\n", - " p_chunks.append(denorm_p_pred.cpu())\n", - "\n", - " normalized_wss_pred = normalized_pred[..., 1:]\n", - " denorm_wss_pred = datamodule.decode(normalized_wss_pred, \"shear_stress\")\n", - " wss_chunks.append(denorm_wss_pred.cpu())\n", - " torch.cuda.empty_cache()\n", - "\n", - " pressure_pred_mesh = torch.cat(p_chunks, dim=1)[0]\n", - " shear_stress_pred_mesh = torch.cat(wss_chunks, dim=1)[0]\n", - " mesh_gt.point_data[\"pMeanTrimPred\"] = pressure_pred_mesh.cpu().float().numpy()\n", - " mesh_gt.point_data[\"wallShearStressMeanTrimPred\"] = shear_stress_pred_mesh.cpu().float().numpy()\n", - " mesh_gt.save(output_path / f\"inference_mesh_{idx}.vtp\")\n", - " print(\"Done.\")\n", - " torch.cuda.empty_cache()\n", - "\n", - "\n", - "# inference_on_sim_mesh()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# v = read_vtp(str(dataset_orig_path / f\"inference_point_cloud_{idx}.vtp\"))\n", - 
"# v = read_vtp(str(dataset_orig_path / f\"run_100/boundary_100.vtp\"))\n", - "# mesh_gt = pv.read(dataset_orig_path / f\"run_100/boundary_100.vtp\")\n", - "\n", - "# mesh_gt = mesh_gt.cell_data_to_point_data()\n", - "# mesh_pred = pv.read(dataset_orig_path / f\"inference/inference_mesh_100.vtp\")\n", - "# rrmse(torch.tensor(mesh_pred.point_data[\"pMeanTrim\"]), torch.tensor(mesh_pred.point_data[\"pMeanTrimPred\"]))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/cfd/external_aerodynamics/figconvnet/src/data/drivaerml_datamodule.py b/examples/cfd/external_aerodynamics/figconvnet/src/data/drivaerml_datamodule.py index 4ff7903dc..c18ba8596 100644 --- a/examples/cfd/external_aerodynamics/figconvnet/src/data/drivaerml_datamodule.py +++ b/examples/cfd/external_aerodynamics/figconvnet/src/data/drivaerml_datamodule.py @@ -30,7 +30,42 @@ class DrivAerMLPartitionedDataset(Dataset): - """DrivAerML partitioned dataset.""" + """DrivAerML partitioned dataset. + + The dataset enables reading meshes from binary files + generated by the Modulus XAeroNet data processing utility. + It reads data from all partitions in the source file and + samples a predefined number of points, + as specified by the `num_points` parameter. + + The dataset expects the data to have the following structure: + + ``` + ├─ partitions + │ ├─ graph_partitions_1.bin + │ ├─ graph_partitions_1.bin + │ ├─ ... + ├─ test_partitions + │ ├─ graph_partitions_100.bin + │ ├─ graph_partitions_101.bin + │ ├─ ... + ├─ validation_partitions + │ ├─ graph_partitions_200.bin + │ ├─ graph_partitions_201.bin + │ ├─ ... 
+ └─ global_stats.json + ``` + + where the `partitions` directory contains training samples. + + For further details and examples on how to create a partitioned dataset, + refer to: `modulus/examples/cfd/external_aerodynamics/xaeronet/surface`. + + Parameters: + ---------- + data_path (Path): path to the directory that contains binary partitioned files. + num_points (int): number of points to sample from the mesh. + """ def __init__( self, From 08751238df5ff17a724e4ab83bee5e94093d93ab Mon Sep 17 00:00:00 2001 From: Alexey Kamenev Date: Wed, 8 Jan 2025 14:34:28 -0800 Subject: [PATCH 5/6] Update CHANGELOG. --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b65314d74..155003505 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- DrivAerML dataset support in FIGConvNet example. + ### Changed - Refactored StormCast training example From 3ac8b4304073d909945102c7d3ca57152ddc0418 Mon Sep 17 00:00:00 2001 From: Alexey Kamenev Date: Wed, 8 Jan 2025 15:53:29 -0800 Subject: [PATCH 6/6] Update docstring. --- .../figconvnet/src/networks/figconvunet_drivaer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/cfd/external_aerodynamics/figconvnet/src/networks/figconvunet_drivaer.py b/examples/cfd/external_aerodynamics/figconvnet/src/networks/figconvunet_drivaer.py index cd6616bb4..501999f83 100644 --- a/examples/cfd/external_aerodynamics/figconvnet/src/networks/figconvunet_drivaer.py +++ b/examples/cfd/external_aerodynamics/figconvnet/src/networks/figconvunet_drivaer.py @@ -38,7 +38,8 @@ class FIGConvUNetDrivAerNet(FIGConvUNet): """FIGConvUNetDrivAerNet - DrivAerNet is a variant of FIGConvUNet that is specialized for the DrivAer dataset. + FIGConvUNetDrivAerNet is a variant of FIGConvUNet + that is specialized for the DrivAerNet dataset.
""" def __init__( @@ -247,7 +248,8 @@ def image_pointcloud_dict(self, data_dict, datamodule) -> Tuple[Dict, Dict]: class FIGConvUNetDrivAerML(FIGConvUNet): """FIGConvUNetDrivAerNet - DrivAerNet is a variant of FIGConvUNet that is specialized for the DrivAer dataset. + FIGConvUNetDrivAerML is a variant of FIGConvUNet + that is specialized for the DrivAerML dataset. """ def __init__(