Skip to content

Commit

Permalink
set time and step via ase atoms object (#103)
Browse files Browse the repository at this point in the history
* test setup

* prepare for setting time / step from ase.atoms object

* set step / time via ase attributes

* test new features

* update lock
  • Loading branch information
PythonFZ authored Jul 26, 2024
1 parent 7d0c4aa commit 758da70
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 34 deletions.
23 changes: 16 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "znh5md"
version = "0.3.1"
version = "0.3.2"
description = "ASE Interface for the H5MD format."
authors = ["zincwarecode <[email protected]>"]
license = "Apache-2.0"
Expand Down
66 changes: 66 additions & 0 deletions tests/test_timestep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import ase.build
import h5py
import numpy as np
import numpy.testing as npt
import pytest
from ase.calculators.singlepoint import SinglePointCalculator

import znh5md


def test_h5md_time(tmp_path):
io = znh5md.IO(tmp_path / "test_time_step.h5", store="time")
for step in range(1, 10):
atoms = ase.build.molecule("H2O")
atoms.calc = SinglePointCalculator(atoms, energy=step * 0.1)
atoms.info["h5md_step"] = step
atoms.info["h5md_time"] = step * 0.5
io.append(atoms)

for idx, atoms in enumerate(io[:]):
assert atoms.info["h5md_step"] == idx + 1
assert atoms.info["h5md_time"] == (idx + 1) * 0.5
assert atoms.get_potential_energy() == (idx + 1) * 0.1

with h5py.File(tmp_path / "test_time_step.h5") as f:
npt.assert_array_equal(
f["particles/atoms/position/time"][:], np.arange(1, 10) * 0.5
)
npt.assert_array_equal(f["particles/atoms/position/step"][:], np.arange(1, 10))
npt.assert_array_equal(
f["observables/atoms/energy/time"][:], np.arange(1, 10) * 0.5
)
npt.assert_array_equal(f["observables/atoms/energy/step"][:], np.arange(1, 10))
npt.assert_array_equal(
f["observables/atoms/energy/value"][:], np.arange(1, 10) * 0.1
)


def test_inconsistent_time(tmp_path):
images = [ase.build.molecule("H2O") for _ in range(10)]
images[5].info["h5md_time"] = 0.5

io = znh5md.IO(tmp_path / "test_inconsistent_time.h5", store="time")
with pytest.raises(ValueError):
io.extend(images)


def test_inconsistent_step(tmp_path):
images = [ase.build.molecule("H2O") for _ in range(10)]
images[5].info["h5md_step"] = 5

io = znh5md.IO(tmp_path / "test_inconsistent_step.h5", store="time")
with pytest.raises(ValueError):
io.extend(images)


def test_wrong_store(tmp_path, capsys):
io = znh5md.IO(tmp_path / "test_wrong_store.h5", store="linear")
atoms = ase.build.molecule("H2O")
atoms.info["h5md_step"] = 1
atoms.info["h5md_time"] = 0.1

with pytest.warns(
UserWarning, match="time and step are ignored in 'linear' storage mode"
):
io.append(atoms)
2 changes: 1 addition & 1 deletion tests/test_znh5md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@


def test_version():
assert znh5md.__version__ == "0.3.1"
assert znh5md.__version__ == "0.3.2"
25 changes: 25 additions & 0 deletions znh5md/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class ASEData:
observables: Dict[str, np.ndarray]
particles: Dict[str, np.ndarray]
metadata: Optional[Dict[str, ASEKeyMetaData]] = None
time: Optional[np.ndarray] = None
step: Optional[np.ndarray] = None

def __post_init__(self):
if self.metadata is None:
Expand Down Expand Up @@ -180,12 +182,17 @@ def extract_atoms_data(atoms: ase.Atoms) -> ASEData:
if key not in all_properties and key not in ASE_TO_H5MD:
particles[key] = value

time = atoms.info.get(CustomINFOData.h5md_time.name, None)
step = atoms.info.get(CustomINFOData.h5md_step.name, None)

return ASEData(
cell=cell,
pbc=pbc,
observables=info_data,
particles=particles,
metadata={key: {"unit": None, "calc": True} for key in uses_calc},
time=time,
step=step,
)


Expand All @@ -200,12 +207,30 @@ def combine_asedata(data: List[ASEData]) -> ASEData:
observables = _combine_dicts([x.observables for x in data])
particles = _combine_dicts([x.particles for x in data])

time_occurrences = sum([x.time is not None for x in data])
step_occurrences = sum([x.step is not None for x in data])
if time_occurrences == len(data):
time = np.array([x.time for x in data])
elif time_occurrences == 0:
time = None
else:
raise ValueError("Time is not consistent across data objects")

if step_occurrences == len(data):
step = np.array([x.step for x in data])
elif step_occurrences == 0:
step = None
else:
raise ValueError("Step is not consistent across data objects")

return ASEData(
cell=cell,
pbc=pbc,
observables=observables,
particles=particles,
metadata=data[0].metadata, # we assume they are all equal
time=time,
step=step,
)


Expand Down
Loading

0 comments on commit 758da70

Please sign in to comment.