Skip to content

Commit

Permalink
Merge pull request #17 from wpbonelli/cleanup
Browse files Browse the repository at this point in the history
Cleanup/fixes
  • Loading branch information
wpbonelli authored Jul 30, 2024
2 parents aad8a54 + b9eb84a commit ae33a56
Show file tree
Hide file tree
Showing 9 changed files with 481 additions and 167 deletions.
14 changes: 9 additions & 5 deletions flopy4/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,6 @@ def __init__(
self._how = how
self._factor = factor

def __get__(self, obj, type=None):
return self if self.value is None else self.value

def __getitem__(self, item):
return self.raw[item]

Expand Down Expand Up @@ -293,8 +290,15 @@ def value(self) -> Optional[np.ndarray]:
return self._value.reshape(self._shape) * self.factor

@value.setter
def value(self, value: np.ndarray):
assert value.shape == self.shape
def value(self, value: Optional[np.ndarray]):
if value is None:
return

if value.shape != self.shape:
raise ValueError(
f"Expected array with shape {self.shape},"
f"got shape {value.shape}"
)
self._value = value

@property
Expand Down
229 changes: 180 additions & 49 deletions flopy4/block.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
from abc import ABCMeta
from collections import UserDict
from collections import OrderedDict, UserDict
from dataclasses import asdict
from io import StringIO
from pprint import pformat
from typing import Any
from typing import Any, Dict, Optional

from flopy4.array import MFArray
from flopy4.compound import MFKeystring, MFRecord
from flopy4.compound import MFKeystring, MFRecord, get_keystrings
from flopy4.param import MFParam, MFParams
from flopy4.scalar import MFScalar
from flopy4.utils import find_upper, strip


def get_keystrings(members, name):
return [
m for m in members.values() if isinstance(m, MFKeystring) and name in m
]


def get_param(members, name, block):
param = next(iter(get_keystrings(members, name)), None)
def get_param(params, block_name, param_name):
"""
Find the first parameter in the collection with
the given name, set its block name, and return it.
"""
param = next(get_keystrings(params, param_name), None)
if param is None:
param = members.get(name)
param = params.get(param_name)
if param is None:
raise ValueError(f"Invalid parameter: {name.upper()}")
param.name = name
param.block = block
raise ValueError(f"Invalid parameter: {param_name}")
param.name = param_name
param.block = block_name
return param


Expand All @@ -40,8 +39,8 @@ def __new__(cls, clsname, bases, attrs):
.lower()
)

# add parameter specification as class attribute.
# dynamically set the parameters' name and block.
# add class attributes for the block parameter specification.
# dynamically set each parameter's name, block and docstring.
params = dict()
for attr_name, attr in attrs.items():
if issubclass(type(attr), MFParam):
Expand All @@ -50,6 +49,7 @@ def __new__(cls, clsname, bases, attrs):
attr.block = block_name
attrs[attr_name] = attr
params[attr_name] = attr

attrs["params"] = MFParams(params)

return super().__new__(cls, clsname, bases, attrs)
Expand All @@ -64,6 +64,7 @@ class MFBlock(MFParams, metaclass=MFBlockMappingMeta):
"""
MF6 input block. Maps parameter names to parameters.
Notes
-----
This class is dynamically subclassed by `MFPackage`
Expand All @@ -74,32 +75,109 @@ class MFBlock(MFParams, metaclass=MFBlockMappingMeta):
attributes expose the parameter value.
The block's name and index are discovered upon load.
Likewise the parameter values are populated on load.
They can also be initialized by passing a dictionary
of names/values to `params` when calling `__init__`.
Only recognized parameters (i.e. parameters known to
the block specification) are allowed.
"""

def __init__(self, name=None, index=None, params=None):
def __init__(
self,
name: Optional[str] = None,
index: Optional[int] = None,
params: Optional[Dict[str, Any]] = None,
):
self.name = name
self.index = index
super().__init__(params)

# if a parameter mapping is provided, coerce it to the
# spec and set default values
if params is not None:
params = type(self).coerce(params, set_default=True)

super().__init__(params=params)

def __getattribute__(self, name: str) -> Any:
if name == "data":
return super().__getattribute__(name)
self_type = type(self)

# shortcut to parameter value for instance attribute.
# the class attribute is the parameter specification.
if name in self_type.params:
return self.value[name]

# add .params attribute as an alias for .value, this
# overrides the class attribute with the param spec.
if name == "params":
return MFParams({k: v.value for k, v in self.data.items()})
return self.value

param = self.data.get(name)
return (
param.value
if param is not None
else super().__getattribute__(name)
)
return super().__getattribute__(name)

def __str__(self):
buffer = StringIO()
self.write(buffer)
return buffer.getvalue()

def __eq__(self, other):
return super().__eq__(other)

@property
def value(self):
"""Get a dictionary of block parameter values."""
return MFParams.value.fget(self)

@value.setter
def value(self, value):
"""Set block parameter values from a dictionary."""

if value is None or not any(value):
return

# coerce the parameter mapping to the spec and set defaults
params = type(self).coerce(value.copy(), set_default=True)
MFParams.value.fset(self, params)

@classmethod
def coerce(
cls, params: Dict[str, Any], set_default: bool = False
) -> Dict[str, MFParam]:
"""
Check that the dictionary contains only expected parameters,
raising an error if any unrecognized parameters are provided.
Dictionary values may be subclasses of `MFParam` or values
provided directly. If the former, this function optionally
sets default values for any missing member parameters.
"""

known = dict()
for param_name, param_spec in cls.params.copy().items():
param = params.pop(param_name, param_spec)

# make sure param is of expected type. set a
# default value if enabled and none provided.
spec_type = type(param_spec)
real_type = type(param)
if issubclass(real_type, MFParam):
if param.value is None and set_default:
param.value = param_spec.default_value
elif issubclass(spec_type, MFScalar) and real_type == spec_type.T:
param = spec_type(value=param, **asdict(param_spec))
else:
raise TypeError(
f"Expected '{param_name}' as {spec_type}, got {real_type}"
)

known[param_name] = param

# raise an error if we have any unknown parameters.
# `MFBlock` strictly disallows unrecognized params,
# for arbitrary parameter collections use `MFParams`.
if any(params):
raise ValueError(f"Unrecognized parameters:\n{pformat(params)}")

return known

@classmethod
def load(cls, f, **kwargs):
"""Load the block from file."""
Expand All @@ -126,23 +204,24 @@ def load(cls, f, **kwargs):
elif key == "end":
break
elif found:
param = get_param(members, key, name)
if param is not None:
f.seek(pos)
spec = asdict(param)
kwrgs = {**kwargs, **spec}
ptype = type(param)
if ptype is MFArray:
# TODO: inject from model somehow?
# and remove special handling here
kwrgs["cwd"] = ""
if ptype is MFRecord:
kwrgs["params"] = param.data.copy()
if ptype is MFKeystring:
kwrgs["params"] = param.data.copy()
params[param.name] = ptype.load(f, **kwrgs)

return cls(name, index, params)
param = get_param(members, name, key)
if param is None:
continue
f.seek(pos)
spec = asdict(param)
kwrgs = {**kwargs, **spec}
ptype = type(param)
if ptype is MFArray:
# TODO: inject from model somehow?
# and remove special handling here
kwrgs["cwd"] = ""
if ptype is MFRecord:
kwrgs["params"] = param.data.copy()
if ptype is MFKeystring:
kwrgs["params"] = param.data.copy()
params[param.name] = ptype.load(f, **kwrgs)

return cls(name=name, index=index, params=params)

def write(self, f):
"""Write the block to file."""
Expand All @@ -156,17 +235,69 @@ def write(self, f):


class MFBlocks(UserDict):
"""Mapping of block names to blocks."""
"""
Mapping of block names to blocks. Acts like a
dictionary, also supports named attribute access.
"""

def __init__(self, blocks=None):
MFBlocks.assert_blocks(blocks)
super().__init__(blocks)
for key, block in self.items():
setattr(self, key, block)

def __repr__(self):
return pformat(self.data)

def write(self, f):
def __eq__(self, other):
if not isinstance(other, MFBlocks):
raise TypeError(f"Expected MFBlocks, got {type(other)}")
return OrderedDict(sorted(self.value)) == OrderedDict(
sorted(other.value)
)

@staticmethod
def assert_blocks(blocks):
"""
Raise an error if any of the given items are
not subclasses of `MFBlock`.
"""
if not blocks:
return
elif isinstance(blocks, dict):
blocks = blocks.values()
not_blocks = [
b
for b in blocks
if b is not None and not issubclass(type(b), MFBlock)
]
if any(not_blocks):
raise TypeError(f"Expected MFBlock subclasses, got {not_blocks}")

@property
def value(self) -> Dict[str, Dict[str, Any]]:
"""
Get a dictionary of package block values. This is a
nested mapping of block names to blocks, where each
block is a mapping of parameter names to parameter
values.
"""
return {k: v.value for k, v in self.items()}

@value.setter
def value(self, value: Optional[Dict[str, Dict[str, Any]]]):
"""Set block values from a nested dictionary."""

if value is None or not any(value):
return

blocks = value.copy()
MFBlocks.assert_blocks(blocks)
self.update(blocks)
for key, block in self.items():
setattr(self, key, block)

def write(self, f, **kwargs):
"""Write the blocks to file."""
for block in self.data.values():
block.write(f)
for block in self.values():
block.write(f, **kwargs)
Loading

0 comments on commit ae33a56

Please sign in to comment.