Skip to content

Commit

Permalink
Fix non-generic protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
Tinche committed Oct 22, 2023
1 parent 23b0fef commit b2931e6
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 23 deletions.
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
- _cattrs_ is now linted with [Ruff](https://beta.ruff.rs/docs/).
- Remove some unused lines in the unstructuring code.
([#416](https://github.com/python-attrs/cattrs/pull/416))
- Fix handling classes inheriting from non-generic protocols.
([#374](https://github.com/python-attrs/cattrs/issues/374))

## 23.1.2 (2023-06-02)

Expand Down
7 changes: 5 additions & 2 deletions src/cattrs/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,11 @@ def is_counter(type):
)

def is_generic(obj) -> bool:
return isinstance(obj, (_GenericAlias, GenericAlias)) or is_subclass(
obj, Generic
"""Whether obj is a generic type."""
# Inheriting from protocol will inject `Generic` into the MRO
# without `__orig_bases__`.
return isinstance(obj, (_GenericAlias, GenericAlias)) or (
is_subclass(obj, Generic) and hasattr(obj, "__orig_bases__")
)

def copy_with(type, args):
Expand Down
58 changes: 37 additions & 21 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,7 @@ def test_raises_if_no_generic_params_supplied(
assert exc.value.type_ is T


def test_unstructure_generic_attrs():
c = Converter()

def test_unstructure_generic_attrs(genconverter):
@attrs(auto_attribs=True)
class Inner(Generic[T]):
a: T
Expand All @@ -191,23 +189,21 @@ class Outer:
inner: Inner[int]

initial = Outer(Inner(1))
raw = c.unstructure(initial)
raw = genconverter.unstructure(initial)

assert raw == {"inner": {"a": 1}}

new = c.structure(raw, Outer)
new = genconverter.structure(raw, Outer)
assert initial == new

@attrs(auto_attribs=True)
class OuterStr:
inner: Inner[str]

assert c.structure(raw, OuterStr) == OuterStr(Inner("1"))
assert genconverter.structure(raw, OuterStr) == OuterStr(Inner("1"))


def test_unstructure_deeply_nested_generics():
c = Converter()

def test_unstructure_deeply_nested_generics(genconverter):
@define
class Inner:
a: int
Expand All @@ -217,16 +213,14 @@ class Outer(Generic[T]):
inner: T

initial = Outer[Inner](Inner(1))
raw = c.unstructure(initial, Outer[Inner])
raw = genconverter.unstructure(initial, Outer[Inner])
assert raw == {"inner": {"a": 1}}

raw = c.unstructure(initial)
raw = genconverter.unstructure(initial)
assert raw == {"inner": {"a": 1}}


def test_unstructure_deeply_nested_generics_list():
c = Converter()

def test_unstructure_deeply_nested_generics_list(genconverter):
@define
class Inner:
a: int
Expand All @@ -236,16 +230,14 @@ class Outer(Generic[T]):
inner: List[T]

initial = Outer[Inner]([Inner(1)])
raw = c.unstructure(initial, Outer[Inner])
raw = genconverter.unstructure(initial, Outer[Inner])
assert raw == {"inner": [{"a": 1}]}

raw = c.unstructure(initial)
raw = genconverter.unstructure(initial)
assert raw == {"inner": [{"a": 1}]}


def test_unstructure_protocol():
c = Converter()

def test_unstructure_protocol(genconverter):
class Proto(Protocol):
a: int

Expand All @@ -258,10 +250,10 @@ class Outer:
inner: Proto

initial = Outer(Inner(1))
raw = c.unstructure(initial, Outer)
raw = genconverter.unstructure(initial, Outer)
assert raw == {"inner": {"a": 1}}

raw = c.unstructure(initial)
raw = genconverter.unstructure(initial)
assert raw == {"inner": {"a": 1}}


Expand Down Expand Up @@ -306,3 +298,27 @@ class B(A[int]):
pass

assert generate_mapping(B, {}) == {T.__name__: int}


def test_nongeneric_protocols(converter):
"""Non-generic protocols work."""

class NongenericProtocol(Protocol):
...

@define
class Entity(NongenericProtocol):
...

assert generate_mapping(Entity) == {}

class GenericProtocol(Protocol[T]):
...

@define
class GenericEntity(GenericProtocol[int]):
a: int

assert generate_mapping(GenericEntity) == {"T": int}

assert converter.structure({"a": 1}, GenericEntity) == GenericEntity(1)

0 comments on commit b2931e6

Please sign in to comment.