diff --git a/poetry.lock b/poetry.lock index 847b28d..aec6faf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -414,13 +414,13 @@ license = ["ukkonen"] [[package]] name = "importlib-metadata" -version = "8.0.0" +version = "8.2.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-8.0.0-py3-none-any.whl", hash = "sha256:15584cf2b1bf449d98ff8a6ff1abef57bf20f3ac6454f431736cd3e660921b2f"}, - {file = "importlib_metadata-8.0.0.tar.gz", hash = "sha256:188bd24e4c346d3f0a933f275c2fec67050326a856b9a359881d7c2a697e8812"}, + {file = "importlib_metadata-8.2.0-py3-none-any.whl", hash = "sha256:11901fa0c2f97919b288679932bb64febaeacf289d18ac84dd68cb2e74213369"}, + {file = "importlib_metadata-8.2.0.tar.gz", hash = "sha256:72e8d4399996132204f9a16dcc751af254a48f8d1b20b9ff0f98d4a8f901e73d"}, ] [package.dependencies] @@ -761,13 +761,13 @@ test = ["coverage"] [[package]] name = "mrcfile" -version = "1.5.1" +version = "1.5.3" description = "MRC file I/O library" optional = false python-versions = "*" files = [ - {file = "mrcfile-1.5.1-py2.py3-none-any.whl", hash = "sha256:06900f1245e66dd4617cbd4a7117a2d75d53fc4e5b74d811766f71a858b059a9"}, - {file = "mrcfile-1.5.1.tar.gz", hash = "sha256:403c4bb0ac842410ce5ea501f4fddc91ea37c12ef869d508d3ac571868d82ac2"}, + {file = "mrcfile-1.5.3-py2.py3-none-any.whl", hash = "sha256:fbf2b5583afae38656343f2d6bac67d85e0e798b2fd608be63ecd2758cd67c61"}, + {file = "mrcfile-1.5.3.tar.gz", hash = "sha256:3f304c02cb9f0900b26683679c5d3d750da64b5c370b58d69af8a8ddf720c0ce"}, ] [package.dependencies] @@ -835,7 +835,6 @@ files = [ {file = "msgpack-1.0.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5fbb160554e319f7b22ecf530a80a3ff496d38e8e07ae763b9e82fadfe96f273"}, {file = "msgpack-1.0.8-cp39-cp39-win32.whl", hash = "sha256:f9af38a89b6a5c04b7d18c492c8ccf2aee7048aff1ce8437c4683bb5a1df893d"}, {file = "msgpack-1.0.8-cp39-cp39-win_amd64.whl", hash = "sha256:ed59dd52075f8fc91da6053b12e8c89e37aa043f8986efd89e61fae69dc1b011"}, - {file = "msgpack-1.0.8-py3-none-any.whl", hash = "sha256:24f727df1e20b9876fa6e95f840a2a2651e34c0ad147676356f4bf5fbb0206ca"}, {file = "msgpack-1.0.8.tar.gz", hash = "sha256:95c02b0e27e706e48d0e5426d1710ca78e0f0628d6e89d5b5a5b91a5f12274f3"}, ] @@ -1305,6 +1304,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1312,8 +1312,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -1330,6 +1337,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -1337,6 +1345,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, diff --git a/pyproject.toml b/pyproject.toml index 3b1054e..188c5fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "znh5md" -version = "0.3.1" +version = "0.3.2" description = "ASE Interface for the H5MD format." authors = ["zincwarecode "] license = "Apache-2.0" diff --git a/tests/test_timestep.py b/tests/test_timestep.py new file mode 100644 index 0000000..142a076 --- /dev/null +++ b/tests/test_timestep.py @@ -0,0 +1,66 @@ +import ase.build +import h5py +import numpy as np +import numpy.testing as npt +import pytest +from ase.calculators.singlepoint import SinglePointCalculator + +import znh5md + + +def test_h5md_time(tmp_path): + io = znh5md.IO(tmp_path / "test_time_step.h5", store="time") + for step in range(1, 10): + atoms = ase.build.molecule("H2O") + atoms.calc = SinglePointCalculator(atoms, energy=step * 0.1) + atoms.info["h5md_step"] = step + atoms.info["h5md_time"] = step * 0.5 + io.append(atoms) + + for idx, atoms in enumerate(io[:]): + assert atoms.info["h5md_step"] == idx + 1 + assert atoms.info["h5md_time"] == (idx + 1) * 0.5 + assert atoms.get_potential_energy() == (idx + 1) * 0.1 + + with h5py.File(tmp_path / "test_time_step.h5") as f: + npt.assert_array_equal( + f["particles/atoms/position/time"][:], np.arange(1, 10) * 0.5 + ) + npt.assert_array_equal(f["particles/atoms/position/step"][:], np.arange(1, 10)) + npt.assert_array_equal( + f["observables/atoms/energy/time"][:], np.arange(1, 10) * 0.5 + ) + npt.assert_array_equal(f["observables/atoms/energy/step"][:], np.arange(1, 10)) + npt.assert_array_equal( + f["observables/atoms/energy/value"][:], np.arange(1, 10) * 0.1 + ) + + +def test_inconsistent_time(tmp_path): + images = [ase.build.molecule("H2O") for _ in range(10)] + images[5].info["h5md_time"] = 0.5 + + io = znh5md.IO(tmp_path / "test_inconsistent_time.h5", store="time") + with pytest.raises(ValueError): + io.extend(images) + + +def test_inconsistent_step(tmp_path): + images = [ase.build.molecule("H2O") for _ in range(10)] + images[5].info["h5md_step"] = 5 + + io = znh5md.IO(tmp_path / "test_inconsistent_step.h5", store="time") + with pytest.raises(ValueError): + io.extend(images) + + +def test_wrong_store(tmp_path, capsys): + io = znh5md.IO(tmp_path / "test_wrong_store.h5", store="linear") + atoms = ase.build.molecule("H2O") + atoms.info["h5md_step"] = 1 + atoms.info["h5md_time"] = 0.1 + + with pytest.warns( + UserWarning, match="time and step are ignored in 'linear' storage mode" + ): + io.append(atoms) diff --git a/tests/test_znh5md.py b/tests/test_znh5md.py index 43f4607..94c7291 100644 --- a/tests/test_znh5md.py +++ b/tests/test_znh5md.py @@ -2,4 +2,4 @@ def test_version(): - assert znh5md.__version__ == "0.3.1" + assert znh5md.__version__ == "0.3.2" diff --git a/znh5md/format.py b/znh5md/format.py index a5b463f..3bfe173 100644 --- a/znh5md/format.py +++ b/znh5md/format.py @@ -31,6 +31,8 @@ class ASEData: observables: Dict[str, np.ndarray] particles: Dict[str, np.ndarray] metadata: Optional[Dict[str, ASEKeyMetaData]] = None + time: Optional[np.ndarray] = None + step: Optional[np.ndarray] = None def __post_init__(self): if self.metadata is None: @@ -180,12 +182,17 @@ def extract_atoms_data(atoms: ase.Atoms) -> ASEData: if key not in all_properties and key not in ASE_TO_H5MD: particles[key] = value + time = atoms.info.get(CustomINFOData.h5md_time.name, None) + step = atoms.info.get(CustomINFOData.h5md_step.name, None) + return ASEData( cell=cell, pbc=pbc, observables=info_data, particles=particles, metadata={key: {"unit": None, "calc": True} for key in uses_calc}, + time=time, + step=step, ) @@ -200,12 +207,30 @@ def combine_asedata(data: List[ASEData]) -> ASEData: observables = _combine_dicts([x.observables for x in data]) particles = _combine_dicts([x.particles for x in data]) + time_occurrences = sum([x.time is not None for x in data]) + step_occurrences = sum([x.step is not None for x in data]) + if time_occurrences == len(data): + time = np.array([x.time for x in data]) + elif time_occurrences == 0: + time = None + else: + raise ValueError("Time is not consistent across data objects") + + if step_occurrences == len(data): + step = np.array([x.step for x in data]) + elif step_occurrences == 0: + step = None + else: + raise ValueError("Step is not consistent across data objects") + return ASEData( cell=cell, pbc=pbc, observables=observables, particles=particles, metadata=data[0].metadata, # we assume they are all equal + time=time, + step=step, ) diff --git a/znh5md/io.py b/znh5md/io.py index c351d98..e20a2bd 100644 --- a/znh5md/io.py +++ b/znh5md/io.py @@ -3,6 +3,7 @@ import os import pathlib import typing as t +import warnings from collections.abc import MutableSequence from typing import List, Optional, Union @@ -195,14 +196,18 @@ def extend(self, images: List[ase.Atoms]): def _create_particle_group(self, f, data: fmt.ASEData): g_particle_grp = f["particles"].create_group(self.particle_group) - self._create_group(g_particle_grp, "box/edges", data.cell) + self._create_group( + g_particle_grp, "box/edges", data.cell, time=data.time, step=data.step + ) g_particle_grp["box"].attrs["dimension"] = 3 g_particle_grp["box"].attrs["boundary"] = [ "periodic" if y else "none" for y in data.pbc[0] ] if self.pbc_group and data.pbc is not None: - self._create_group(g_particle_grp, "box/pbc", data.pbc) + self._create_group( + g_particle_grp, "box/pbc", data.pbc, time=data.time, step=data.step + ) for key, value in data.particles.items(): self._create_group( g_particle_grp, @@ -210,11 +215,15 @@ def _create_particle_group(self, f, data: fmt.ASEData): value, data.metadata.get(key, {}).get("unit"), calc=data.metadata.get(key, {}).get("calc"), + time=data.time, + step=data.step, ) self._create_observables( f, data.observables, data.metadata, + time=data.time, + step=data.step, ) def _create_group( @@ -224,6 +233,8 @@ def _create_group( data, unit: Optional[str] = None, calc: Optional[bool] = None, + time: np.ndarray | None = None, + step: np.ndarray | None = None, ): if data is not None: g_grp = parent_grp.create_group(name) @@ -240,14 +251,18 @@ def _create_group( ds_value.attrs["ASE_CALCULATOR_RESULT"] = calc if unit and self.save_units: ds_value.attrs["unit"] = unit - self._add_time_and_step(g_grp, len(data)) + if time is None: + time = np.arange(len(data)) * self.timestep + if step is None: + step = np.arange(len(data)) + self._add_time_and_step(g_grp, step, time) - def _add_time_and_step(self, grp, length): + def _add_time_and_step(self, grp, step: np.ndarray, time: np.ndarray): if self.store == "time": ds_time = grp.create_dataset( "time", dtype=np.float64, - data=np.arange(length) * self.timestep, + data=time, compression=self.compression, compression_opts=self.compression_opts, maxshape=(None,), @@ -256,12 +271,14 @@ def _add_time_and_step(self, grp, length): ds_step = grp.create_dataset( "step", dtype=int, - data=np.arange(length), + data=step, compression=self.compression, compression_opts=self.compression_opts, maxshape=(None,), ) elif self.store == "linear": + if time is not None or step is not None: + warnings.warn("time and step are ignored in 'linear' storage mode") ds_time = grp.create_dataset( "time", dtype=np.float64, @@ -281,6 +298,8 @@ def _create_observables( f, info_data, metadata: dict, + time: np.ndarray | None = None, + step: np.ndarray | None = None, ): if info_data: g_observables = f.require_group("observables") @@ -300,33 +319,58 @@ def _create_observables( ds_value.attrs["ASE_CALCULATOR_RESULT"] = metadata[key]["calc"] if metadata.get(key, {}).get("unit") and self.save_units: ds_value.attrs["unit"] = metadata[key]["unit"] - self._add_time_and_step(g_observable, len(value)) + if time is None: + time = np.arange(len(value)) * self.timestep + if step is None: + step = np.arange(len(value)) + self._add_time_and_step(g_observable, step, time) def _extend_existing_data(self, f, data: fmt.ASEData): g_particle_grp = f["particles"][self.particle_group] - self._extend_group(g_particle_grp, "box/edges", data.cell) + self._extend_group( + g_particle_grp, "box/edges", data.cell, step=data.step, time=data.time + ) if self.pbc_group and data.pbc is not None: - self._extend_group(g_particle_grp, "box/pbc", data.pbc) + self._extend_group( + g_particle_grp, "box/pbc", data.pbc, step=data.step, time=data.time + ) for key, value in data.particles.items(): - self._extend_group(g_particle_grp, key, value) - self._extend_observables(f, data.observables) + self._extend_group( + g_particle_grp, key, value, step=data.step, time=data.time + ) + self._extend_observables(f, data.observables, step=data.step, time=data.time) - def _extend_group(self, parent_grp, name, data): + def _extend_group( + self, + parent_grp, + name, + data, + step: np.ndarray | None = None, + time: np.ndarray | None = None, + ): if data is not None and name in parent_grp: g_grp = parent_grp[name] utils.fill_dataset(g_grp["value"], data) if self.store == "time": - last_time = g_grp["time"][-1] - last_step = g_grp["step"][-1] + if time is None: + last_time = g_grp["time"][-1] + time = np.arange(len(data)) * self.timestep + last_time + if step is None: + last_step = g_grp["step"][-1] + step = np.arange(len(data)) + last_step utils.fill_dataset( g_grp["time"], - np.arange(1, len(data) + 1) * self.timestep + last_time, - ) - utils.fill_dataset( - g_grp["step"], np.arange(1, len(data) + 1) + last_step + time, ) + utils.fill_dataset(g_grp["step"], step) - def _extend_observables(self, f, info_data): + def _extend_observables( + self, + f, + info_data, + step: np.ndarray | None = None, + time: np.ndarray | None = None, + ): if f"observables/{self.particle_group}" in f: g_observables = f[f"observables/{self.particle_group}"] for key, value in info_data.items(): @@ -334,15 +378,17 @@ def _extend_observables(self, f, info_data): g_val = g_observables[key] utils.fill_dataset(g_val["value"], value) if self.store == "time": - last_time = g_val["time"][-1] - last_step = g_val["step"][-1] + if time is None: + last_time = g_val["time"][-1] + time = np.arange(len(value)) * self.timestep + last_time + if step is None: + last_step = g_val["step"][-1] + step = np.arange(len(value)) + last_step utils.fill_dataset( g_val["time"], - np.arange(len(value)) * self.timestep + last_time, - ) - utils.fill_dataset( - g_val["step"], np.arange(len(value)) + last_step + time, ) + utils.fill_dataset(g_val["step"], step) def append(self, atoms: ase.Atoms): self.extend([atoms])