Skip to content

Commit

Permalink
Deployment module (#97)
Browse files Browse the repository at this point in the history
* move to deployment module

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove deployment file

* use vanilla deployment class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* temporary circular import bugfix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use base class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add deployment exchanging tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* works, except for non-relevant tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix most deployment tests

* update deployment tests

* move to conftest.py

* fix not running nodes multiple times

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ignore F811

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup by copilot

* manual cleanup

* use `available`

* use super

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update README.md

* remove print statement

* bump version to 0.2.0

* use descriptive key

* use result

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more tests including deployment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more deployment testing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo and import order

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bugfix, mock IPS lotf workflow

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* raise Error for _external_ nodes in dask deployment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add future-me problems

* typo

* pre-release version

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Feb 28, 2024
1 parent f7b4e43 commit a7b3c41
Show file tree
Hide file tree
Showing 19 changed files with 456 additions and 244 deletions.
17 changes: 7 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,26 +127,23 @@ class ComputeMean(znflow.Node):
def run(self):
self.results = (self.x + self.y) / 2

with znflow.DiGraph() as graph:

client = Client()
deployment = znflow.deployment.DaskDeployment(client=client)


with znflow.DiGraph(deployment=deployment) as graph:
n1 = ComputeMean(2, 8)
n2 = compute_mean(13, 7)
# connecting classes and functions to a Node
n3 = ComputeMean(n1.results, n2)

client = Client()
deployment = znflow.deployment.Deployment(graph=graph, client=client)
deployment.submit_graph()
graph.run()

n3 = deployment.get_results(n3)
print(n3)
# >>> ComputeMean(x=5.0, y=10.0, results=7.5)
```

We need to get the updated instance from the Dask worker via
`Deployment.get_results`. Due to the way Dask works, an inplace update is not
possible. To retrieve the full graph, you can use
`Deployment.get_results(graph.nodes)` instead.

### Working with lists

ZnFlow supports some special features for working with lists. In the following
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "znflow"
version = "0.1.15"
version = "0.2.0a0"
description = "A general purpose framework for building and running computational graphs."
authors = ["zincwarecode <[email protected]>"]
license = "Apache-2.0"
Expand Down
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest
from distributed.utils_test import ( # noqa: F401
cleanup,
client,
cluster_fixture,
loop,
loop_in_thread,
)

import znflow


@pytest.fixture
def vanilla_deployment():
return znflow.deployment.VanillaDeployment()


@pytest.fixture
def dask_deployment(client): # noqa: F811
return znflow.deployment.DaskDeployment(client=client)
81 changes: 81 additions & 0 deletions tests/examples/test_ips_lotf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Mock version of IPS LotF workflow for testing purposes."""

import dataclasses
import random

import pytest

import znflow


@dataclasses.dataclass
class AddData(znflow.Node):
file: str

def run(self):
if self.file is None:
raise ValueError("File is None")
print(f"Adding data from {self.file}")

@property
def atoms(self):
return "Atoms"


@dataclasses.dataclass
class TrainModel(znflow.Node):
data: str
model: str = None

def run(self):
if self.data is None:
raise ValueError("Data is None")
self.model = "Model"
print(f"Model: {self.model}")


@dataclasses.dataclass
class MD(znflow.Node):
model: str
atoms: str = None

def run(self):
if self.model is None:
raise ValueError("Model is None")
self.atoms = "Atoms"
print(f"Atoms: {self.atoms}")


@dataclasses.dataclass
class EvaluateModel(znflow.Node):
model: str
seed: int
metrics: float = None

def run(self):
random.seed(self.seed)
if self.model is None:
raise ValueError("Model is None")
self.metrics = random.random()
print(f"Metrics: {self.metrics}")


@pytest.mark.parametrize("deployment", ["vanilla_deployment", "dask_deployment"])
def test_lotf(deployment, request):
deployment = request.getfixturevalue(deployment)

graph = znflow.DiGraph(deployment=deployment)
with graph:
data = AddData(file="data.xyz")
model = TrainModel(data=data.atoms)
md = MD(model=model.model)
metrics = EvaluateModel(model=model.model, seed=0)
for idx in range(10):
model = TrainModel(data=md.atoms)
md = MD(model=model.model)
metrics = EvaluateModel(model=model.model, seed=idx)
if znflow.resolve(metrics.metrics) == pytest.approx(0.623, 1e-3):
# break loop after 6th iteration
break

assert len(graph) == 22
111 changes: 66 additions & 45 deletions tests/test_deployment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses

import numpy as np
import pytest

import znflow

Expand All @@ -25,76 +26,94 @@ def add_to_ComputeSum(instance: ComputeSum):
return instance.outputs + 1


def test_single_nodify():
with znflow.DiGraph() as graph:
@pytest.mark.parametrize(
"deployment",
["vanilla_deployment", "dask_deployment"],
)
def test_single_nodify(request, deployment):
deployment = request.getfixturevalue(deployment)

with znflow.DiGraph(deployment=deployment) as graph:
node1 = compute_sum(1, 2, 3)

depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()
graph.run()

assert depl.get_results(node1) == 6
assert node1.result == 6


def test_single_Node():
with znflow.DiGraph() as graph:
node1 = ComputeSum(inputs=[1, 2, 3])
@pytest.mark.parametrize(
"deployment",
["vanilla_deployment", "dask_deployment"],
)
def test_single_Node(request, deployment):
deployment = request.getfixturevalue(deployment)

depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()
with znflow.DiGraph(deployment=deployment) as graph:
node1 = ComputeSum(inputs=[1, 2, 3])

node1 = depl.get_results(node1)
graph.run()
assert node1.outputs == 6


def test_multiple_nodify():
with znflow.DiGraph() as graph:
@pytest.mark.parametrize(
"deployment",
["vanilla_deployment", "dask_deployment"],
)
def test_multiple_nodify(request, deployment):
deployment = request.getfixturevalue(deployment)

with znflow.DiGraph(deployment=deployment) as graph:
node1 = compute_sum(1, 2, 3)
node2 = compute_sum(4, 5, 6)
node3 = compute_sum(node1, node2)

depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()
graph.run()

assert node1.result == 6
assert node2.result == 15
assert node3.result == 21

assert depl.get_results(node1) == 6
assert depl.get_results(node2) == 15
assert depl.get_results(node3) == 21

@pytest.mark.parametrize(
"deployment",
["vanilla_deployment", "dask_deployment"],
)
def test_multiple_Node(request, deployment):
deployment = request.getfixturevalue(deployment)

def test_multiple_Node():
with znflow.DiGraph() as graph:
with znflow.DiGraph(deployment=deployment) as graph:
node1 = ComputeSum(inputs=[1, 2, 3])
node2 = ComputeSum(inputs=[4, 5, 6])
node3 = ComputeSum(inputs=[node1.outputs, node2.outputs])

depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()
graph.run()

node1 = depl.get_results(node1)
node2 = depl.get_results(node2)
node3 = depl.get_results(node3)
assert node1.outputs == 6
assert node2.outputs == 15
assert node3.outputs == 21


def test_multiple_nodify_and_Node():
with znflow.DiGraph() as graph:
@pytest.mark.parametrize(
"deployment",
["vanilla_deployment", "dask_deployment"],
)
def test_multiple_nodify_and_Node(request, deployment):
deployment = request.getfixturevalue(deployment)

with znflow.DiGraph(deployment=deployment) as graph:
node1 = compute_sum(1, 2, 3)
node2 = ComputeSum(inputs=[4, 5, 6])
node3 = compute_sum(node1, node2.outputs)
node4 = ComputeSum(inputs=[node1, node2.outputs, node3])
node5 = add_to_ComputeSum(node4)

depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()

results = depl.get_results(graph.nodes)
graph.run()

assert results[node1.uuid] == 6
assert results[node2.uuid].outputs == 15
assert results[node3.uuid] == 21
assert results[node4.uuid].outputs == 42
assert results[node5.uuid] == 43
assert node1.result == 6
assert node2.outputs == 15
assert node3.result == 21
assert node4.outputs == 42
assert node5.result == 43


@znflow.nodify
Expand All @@ -107,16 +126,18 @@ def concatenate(forces):
return np.concatenate(forces)


def test_concatenate():
with znflow.DiGraph() as graph:
@pytest.mark.parametrize(
"deployment",
["vanilla_deployment", "dask_deployment"],
)
def test_concatenate(request, deployment):
deployment = request.getfixturevalue(deployment)

with znflow.DiGraph(deployment=deployment) as graph:
forces = [get_forces() for _ in range(10)]
forces = concatenate(forces)

deployment = znflow.deployment.Deployment(
graph=graph,
)
deployment.submit_graph()
results = deployment.get_results(forces)
graph.run()

assert isinstance(results, np.ndarray)
assert results.shape == (1000, 3)
assert isinstance(forces.result, np.ndarray)
assert forces.result.shape == (1000, 3)
Loading

0 comments on commit a7b3c41

Please sign in to comment.