Skip to content

Commit

Permalink
UW-489 JEDI driver (ufs-community#453)
Browse files Browse the repository at this point in the history
* Add jedi.yaml and jedi.py

* WIP

* jedi.py and jedi schema

* API addition

* Jedi CLI

* Added tasks to jedi driver

* WIP

* Update iotaa version number

* Update jedi api to include cycle

* WIP

* WIP

* WI[

* Initial driver tests

* WIP

* WIP and fix error

* WIP

* WIP

* Fix minor error in drivers

* Format check

* Debug join error

* WIP unit tests

* Fixing minor errors

* WIP

* WIP

* WIP

* WIP

* WIP files copied and linked unit tests

* WIP

* WIP

* Updates (ufs-community#3)

* WIP

* WIP

* Update (ufs-community#4)

* Update src/uwtools/cli.py

Co-authored-by: Christina Holt <[email protected]>

* Update src/uwtools/api/jedi.py

Co-authored-by: Paul Madden <[email protected]>

* Update src/uwtools/cli.py

Co-authored-by: Paul Madden <[email protected]>

* Update src/uwtools/drivers/jedi.py

Co-authored-by: Paul Madden <[email protected]>

* Addressing comments

* Addressing PR comments

* Minor fixs

* WIP

* Change jedi.yaml to config.yaml

* Addressing PR comments

* Debug test case

* Addressing PR comments

* jedi-test-fix (ufs-community#5)

* Update

* Simplify

* Simplify

---------

Co-authored-by: Paul Madden <[email protected]>
Co-authored-by: Christina Holt <[email protected]>
  • Loading branch information
3 people authored Apr 17, 2024
1 parent 796f9b6 commit 455ec90
Show file tree
Hide file tree
Showing 13 changed files with 566 additions and 10 deletions.
27 changes: 27 additions & 0 deletions docs/shared/jedi.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
platform:
account: me
scheduler: slurm
jedi:
execution:
batchargs:
nodes: 1
stdout: path/to/runscript.out
walltime: 08:00:00
envcmds:
- module load some-module
- module load jedi-module
executable: /path/to/jedi
mpiargs:
- --export=ALL
- --ntasks $SLURM_CPUS_ON_NODE
mpicmd: srun
configuration_file:
base_file: path/to/config.yaml
update_values: {"baz": "qux"}
files_to_copy:
f1: /path/to/f1
d/f2: /path/to/f2
files_to_link:
f3: /path/to/f3
f4: d/f4
run_dir: /path/to/run
54 changes: 54 additions & 0 deletions src/uwtools/api/jedi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
API access to the ``uwtools`` ``jedi`` driver.
"""

import datetime as dt
from pathlib import Path
from typing import Dict, Optional

import uwtools.drivers.support as _support
from uwtools.drivers.jedi import JEDI as _JEDI


def execute(
task: str,
config: Path,
cycle: dt.datetime,
batch: bool = False,
dry_run: bool = False,
graph_file: Optional[Path] = None,
) -> bool:
"""
Execute a JEDI task.
If ``batch`` is specified, a runscript will be written and submitted to the batch system.
Otherwise, the executable will be run directly on the current system.
:param task: The task to execute
:param cycle: The cycle.
:param config: Path to YAML config file
:param cycle: The cycle to run
:param batch: Submit run to the batch system
:param dry_run: Do not run the executable, just report what would have been done
:param graph_file: Write Graphviz DOT output here
:return: True if task completes without raising an exception
"""
obj = _JEDI(config=config, cycle=cycle, batch=batch, dry_run=dry_run)
getattr(obj, task)()
if graph_file:
with open(graph_file, "w", encoding="utf-8") as f:
print(graph(), file=f)
return True


def graph() -> str:
"""
Returns Graphviz DOT code for the most recently executed task.
"""
return _support.graph()


def tasks() -> Dict[str, str]:
"""
Returns a mapping from task names to their one-line descriptions.
"""
return _support.tasks(_JEDI)
57 changes: 57 additions & 0 deletions src/uwtools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import uwtools.api.config
import uwtools.api.file
import uwtools.api.fv3
import uwtools.api.jedi
import uwtools.api.mpas
import uwtools.api.mpas_init
import uwtools.api.rocoto
Expand Down Expand Up @@ -63,6 +64,7 @@ def main() -> None:
STR.config: _dispatch_config,
STR.file: _dispatch_file,
STR.fv3: _dispatch_fv3,
STR.jedi: _dispatch_jedi,
STR.mpas: _dispatch_mpas,
STR.mpasinit: _dispatch_mpas_init,
STR.rocoto: _dispatch_rocoto,
Expand Down Expand Up @@ -431,6 +433,59 @@ def _dispatch_fv3(args: Args) -> bool:
)


# Mode jedi


def _add_subparser_jedi(subparsers: Subparsers) -> ModeChecks:
"""
Subparser for mode: jedi
:param subparsers: Parent parser's subparsers, to add this subparser to.
"""
parser = _add_subparser(subparsers, STR.jedi, "Execute JEDI tasks")
_basic_setup(parser)
subparsers = _add_subparsers(parser, STR.action, STR.task.upper())
return {
task: _add_subparser_jedi_task(subparsers, task, helpmsg)
for task, helpmsg in uwtools.api.jedi.tasks().items()
}


def _add_subparser_jedi_task(subparsers: Subparsers, task: str, helpmsg: str) -> ActionChecks:
"""
Subparser for mode: jedi <task>
:param subparsers: Parent parser's subparsers, to add this subparser to.
:param task: The task to add a subparser for.
:param helpmsg: Help message for task.
"""
parser = _add_subparser(subparsers, task, helpmsg.rstrip("."))
required = parser.add_argument_group(TITLE_REQ_ARG)
_add_arg_config_file(group=required, required=True)
_add_arg_cycle(required)
optional = _basic_setup(parser)
_add_arg_batch(optional)
_add_arg_dry_run(optional)
_add_arg_graph_file(optional)
checks = _add_args_verbosity(optional)
return checks


def _dispatch_jedi(args: Args) -> bool:
"""
Dispatch logic for jedi mode.
:param args: Parsed command-line args.
"""
return uwtools.api.jedi.execute(
task=args[STR.action],
config=args[STR.cfgfile],
cycle=args[STR.cycle],
batch=args[STR.batch],
dry_run=args[STR.dryrun],
graph_file=args[STR.graphfile],
)


# Mode mpas


Expand Down Expand Up @@ -459,6 +514,7 @@ def _add_subparser_mpas_task(subparsers: Subparsers, task: str, helpmsg: str) ->
"""
parser = _add_subparser(subparsers, task, helpmsg.rstrip("."))
required = parser.add_argument_group(TITLE_REQ_ARG)

_add_arg_cycle(required)
optional = _basic_setup(parser)
_add_arg_config_file(group=optional, required=False)
Expand Down Expand Up @@ -1206,6 +1262,7 @@ def _parse_args(raw_args: List[str]) -> Tuple[Args, Checks]:
STR.config: _add_subparser_config(subparsers),
STR.file: _add_subparser_file(subparsers),
STR.fv3: _add_subparser_fv3(subparsers),
STR.jedi: _add_subparser_jedi(subparsers),
STR.mpas: _add_subparser_mpas(subparsers),
STR.mpasinit: _add_subparser_mpas_init(subparsers),
STR.rocoto: _add_subparser_rocoto(subparsers),
Expand Down
138 changes: 138 additions & 0 deletions src/uwtools/drivers/jedi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
A driver for the jedi component.
"""

import logging
from datetime import datetime
from pathlib import Path

from iotaa import asset, dryrun, refs, run, task, tasks

from uwtools.config.formats.yaml import YAMLConfig
from uwtools.drivers.driver import Driver
from uwtools.strings import STR
from uwtools.utils.tasks import file, filecopy, symlink


class JEDI(Driver):
"""
A driver for the JEDI component.
"""

def __init__(self, config: Path, cycle: datetime, dry_run: bool = False, batch: bool = False):
"""
The driver.
:param config: Path to config file.
:param cycle: The forecast cycle.
:param dry_run: Run in dry-run mode?
:param batch: Run component via the batch system?
"""
super().__init__(config=config, dry_run=dry_run, batch=batch, cycle=cycle)
if self._dry_run:
dryrun()
self._cycle = cycle

# Workflow tasks

@task
def configuration_file(self):
"""
The configuration file.
"""
fn = "jedi.yaml"
yield self._taskname(fn)
path = self._rundir / fn
yield asset(path, path.is_file)
yield None
self._create_user_updated_config(
config_class=YAMLConfig,
config_values=self._driver_config["configuration_file"],
path=path,
)

@tasks
def files_copied(self):
"""
Files copied for run.
"""
yield self._taskname("files copied")
yield [
filecopy(src=Path(src), dst=self._rundir / dst)
for dst, src in self._driver_config.get("files_to_copy", {}).items()
]

@tasks
def files_linked(self):
"""
Files linked for run.
"""
yield self._taskname("files linked")
yield [
symlink(target=Path(target), linkname=self._rundir / linkname)
for linkname, target in self._driver_config.get("files_to_link", {}).items()
]

@tasks
def provisioned_run_directory(self):
"""
Run directory provisioned with all required content.
"""
yield self._taskname("provisioned run directory")
yield [
self.configuration_file(),
self.files_copied(),
self.files_linked(),
self.runscript(),
self.validate_only(),
]

@task
def runscript(self):
"""
The runscript.
"""
path = self._runscript_path
yield self._taskname(path.name)
yield asset(path, path.is_file)
yield None
self._write_runscript(path=path, envvars={})

@task
def validate_only(self):
"""
Validate JEDI config YAML.
"""
taskname = self._taskname("validate_only")
yield taskname
a = asset(None, lambda: False)
yield a
executable = file(Path(self._driver_config["execution"]["executable"]))
config = self.configuration_file()
yield [executable, config]
cmd = "%s && time %s --validate-only %s 2>&1" % (
" && ".join(self._driver_config["execution"]["envcmds"]),
refs(executable),
refs(config),
)
result = run(taskname, cmd)
if result.success:
logging.info("%s: Config is valid", taskname)
a.ready = lambda: True

# Private helper methods

@property
def _driver_name(self) -> str:
"""
Returns the name of this driver.
"""
return STR.jedi

def _taskname(self, suffix: str) -> str:
"""
Returns a common tag for graph-task log messages.
:param suffix: Log-string suffix.
"""
return "%s %s %s" % (self._cycle.strftime("%Y%m%d %HZ"), self._driver_name, suffix)
8 changes: 8 additions & 0 deletions src/uwtools/resources/jsonschema/jedi.jsonschema
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"properties": {
"jedi": {
"type": "object"
}
},
"type": "object"
}
1 change: 1 addition & 0 deletions src/uwtools/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class STR:
help: str = "help"
infile: str = "input_file"
infmt: str = "input_format"
jedi: str = "jedi"
keys: str = "keys"
keyvalpairs: str = "key_eq_val_pairs"
link: str = "link"
Expand Down
34 changes: 34 additions & 0 deletions src/uwtools/tests/api/test_jedi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# pylint: disable=missing-function-docstring,protected-access

import datetime as dt
from unittest.mock import patch

from uwtools.api import jedi


def test_execute(tmp_path):
dot = tmp_path / "graph.dot"
args: dict = {
"config": "config.yaml",
"cycle": dt.datetime.now(),
"batch": False,
"dry_run": True,
"graph_file": dot,
}
with patch.object(jedi, "_JEDI") as JEDI:
assert jedi.execute(**args, task="foo") is True
del args["graph_file"]
JEDI.assert_called_once_with(**args)
JEDI().foo.assert_called_once_with()


def test_graph():
with patch.object(jedi._support, "graph") as graph:
jedi.graph()
graph.assert_called_once_with()


def test_tasks():
with patch.object(jedi._support, "tasks") as _tasks:
jedi.tasks()
_tasks.assert_called_once_with(jedi._JEDI)
Loading

0 comments on commit 455ec90

Please sign in to comment.