Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modern path handling and pydantic implementation #82

Draft
wants to merge 3 commits into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 132 additions & 92 deletions COSIPY.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions conda_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ nco
cdo
cartopy
vtk
pydantic
richdem
coveralls
codecov
Expand Down
26 changes: 17 additions & 9 deletions convert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import configparser
import inspect
import sys
from pathlib import Path

import config
import constants
Expand Down Expand Up @@ -284,12 +285,12 @@ def get_utilities_params() -> dict:
return params


def write_toml(parameters: dict, filename: str):
def write_toml(parameters: dict, filename: Path | str):
"""Write parameters to .toml file."""

with open(f"{filename}.toml", "w") as f:
if isinstance(filename, str):
filename = Path(filename)
with filename.with_suffix(".toml").open("w") as f:
toml.dump(parameters, f)


print(f"Generated {filename}.toml")

Expand All @@ -308,15 +309,22 @@ def main():

print_warning()

script_path = inspect.getfile(inspect.currentframe())
toml_suffix = script_path.split("/")[-2] # avoid overwrite
frame = inspect.currentframe()
if frame is None:
msg = "Could not find the current frame. This is likely due to a bug in the code."
raise RuntimeError(msg)
try:
script_path = Path(inspect.getfile(frame)).resolve()
finally:
del frame
_ = script_path.parent.name # HACK: avoid overwrite (Why is this here?)

config_params = get_config_params()
write_toml(parameters=config_params, filename=f"config")
write_toml(parameters=config_params, filename="config")
constants_params = get_constants_params()
write_toml(parameters=constants_params, filename=f"constants")
write_toml(parameters=constants_params, filename="constants")
utilities_params = get_utilities_params()
write_toml(parameters=utilities_params, filename=f"utilities_config")
write_toml(parameters=utilities_params, filename="utilities_config")


if __name__ == "__main__":
Expand Down
214 changes: 143 additions & 71 deletions cosipy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,25 @@
import argparse
import sys
from importlib.metadata import entry_points
from pathlib import Path
from typing import Annotated, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic.types import StringConstraints
from typing_extensions import Self

if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib # backwards compatibility

# FIXME: Will this work for all occasions or do we need to use frame?
cwd = Path.cwd()
default_path = cwd / "config.toml"
default_slurm_path = cwd / "slurm_config.toml"
default_constants_path = cwd / "constants.toml"
default_utilities_path = cwd / "utilities_config.toml"


def set_parser() -> argparse.ArgumentParser:
"""Set argument parser for COSIPY."""
Expand All @@ -23,9 +36,9 @@ def set_parser() -> argparse.ArgumentParser:
parser.add_argument(
"-c",
"--config",
default="./config.toml",
default=default_path,
dest="config_path",
type=str,
type=Path,
metavar="<path>",
required=False,
help="relative path to configuration file",
Expand All @@ -34,9 +47,9 @@ def set_parser() -> argparse.ArgumentParser:
parser.add_argument(
"-x",
"--constants",
default="./constants.toml",
default=default_constants_path,
dest="constants_path",
type=str,
type=Path,
metavar="<path>",
required=False,
help="relative path to constants file",
Expand All @@ -45,9 +58,9 @@ def set_parser() -> argparse.ArgumentParser:
parser.add_argument(
"-s",
"--slurm",
default="./slurm_config.toml",
default=default_slurm_path,
dest="slurm_path",
type=str,
type=Path,
metavar="<path>",
required=False,
help="relative path to Slurm configuration file",
Expand Down Expand Up @@ -87,6 +100,74 @@ def get_user_arguments() -> argparse.Namespace:
return arguments


DatetimeStr = Annotated[
str,
StringConstraints(
strip_whitespace=True, pattern=r"\d{4}-[01]\d-[0-3]\dT[0-2]\d:[0-5]\d"
),
]


class CosipyConfigModel(BaseModel):
"""COSIPY configuration model."""

model_config = ConfigDict(from_attributes=True)
time_start: DatetimeStr = Field(
description="Start time of the simulation in ISO format"
)
time_end: DatetimeStr = Field(description="End time of the simulation in ISO format")
data_path: Path = Field(description="Path to the data directory")
input_netcdf: Path = Field(description="Input NetCDF file path")
output_prefix: str = Field(description="Prefix for output files")
restart: bool = Field(description="Restart flag")
stake_evaluation: bool = Field(description="Flag for stake data evaluation")
stakes_loc_file: Path = Field(description="Path to stake location file")
stakes_data_file: Path = Field(description="Path to stake data file")
eval_method: Literal["rmse"] = Field(
"rmse", description="Evaluation method for simulations"
)
obs_type: Literal["mb", "snowheight"] = Field(description="Type of stake data used")
WRF: bool = Field(description="Flag for WRF input")
WRF_X_CSPY: bool = Field(description="Interactive simulation with WRF flag")
northing: str = Field(description="Name of northing dimension")
easting: str = Field(description="Name of easting dimension")
compression_level: int = Field(
ge=0, le=9, description="Output NetCDF compression level"
)
slurm_use: bool = Field(description="Use SLURM flag")
workers: Optional[int] = Field(
default=None,
ge=0,
description="""
Setting is only used is slurm_use is False.
Number of workers (cores), with 0 all available cores are used.
""",
)
local_port: int = Field(default=8786, gt=0, description="Port for local cluster")
full_field: bool = Field(description="Flag for writing full fields to file")
force_use_TP: bool = Field(..., description="Total precipitation flag")
force_use_N: bool = Field(..., description="Cloud cover fraction flag")
tile: bool = Field(description="Flag for tiling")
xstart: int = Field(ge=0, description="Start x index")
xend: int = Field(ge=0, description="End x index")
ystart: int = Field(ge=0, description="Start y index")
yend: int = Field(ge=0, description="End y index")
output_atm: str = Field(description="Atmospheric output variables")
output_internal: str = Field(description="Internal output variables")
output_full: str = Field(description="Full output variables")

@model_validator(mode="after")
def validate_output_variables(self) -> Self:
if self.WRF:
self.northing = "south_north"
self.easting = "west_east"
if self.WRF_X_CSPY:
self.full_field = True
if self.workers == 0:
self.workers = None
return self


def get_help():
"""Print help for commands."""
parser = set_parser()
Expand Down Expand Up @@ -127,7 +208,7 @@ class TomlLoader(object):
"""Load and parse configuration files."""

@staticmethod
def get_raw_toml(file_path: str = "./config.toml") -> dict:
def get_raw_toml(file_path: Path = default_path) -> dict:
"""Open and load .toml configuration file.

Args:
Expand All @@ -136,21 +217,20 @@ def get_raw_toml(file_path: str = "./config.toml") -> dict:
Returns:
Loaded .toml data.
"""
with open(file_path, "rb") as f:
raw_config = tomllib.load(f)

return raw_config
with file_path.open("rb") as f:
return tomllib.load(f)

@classmethod
def set_config_values(cls, config_table: dict):
@staticmethod
def flatten(config_table: dict[str, dict]) -> dict:
"""Overwrite attributes with configuration data.

Args:
config_table: Loaded .toml data.
"""
for _, table in config_table.items():
for k, v in table.items():
setattr(cls, k, v)
flat_dict = {}
for table in config_table.values():
flat_dict = {**flat_dict, **table}
return flat_dict


class Config(TomlLoader):
Expand All @@ -160,37 +240,45 @@ class Config(TomlLoader):
.toml file.
"""

def __init__(self):
self.args = get_user_arguments()
self.load(self.args.config_path)

@classmethod
def load(cls, path: str = "./config.toml"):
raw_toml = cls.get_raw_toml(path)
parsed_toml = cls.set_correct_config(raw_toml)
cls.set_config_values(parsed_toml)
def __init__(self, path: Path = default_path) -> None:
raw_toml = self.get_raw_toml(path)
self.raw_toml = self.flatten(raw_toml)

@classmethod
def set_correct_config(cls, config_table: dict) -> dict:
"""Adjust invalid or mutually exclusive configuration values.

Args:
config_table: Loaded .toml data.
def validate(self) -> CosipyConfigModel:
"""Validate configuration using Pydantic class.

Returns:
Adjusted .toml data.
CosipyConfigModel: Validated configuration.
"""
# WRF Compatibility
if config_table["DIMENSIONS"]["WRF"]:
config_table["DIMENSIONS"]["northing"] = "south_north"
config_table["DIMENSIONS"]["easting"] = "west_east"
if config_table["DIMENSIONS"]["WRF_X_CSPY"]:
config_table["FULL_FIELDS"]["full_field"] = True
# TOML doesn't support null values
if config_table["PARALLELIZATION"]["workers"] == 0:
config_table["PARALLELIZATION"]["workers"] = None
return CosipyConfigModel(**self.raw_toml)


ShebangStr = Annotated[str, StringConstraints(strip_whitespace=True, pattern=r"^#!")]


class SlurmConfigModel(BaseModel):
"""Slurm configuration model."""

account: str = Field(description="Slurm account/group")
name: str = Field(description="Equivalent to Slurm parameter `--job-name`")
queue: str = Field(description="Queue name")
slurm_parameters: list[str] = Field(description="Additional Slurm parameters")
shebang: ShebangStr = Field(description="Shebang string")
local_directory: Path = Field(description="Local directory")
port: int = Field(description="Network port number")
cores: int = Field(description="One grid point per core")
nodes: int = Field(description="Grid points submitted in one sbatch script")
processes: int = Field(description="Number of processes")
memory: str = Field(description="Memory per process")
memory_per_process: Optional[int] = Field(gt=0, description="Memory per process")

return config_table
@model_validator(mode="after")
def validate_output_variables(self):
if self.memory_per_process:
memory = self.memory_per_process * self.cores
self.memory = f"{memory}GB"

return self


class SlurmConfig(TomlLoader):
Expand All @@ -214,43 +302,27 @@ class SlurmConfig(TomlLoader):
slurm_parameters (List[str]): Additional Slurm parameters.
"""

def __init__(self):
self.args = get_user_arguments()
self.load(self.args.slurm_path)

@classmethod
def load(cls, path: str = "./slurm_config.toml"):
raw_toml = cls.get_raw_toml(path)
parsed_toml = cls.set_correct_config(raw_toml)
cls.set_config_values(parsed_toml)
def __init__(self, path: Path = default_slurm_path) -> None:
raw_toml = self.get_raw_toml(path)
self.raw_toml = self.flatten(raw_toml)

@classmethod
def set_correct_config(cls, config_table: dict) -> dict:
"""Adjust invalid or mutually exclusive configuration values.

Args:
config_table: Loaded .toml data.
def validate(self) -> SlurmConfigModel:
"""Validate configuration using Pydantic class.

Returns:
Adjusted .toml data.
CosipyConfigModel: Validated configuration.
"""
if config_table["OVERRIDES"]["memory_per_process"]:
memory = (
config_table["OVERRIDES"]["memory_per_process"]
* config_table["MEMORY"]["cores"]
)
config_table["MEMORY"]["memory"] = f"{memory}GB"

return config_table
return SlurmConfigModel(**self.raw_toml)


def main():
cfg = Config()
if cfg.slurm_use:
SlurmConfig()
def main() -> tuple[CosipyConfigModel, Optional[SlurmConfigModel]]:
args = get_user_arguments()
cfg = Config(args.config_path).validate()
slurm_cfg = SlurmConfig(args.slurm_path).validate() if cfg.slurm_use else None
return cfg, slurm_cfg


if __name__ == "__main__":
main()
else:
main()
main_config, slurm_config = main()
Loading