Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: raise TypeError if config field does not match declared type #7

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
uv.lock
tests/test_serialization.json

WadaSNR/
.idea/
Expand Down
50 changes: 40 additions & 10 deletions coqpit/coqpit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import operator
import typing
import warnings
from collections.abc import Callable, ItemsView, Iterable, Iterator, MutableMapping
from dataclasses import MISSING as _MISSING
from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace
Expand Down Expand Up @@ -182,6 +183,9 @@ def _deserialize_dict(x: dict[Any, Any]) -> dict[Any, Any]:
Returns:
Dict: deserialized dictionary.
"""
if not isinstance(x, dict):
msg = f"Value `{x}` is not a dictionary"
raise TypeError(msg)
out_dict: dict[Any, Any] = {}
for k, v in x.items():
if v is None: # if {'key':None}
Expand All @@ -204,6 +208,9 @@ def _deserialize_list(x: list[Any], field_type: FieldType) -> list[Any]:
Returns:
[List]: deserialized list.
"""
if not isinstance(x, list):
msg = f"Value `{x}` does not match field type `{field_type}`"
raise TypeError(msg)
field_args = typing.get_args(field_type)
if len(field_args) == 0:
return x
Expand Down Expand Up @@ -232,7 +239,7 @@ def _deserialize_union(x: Any, field_type: UnionType) -> Any:
try:
x = _deserialize(x, arg)
break
except ValueError:
except (TypeError, ValueError):
pass
return x

Expand All @@ -252,18 +259,30 @@ def _deserialize_primitive_types(
Returns:
Union[int, float, str, bool]: deserialized value.
"""
if isinstance(x, str | bool):
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)

type_mismatch = f"Value `{x}` does not match field type `{field_type}`"
if x is None and type(None) in typing.get_args(field_type):
return None
if isinstance(x, str):
if base_type is not str:
raise TypeError(type_mismatch)
return x
if isinstance(x, bool):
if base_type is not bool:
raise TypeError(type_mismatch)
return x
if isinstance(x, int | float):
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)
if x == float("inf") or x == float("-inf"):
# if value type is inf return regardless.
return x
if base_type is not int and base_type is not float:
raise TypeError(type_mismatch)
return base_type(x)
return None
raise TypeError(type_mismatch)


def _deserialize_path(x: Any, field_type: FieldType) -> Path | None:
Expand Down Expand Up @@ -299,8 +318,8 @@ def _deserialize(x: Any, field_type: FieldType) -> Any:
return _deserialize_path(x, field_type)
if _is_primitive_type(_drop_none_type(field_type)):
return _deserialize_primitive_types(x, field_type)
msg = f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type."
raise ValueError(msg)
msg = f"Type '{type(x)}' of value '{x}' does not match declared '{field_type}' field type."
raise TypeError(msg)


CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"]
Expand Down Expand Up @@ -433,7 +452,18 @@ def deserialize(self, data: dict[str, Any]) -> Self:
if value == MISSING:
msg = f"deserialized with unknown value for {field.name} in {self.__class__.__name__}"
raise ValueError(msg)
value = _deserialize(value, field.type)
try:
value = _deserialize(value, field.type)
except TypeError:
warnings.warn(
(
f"Type mismatch in {type(self).__name__}\n"
f"Failed to deserialize field: {field.name} ({field.type}) = {value}\n"
f"Replaced it with field's default value: {_default_value(field)}"
),
stacklevel=2,
)
value = _default_value(field)
init_kwargs[field.name] = value
for k, v in init_kwargs.items():
setattr(self, k, v)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "coqpit-config"
version = "0.1.2"
version = "0.2.0"
description = "Simple (maybe too simple), light-weight config management through python data-classes."
readme = "README.md"
requires-python = ">=3.10"
Expand Down
24 changes: 0 additions & 24 deletions tests/test_serialization.json

This file was deleted.

97 changes: 83 additions & 14 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass, field
from pathlib import Path
from types import UnionType
from typing import Any

import pytest

from coqpit.coqpit import Coqpit, _deserialize_list, _deserialize_primitive_types
from coqpit.coqpit import Coqpit, FieldType, _deserialize_list, _deserialize_primitive_types, _deserialize_union


@dataclass
Expand Down Expand Up @@ -62,39 +64,106 @@ def test_serialization() -> None:
assert ref_config.some_dict["c"] == new_config.some_dict["c"]


def test_serialization_type_mismatch() -> None:
file_path = Path(__file__).resolve().parent / "test_serialization.json"

ref_config = Reference()
ref_config.size = True
ref_config.save_json(file_path)

new_config = Group()
with pytest.warns(UserWarning, match="Type mismatch"):
new_config.load_json(file_path)
new_config.pprint()

# check values
assert len(ref_config) == len(new_config)
assert new_config.size is None


def test_deserialize_list() -> None:
assert _deserialize_list([1, 2, 3], list) == [1, 2, 3]
assert _deserialize_list([1, 2, 3], list[int]) == [1, 2, 3]
assert _deserialize_list([[1, 2, 3]], list[list[int]]) == [[1, 2, 3]]
assert _deserialize_list([1.0, 2.0, 3.0], list[float]) == [1.0, 2.0, 3.0]
assert _deserialize_list([1, 2, 3], list[float]) == [1.0, 2.0, 3.0]
assert _deserialize_list([1, 2, 3], list[str]) == ["1", "2", "3"]
assert _deserialize_list(["1", "2", "3"], list[str]) == ["1", "2", "3"]

with pytest.raises(TypeError, match="does not match field type"):
_deserialize_list([1, 2, 3], list[list[int]])

def test_deserialize_primitive_type() -> None:
cases = (

@pytest.mark.parametrize(
("value", "field_type", "expected"),
[
(True, bool, True),
(False, bool, False),
("a", str, "a"),
("3", str, "3"),
(3, int, 3),
(3, float, 3.0),
(3, str, "3"),
(3.0, str, "3.0"),
(3, bool, True),
("a", str | None, "a"),
("3", str | None, "3"),
(3, int | None, 3),
(3, float | None, 3.0),
(None, str | None, None),
(None, int | None, None),
(None, float | None, None),
(None, str | None, None),
(None, bool | None, None),
(float("inf"), float, float("inf")),
(float("inf"), int, float("inf")),
(float("-inf"), float, float("-inf")),
(float("-inf"), int, float("-inf")),
)
for value, field_type, expected in cases:
assert _deserialize_primitive_types(value, field_type) == expected

with pytest.raises(TypeError):
_deserialize_primitive_types(3, Coqpit)
],
)
def test_deserialize_primitive_type(
value: str | bool | float | None,
field_type: FieldType,
expected: str | bool | float | None,
) -> None:
assert _deserialize_primitive_types(value, field_type) == expected


@pytest.mark.parametrize(
("value", "field_type"),
[
(3, str),
(3, str | None),
(3.0, str),
(3, bool),
("1", int),
("2.0", float),
("True", bool),
("True", bool | None),
("", bool | None),
([1, 2], str),
([1, 2, 3], int),
],
)
def test_deserialize_primitive_type_mismatch(
value: str | bool | float | None,
field_type: FieldType,
) -> None:
with pytest.raises(TypeError, match="does not match field type"):
_deserialize_primitive_types(value, field_type)


@pytest.mark.parametrize(
("value", "field_type", "expected"),
[
("a", int | str, "a"),
("a", str | int, "a"),
(1, int | str, 1),
(1, str | int, 1),
(1, str | int | list[int], 1),
([1, 2], str | int | list[int], [1, 2]),
([1, 2], list[int] | int | str, [1, 2]),
([1, 2], dict | list, [1, 2]),
(["a", "b"], list[str] | list[list[str]], ["a", "b"]),
(["a", "b"], list[list[str]] | list[str], ["a", "b"]),
([["a", "b"]], list[str] | list[list[str]], [["a", "b"]]),
([["a", "b"]], list[list[str]] | list[str], [["a", "b"]]),
],
)
def test_deserialize_union(value: Any, field_type: UnionType, expected: Any) -> None:
assert _deserialize_union(value, field_type) == expected
Loading