From 0f1d7816ac4a8ab898b5428fcf84b3db2aba7ff1 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Mon, 2 Dec 2024 17:28:59 -0600 Subject: [PATCH] [lammps] Fix handling types when restoring --- pysages/backends/lammps.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pysages/backends/lammps.py b/pysages/backends/lammps.py index 63892f90..e559875d 100644 --- a/pysages/backends/lammps.py +++ b/pysages/backends/lammps.py @@ -154,6 +154,16 @@ def add_bias(forces, biases): def add_bias(forces, biases): forces[:, :3] += factor * biases + def restore_vm(view, snapshot, prev_snapshot): + velocities = view(snapshot.vel_mass[0]) + masses_types = snapshot.vel_mass[1] + masses = view(masses_types[0]) + types = view(masses_types[1]) + prev_masses_types = prev_snapshot.vel_mass[1] + velocities[:] = view(prev_snapshot.vel_mass[0]) + masses[:] = view(prev_masses_types[0]) + types[:] = view(prev_masses_types[1]) + # TODO: check if this can be sped up. # pylint: disable=W0511 def bias(snapshot, state): """Adds the computed bias to the forces.""" @@ -166,7 +176,7 @@ def bias(snapshot, state): snapshot_methods = build_snapshot_methods(sampling_method, on_gpu) flags = sampling_method.snapshot_flags - restore = partial(restore_fn, view) + restore = partial(restore_fn, view, restore_vm=restore_vm) helpers = HelperMethods(build_data_querier(snapshot_methods, flags), lambda: dim) return helpers, restore, bias