diff --git a/docs/en/docs/release-notes.md b/docs/en/docs/release-notes.md index 42e2e92..6eaa787 100644 --- a/docs/en/docs/release-notes.md +++ b/docs/en/docs/release-notes.md @@ -7,9 +7,14 @@ hide: ## Unreleased +### Changed + +- Encoders saved on responses are ensured to be instances and not classes. + ### Fixed - Crash when passing string to is_type_structure (e.g. string annotations). +- Fix Encoder type in responses. ## 0.11.1 diff --git a/lilya/_internal/_encoders.py b/lilya/_internal/_encoders.py index 394abe1..502b230 100644 --- a/lilya/_internal/_encoders.py +++ b/lilya/_internal/_encoders.py @@ -240,7 +240,7 @@ def json_encode( *, json_encode_fn: Callable[..., Any] = json.dumps, post_transform_fn: Callable[[Any], Any] | None = json.loads, - with_encoders: Sequence[EncoderProtocol] | None = None, + with_encoders: Sequence[EncoderProtocol | MoldingProtocol] | None = None, ) -> Any: """ Encode a value to a JSON-compatible format using a list of encoder types. @@ -286,7 +286,7 @@ def apply_structure( structure: Any, value: Any, *, - with_encoders: Sequence[EncoderProtocol] | None = None, + with_encoders: Sequence[EncoderProtocol | MoldingProtocol] | None = None, ) -> Any: """ Apply structure to value. Decoding for e.g. input parameters diff --git a/lilya/responses.py b/lilya/responses.py index ea98c27..d104af6 100644 --- a/lilya/responses.py +++ b/lilya/responses.py @@ -9,6 +9,7 @@ from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, Sequence from datetime import datetime from email.utils import format_datetime, formatdate +from inspect import isclass from mimetypes import guess_type from typing import ( Any, @@ -26,15 +27,18 @@ from lilya.compat import md5_hexdigest from lilya.concurrency import iterate_in_threadpool from lilya.datastructures import URL, Header -from lilya.encoders import ENCODER_TYPES, Encoder, json_encode +from lilya.encoders import ENCODER_TYPES, EncoderProtocol, MoldingProtocol, json_encode from lilya.enums import Event, HTTPMethod, MediaType from lilya.types import Receive, Scope, Send Content = Union[str, bytes] +Encoder = Union[EncoderProtocol, MoldingProtocol] SyncContentStream = Iterable[Content] AsyncContentStream = AsyncIterable[Content] ContentStream = Union[AsyncContentStream, SyncContentStream] +_empty: tuple[Any, ...] = () + class Response: media_type: str | None = None @@ -49,7 +53,7 @@ def __init__( cookies: Mapping[str, str] | Any | None = None, media_type: str | None = None, background: Task | None = None, - encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None, + encoders: Sequence[Encoder | type[Encoder]] | None = None, ) -> None: if status_code is not None: self.status_code = status_code @@ -57,7 +61,9 @@ def __init__( self.media_type = media_type self.background = background self.cookies = cookies - self.encoders = encoders or [] + self.encoders: list[Encoder] = [ + encoder() if isclass(encoder) else encoder for encoder in encoders or _empty + ] self.body = self.make_response(content) self.raw_headers: list[Any] = [] @@ -249,7 +255,7 @@ def __init__( headers: Mapping[str, str] | None = None, media_type: str | None = None, background: Task | None = None, - encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None, + encoders: Sequence[Encoder | type[Encoder]] | None = None, ) -> None: super().__init__( content=content, @@ -281,7 +287,7 @@ def __init__( status_code: int = status.HTTP_307_TEMPORARY_REDIRECT, headers: Mapping[str, str] | None = None, background: Task | None = None, - encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None, + encoders: Sequence[Encoder | type[Encoder]] | None = None, ) -> None: super().__init__( content=b"", @@ -303,9 +309,11 @@ def __init__( headers: Mapping[str, str] | None = None, media_type: str | None = None, background: Task | None = None, - encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None, + encoders: Sequence[Encoder | type[Encoder]] | None = None, ) -> None: - self.encoders = encoders or [] + self.encoders: list[Encoder] = [ + encoder() if isclass(encoder) else encoder for encoder in encoders or _empty + ] if isinstance(content, AsyncIterable): self.body_iterator = content @@ -360,7 +368,7 @@ def __init__( stat_result: os.stat_result | None = None, method: str | None = None, content_disposition_type: str = "attachment", - encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None, + encoders: Sequence[Encoder | type[Encoder]] | None = None, ) -> None: self.path = path self.status_code = status_code @@ -371,7 +379,9 @@ def __init__( self.media_type = media_type self.background = background - self.encoders = encoders or [] + self.encoders: list[Encoder] = [ + encoder() if isclass(encoder) else encoder for encoder in encoders or _empty + ] self.make_headers(headers) if self.filename is not None: @@ -440,7 +450,7 @@ def __init__( background: Task | None = None, headers: dict[str, Any] | None = None, media_type: MediaType | str = MediaType.HTML, - encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None, + encoders: Sequence[Encoder | type[Encoder]] | None = None, ): self.template = template self.context = context or {} diff --git a/tests/test_responses.py b/tests/test_responses.py index 53178a9..845b9be 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -9,6 +9,7 @@ from lilya import status from lilya.background import Task +from lilya.encoders import Encoder from lilya.requests import Request from lilya.responses import ( Error, @@ -22,9 +23,28 @@ from lilya.testclient import TestClient +class Foo: ... + + +# check that encoders are saved as instances on responses +class FooEncoder(Encoder): + __type__ = Foo + + def serialize(self, obj: Foo) -> bool: + return True + + def encode( + self, + structure: type[Foo], + obj, + ): + return True + + def test_text_response(test_client_factory): async def app(scope, receive, send): - response = Response("hello, world", media_type="text/plain") + response = Response("hello, world", media_type="text/plain", encoders=[FooEncoder]) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) client = test_client_factory(app) @@ -34,7 +54,8 @@ async def app(scope, receive, send): def test_ok_response(test_client_factory): async def app(scope, receive, send): - response = Ok("hello, world") + response = Ok("hello, world", encoders=[FooEncoder]) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) client = test_client_factory(app) @@ -44,7 +65,8 @@ async def app(scope, receive, send): def test_error_response(test_client_factory): async def app(scope, receive, send): - response = Error("hello, world") + response = Error("hello, world", encoders=[FooEncoder]) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) client = test_client_factory(app) @@ -55,7 +77,8 @@ async def app(scope, receive, send): def test_bytes_response(test_client_factory): async def app(scope, receive, send): - response = Response(b"xxxxx", media_type="image/png") + response = Response(b"xxxxx", media_type="image/png", encoders=[FooEncoder]) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) client = test_client_factory(app) @@ -65,7 +88,8 @@ async def app(scope, receive, send): def test_json_none_response(test_client_factory): async def app(scope, receive, send): - response = JSONResponse(None) + response = JSONResponse(None, encoders=[FooEncoder]) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) client = test_client_factory(app) @@ -77,9 +101,10 @@ async def app(scope, receive, send): def test_redirect_response(test_client_factory): async def app(scope, receive, send): if scope["path"] == "/": - response = Response("hello, world", media_type="text/plain") + response = Response("hello, world", media_type="text/plain", encoders=[FooEncoder]) else: - response = RedirectResponse("/") + response = RedirectResponse("/", encoders=[FooEncoder]) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) client = test_client_factory(app) @@ -106,7 +131,10 @@ async def numbers_for_cleanup(start=1, stop=5): cleanup_task = Task(numbers_for_cleanup, start=6, stop=9) generator = numbers(1, 5) - response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task) + response = StreamingResponse( + generator, media_type="text/plain", background=cleanup_task, encoders=[FooEncoder] + ) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) assert filled_by_bg_task == "" @@ -131,7 +159,10 @@ async def __anext__(self): self._called += 1 return str(self._called) - response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain") + response = StreamingResponse( + CustomAsyncIterator(), media_type="text/plain", encoders=[FooEncoder] + ) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) client = test_client_factory(app) @@ -146,7 +177,10 @@ async def __aiter__(self): for i in range(5): yield str(i + 1) - response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain") + response = StreamingResponse( + CustomAsyncIterable(), media_type="text/plain", encoders=[FooEncoder] + ) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) client = test_client_factory(app) @@ -163,7 +197,8 @@ def numbers(minimum, maximum): yield ", " generator = numbers(1, 5) - response = StreamingResponse(generator, media_type="text/plain") + response = StreamingResponse(generator, media_type="text/plain", encoders=[FooEncoder]) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) client = test_client_factory(app) @@ -219,7 +254,10 @@ async def numbers_for_cleanup(start=1, stop=5): cleanup_task = Task(numbers_for_cleanup, start=6, stop=9) async def app(scope, receive, send): - response = FileResponse(path=path, filename="example.png", background=cleanup_task) + response = FileResponse( + path=path, filename="example.png", background=cleanup_task, encoders=[FooEncoder] + ) + assert isinstance(response.encoders[0], FooEncoder) await response(scope, receive, send) assert filled_by_bg_task == "" @@ -289,7 +327,8 @@ def test_set_cookie(test_client_factory, monkeypatch): monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp()) async def app(scope, receive, send): - response = Response("Hello, world!", media_type="text/plain") + response = Response("Hello, world!", media_type="text/plain", encoders=[FooEncoder]) + assert isinstance(response.encoders[0], FooEncoder) response.set_cookie( "mycookie", "myvalue",