Skip to content

Commit

Permalink
Update OpenAPI for msgspec
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Nov 2, 2023
1 parent e997ba7 commit eae733c
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 18 deletions.
8 changes: 5 additions & 3 deletions esmerald/datastructures/msgspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from pydantic._internal._schema_generation_shared import (
GetJsonSchemaHandler as GetJsonSchemaHandler,
)
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core.core_schema import (
CoreSchema,
PlainValidatorFunctionSchema,
with_info_plain_validator_function,
)

REF_TEMPLATE = "#/components/schemas/{name}"


class Struct(msgspec.Struct):
"""
Expand Down Expand Up @@ -66,8 +67,9 @@ def _validate(cls, __input_value: Any, _: Any) -> "Struct":
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return {"type": "object"}
) -> Any:
_, schema_definitions = msgspec.json.schema_components((cls,), REF_TEMPLATE)
return schema_definitions[cls.__name__]

@classmethod
def __get_pydantic_core_schema__(
Expand Down
8 changes: 5 additions & 3 deletions esmerald/openapi/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import BaseModel, field_validator
from typing_extensions import Annotated, Doc

from esmerald.datastructures.msgspec import Struct
from esmerald.enums import MediaType


Expand Down Expand Up @@ -34,7 +35,7 @@ async def create() -> Union[None, ItemOut]:
"""

model: Annotated[
Union[Type[BaseModel], List[Type[BaseModel]]],
Union[Type[BaseModel], List[Type[BaseModel]], Type[Struct], List[Type[Struct]]],
Doc(
"""
A `pydantic.BaseModel` type of object of a `list` of
Expand Down Expand Up @@ -91,8 +92,9 @@ class Error(BaseModel):

@field_validator("model", mode="before")
def validate_model(
cls, model: Union[Type[BaseModel], List[Type[BaseModel]]]
) -> Union[Type[BaseModel], List[Type[BaseModel]]]:
cls,
model: Union[Type[BaseModel], List[Type[BaseModel]], Type[Struct], List[Type[Struct]]],
) -> Union[Type[BaseModel], List[Type[BaseModel]], Type[Struct], List[Type[Struct]]]:
if isinstance(model, list) and len(model) > 1:
raise ValueError(
"The representation of a list of models in OpenAPI can only be a total of one. Example: OpenAPIResponse(model=[MyModel])."
Expand Down
24 changes: 24 additions & 0 deletions esmerald/openapi/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from typing import Any, Dict, List, Tuple, Union

import msgspec
from pydantic import TypeAdapter
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from typing_extensions import Literal

from esmerald.datastructures.msgspec import Struct
from esmerald.openapi.validation import (
validation_error_definition,
validation_error_response_definition,
)
from esmerald.utils.helpers import is_class_and_subclass

VALIDATION_ERROR_DEFINITION = validation_error_definition.model_dump(exclude_none=True)
VALIDATION_ERROR_RESPONSE_DEFINITION = validation_error_response_definition.model_dump(
Expand All @@ -32,6 +35,25 @@
"4XX",
"5XX",
}
REF_TEMPLATE = "#/components/schemas/{name}"


def get_msgspec_definitions(
field_mapping: Dict[Tuple[FieldInfo, Literal["validation", "serialization"]], JsonSchemaValue]
) -> Dict[str, str]:
"""
Gets any field definition for a msgspec Struct declared
in the OpenAPI spec.
"""
definitions: Dict[str, str] = {}
for field, _ in field_mapping:
if isinstance(field.annotation, Struct) or is_class_and_subclass(field.annotation, Struct):
_, schema_definitions = msgspec.json.schema_components(
(field.annotation,), REF_TEMPLATE
)
definitions.update(**schema_definitions)

return definitions


def get_definitions(
Expand All @@ -46,6 +68,8 @@ def get_definitions(
field_mapping, definitions = schema_generator.generate_definitions(
inputs=inputs # type: ignore
)

definitions.update(**get_msgspec_definitions(field_mapping)) # type: ignore
return field_mapping, definitions # type: ignore[return-value]


Expand Down
14 changes: 4 additions & 10 deletions esmerald/routing/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,16 @@ def response_models(self) -> Dict[int, Any]:
responses: Dict[int, ResponseParam] = {}
if self.responses:
for status_code, response in self.responses.items():
annotation = (
List[response.model[0]] # type: ignore
if isinstance(response.model, list)
else response.model
)
model = response.model[0] if isinstance(response.model, list) else response.model

name = (
response.model[0].__name__
if isinstance(response.model, list)
else response.model.__name__
annotation = (
List[model] if isinstance(response.model, list) else model # type: ignore
)

responses[status_code] = ResponseParam(
annotation=annotation,
description=response.description,
alias=name,
alias=model.__name__,
)
return responses

Expand Down
60 changes: 58 additions & 2 deletions tests/test_msgspec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union
from typing import Annotated, Union

# from msgspec import Struct
import msgspec
from pydantic import BaseModel
from starlette import status

Expand Down Expand Up @@ -55,3 +55,59 @@ def test_user_msgspec_two(test_client_factory):
data = {"user": {"name": "Esmerald", "email": "[email protected]"}}
response = client.post("/", json=data)
assert response.json() == data


Id = Annotated[int, msgspec.Meta(gt=0)]
Email = Annotated[str, msgspec.Meta(min_length=5, max_length=100, pattern="[^@]+@[^@]+\\.[^@]+")]


class Comment(msgspec.Struct):
id: Id
email: Email


@post(status_code=status.HTTP_202_ACCEPTED)
def comments(payload: Comment) -> Comment:
return payload


def test_user_msgspec_constraints_name(test_client_factory):
with create_client(routes=[Gateway(handler=comments)]) as client:
data = {"id": -1, "email": "cenas"}
response = client.post("/", json=data)

assert response.status_code == 400
assert response.json()["errors"] == [{"id": "Expected `int` >= 1"}]


def test_user_msgspec_constraints_email(test_client_factory):
with create_client(routes=[Gateway(handler=comments)]) as client:
data = {"id": 4, "email": "cenas"}
response = client.post("/", json=data)

assert response.status_code == 400
assert response.json()["errors"] == [
{"email": "Expected `str` matching regex '[^@]+@[^@]+\\\\.[^@]+'"}
]


class Address(msgspec.Struct):
name: str


class AddressBook(msgspec.Struct):
address: Address


@post()
def nested(payload: AddressBook) -> AddressBook:
return payload


def test_nested_msgspec_struct(test_client_factory):
with create_client(routes=[Gateway(handler=nested)]) as client:
data = {"address": {"name": "Lisbon, Portugal"}}
response = client.post("/", json=data)

assert response.status_code == 201
assert response.json() == {"address": {"name": "Lisbon, Portugal"}}

0 comments on commit eae733c

Please sign in to comment.