Skip to content

Commit

Permalink
ensure all encoders on a response are instances (#106)
Browse files Browse the repository at this point in the history
- Ensure all encoders on a response are instances
- Add test for check that encoders are saved as instances on responses
  • Loading branch information
devkral authored Nov 26, 2024
1 parent 894442f commit 892400a
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 25 deletions.
5 changes: 5 additions & 0 deletions docs/en/docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@ hide:

## Unreleased

### Changed

- Encoders saved on responses are ensured to be instances and not classes.

### Fixed

- Crash when passing string to is_type_structure (e.g. string annotations).
- Fix Encoder type in responses.

## 0.11.1

Expand Down
4 changes: 2 additions & 2 deletions lilya/_internal/_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def json_encode(
*,
json_encode_fn: Callable[..., Any] = json.dumps,
post_transform_fn: Callable[[Any], Any] | None = json.loads,
with_encoders: Sequence[EncoderProtocol] | None = None,
with_encoders: Sequence[EncoderProtocol | MoldingProtocol] | None = None,
) -> Any:
"""
Encode a value to a JSON-compatible format using a list of encoder types.
Expand Down Expand Up @@ -286,7 +286,7 @@ def apply_structure(
structure: Any,
value: Any,
*,
with_encoders: Sequence[EncoderProtocol] | None = None,
with_encoders: Sequence[EncoderProtocol | MoldingProtocol] | None = None,
) -> Any:
"""
Apply structure to value. Decoding for e.g. input parameters
Expand Down
30 changes: 20 additions & 10 deletions lilya/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, Sequence
from datetime import datetime
from email.utils import format_datetime, formatdate
from inspect import isclass
from mimetypes import guess_type
from typing import (
Any,
Expand All @@ -26,15 +27,18 @@
from lilya.compat import md5_hexdigest
from lilya.concurrency import iterate_in_threadpool
from lilya.datastructures import URL, Header
from lilya.encoders import ENCODER_TYPES, Encoder, json_encode
from lilya.encoders import ENCODER_TYPES, EncoderProtocol, MoldingProtocol, json_encode
from lilya.enums import Event, HTTPMethod, MediaType
from lilya.types import Receive, Scope, Send

Content = Union[str, bytes]
Encoder = Union[EncoderProtocol, MoldingProtocol]
SyncContentStream = Iterable[Content]
AsyncContentStream = AsyncIterable[Content]
ContentStream = Union[AsyncContentStream, SyncContentStream]

_empty: tuple[Any, ...] = ()


class Response:
media_type: str | None = None
Expand All @@ -49,15 +53,17 @@ def __init__(
cookies: Mapping[str, str] | Any | None = None,
media_type: str | None = None,
background: Task | None = None,
encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None,
encoders: Sequence[Encoder | type[Encoder]] | None = None,
) -> None:
if status_code is not None:
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.background = background
self.cookies = cookies
self.encoders = encoders or []
self.encoders: list[Encoder] = [
encoder() if isclass(encoder) else encoder for encoder in encoders or _empty
]

self.body = self.make_response(content)
self.raw_headers: list[Any] = []
Expand Down Expand Up @@ -249,7 +255,7 @@ def __init__(
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: Task | None = None,
encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None,
encoders: Sequence[Encoder | type[Encoder]] | None = None,
) -> None:
super().__init__(
content=content,
Expand Down Expand Up @@ -281,7 +287,7 @@ def __init__(
status_code: int = status.HTTP_307_TEMPORARY_REDIRECT,
headers: Mapping[str, str] | None = None,
background: Task | None = None,
encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None,
encoders: Sequence[Encoder | type[Encoder]] | None = None,
) -> None:
super().__init__(
content=b"",
Expand All @@ -303,9 +309,11 @@ def __init__(
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: Task | None = None,
encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None,
encoders: Sequence[Encoder | type[Encoder]] | None = None,
) -> None:
self.encoders = encoders or []
self.encoders: list[Encoder] = [
encoder() if isclass(encoder) else encoder for encoder in encoders or _empty
]

if isinstance(content, AsyncIterable):
self.body_iterator = content
Expand Down Expand Up @@ -360,7 +368,7 @@ def __init__(
stat_result: os.stat_result | None = None,
method: str | None = None,
content_disposition_type: str = "attachment",
encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None,
encoders: Sequence[Encoder | type[Encoder]] | None = None,
) -> None:
self.path = path
self.status_code = status_code
Expand All @@ -371,7 +379,9 @@ def __init__(
self.media_type = media_type
self.background = background

self.encoders = encoders or []
self.encoders: list[Encoder] = [
encoder() if isclass(encoder) else encoder for encoder in encoders or _empty
]
self.make_headers(headers)

if self.filename is not None:
Expand Down Expand Up @@ -440,7 +450,7 @@ def __init__(
background: Task | None = None,
headers: dict[str, Any] | None = None,
media_type: MediaType | str = MediaType.HTML,
encoders: Sequence[Encoder] | Sequence[type[Encoder]] | None = None,
encoders: Sequence[Encoder | type[Encoder]] | None = None,
):
self.template = template
self.context = context or {}
Expand Down
65 changes: 52 additions & 13 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from lilya import status
from lilya.background import Task
from lilya.encoders import Encoder
from lilya.requests import Request
from lilya.responses import (
Error,
Expand All @@ -22,9 +23,28 @@
from lilya.testclient import TestClient


class Foo: ...


# check that encoders are saved as instances on responses
class FooEncoder(Encoder):
__type__ = Foo

def serialize(self, obj: Foo) -> bool:
return True

def encode(
self,
structure: type[Foo],
obj,
):
return True


def test_text_response(test_client_factory):
async def app(scope, receive, send):
response = Response("hello, world", media_type="text/plain")
response = Response("hello, world", media_type="text/plain", encoders=[FooEncoder])
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

client = test_client_factory(app)
Expand All @@ -34,7 +54,8 @@ async def app(scope, receive, send):

def test_ok_response(test_client_factory):
async def app(scope, receive, send):
response = Ok("hello, world")
response = Ok("hello, world", encoders=[FooEncoder])
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

client = test_client_factory(app)
Expand All @@ -44,7 +65,8 @@ async def app(scope, receive, send):

def test_error_response(test_client_factory):
async def app(scope, receive, send):
response = Error("hello, world")
response = Error("hello, world", encoders=[FooEncoder])
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

client = test_client_factory(app)
Expand All @@ -55,7 +77,8 @@ async def app(scope, receive, send):

def test_bytes_response(test_client_factory):
async def app(scope, receive, send):
response = Response(b"xxxxx", media_type="image/png")
response = Response(b"xxxxx", media_type="image/png", encoders=[FooEncoder])
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

client = test_client_factory(app)
Expand All @@ -65,7 +88,8 @@ async def app(scope, receive, send):

def test_json_none_response(test_client_factory):
async def app(scope, receive, send):
response = JSONResponse(None)
response = JSONResponse(None, encoders=[FooEncoder])
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

client = test_client_factory(app)
Expand All @@ -77,9 +101,10 @@ async def app(scope, receive, send):
def test_redirect_response(test_client_factory):
async def app(scope, receive, send):
if scope["path"] == "/":
response = Response("hello, world", media_type="text/plain")
response = Response("hello, world", media_type="text/plain", encoders=[FooEncoder])
else:
response = RedirectResponse("/")
response = RedirectResponse("/", encoders=[FooEncoder])
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

client = test_client_factory(app)
Expand All @@ -106,7 +131,10 @@ async def numbers_for_cleanup(start=1, stop=5):

cleanup_task = Task(numbers_for_cleanup, start=6, stop=9)
generator = numbers(1, 5)
response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task)
response = StreamingResponse(
generator, media_type="text/plain", background=cleanup_task, encoders=[FooEncoder]
)
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

assert filled_by_bg_task == ""
Expand All @@ -131,7 +159,10 @@ async def __anext__(self):
self._called += 1
return str(self._called)

response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain")
response = StreamingResponse(
CustomAsyncIterator(), media_type="text/plain", encoders=[FooEncoder]
)
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

client = test_client_factory(app)
Expand All @@ -146,7 +177,10 @@ async def __aiter__(self):
for i in range(5):
yield str(i + 1)

response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain")
response = StreamingResponse(
CustomAsyncIterable(), media_type="text/plain", encoders=[FooEncoder]
)
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

client = test_client_factory(app)
Expand All @@ -163,7 +197,8 @@ def numbers(minimum, maximum):
yield ", "

generator = numbers(1, 5)
response = StreamingResponse(generator, media_type="text/plain")
response = StreamingResponse(generator, media_type="text/plain", encoders=[FooEncoder])
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

client = test_client_factory(app)
Expand Down Expand Up @@ -219,7 +254,10 @@ async def numbers_for_cleanup(start=1, stop=5):
cleanup_task = Task(numbers_for_cleanup, start=6, stop=9)

async def app(scope, receive, send):
response = FileResponse(path=path, filename="example.png", background=cleanup_task)
response = FileResponse(
path=path, filename="example.png", background=cleanup_task, encoders=[FooEncoder]
)
assert isinstance(response.encoders[0], FooEncoder)
await response(scope, receive, send)

assert filled_by_bg_task == ""
Expand Down Expand Up @@ -289,7 +327,8 @@ def test_set_cookie(test_client_factory, monkeypatch):
monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())

async def app(scope, receive, send):
response = Response("Hello, world!", media_type="text/plain")
response = Response("Hello, world!", media_type="text/plain", encoders=[FooEncoder])
assert isinstance(response.encoders[0], FooEncoder)
response.set_cookie(
"mycookie",
"myvalue",
Expand Down

0 comments on commit 892400a

Please sign in to comment.