-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* move to deployment module * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove deployment file * use vanilla deployment class * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * temporary circular import bugfix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use base class * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add deployment exchanging tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * works, except for non-relevant tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix most deployment tests * update deployment tests * move to conftest.py * fix not running nodes multiple times * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ignore F811 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup by copilot * manual cleanup * use `available` * use super * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update README.md * remove print statement * bump version to 0.2.0 * use descriptive key * use result * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more tests including deployment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more deployment testing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo and import order * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * bugfix, mock IPS lotf workflow * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * raise Error for _external_ nodes in dask deployment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add future-me problems * typo * pre-release version --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
f7b4e43
commit a7b3c41
Showing
19 changed files
with
456 additions
and
244 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "znflow" | ||
version = "0.1.15" | ||
version = "0.2.0a0" | ||
description = "A general purpose framework for building and running computational graphs." | ||
authors = ["zincwarecode <[email protected]>"] | ||
license = "Apache-2.0" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pytest | ||
from distributed.utils_test import ( # noqa: F401 | ||
cleanup, | ||
client, | ||
cluster_fixture, | ||
loop, | ||
loop_in_thread, | ||
) | ||
|
||
import znflow | ||
|
||
|
||
@pytest.fixture | ||
def vanilla_deployment(): | ||
return znflow.deployment.VanillaDeployment() | ||
|
||
|
||
@pytest.fixture | ||
def dask_deployment(client): # noqa: F811 | ||
return znflow.deployment.DaskDeployment(client=client) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
"""Mock version of IPS LotF workflow for testing purposes.""" | ||
|
||
import dataclasses | ||
import random | ||
|
||
import pytest | ||
|
||
import znflow | ||
|
||
|
||
@dataclasses.dataclass | ||
class AddData(znflow.Node): | ||
file: str | ||
|
||
def run(self): | ||
if self.file is None: | ||
raise ValueError("File is None") | ||
print(f"Adding data from {self.file}") | ||
|
||
@property | ||
def atoms(self): | ||
return "Atoms" | ||
|
||
|
||
@dataclasses.dataclass | ||
class TrainModel(znflow.Node): | ||
data: str | ||
model: str = None | ||
|
||
def run(self): | ||
if self.data is None: | ||
raise ValueError("Data is None") | ||
self.model = "Model" | ||
print(f"Model: {self.model}") | ||
|
||
|
||
@dataclasses.dataclass | ||
class MD(znflow.Node): | ||
model: str | ||
atoms: str = None | ||
|
||
def run(self): | ||
if self.model is None: | ||
raise ValueError("Model is None") | ||
self.atoms = "Atoms" | ||
print(f"Atoms: {self.atoms}") | ||
|
||
|
||
@dataclasses.dataclass | ||
class EvaluateModel(znflow.Node): | ||
model: str | ||
seed: int | ||
metrics: float = None | ||
|
||
def run(self): | ||
random.seed(self.seed) | ||
if self.model is None: | ||
raise ValueError("Model is None") | ||
self.metrics = random.random() | ||
print(f"Metrics: {self.metrics}") | ||
|
||
|
||
@pytest.mark.parametrize("deployment", ["vanilla_deployment", "dask_deployment"]) | ||
def test_lotf(deployment, request): | ||
deployment = request.getfixturevalue(deployment) | ||
|
||
graph = znflow.DiGraph(deployment=deployment) | ||
with graph: | ||
data = AddData(file="data.xyz") | ||
model = TrainModel(data=data.atoms) | ||
md = MD(model=model.model) | ||
metrics = EvaluateModel(model=model.model, seed=0) | ||
for idx in range(10): | ||
model = TrainModel(data=md.atoms) | ||
md = MD(model=model.model) | ||
metrics = EvaluateModel(model=model.model, seed=idx) | ||
if znflow.resolve(metrics.metrics) == pytest.approx(0.623, 1e-3): | ||
# break loop after 6th iteration | ||
break | ||
|
||
assert len(graph) == 22 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.