Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CI: PyTorch Surrogate Example #621

Merged
merged 6 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/dependencies/gcc-openmpi.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ python3 -m pip install -U -r src/python/impactx/dashboard/requirements.txt
python3 -m pip install -U -r examples/requirements.txt
python3 -m pip install -U -r tests/python/requirements.txt

# extra tests
python3 -m pip install -U -r examples/requirements_torch_cpu.txt
python3 -m pip install -U openPMD-validator
2 changes: 2 additions & 0 deletions .github/workflows/dependencies/gcc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ python3 -m pip install -U -r src/python/impactx/dashboard/requirements.txt
python3 -m pip install -U -r examples/requirements.txt
python3 -m pip install -U -r tests/python/requirements.txt

# extra tests
python3 -m pip install -U -r examples/requirements_torch_cpu.txt
python3 -m pip install -U openPMD-validator
25 changes: 25 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ function(add_impactx_test name input is_mpi analysis_script plot_script)
else()
set_property(TEST ${name}.run APPEND PROPERTY ENVIRONMENT "OMP_NUM_THREADS=2")
endif()
# special return code for skipped tests (e.g., runtime prerequisite fails)
set_tests_properties(${name}.run PROPERTIES SKIP_RETURN_CODE 42)

# analysis and plots
set(THIS_Python_SCRIPT_EXE)
Expand All @@ -131,6 +133,11 @@ function(add_impactx_test name input is_mpi analysis_script plot_script)

# make HDF5 I/O more robust on various filesystems
set_property(TEST ${name}.analysis APPEND PROPERTY ENVIRONMENT "HDF5_USE_FILE_LOCKING=FALSE")

# run test failed? Mark this as skipped
set_property(TEST ${name}.analysis PROPERTY SKIP_REGULAR_EXPRESSION
"Supplied directory is not valid: diags"
)
endif()
if(plot_script)
add_test(NAME ${name}.plot
Expand All @@ -141,6 +148,11 @@ function(add_impactx_test name input is_mpi analysis_script plot_script)

# make HDF5 I/O more robust on various filesystems
set_property(TEST ${name}.plot APPEND PROPERTY ENVIRONMENT "HDF5_USE_FILE_LOCKING=FALSE")

# run test failed? Mark this as skipped
set_property(TEST ${name}.plot PROPERTY SKIP_REGULAR_EXPRESSION
"ValueError: No objects to concatenate"
)
endif()
endfunction()

Expand Down Expand Up @@ -1000,6 +1012,7 @@ add_impactx_test(spectrometer.py
OFF # no plot script yet
)


# Chicane with CSR ###########################################################
#
if(ImpactX_FFT)
Expand Down Expand Up @@ -1097,6 +1110,7 @@ add_impactx_test(linac-segment.py
OFF # no plot script yet
)


# Iteration of a linear one-turn map #########################################
#
# w/o space charge
Expand All @@ -1112,3 +1126,14 @@ add_impactx_test(linear-map.py
examples/linear_map/analysis_map.py
OFF # no plot script yet
)


# PyTorch Surrogate: Staged LPA ##############################################
#
add_impactx_test(pytorch_surrogate_model
examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py
OFF # ImpactX MPI-parallel
examples/pytorch_surrogate_model/analyze_ml_surrogate_15_stage.py
examples/pytorch_surrogate_model/visualize_ml_surrogate_15_stage.py
)
label_impactx_test(pytorch_surrogate_model slow)
18 changes: 10 additions & 8 deletions examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
print("Warning: Cannot import PyTorch. Skipping test.")
import sys

sys.exit(0)
sys.exit(42) # ImpactX special return code for skipped tests

import zipfile
from urllib import request
Expand Down Expand Up @@ -100,18 +100,19 @@ def download_and_unzip(url, data_dir):
data_url = "https://zenodo.org/records/10810754/files/models.zip?download=1"
download_and_unzip(data_url, "models.zip")

# It was found that the PyTorch multithreaded defaults interfere with MPI-enabled AMReX
# when initializing the models: https://github.com/AMReX-Codes/pyamrex/issues/322
# It was found that the PyTorch multithreaded defaults interfere with AMReX OpenMP
# when initializing the models or iterating elements:
# https://github.com/AMReX-Codes/pyamrex/issues/322
# https://github.com/ECP-WarpX/impactx/issues/773#issuecomment-2585043099
# So we manually set the number of threads to serial (1).
if Config.have_mpi:
n_threads = torch.get_num_threads()
torch.set_num_threads(1)
# Torch threading is not a problem with GPUs and might work when MPI is disabled.
# Could also just be a mixing of OpenMP libraries (gomp and llvm omp) when using the
# pre-build PyTorch pip packages.
torch.set_num_threads(1)
model_list = [
surrogate_model(f"models/beam_stage_{stage_i}_model.pt", device=device)
for stage_i in range(N_stage)
]
if Config.have_mpi:
torch.set_num_threads(n_threads)

pp_amrex = amr.ParmParse("amrex")
pp_amrex.add("the_arena_init_size", 0)
Expand Down Expand Up @@ -328,6 +329,7 @@ def set_lens(self, pc, step, period):
lpa = LPASurrogateStage(i, model_list[i], L_surrogate, L_stage_period * i)
lpa.nslice = n_slice
lpa.ds = L_surrogate
lpa.threadsafe = False
lpa_stages.append(lpa)

monitor = elements.BeamMonitor("monitor")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ class surrogate_model:
def __init__(self, model_file, device=None):
self.device = device
if device is None:
model_dict = torch.load(model_file, map_location="cpu")
model_dict = torch.load(model_file, map_location="cpu", weights_only=False)
else:
model_dict = torch.load(model_file, map_location=device)
model_dict = torch.load(model_file, map_location=device, weights_only=False)
self.source_means = torch.tensor(
model_dict["source_means"], device=self.device, dtype=torch.float64
)
Expand Down
6 changes: 6 additions & 0 deletions examples/requirements_torch_cpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# This is for CPU CI tests with extra requirements.
#
# For PyTorch, see alternative packages, e.g., for GPU here:
# https://pytorch.org/get-started/locally/
--extra-index-url https://download.pytorch.org/whl/cpu
torch
Loading