From 47b63d4015e60ca5c75822a677f2a3f03b24654b Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 10 Jan 2025 22:41:09 +0100 Subject: [PATCH 1/3] feat!: raise TypeError if config field does not match declared type This makes parsing stricter and could result in errors in some existing configs. However, it allows for more precise deserialization, especially in case of union types. --- coqpit/coqpit.py | 50 +++++++++++++++---- tests/test_serialization.py | 97 +++++++++++++++++++++++++++++++------ 2 files changed, 123 insertions(+), 24 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 9d8866c..64335d8 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -7,6 +7,7 @@ import json import operator import typing +import warnings from collections.abc import Callable, ItemsView, Iterable, Iterator, MutableMapping from dataclasses import MISSING as _MISSING from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace @@ -182,6 +183,9 @@ def _deserialize_dict(x: dict[Any, Any]) -> dict[Any, Any]: Returns: Dict: deserialized dictionary. """ + if not isinstance(x, dict): + msg = f"Value `{x}` is not a dictionary" + raise TypeError(msg) out_dict: dict[Any, Any] = {} for k, v in x.items(): if v is None: # if {'key':None} @@ -204,6 +208,9 @@ def _deserialize_list(x: list[Any], field_type: FieldType) -> list[Any]: Returns: [List]: deserialized list. """ + if not isinstance(x, list): + msg = f"Value `{x}` does not match field type `{field_type}`" + raise TypeError(msg) field_args = typing.get_args(field_type) if len(field_args) == 0: return x @@ -232,7 +239,7 @@ def _deserialize_union(x: Any, field_type: UnionType) -> Any: try: x = _deserialize(x, arg) break - except ValueError: + except (TypeError, ValueError): pass return x @@ -252,18 +259,30 @@ def _deserialize_primitive_types( Returns: Union[int, float, str, bool]: deserialized value. """ - if isinstance(x, str | bool): + base_type = _drop_none_type(field_type) + if base_type is not float and base_type is not int and base_type is not str and base_type is not bool: + raise TypeError + base_type = typing.cast(type[int | float | str | bool], base_type) + + type_mismatch = f"Value `{x}` does not match field type `{field_type}`" + if x is None and type(None) in typing.get_args(field_type): + return None + if isinstance(x, str): + if base_type is not str: + raise TypeError(type_mismatch) + return x + if isinstance(x, bool): + if base_type is not bool: + raise TypeError(type_mismatch) return x if isinstance(x, int | float): - base_type = _drop_none_type(field_type) - if base_type is not float and base_type is not int and base_type is not str and base_type is not bool: - raise TypeError - base_type = typing.cast(type[int | float | str | bool], base_type) if x == float("inf") or x == float("-inf"): # if value type is inf return regardless. return x + if base_type is not int and base_type is not float: + raise TypeError(type_mismatch) return base_type(x) - return None + raise TypeError(type_mismatch) def _deserialize_path(x: Any, field_type: FieldType) -> Path | None: @@ -299,8 +318,8 @@ def _deserialize(x: Any, field_type: FieldType) -> Any: return _deserialize_path(x, field_type) if _is_primitive_type(_drop_none_type(field_type)): return _deserialize_primitive_types(x, field_type) - msg = f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type." - raise ValueError(msg) + msg = f"Type '{type(x)}' of value '{x}' does not match declared '{field_type}' field type." + raise TypeError(msg) CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"] @@ -433,7 +452,18 @@ def deserialize(self, data: dict[str, Any]) -> Self: if value == MISSING: msg = f"deserialized with unknown value for {field.name} in {self.__class__.__name__}" raise ValueError(msg) - value = _deserialize(value, field.type) + try: + value = _deserialize(value, field.type) + except TypeError: + warnings.warn( + ( + f"Type mismatch in {type(self).__name__}\n" + f"Failed to deserialize field: {field.name} ({field.type}) = {value}\n" + f"Replaced it with field's default value: {_default_value(field)}" + ), + stacklevel=2, + ) + value = _default_value(field) init_kwargs[field.name] = value for k, v in init_kwargs.items(): setattr(self, k, v) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 4470c60..43af10f 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,9 +1,11 @@ from dataclasses import dataclass, field from pathlib import Path +from types import UnionType +from typing import Any import pytest -from coqpit.coqpit import Coqpit, _deserialize_list, _deserialize_primitive_types +from coqpit.coqpit import Coqpit, FieldType, _deserialize_list, _deserialize_primitive_types, _deserialize_union @dataclass @@ -62,24 +64,44 @@ def test_serialization() -> None: assert ref_config.some_dict["c"] == new_config.some_dict["c"] +def test_serialization_type_mismatch() -> None: + file_path = Path(__file__).resolve().parent / "test_serialization.json" + + ref_config = Reference() + ref_config.size = True + ref_config.save_json(file_path) + + new_config = Group() + with pytest.warns(UserWarning, match="Type mismatch"): + new_config.load_json(file_path) + new_config.pprint() + + # check values + assert len(ref_config) == len(new_config) + assert new_config.size is None + + def test_deserialize_list() -> None: assert _deserialize_list([1, 2, 3], list) == [1, 2, 3] assert _deserialize_list([1, 2, 3], list[int]) == [1, 2, 3] + assert _deserialize_list([[1, 2, 3]], list[list[int]]) == [[1, 2, 3]] + assert _deserialize_list([1.0, 2.0, 3.0], list[float]) == [1.0, 2.0, 3.0] assert _deserialize_list([1, 2, 3], list[float]) == [1.0, 2.0, 3.0] - assert _deserialize_list([1, 2, 3], list[str]) == ["1", "2", "3"] + assert _deserialize_list(["1", "2", "3"], list[str]) == ["1", "2", "3"] + with pytest.raises(TypeError, match="does not match field type"): + _deserialize_list([1, 2, 3], list[list[int]]) -def test_deserialize_primitive_type() -> None: - cases = ( + +@pytest.mark.parametrize( + ("value", "field_type", "expected"), + [ (True, bool, True), (False, bool, False), ("a", str, "a"), ("3", str, "3"), (3, int, 3), (3, float, 3.0), - (3, str, "3"), - (3.0, str, "3.0"), - (3, bool, True), ("a", str | None, "a"), ("3", str | None, "3"), (3, int | None, 3), @@ -87,14 +109,61 @@ def test_deserialize_primitive_type() -> None: (None, str | None, None), (None, int | None, None), (None, float | None, None), - (None, str | None, None), + (None, bool | None, None), (float("inf"), float, float("inf")), (float("inf"), int, float("inf")), (float("-inf"), float, float("-inf")), (float("-inf"), int, float("-inf")), - ) - for value, field_type, expected in cases: - assert _deserialize_primitive_types(value, field_type) == expected - - with pytest.raises(TypeError): - _deserialize_primitive_types(3, Coqpit) + ], +) +def test_deserialize_primitive_type( + value: str | bool | float | None, + field_type: FieldType, + expected: str | bool | float | None, +) -> None: + assert _deserialize_primitive_types(value, field_type) == expected + + +@pytest.mark.parametrize( + ("value", "field_type"), + [ + (3, str), + (3, str | None), + (3.0, str), + (3, bool), + ("1", int), + ("2.0", float), + ("True", bool), + ("True", bool | None), + ("", bool | None), + ([1, 2], str), + ([1, 2, 3], int), + ], +) +def test_deserialize_primitive_type_mismatch( + value: str | bool | float | None, + field_type: FieldType, +) -> None: + with pytest.raises(TypeError, match="does not match field type"): + _deserialize_primitive_types(value, field_type) + + +@pytest.mark.parametrize( + ("value", "field_type", "expected"), + [ + ("a", int | str, "a"), + ("a", str | int, "a"), + (1, int | str, 1), + (1, str | int, 1), + (1, str | int | list[int], 1), + ([1, 2], str | int | list[int], [1, 2]), + ([1, 2], list[int] | int | str, [1, 2]), + ([1, 2], dict | list, [1, 2]), + (["a", "b"], list[str] | list[list[str]], ["a", "b"]), + (["a", "b"], list[list[str]] | list[str], ["a", "b"]), + ([["a", "b"]], list[str] | list[list[str]], [["a", "b"]]), + ([["a", "b"]], list[list[str]] | list[str], [["a", "b"]]), + ], +) +def test_deserialize_union(value: Any, field_type: UnionType, expected: Any) -> None: + assert _deserialize_union(value, field_type) == expected From eacb5b335da34f83f3ac70f33571a0f2249caef3 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 11 Jan 2025 19:38:45 +0100 Subject: [PATCH 2/3] chore: remove test output file from git --- .gitignore | 1 + tests/test_serialization.json | 24 ------------------------ 2 files changed, 1 insertion(+), 24 deletions(-) delete mode 100644 tests/test_serialization.json diff --git a/.gitignore b/.gitignore index a3bbda1..81689fc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ uv.lock +tests/test_serialization.json WadaSNR/ .idea/ diff --git a/tests/test_serialization.json b/tests/test_serialization.json deleted file mode 100644 index 81be55d..0000000 --- a/tests/test_serialization.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "name": "Coqpit", - "size": 3, - "path": "a/b", - "people": [ - { - "name": "Eren", - "age": 11 - }, - { - "name": "Geren", - "age": 12 - }, - { - "name": "Ceren", - "age": 15 - } - ], - "some_dict": { - "a": 1, - "b": 2, - "c": null - } -} From d41fd357ca4009d9e140bccc4802ed89a49a592d Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 11 Jan 2025 18:11:30 +0100 Subject: [PATCH 3/3] chore: bump version to 0.2.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b457ea5..905a0ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "coqpit-config" -version = "0.1.2" +version = "0.2.0" description = "Simple (maybe too simple), light-weight config management through python data-classes." readme = "README.md" requires-python = ">=3.10"