Skip to content

Commit

Permalink
fix (#1263): correct AsyncAPI schema in descriminator case (#1272)
Browse files Browse the repository at this point in the history
* fix (#1271): correct AsyncAPI schema in descriminator case

* lint: fix types

* tests: RMQ-compatible tests

* fix: mypy ignores

* chore: bump version
  • Loading branch information
Lancetnik authored Feb 27, 2024
1 parent b375025 commit 0318b97
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 36 deletions.
2 changes: 1 addition & 1 deletion faststream/__about__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Simple and fast framework to create message brokers based microservices."""
__version__ = "0.4.4"
__version__ = "0.4.5"


INSTALL_YAML = """
Expand Down
5 changes: 5 additions & 0 deletions faststream/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,11 @@ def raise_fastapi_validation_error(errors: List[Any], body: AnyDict) -> Never:
)

from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import PydanticUndefined as PydanticUndefined
from pydantic_core import to_jsonable_python

SCHEMA_FIELD = "json_schema_extra"
DEF_KEY = "$defs"

def model_to_jsonable(
model: BaseModel,
Expand Down Expand Up @@ -131,6 +133,9 @@ def model_copy(model: ModelVar, **kwargs: Any) -> ModelVar:
CoreSchema = Any # type: ignore[assignment,misc]

SCHEMA_FIELD = "schema_extra"
DEF_KEY = "definitions"

PydanticUndefined = Ellipsis # type: ignore[assignment]

def dump_json(data: Any) -> bytes:
return json_dumps(data, default=pydantic_encoder)
Expand Down
50 changes: 33 additions & 17 deletions faststream/asyncapi/generate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union

from faststream._compat import HAS_FASTAPI, PYDANTIC_V2
from faststream._compat import DEF_KEY, HAS_FASTAPI
from faststream.app import FastStream
from faststream.asyncapi.schema import (
Channel,
Expand Down Expand Up @@ -181,19 +181,32 @@ def _resolve_msg_payloads(
) -> Reference:
one_of_list: List[Reference] = []

pydantic_key = "$defs" if PYDANTIC_V2 else "definitions"

for p_title, p in m.payload.get("oneOf", {}).items():
p = _move_pydantic_refs(p, pydantic_key)
payloads.update(p.pop(pydantic_key, {}))
payloads[p_title] = p
one_of_list.append(Reference(**{"$ref": f"#/components/schemas/{p_title}"}))
m.payload = _move_pydantic_refs(m.payload, DEF_KEY)
if DEF_KEY in m.payload:
payloads.update(m.payload.pop(DEF_KEY))
if "discriminator" in m.payload:
m.payload["discriminator"] = m.payload["discriminator"]["propertyName"]

one_of = m.payload.get("oneOf")
if isinstance(one_of, dict):
for p_title, p in one_of.items():
payloads.update(p.pop(DEF_KEY, {}))
if p_title not in payloads:
payloads[p_title] = p
one_of_list.append(Reference(**{"$ref": f"#/components/schemas/{p_title}"}))

elif one_of is not None:
for p in one_of:
p_title = next(iter(p.values())).split("/")[-1]
if p_title not in payloads:
payloads[p_title] = p
one_of_list.append(Reference(**{"$ref": f"#/components/schemas/{p_title}"}))

if not one_of_list:
p = _move_pydantic_refs(m.payload, pydantic_key)
payloads.update(p.pop(pydantic_key, {}))
p_title = p.get("title", f"{channel_name}Payload")
payloads[p_title] = p
payloads.update(m.payload.pop(DEF_KEY, {}))
p_title = m.payload.get("title", f"{channel_name}Payload")
if p_title not in payloads:
payloads[p_title] = m.payload
m.payload = {"$ref": f"#/components/schemas/{p_title}"}

else:
Expand All @@ -214,14 +227,17 @@ def _move_pydantic_refs(
data = original.copy()

for k in data:
if k == "$ref":
data[k] = data[k].replace(key, "components/schemas")
item = data[k]

if isinstance(item, str):
if key in item:
data[k] = data[k].replace(key, "components/schemas")

elif isinstance(data[k], dict):
elif isinstance(item, dict):
data[k] = _move_pydantic_refs(data[k], key)

elif isinstance(data[k], List):
elif isinstance(item, List):
for i in range(len(data[k])):
data[k][i] = _move_pydantic_refs(data[k][i], key)
data[k][i] = _move_pydantic_refs(item[i], key)

return data
5 changes: 4 additions & 1 deletion faststream/asyncapi/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fast_depends.core import CallModel
from pydantic import BaseModel, create_model

from faststream._compat import PYDANTIC_V2, get_model_fields, model_schema
from faststream._compat import DEF_KEY, PYDANTIC_V2, get_model_fields, model_schema


def parse_handler_params(call: CallModel[Any, Any], prefix: str = "") -> Dict[str, Any]:
Expand Down Expand Up @@ -184,6 +184,9 @@ def get_model_schema(
param_body: Dict[str, Any] = body.get("properties", {})
param_body = param_body[name]

if defs := body.get(DEF_KEY):
param_body[DEF_KEY] = defs

original_title = param.title if PYDANTIC_V2 else param.field_info.title # type: ignore[attr-defined]

if original_title:
Expand Down
72 changes: 63 additions & 9 deletions faststream/broker/core/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from fast_depends._compat import PYDANTIC_V2
from fast_depends.core import CallModel, build_call_model
from fast_depends.dependencies import Depends
from pydantic import create_model
from pydantic import Field, create_model

from faststream._compat import is_test_env
from faststream._compat import PydanticUndefined, is_test_env
from faststream.asyncapi import schema as asyncapi
from faststream.broker.core.mixins import LoggingMixin
from faststream.broker.handler import BaseHandler
Expand Down Expand Up @@ -551,17 +551,71 @@ def _patch_fastapi_dependant(
params.extend(d.query_params + d.body_params) # type: ignore[attr-defined]

params_unique = {}
params_names = set()
for p in params:
if p.name not in params_names:
params_names.add(p.name)
if p.name not in params_unique:
info = p.field_info if PYDANTIC_V2 else p
params_unique[p.name] = (info.annotation, info.default)

dependant.model = create_model( # type: ignore[call-overload]
getattr(dependant.call, "__name__", type(dependant.call).__name__),
**params_unique,
field_data = {
"default": ... if info.default is PydanticUndefined else info.default,
"default_factory": info.default_factory,
"alias": info.alias,
}

if PYDANTIC_V2:
from pydantic.fields import FieldInfo

info = cast(FieldInfo, info)

field_data.update(
{
"title": info.title,
"alias_priority": info.alias_priority,
"validation_alias": info.validation_alias,
"serialization_alias": info.serialization_alias,
"description": info.description,
"discriminator": info.discriminator,
"examples": info.examples,
"exclude": info.exclude,
"json_schema_extra": info.json_schema_extra,
}
)

f = next(
filter(
lambda x: isinstance(x, FieldInfo),
p.field_info.metadata or (),
),
Field(**field_data), # type: ignore[pydantic-field]
)

else:
from pydantic.fields import ModelField # type: ignore[attr-defined]

info = cast(ModelField, info)

field_data.update(
{
"title": info.field_info.title,
"description": info.field_info.description,
"discriminator": info.field_info.discriminator,
"exclude": info.field_info.exclude,
"gt": info.field_info.gt,
"ge": info.field_info.ge,
"lt": info.field_info.lt,
"le": info.field_info.le,
}
)
f = Field(**field_data) # type: ignore[pydantic-field]

params_unique[p.name] = (
info.annotation,
f,
)

dependant.model = create_model(
getattr(dependant.call, "__name__", type(dependant.call).__name__)
)

dependant.custom_fields = {}
dependant.flat_params = params_unique # type: ignore[assignment,misc]

Expand Down
10 changes: 8 additions & 2 deletions faststream/broker/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,15 @@ def get_payloads(self) -> List[Tuple[AnyDict, str]]:

for h, _, _, _, _, dep in self.calls:
body = parse_handler_params(
dep, prefix=f"{self._title or self.call_name}:Message"
dep,
prefix=f"{self._title or self.call_name}:Message",
)
payloads.append(
(
body,
to_camelcase(unwrap(h._original_call).__name__),
),
)
payloads.append((body, to_camelcase(unwrap(h._original_call).__name__)))

return payloads

Expand Down
8 changes: 4 additions & 4 deletions faststream/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
Sequence[JsonDecodable],
JsonDecodable,
]
JsonCompatible: TypeAlias = Union[JsonDecodable, datetime]
SendableMessage: TypeAlias = Union[
Dict[str, Union[JsonDecodable, datetime]],
Sequence[Union[JsonDecodable, datetime]],
Union[JsonDecodable, datetime],
datetime,
Dict[str, Union[JsonCompatible, Sequence[JsonCompatible]]],
Sequence[JsonCompatible],
JsonCompatible,
BaseModel,
None,
]
Expand Down
59 changes: 57 additions & 2 deletions tests/asyncapi/base/arguments.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from enum import Enum
from typing import Optional, Type
from typing import Optional, Type, Union

import pydantic
from dirty_equals import IsDict, IsPartialDict
from dirty_equals import IsDict, IsPartialDict, IsStr
from fast_depends import Depends
from fastapi import Depends as APIDepends
from typing_extensions import Annotated, Literal

from faststream import Context, FastStream
from faststream._compat import PYDANTIC_V2
from faststream.asyncapi.generate import get_app_schema
from faststream.broker.core.abc import BrokerUsecase
from tests.marks import pydanticV2


class FastAPICompatible: # noqa: D101
Expand Down Expand Up @@ -428,6 +430,59 @@ async def handle(id: int, message=message):
"type": "object",
}

@pydanticV2
def test_descriminator(self):
class Sub2(pydantic.BaseModel):
type: Literal["sub2"]

class Sub(pydantic.BaseModel):
type: Literal["sub"]

descriminator = Annotated[
Union[Sub2, Sub], pydantic.Field(..., discriminator="type")
]

broker = self.broker_class()

@broker.subscriber("test")
async def handle(user: descriminator):
...

schema = get_app_schema(self.build_app(broker)).to_jsonable()

key = tuple(schema["components"]["messages"].keys())[0]
assert key == IsStr(regex=r"test[\w:]*:Handle:Message")
assert schema["components"] == {
"messages": {
key: {
"title": key,
"correlationId": {"location": "$message.header#/correlation_id"},
"payload": {
"discriminator": "type",
"oneOf": [
{"$ref": "#/components/schemas/Sub2"},
{"$ref": "#/components/schemas/Sub"},
],
"title": "Handle:Message:Payload",
},
}
},
"schemas": {
"Sub": {
"properties": {"type": {"const": "sub", "title": "Type"}},
"required": ["type"],
"title": "Sub",
"type": "object",
},
"Sub2": {
"properties": {"type": {"const": "sub2", "title": "Type"}},
"required": ["type"],
"title": "Sub2",
"type": "object",
},
},
}, schema["components"]


class ArgumentsTestcase(FastAPICompatible): # noqa: D101
dependency_builder = staticmethod(Depends)
Expand Down

0 comments on commit 0318b97

Please sign in to comment.