Skip to content

Commit

Permalink
add sevennet support (#226)
Browse files Browse the repository at this point in the history
* add sevennet support
---------

Co-authored-by: ElliottKasoar <[email protected]>
Co-authored-by: Jacob Wilkins <[email protected]>
  • Loading branch information
3 people authored Jul 31, 2024
1 parent ccf2081 commit 8d77904
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 40 deletions.
30 changes: 25 additions & 5 deletions docs/source/developer_guide/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ Converting ``model_path`` into ``path`` is a minimum requirement, but we also ai

.. note::
``model_path`` will already be a ``pathlib.Path`` object, if the path exists.
Some MLIPs do not support this, so you may be required to cast it back to a string (``str(model_path)``).

To ensure that the calculator does not receive multiple versions of keywords, it's also necessary to set ``model_path = path``, and remove ``path`` from ``kwargs``.

Expand All @@ -104,7 +105,12 @@ In addition to setting the calculator, ``__version__`` must also imported here,

Tests must be added to ensure that, at a minimum, the new calculator allows an MLIP to be loaded correctly, and that an energy can be calculated.

This can be done by adding the appropriate data as tuples to the ``pytest.mark.parametrize`` lists in the ``tests.test_mlip_calculators`` and ``tests.test_single_point`` modules.
This can be done by adding the appropriate data as tuples to the ``pytest.mark.parametrize`` lists in the ``tests.test_mlip_calculators`` and ``tests.test_single_point`` modules
that reside in files ``tests/test_mlip_calculators.py``` and ``tests/test_single_point.py``, respectively.


Load models - success
^^^^^^^^^^^^^^^^^^^^^

For ``tests.test_mlip_calculators``, ``architecture``, ``device`` and accepted forms of ``model_path`` should be tested, ensuring that the calculator and its version are correctly set::

Expand All @@ -121,27 +127,41 @@ For ``tests.test_mlip_calculators``, ``architecture``, ``device`` and accepted f
)
def test_extra_mlips(architecture, device, kwargs):

It is also useful to test that ``model_path``, and ``model`` or and the "standard" MLIP calculator parameter (``path``) cannot be defined simultaneously::
.. note::
Not all models support an empty (default) model path, so the equivalent test to``("alignn", "cpu", {})`` may need to be removed, or moved to the tests described in `Load models - failure`_.

Load models - failure
^^^^^^^^^^^^^^^^^^^^^

It is also useful to test that ``model_path``, and ``model`` or and the "standard" MLIP calculator parameter (``path``) cannot be defined simultaneously

.. code-block:: python
@pytest.mark.extra_mlips
@pytest.mark.parametrize(
"kwargs",
[
{
"model_path": "tests/models/v5.27.2024/best_model.pt",
"model": "tests/models/v5.27.2024/best_model.pt",
"architecture": "alignn",
"model_path": MODEL_PATH / "v5.27.2024" / "best_model.pt",
"model": MODEL_PATH / "v5.27.2024" / "best_model.pt",
},
{
"architecture": "alignn",
"model_path": "tests/models/v5.27.2024/best_model.pt",
"path": "tests/models/v5.27.2024/best_model.pt",
},
],
)
def test_extra_mlips_invalid(kwargs):
Test correctness
^^^^^^^^^^^^^^^^

For ``tests.test_single_point``, ``architecture``, ``device``, and the potential energy of NaCl predicted by the MLIP should be defined, ensuring that calculations can be performed::

test_extra_mlips_data = [("alignn", "cpu", -11.148092269897461)]
test_extra_mlips_data = [("alignn", "cpu", -11.148092269897461, {})]


Running these tests requires an additional flag to be passed to ``pytest``::

Expand Down
4 changes: 3 additions & 1 deletion janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ class CorrelationKwargs(TypedDict, total=True):


# Janus specific
Architectures = Literal["mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn"]
Architectures = Literal[
"mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet"
]
Devices = Literal["cpu", "cuda", "mps", "xpu"]
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"]
Properties = Literal["energy", "stress", "forces"]
Expand Down
21 changes: 20 additions & 1 deletion janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ def choose_calculator(

# No default `model_path`
if model_path is None:
raise ValueError("Please specify `model_path`")
raise ValueError(
"Please specify `model_path`, as there is no "
f"default model for {architecture}"
)
# Default to float64 precision
kwargs.setdefault("default_dtype", "float64")

Expand Down Expand Up @@ -203,6 +206,22 @@ def choose_calculator(

calculator = AlignnAtomwiseCalculator(path=path, device=device, **kwargs)

elif architecture == "sevennet":
from sevenn.sevennet_calculator import SevenNetCalculator

__version__ = "0.0.0"

if isinstance(model_path, Path):
model = str(model_path)
elif isinstance(model_path, str):
model = model_path
else:
model = "SevenNet-0_11July2024"

kwargs.setdefault("file_type", "checkpoint")
kwargs.setdefault("sevennet_config", None)
calculator = SevenNetCalculator(model=model, device=device, **kwargs)

else:
raise ValueError(
f"Unrecognized {architecture=}. Suported architectures "
Expand Down
13 changes: 8 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@ typer-config = "^1.4.0"
phonopy = "^2.23.1"
seekpath = "^1.9.7"
spglib = "^2.3.0"
torch-dftd = "^0.4.0"
torch-dftd = "0.4.0"
codecarbon = "^2.5.0"
alignn = { version = "2024.5.27", optional = true }
sevenn = { version = "0.9.3", optional = true }
torch_scatter = { version = "^2.1.2", optional = true }
torch_geometric = { version = "^2.5.3", optional = true }

[tool.poetry.group.extra-mlips]
optional = true
[tool.poetry.group.extra-mlips.dependencies]
alignn = "^2024.5.27"
[tool.poetry.extras]
alignnff = ["alignn"]
sevennet = ["sevenn", "torch_scatter", "torch_geometric"]

[tool.poetry.group.dev.dependencies]
coverage = {extras = ["toml"], version = "^7.4.1"}
Expand Down
Binary file added tests/models/sevennet_0.pth
Binary file not shown.
37 changes: 29 additions & 8 deletions tests/test_mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
CHGNET_PATH = MODEL_PATH / "chgnet_0.3.0_e29f68s314m37.pth.tar"
CHGNET_MODEL = CHGNet.from_file(path=CHGNET_PATH)

SEVENNET_PATH = MODEL_PATH / "sevennet_0.pth"

ALIGNN_PATH = MODEL_PATH / "v5.27.2024"


@pytest.mark.parametrize(
"architecture, device, kwargs",
Expand Down Expand Up @@ -102,10 +106,15 @@ def test_invalid_device(architecture):
"architecture, device, kwargs",
[
("alignn", "cpu", {}),
("alignn", "cpu", {"model_path": MODEL_PATH / "v5.27.2024"}),
("alignn", "cpu", {"model_path": MODEL_PATH / "v5.27.2024/best_model.pt"}),
("alignn", "cpu", {"model_path": ALIGNN_PATH}),
("alignn", "cpu", {"model_path": ALIGNN_PATH / "best_model.pt"}),
("alignn", "cpu", {"model": "alignnff_wt10"}),
("alignn", "cpu", {"path": MODEL_PATH / "v5.27.2024"}),
("alignn", "cpu", {"path": ALIGNN_PATH}),
("sevennet", "cpu", {"model": SEVENNET_PATH}),
("sevennet", "cpu", {"path": SEVENNET_PATH}),
("sevennet", "cpu", {"model_path": SEVENNET_PATH}),
("sevennet", "cpu", {}),
("sevennet", "cpu", {"model": "sevennet-0"}),
],
)
def test_extra_mlips(architecture, device, kwargs):
Expand All @@ -123,16 +132,28 @@ def test_extra_mlips(architecture, device, kwargs):
"kwargs",
[
{
"model_path": MODEL_PATH / "v5.27.2024/best_model.pt",
"model": MODEL_PATH / "v5.27.2024/best_model.pt",
"architecture": "alignn",
"model_path": ALIGNN_PATH / "best_model.pt",
"model": ALIGNN_PATH / "best_model.pt",
},
{
"architecture": "alignn",
"model_path": ALIGNN_PATH / "best_model.pt",
"path": ALIGNN_PATH / "best_model.pt",
},
{
"architecture": "sevennet",
"model_path": SEVENNET_PATH,
"path": SEVENNET_PATH,
},
{
"model_path": MODEL_PATH / "v5.27.2024/best_model.pt",
"path": MODEL_PATH / "v5.27.2024/best_model.pt",
"architecture": "sevennet",
"model_path": SEVENNET_PATH,
"model": SEVENNET_PATH,
},
],
)
def test_extra_mlips_invalid(kwargs):
"""Test error raised if multiple model paths defined for extra MLIPs."""
with pytest.raises(ValueError):
choose_calculator(architecture="alignn", **kwargs)
choose_calculator(**kwargs)
49 changes: 29 additions & 20 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from tests.utils import read_atoms

DATA_PATH = Path(__file__).parent / "data"
MODEL_PATH = Path(__file__).parent / "models" / "mace_mp_small.model"
MODEL_PATH = Path(__file__).parent / "models"

MACE_PATH = MODEL_PATH / "mace_mp_small.model"
SEVENNET_PATH = MODEL_PATH / "sevennet_0.pth"

test_data = [
(DATA_PATH / "benzene.xyz", -76.0605725422795, "energy", "energy", {}, None),
Expand All @@ -34,7 +37,7 @@ def test_potential_energy(
struct_path, expected, properties, prop_key, calc_kwargs, idx
):
"""Test single point energy using MACE calculators."""
calc_kwargs["model"] = MODEL_PATH
calc_kwargs["model"] = MACE_PATH
single_point = SinglePoint(
struct_path=struct_path, architecture="mace", calc_kwargs=calc_kwargs
)
Expand All @@ -59,7 +62,7 @@ def test_single_point_none():
single_point = SinglePoint(
struct_path=DATA_PATH / "NaCl.cif",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)

results = single_point.run()
Expand All @@ -72,7 +75,7 @@ def test_single_point_clean():
single_point = SinglePoint(
struct_path=DATA_PATH / "H2O.cif",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)

results = single_point.run()
Expand All @@ -86,7 +89,7 @@ def test_single_point_traj():
single_point = SinglePoint(
struct_path=DATA_PATH / "benzene-traj.xyz",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)

assert len(single_point.struct) == 2
Expand All @@ -110,7 +113,7 @@ def test_single_point_write():
single_point = SinglePoint(
struct_path=data_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert "mace_forces" not in single_point.struct.arrays

Expand Down Expand Up @@ -142,7 +145,7 @@ def test_single_point_write_kwargs(tmp_path):
single_point = SinglePoint(
struct_path=data_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert "mace_forces" not in single_point.struct.arrays

Expand All @@ -159,7 +162,7 @@ def test_single_point_molecule(tmp_path):
single_point = SinglePoint(
struct_path=data_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)

assert isfinite(single_point.run("energy")["energy"]).all()
Expand All @@ -177,7 +180,7 @@ def test_invalid_prop():
single_point = SinglePoint(
struct_path=DATA_PATH / "H2O.cif",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
with pytest.raises(NotImplementedError):
single_point.run("invalid")
Expand All @@ -190,7 +193,7 @@ def test_atoms():
struct=struct,
struct_name="NaCl",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert single_point.struct_name == "NaCl"
assert single_point.run("energy")["energy"] < 0
Expand All @@ -202,7 +205,7 @@ def test_default_atoms_name():
single_point = SinglePoint(
struct=struct,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert single_point.struct_name == "Cl4Na4"

Expand All @@ -213,7 +216,7 @@ def test_default_path_name():
single_point = SinglePoint(
struct_path=struct_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert single_point.struct_name == "NaCl"

Expand All @@ -225,7 +228,7 @@ def test_path_specify_name():
struct_path=struct_path,
struct_name="example_name",
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)
assert single_point.struct_name == "example_name"

Expand All @@ -239,7 +242,7 @@ def test_atoms_and_path():
struct=struct,
struct_path=struct_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)


Expand All @@ -248,7 +251,7 @@ def test_no_atoms_or_path():
with pytest.raises(ValueError):
SinglePoint(
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)


Expand All @@ -258,7 +261,7 @@ def test_invalidate_calc():
single_point = SinglePoint(
struct_path=struct_path,
architecture="mace",
calc_kwargs={"model": MODEL_PATH},
calc_kwargs={"model": MACE_PATH},
)

single_point.run(write_kwargs={"invalidate_calc": False})
Expand Down Expand Up @@ -286,17 +289,23 @@ def test_mlips(arch, device, expected_energy):
assert energy == pytest.approx(expected_energy)


test_extra_mlips_data = [("alignn", "cpu", -11.148092269897461)]
test_extra_mlips_data = [
("alignn", "cpu", -11.148092269897461, {}),
("sevennet", "cpu", -27.061979293823242, {"model_path": SEVENNET_PATH}),
("sevennet", "cpu", -27.061979293823242, {}),
("sevennet", "cpu", -27.061979293823242, {"model": "SevenNet-0_11July2024"}),
]


@pytest.mark.extra_mlips
@pytest.mark.parametrize("arch, device, expected_energy", test_extra_mlips_data)
def test_extra_mlips(arch, device, expected_energy):
"""Test single point energy using ALIGNN-FF calculator."""
@pytest.mark.parametrize("arch, device, expected_energy, kwargs", test_extra_mlips_data)
def test_extra_mlips_alignn(arch, device, expected_energy, kwargs):
"""Test single point energy using extra mlips calculators."""
single_point = SinglePoint(
struct_path=DATA_PATH / "NaCl.cif",
architecture=arch,
device=device,
**kwargs,
)
energy = single_point.run("energy")["energy"]
assert energy == pytest.approx(expected_energy)

0 comments on commit 8d77904

Please sign in to comment.