Skip to content

Commit

Permalink
Add signature for dataclasses.replace (#14849)
Browse files Browse the repository at this point in the history
Validate `dataclassses.replace` actual arguments to match the fields:

- Unlike `__init__`, the arguments are always named.
- All arguments are optional except for `InitVar`s without a default
value.

The tricks:
- We're looking up type of the first positional argument ("obj") through
private API. See #10216, #14845.
- We're preparing the signature of "replace" (for that specific
dataclass) during the dataclass transformation and storing it in a
"private" class attribute `__mypy-replace` (obviously not part of
PEP-557 but contains a hyphen so should not conflict with any future
valid identifier). Stashing the signature into the symbol table allows
it to be passed across phases and cached across invocations. The stashed
signature lacks the first argument, which we prepend at function
signature hook time, since it depends on the type that `replace` is
called on.

Based on #14526 but actually simpler.
Partially addresses #5152.

# Remaining tasks

- [x] handle generic dataclasses
- [x] avoid data class transforms
- [x] fine-grained mode tests

---------

Co-authored-by: Alex Waygood <[email protected]>
  • Loading branch information
ikonst and AlexWaygood authored Jun 17, 2023
1 parent 21cc1c7 commit 6f2bfff
Show file tree
Hide file tree
Showing 8 changed files with 389 additions and 3 deletions.
163 changes: 162 additions & 1 deletion mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from mypy import errorcodes, message_registry
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.meet import meet_types
from mypy.messages import format_type_bare
from mypy.nodes import (
ARG_NAMED,
ARG_NAMED_OPT,
Expand Down Expand Up @@ -38,7 +40,7 @@
TypeVarExpr,
Var,
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugin import ClassDefContext, FunctionSigContext, SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
_get_callee_type,
_get_decorator_bool_argument,
Expand All @@ -56,10 +58,13 @@
Instance,
LiteralType,
NoneType,
ProperType,
TupleType,
Type,
TypeOfAny,
TypeVarType,
UninhabitedType,
UnionType,
get_proper_type,
)
from mypy.typevars import fill_typevars
Expand All @@ -76,6 +81,7 @@
frozen_default=False,
field_specifiers=("dataclasses.Field", "dataclasses.field"),
)
_INTERNAL_REPLACE_SYM_NAME = "__mypy-replace"


class DataclassAttribute:
Expand Down Expand Up @@ -344,13 +350,47 @@ def transform(self) -> bool:

self._add_dataclass_fields_magic_attribute()

if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
self._add_internal_replace_method(attributes)

info.metadata["dataclass"] = {
"attributes": [attr.serialize() for attr in attributes],
"frozen": decorator_arguments["frozen"],
}

return True

def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) -> None:
"""
Stashes the signature of 'dataclasses.replace(...)' for this specific dataclass
to be used later whenever 'dataclasses.replace' is called for this dataclass.
"""
arg_types: list[Type] = []
arg_kinds = []
arg_names: list[str | None] = []

info = self._cls.info
for attr in attributes:
attr_type = attr.expand_type(info)
assert attr_type is not None
arg_types.append(attr_type)
arg_kinds.append(
ARG_NAMED if attr.is_init_var and not attr.has_default else ARG_NAMED_OPT
)
arg_names.append(attr.name)

signature = CallableType(
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
ret_type=NoneType(),
fallback=self._api.named_type("builtins.function"),
)

self._cls.info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
)

def add_slots(
self, info: TypeInfo, attributes: list[DataclassAttribute], *, correct_version: bool
) -> None:
Expand Down Expand Up @@ -893,3 +933,124 @@ def _has_direct_dataclass_transform_metaclass(info: TypeInfo) -> bool:
info.declared_metaclass is not None
and info.declared_metaclass.type.dataclass_transform_spec is not None
)


def _fail_not_dataclass(ctx: FunctionSigContext, t: Type, parent_t: Type) -> None:
t_name = format_type_bare(t, ctx.api.options)
if parent_t is t:
msg = (
f'Argument 1 to "replace" has a variable type "{t_name}" not bound to a dataclass'
if isinstance(t, TypeVarType)
else f'Argument 1 to "replace" has incompatible type "{t_name}"; expected a dataclass'
)
else:
pt_name = format_type_bare(parent_t, ctx.api.options)
msg = (
f'Argument 1 to "replace" has type "{pt_name}" whose item "{t_name}" is not bound to a dataclass'
if isinstance(t, TypeVarType)
else f'Argument 1 to "replace" has incompatible type "{pt_name}" whose item "{t_name}" is not a dataclass'
)

ctx.api.fail(msg, ctx.context)


def _get_expanded_dataclasses_fields(
ctx: FunctionSigContext, typ: ProperType, display_typ: ProperType, parent_typ: ProperType
) -> list[CallableType] | None:
"""
For a given type, determine what dataclasses it can be: for each class, return the field types.
For generic classes, the field types are expanded.
If the type contains Any or a non-dataclass, returns None; in the latter case, also reports an error.
"""
if isinstance(typ, AnyType):
return None
elif isinstance(typ, UnionType):
ret: list[CallableType] | None = []
for item in typ.relevant_items():
item = get_proper_type(item)
item_types = _get_expanded_dataclasses_fields(ctx, item, item, parent_typ)
if ret is not None and item_types is not None:
ret += item_types
else:
ret = None # but keep iterating to emit all errors
return ret
elif isinstance(typ, TypeVarType):
return _get_expanded_dataclasses_fields(
ctx, get_proper_type(typ.upper_bound), display_typ, parent_typ
)
elif isinstance(typ, Instance):
replace_sym = typ.type.get_method(_INTERNAL_REPLACE_SYM_NAME)
if replace_sym is None:
_fail_not_dataclass(ctx, display_typ, parent_typ)
return None
replace_sig = replace_sym.type
assert isinstance(replace_sig, ProperType)
assert isinstance(replace_sig, CallableType)
return [expand_type_by_instance(replace_sig, typ)]
else:
_fail_not_dataclass(ctx, display_typ, parent_typ)
return None


# TODO: we can potentially get the function signature hook to allow returning a union
# and leave this to the regular machinery of resolving a union of callables
# (https://github.com/python/mypy/issues/15457)
def _meet_replace_sigs(sigs: list[CallableType]) -> CallableType:
"""
Produces the lowest bound of the 'replace' signatures of multiple dataclasses.
"""
args = {
name: (typ, kind)
for name, typ, kind in zip(sigs[0].arg_names, sigs[0].arg_types, sigs[0].arg_kinds)
}

for sig in sigs[1:]:
sig_args = {
name: (typ, kind)
for name, typ, kind in zip(sig.arg_names, sig.arg_types, sig.arg_kinds)
}
for name in (*args.keys(), *sig_args.keys()):
sig_typ, sig_kind = args.get(name, (UninhabitedType(), ARG_NAMED_OPT))
sig2_typ, sig2_kind = sig_args.get(name, (UninhabitedType(), ARG_NAMED_OPT))
args[name] = (
meet_types(sig_typ, sig2_typ),
ARG_NAMED_OPT if sig_kind == sig2_kind == ARG_NAMED_OPT else ARG_NAMED,
)

return sigs[0].copy_modified(
arg_names=list(args.keys()),
arg_types=[typ for typ, _ in args.values()],
arg_kinds=[kind for _, kind in args.values()],
)


def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType:
"""
Returns a signature for the 'dataclasses.replace' function that's dependent on the type
of the first positional argument.
"""
if len(ctx.args) != 2:
# Ideally the name and context should be callee's, but we don't have it in FunctionSigContext.
ctx.api.fail(f'"{ctx.default_signature.name}" has unexpected type annotation', ctx.context)
return ctx.default_signature

if len(ctx.args[0]) != 1:
return ctx.default_signature # leave it to the type checker to complain

obj_arg = ctx.args[0][0]
obj_type = get_proper_type(ctx.api.get_expression_type(obj_arg))
inst_type_str = format_type_bare(obj_type, ctx.api.options)

replace_sigs = _get_expanded_dataclasses_fields(ctx, obj_type, obj_type, obj_type)
if replace_sigs is None:
return ctx.default_signature
replace_sig = _meet_replace_sigs(replace_sigs)

return replace_sig.copy_modified(
arg_names=[None, *replace_sig.arg_names],
arg_kinds=[ARG_POS, *replace_sig.arg_kinds],
arg_types=[obj_type, *replace_sig.arg_types],
ret_type=obj_type,
fallback=ctx.default_signature.fallback,
name=f"{ctx.default_signature.name} of {inst_type_str}",
)
4 changes: 3 additions & 1 deletion mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
def get_function_signature_hook(
self, fullname: str
) -> Callable[[FunctionSigContext], FunctionLike] | None:
from mypy.plugins import attrs
from mypy.plugins import attrs, dataclasses

if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
return attrs.evolve_function_sig_callback
elif fullname in ("attr.fields", "attrs.fields"):
return attrs.fields_function_sig_callback
elif fullname == "dataclasses.replace":
return dataclasses.replace_function_sig_callback
return None

def get_method_signature_hook(
Expand Down
19 changes: 19 additions & 0 deletions test-data/unit/check-dataclass-transform.test
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,24 @@ reveal_type(bar.base) # N: Revealed type is "builtins.int"
[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformReplace]
from dataclasses import replace
from typing import dataclass_transform, Type

@dataclass_transform()
def my_dataclass(cls: Type) -> Type:
return cls

@my_dataclass
class Person:
name: str

p = Person('John')
y = replace(p, name='Bob') # E: Argument 1 to "replace" has incompatible type "Person"; expected a dataclass

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformSimpleDescriptor]
# flags: --python-version 3.11

Expand Down Expand Up @@ -1051,5 +1069,6 @@ class Desc2:
class C:
x: Desc # E: Unsupported signature for "__set__" in "Desc"
y: Desc2 # E: Unsupported "__set__" in "Desc2"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]
Loading

0 comments on commit 6f2bfff

Please sign in to comment.