Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing module #6

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
49 changes: 24 additions & 25 deletions ipsuite/analysis/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import logging
import pathlib
import typing

import ase
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -399,28 +400,21 @@ class BoxScaleAnalysis(base.ProcessSingleAtom):
Attributes
----------
model: The MLModel node that implements the 'predict' method
atoms: list[Atoms] to predict properties for
logspace: bool, default=True
Increase the stdev of rattle with 'np.logspace' instead of 'np.linspace'
stop: float, default = 1.0
The stop value for the generated space of stdev points
num: int, default = 100
The size of the generated space of stdev points
factor: float, default = 0.001
The 'np.linspace(0.0, stop, num) * factor'
atom_id: int, default = 0
The atom to pick from self.atoms as a starting point
start: int, default = None
The initial box scale, default value is the original box size.
"""

model: models.MLModel = zntrack.zn.deps()

logspace: bool = zntrack.zn.params(False)
stop: float = zntrack.zn.params(2.0)
factor: float = zntrack.zn.params(1.0)
mapping: typing.Callable = zntrack.zn.deps(None)

start: float = zntrack.zn.params(0.9)
stop: float = zntrack.zn.params(2.5)
num: int = zntrack.zn.params(100)
start: float = zntrack.zn.params(None)

energies: pd.DataFrame = zntrack.zn.plots(
# x="x",
Expand All @@ -429,35 +423,40 @@ class BoxScaleAnalysis(base.ProcessSingleAtom):
# y_label="predicted energy",
)

def post_init(self):
figure = zntrack.dvc.outs(zntrack.nwd / "box_scale_analysis.png")

def _post_init_(self):
self.data = utils.helpers.get_deps_if_node(self.data, "atoms")
if self.start is None:
self.start = 0.0 if self.logspace else 1.0

def run(self):
if self.logspace:
scale_space = (
np.logspace(start=self.start, stop=self.stop, num=self.num) * self.factor
)
else:
scale_space = (
np.linspace(start=self.start, stop=self.stop, num=self.num) * self.factor
)
scale_space = np.linspace(start=self.start, stop=self.stop, num=self.num)

atoms = self.get_data()
cell = atoms.copy().cell
atoms.calc = self.model.calc
calc = self.model.calc

energies = []
self.atoms = []

for scale in tqdm.tqdm(scale_space, ncols=70):
atoms.set_cell(cell=cell * scale, scale_atoms=True)
energies.append(atoms.get_potential_energy())
self.atoms.append(atoms.copy())
if self.mapping is not None:
new_atoms = self.mapping({self.data_id: atoms})[0].copy()
else:
new_atoms = atoms.copy()
new_atoms.calc = calc
energies.append(new_atoms.get_potential_energy())

self.atoms.append(new_atoms)

self.energies = pd.DataFrame({"y": energies, "x": scale_space})

fig, ax = plt.subplots()
ax.plot(scale_space, energies)
ax.set_ylabel("Predicted Energy (eV)")
ax.set_xlabel("Scale factor of the initial cell")
fig.savefig(self.figure, bbox_inches="tight")


class BoxHeatUp(base.ProcessSingleAtom):
"""Attributes
Expand Down
3 changes: 2 additions & 1 deletion ipsuite/fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Custom ZnTrack serialization types."""

from ipsuite.fields.atoms import Atoms
from ipsuite.fields.graph import NxGraph

__all__ = ["Atoms"]
__all__ = ["Atoms", "NxGraph"]
5 changes: 4 additions & 1 deletion ipsuite/fields/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def get_stage_add_argument(self, instance: zntrack.Node) -> typing.List[tuple]:

def save(self, instance: zntrack.Node):
"""Save value with ase.db.connect."""
atoms: base.ATOMS_LST = getattr(instance, self.name)
try:
atoms: base.ATOMS_LST = getattr(instance, self.name)
except AttributeError:
return
instance.nwd.mkdir(exist_ok=True, parents=True)
file = self.get_files(instance)[0]

Expand Down
40 changes: 40 additions & 0 deletions ipsuite/fields/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Lazy ASE Atoms loading."""
import json
import pathlib
import typing

import networkx as nx
import zntrack


class NxGraph(zntrack.Field):
"""Store list[ase.Atoms] in an ASE database."""

dvc_option = "--outs"
group = zntrack.FieldGroup.RESULT

def __init__(self):
super().__init__(use_repr=False)

def get_files(self, instance: zntrack.Node) -> list:
return [(instance.nwd / f"{self.name}.json").as_posix()]

def get_stage_add_argument(self, instance: zntrack.Node) -> typing.List[tuple]:
return [(self.dvc_option, file) for file in self.get_files(instance)]

def save(self, instance: zntrack.Node):
"""Save value with ase.db.connect."""
try:
graph: nx.Graph = getattr(instance, self.name)
except AttributeError:
return
instance.nwd.mkdir(exist_ok=True, parents=True)
file = self.get_files(instance)[0]
with pathlib.Path(file).open("w") as f:
json.dump(graph, f, default=nx.node_link_data)

def get_data(self, instance: zntrack.Node) -> typing.List[nx.Graph]:
"""Get graph File."""
file = self.get_files(instance)[0]
with pathlib.Path(file).open("r") as f:
return [nx.node_link_graph(x) for x in json.load(f)]
147 changes: 147 additions & 0 deletions ipsuite/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import os
import pathlib
import typing

import ase.io
import dvc.cli
import git
import numpy as np
import zntrack
from ase.calculators.singlepoint import SinglePointCalculator
from zntrack.tools import timeit

from ipsuite import AddData, Project, base, fields


class UpdateCalculator(base.ProcessSingleAtom):
"""Update the calculator of an atoms object.

Set energy, forces to zero.
"""

energy = zntrack.zn.params(0.0)
forces = zntrack.zn.params((0, 0, 0))

time: float = zntrack.zn.metrics()

@timeit(field="time")
def run(self) -> None:
self.atoms = self.get_data()

self.atoms.calc = SinglePointCalculator(
self.atoms,
energy=self.energy,
forces=np.stack([self.forces] * len(self.atoms)),
)
self.atoms = [self.atoms]


class MockAtoms(zntrack.Node):
"""Create Atoms objects with random data."""

atoms: typing.List[ase.Atoms] = fields.Atoms()
seed: int = zntrack.zn.params(0)

n_configurations: int = zntrack.zn.params(10)
n_atoms: int = zntrack.zn.params(10)

calculator: bool = zntrack.zn.params(True)

def run(self) -> None:
self.atoms = []
np.random.seed(self.seed)
for _ in range(self.n_configurations):
atoms = ase.Atoms(
symbols="C" * self.n_atoms,
positions=np.random.random((self.n_atoms, 3)),
)
if self.calculator:
atoms.calc = SinglePointCalculator(
atoms,
energy=np.random.random(),
forces=np.random.random((self.n_atoms, 3)),
)
self.atoms.append(atoms)


class AtomsToXYZ(base.AnalyseAtoms):
"""Convert Atoms objects to XYZ files."""

output: pathlib.Path = zntrack.dvc.outs(zntrack.nwd / "atoms")

def run(self) -> None:
self.output.mkdir(parents=True, exist_ok=True)
for idx, atom in enumerate(self.data):
ase.io.write(self.output / f"{idx:05d}.xyz", atom)

@property
def files(self) -> typing.List[pathlib.Path]:
return [x.resolve() for x in self.output.glob("*.xyz")]


class NodesPerAtoms(base.ProcessAtoms):
processor: base.ProcessSingleAtom = zntrack.zn.nodes()
repo: str = zntrack.meta.Text(None)
commit: bool = zntrack.meta.Text(True)
clean_exp: bool = zntrack.meta.Text(True)

def run(self):
# lazy loading: load now
_ = self.data
processor = self.processor
processor.name = processor.__class__.__name__

repo = git.Repo.init(self.repo or self.name)

gitignore = pathlib.Path(".gitignore")
# TODO: move this into a function
if not gitignore.exists():
gitignore.write_text(f"{repo.working_dir}\n")
elif repo.working_dir not in gitignore.read_text().split(" "):
gitignore.write_text(f"{repo.working_dir}\n")

os.chdir(repo.working_dir)
dvc.cli.main(["init"])
project = Project()

with project:
data = AddData(file="atoms.xyz")
project.run(repro=False)

processor.data = data @ "atoms"
processor.write_graph()

repo.git.add(all=True)
repo.index.commit("Build graph")

if self.clean_exp:
dvc.cli.main(["exp", "gc", "-w", "-f"])

self.run_exp(project, processor)
if self.commit:
self.run_commits(repo)

os.chdir("..") # we need to go back to save

def run_exp(self, project, processor):
exp_lst = []
for atom in self.data:
with project.create_experiment() as exp:
ase.io.write("atoms.xyz", atom)
exp_lst.append(exp)
project.run_exp()

self.atoms = [
processor.from_rev(name=processor.name, rev=x.name).atoms[0] for x in exp_lst
]

def run_commits(self, repo):
commits = []
for idx, atom in enumerate(self.data):
ase.io.write("atoms.xyz", atom)
dvc.cli.main(["add", "atoms.xyz"])
dvc.cli.main(["repro"])
repo.git.add(all=True)
# do not use repo.index.add("*"); it will add atoms.xyz
commit_message = f"repro {self.name}_{idx}"
commits.append(repo.index.commit(commit_message))
Loading