Skip to content

Commit

Permalink
refactor: manually address remaining ruff issues for python 3.10
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Nov 25, 2024
1 parent 0e41a42 commit dc569eb
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 40 deletions.
14 changes: 2 additions & 12 deletions coqpit/coqpit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,17 @@
import contextlib
import json
import operator
import sys
import typing
from collections.abc import Callable, ItemsView, Iterable, Iterator, MutableMapping
from dataclasses import MISSING as _MISSING
from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace
from pathlib import Path
from pprint import pprint
from types import GenericAlias
from types import GenericAlias, UnionType
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeGuard, TypeVar, Union, overload

from typing_extensions import Self, TypeIs

# TODO: Available from Python 3.10
if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType: TypeAlias = Union

if TYPE_CHECKING: # pragma: no cover
import os
from dataclasses import _MISSING_TYPE
Expand Down Expand Up @@ -90,10 +83,7 @@ def _is_union(field_type: FieldType) -> TypeIs[UnionType]:
bool: True if input type is `Union`
"""
origin = typing.get_origin(field_type)
is_union = origin is Union
if sys.version_info >= (3, 10):
is_union = is_union or origin is UnionType
return is_union
return origin is Union or origin is UnionType


def _is_union_and_not_simple_optional(field_type: FieldType) -> TypeGuard[UnionType]:
Expand Down
21 changes: 9 additions & 12 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union

import pytest

Expand Down Expand Up @@ -70,8 +69,6 @@ def test_deserialize_list() -> None:
assert _deserialize_list([1, 2, 3], list[str]) == ["1", "2", "3"]


# TODO: `type: ignore` can probably be removed when switching to Python 3.10
# Union syntax (e.g. str | int)
def test_deserialize_primitive_type() -> None:
cases = (
(True, bool, True),
Expand All @@ -83,21 +80,21 @@ def test_deserialize_primitive_type() -> None:
(3, str, "3"),
(3.0, str, "3.0"),
(3, bool, True),
("a", Union[str, None], "a"),
("3", Union[str, None], "3"),
(3, Union[int, None], 3),
(3, Union[float, None], 3.0),
(None, Union[str, None], None),
(None, Union[int, None], None),
(None, Union[float, None], None),
(None, Union[str, None], None),
("a", str | None, "a"),
("3", str | None, "3"),
(3, int | None, 3),
(3, float | None, 3.0),
(None, str | None, None),
(None, int | None, None),
(None, float | None, None),
(None, str | None, None),
(float("inf"), float, float("inf")),
(float("inf"), int, float("inf")),
(float("-inf"), float, float("-inf")),
(float("-inf"), int, float("-inf")),
)
for value, field_type, expected in cases:
assert _deserialize_primitive_types(value, field_type) == expected # type: ignore[arg-type]
assert _deserialize_primitive_types(value, field_type) == expected

with pytest.raises(TypeError):
_deserialize_primitive_types(3, Coqpit)
27 changes: 11 additions & 16 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,39 @@
from typing import Union

from coqpit.coqpit import _is_optional_field, _is_union, _is_union_and_not_simple_optional

# TODO: `type: ignore` can probably be removed when switching to Python 3.10
# Union syntax (e.g. str | int)


def test_is_union() -> None:
cases = (
(Union[str, int], True),
(Union[str, None], True),
(str | int, True),
(str | None, True),
(int, False),
(list[int], False),
(list[str | int], False),
)
for item, expected in cases:
assert _is_union(item) == expected # type: ignore[arg-type]
assert _is_union(item) == expected


def test_is_union_and_not_simple_optional() -> None:
cases = (
(Union[str, int], True),
(Union[str, None], False),
(Union[list[int], None], False),
(str | int, True),
(str | None, False),
(list[int] | None, False),
(int, False),
(list[int], False),
(list[str | int], False),
)
for item, expected in cases:
assert _is_union_and_not_simple_optional(item) == expected # type: ignore[arg-type]
assert _is_union_and_not_simple_optional(item) == expected


def test_is_optional_field() -> None:
cases = (
(Union[str, int], False),
(Union[str, None], True),
(Union[list[int], None], True),
(str | int, False),
(str | None, True),
(list[int] | None, True),
(int, False),
(list[int], False),
(list[str | int], False),
)
for item, expected in cases:
assert _is_optional_field(item) == expected # type: ignore[arg-type]
assert _is_optional_field(item) == expected

0 comments on commit dc569eb

Please sign in to comment.