Skip to content

Commit

Permalink
bugfix: single entry issue (#157)
Browse files Browse the repository at this point in the history
* bump version

* add test case

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test info/calc/arrays

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* partial test

* partial bugfix

* bugfix finalize

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* undo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* partially fixed

* final bugfix

* remove debug prints

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove comment

* fix version test

* test frames as well

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* another test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Dec 9, 2024
1 parent 8fc920a commit fa64d56
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "znh5md"
version = "0.4.1"
version = "0.4.2"
description = "ASE Interface for the H5MD format."
authors = ["zincwarecode <[email protected]>"]
license = "Apache-2.0"
Expand Down
66 changes: 66 additions & 0 deletions tests/test_single_obs_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy.testing as npt
from ase.build import molecule
from ase.calculators.singlepoint import SinglePointCalculator

import znh5md
import znh5md.serialization


def test_single_entry_info(tmp_path):
    """Round-trip frames where only the first one carries an ``info`` entry.

    Regression test: a key that is present on the first configuration only
    used to raise an error.
    """
    expected = 6
    store = znh5md.IO(tmp_path / "test.h5")
    atoms = molecule("H2O")

    # Frame 0 is written with the key, the five following frames without it.
    atoms.info["density"] = 0.997
    store.append(atoms)
    del atoms.info["density"]
    store.extend([atoms] * 5)

    # The length must agree across len(), iteration and slicing.
    assert len(store) == expected
    assert len(list(store)) == expected
    assert len(store[:]) == expected

    # Only the first frame exposes the key when read back.
    assert store[0].info["density"] == 0.997
    assert "density" not in store[1].info

    # The Frames container built from the read-back data must be complete too.
    frames = znh5md.serialization.Frames.from_ase(list(store))
    assert len(frames) == expected
    assert len(list(frames)) == expected


def test_single_entry_arrays(tmp_path):
    """Round-trip frames where only the first one carries an ``arrays`` entry.

    Regression test: a key that is present on the first configuration only
    used to raise an error.
    """
    expected = 6
    store = znh5md.IO(tmp_path / "test.h5")
    atoms = molecule("H2O")

    # Frame 0 is written with the key, the five following frames without it.
    atoms.arrays["density"] = [0.997, 0.998, 0.999]
    store.append(atoms)
    del atoms.arrays["density"]
    store.extend([atoms] * 5)

    # The length must agree across len(), iteration and slicing.
    assert len(store) == expected
    assert len(list(store)) == expected
    assert len(store[:]) == expected

    # Only the first frame exposes the key when read back.
    npt.assert_array_equal(store[0].arrays["density"], [0.997, 0.998, 0.999])
    assert "density" not in store[1].arrays

    # The Frames container built from the read-back data must be complete too.
    frames = znh5md.serialization.Frames.from_ase(list(store))
    assert len(frames) == expected
    assert len(list(frames)) == expected


def test_single_entry_calc(tmp_path):
    """Round-trip frames where only the first one has a calculator attached.

    Regression test: calculator results that are present on the first
    configuration only used to raise an error.
    """
    expected = 6
    store = znh5md.IO(tmp_path / "test.h5")
    atoms = molecule("H2O")

    # Frame 0 is written with a calculator, the five following frames without.
    # NOTE(review): forces is a flat 3-vector although H2O has three atoms;
    # confirm that this shape is intended.
    atoms.calc = SinglePointCalculator(atoms, energy=0.0, forces=[0.0, 0.0, 0.0])
    store.append(atoms)
    atoms.calc = None
    store.extend([atoms] * 5)

    # The length must agree across len(), iteration and slicing.
    assert len(store) == expected
    assert len(list(store)) == expected
    assert len(store[:]) == expected

    # Only the first frame has calculator results when read back.
    assert store[0].calc.results["energy"] == 0.0
    assert store[1].calc is None

    # The Frames container built from the read-back data must be complete too.
    frames = znh5md.serialization.Frames.from_ase(list(store))
    assert len(frames) == expected
    assert len(list(frames)) == expected
2 changes: 1 addition & 1 deletion tests/test_znh5md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def test_version():
assert znh5md.__version__ == "0.4.1"
assert znh5md.__version__ == "0.4.2"


def test_creator(tmp_path):
Expand Down
23 changes: 12 additions & 11 deletions znh5md/interface/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,17 +329,18 @@ def process_observables(self, frames: Frames, observables, index) -> None:
origin = grp.attrs.get(AttributePath.origin.value, None)
try:
try:
update_frames(
frames,
H5MDToASEMapping[grp_name].value,
grp["value"][index],
origin,
self.use_ase_calc,
)
except KeyError:
update_frames(
frames, grp_name, grp["value"][index], origin, self.use_ase_calc
)
try:
update_frames(
frames,
H5MDToASEMapping[grp_name].value,
grp["value"][index],
origin,
self.use_ase_calc,
)
except KeyError:
update_frames(
frames, grp_name, grp["value"][index], origin, self.use_ase_calc
)
except (OSError, IndexError):
pass # Handle backfilling for invalid values
except KeyError:
Expand Down
30 changes: 18 additions & 12 deletions znh5md/serialization/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import dataclasses
import functools
import json
Expand Down Expand Up @@ -238,28 +239,33 @@ def __len__(self) -> int:

def __getitem__(self, idx: int) -> ase.Atoms:
"""Return a single frame."""
# this raises the IndexError to determine the length of the Frames object
atoms = ase.Atoms(
numbers=self.numbers[idx],
positions=self.positions[idx],
cell=self.cell[idx],
pbc=self.pbc[idx],
)
# all data following here can be missing
for key in self.arrays:
if isinstance(self.arrays[key][idx], _MISSING):
continue
if key == "velocities":
atoms.set_velocities(self.arrays[key][idx])
else:
atoms.arrays[key] = self.arrays[key][idx]
with contextlib.suppress(IndexError):
if isinstance(self.arrays[key][idx], _MISSING):
continue
if key == "velocities":
atoms.set_velocities(self.arrays[key][idx])
else:
atoms.arrays[key] = self.arrays[key][idx]

for key in self.info:
if not isinstance(self.info[key][idx], _MISSING):
atoms.info[key] = self.info[key][idx]
with contextlib.suppress(IndexError):
if not isinstance(self.info[key][idx], _MISSING):
atoms.info[key] = self.info[key][idx]
for key in self.calc:
if not isinstance(self.calc[key][idx], _MISSING):
if atoms.calc is None:
atoms.calc = SinglePointCalculator(atoms)
atoms.calc.results[key] = self.calc[key][idx]
with contextlib.suppress(IndexError):
if not isinstance(self.calc[key][idx], _MISSING):
if atoms.calc is None:
atoms.calc = SinglePointCalculator(atoms)
atoms.calc.results[key] = self.calc[key][idx]

return atoms

Expand Down

0 comments on commit fa64d56

Please sign in to comment.