Skip to content

Commit

Permalink
Refactors "set_fields" routine with custom context manager (RascalSof…
Browse files Browse the repository at this point in the history
…tware#112)

* Refactors "set_fields" routine with custom context manager

* Addresses review comments
  • Loading branch information
DrPaulSharp authored Jan 9, 2025
1 parent 082657c commit 973ff1d
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 4 deletions.
65 changes: 62 additions & 3 deletions RATapi/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import collections
import contextlib
import importlib
import warnings
from collections.abc import Sequence
from typing import Any, Generic, TypeVar, Union
Expand Down Expand Up @@ -261,9 +262,67 @@ def extend(self, other: Sequence[T]) -> None:
def set_fields(self, index: int, **kwargs) -> None:
"""Assign the values of an existing object's attributes using keyword arguments."""
self._validate_name_field(kwargs)
class_handle = self.data[index].__class__
new_fields = {**self.data[index].__dict__, **kwargs}
self.data[index] = class_handle(**new_fields)
pydantic_object = False

if importlib.util.find_spec("pydantic"):
# Pydantic is installed, so set up a context manager that will
# suppress custom validation errors until all fields have been set.
from pydantic import BaseModel, ValidationError

if isinstance(self.data[index], BaseModel):
pydantic_object = True

# Define a custom context manager
class SuppressCustomValidation(contextlib.AbstractContextManager):
"""Context manager to suppress "value_error" based validation errors in pydantic.
This validation context is necessary because errors can occur whilst individual
model values are set, which are resolved when all of the input values are set.
After the exception is suppressed, execution proceeds with the next
statement following the with statement.
with SuppressCustomValidation():
setattr(self.data[index], key, value)
# Execution still resumes here if the attribute cannot be set
"""

def __init__(self):
pass

def __enter__(self):
pass

def __exit__(self, exctype, excinst, exctb):
# If the return of __exit__ is True or truthy, the exception is suppressed.
# Otherwise, the default behaviour of raising the exception applies.
#
# To suppress errors arising from field and model validators in pydantic,
# we will examine the validation errors raised. If all of the errors
# listed in the exception have the type "value_error", this indicates
# they have arisen from field or model validators and will be suppressed.
# Otherwise, they will be raised.
if exctype is None:
return
if issubclass(exctype, ValidationError) and all(
[error["type"] == "value_error" for error in excinst.errors()]
):
return True
return False

validation_context = SuppressCustomValidation()
else:
validation_context = contextlib.nullcontext()

for key, value in kwargs.items():
with validation_context:
setattr(self.data[index], key, value)

# We have suppressed custom validation errors for pydantic objects.
# We now must revalidate the pydantic model outside the validation context
# to catch any errors that remain after setting all of the fields.
if pydantic_object:
self._class_handle.model_validate(self.data[index])

def get_names(self) -> list[str]:
"""Return a list of the values of the name_field attribute of each class object in the list.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ extend-exclude = ["*.ipynb"]

[tool.ruff.lint]
select = ["E", "F", "UP", "B", "SIM", "I"]
ignore = ["SIM108"]
ignore = ["SIM103", "SIM108"]

[tool.ruff.lint.flake8-pytest-style]
fixture-parentheses = false
Expand Down
22 changes: 22 additions & 0 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,3 +1005,25 @@ class NestedModel(pydantic.BaseModel):
for submodel, exp_dict in zip(model.submodels, submodels_list):
for key, value in exp_dict.items():
assert getattr(submodel, key) == value

def test_set_pydantic_fields(self):
"""Test that intermediate validation errors for pydantic models are suppressed when using "set_fields"."""
from pydantic import BaseModel, model_validator

class MinMaxModel(BaseModel):
min: float
value: float
max: float

@model_validator(mode="after")
def check_value_in_range(self) -> "MinMaxModel":
if self.value < self.min or self.value > self.max:
raise ValueError(
f"value {self.value} is not within the defined range: {self.min} <= value <= {self.max}"
)
return self

model_list = ClassList([MinMaxModel(min=1, value=2, max=5)])
model_list.set_fields(0, min=3, value=4)

assert model_list == ClassList([MinMaxModel(min=3.0, value=4.0, max=5.0)])

0 comments on commit 973ff1d

Please sign in to comment.