diff --git a/lilya/context.py b/lilya/context.py index 21acbda..d89ba6b 100644 --- a/lilya/context.py +++ b/lilya/context.py @@ -3,7 +3,6 @@ import contextvars import copy import warnings -from functools import lru_cache from typing import TYPE_CHECKING, Annotated, Any, cast from lilya.datastructures import URL @@ -233,12 +232,50 @@ def __repr__(self: G) -> str: return f"{self.__class__.__name__}()" -@lru_cache +g_context: contextvars.ContextVar[G] = contextvars.ContextVar("g_context") + + def get_g() -> G: - return G() + return g_context.get() + + +class LazyGProxy: + """ + A proxy to lazily evaluate attributes of the current request. + """ + @property + def store(self) -> Any: + return get_g().store -g = get_g() + def __getattr__(self, item: str) -> Any: + return getattr(get_g(), item) + + def __setattr__(self, key: Any, value: Any) -> None: + get_g().__setattr__(key, value) + + def __delattr__(self, key: Any) -> None: + delattr(get_g(), key) + + def __copy__(self) -> G: + try: + return get_g() + except LookupError: + return G() + + def __len__(self) -> int: + return len(get_g()) + + def __getitem__(self, key: str) -> Any: + return get_g()[key] + + def __repr__(self: LazyRequestProxy) -> str: + if get_g() is None: + return "" + return '' + + +g: G = cast("G", LazyGProxy()) class LazyRequestProxy: @@ -258,7 +295,7 @@ def __getattr__(self, item: str) -> Any: return getattr(request, item) def __repr__(self: LazyRequestProxy) -> str: - if self._request_getter is None: + if self._request_getter() is None: return "" return '' @@ -275,11 +312,11 @@ class RequestContext: lazy_request = LazyRequestProxy(lambda: RequestContext._request_context.get(None)) @classmethod - def set_request(cls, request: Request) -> None: + def set_request(cls, request: Request) -> contextvars.Token: """ Set the current request in the context. """ - cls._request_context.set(request) + return cls._request_context.set(request) @classmethod def get_request(cls) -> Request: diff --git a/lilya/middleware/global_context.py b/lilya/middleware/global_context.py index d37a3c6..0ac8133 100644 --- a/lilya/middleware/global_context.py +++ b/lilya/middleware/global_context.py @@ -2,7 +2,7 @@ from abc import ABC -from lilya.context import g +from lilya.context import G, g_context from lilya.protocols.middleware import MiddlewareProtocol from lilya.types import ASGIApp, Receive, Scope, Send @@ -47,7 +47,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: Returns: None """ + token = g_context.set(G()) try: await self.app(scope, receive, send) finally: - g.clear() + g_context.reset(token) diff --git a/lilya/middleware/request_context.py b/lilya/middleware/request_context.py index 500ea69..94d51ff 100644 --- a/lilya/middleware/request_context.py +++ b/lilya/middleware/request_context.py @@ -55,7 +55,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return global_request = Request(scope, receive) - token = RequestContext._request_context.set(global_request) + token = RequestContext.set_request(global_request) try: await self.app(scope, receive, send) finally: diff --git a/tests/test_global_context.py b/tests/test_global_context.py index d24f257..7126f06 100644 --- a/tests/test_global_context.py +++ b/tests/test_global_context.py @@ -1,6 +1,8 @@ import functools -from lilya.context import g +import pytest + +from lilya.context import g, get_g as lilya_get_g from lilya.routing import Path from lilya.testclient import create_client @@ -14,14 +16,23 @@ async def show_g() -> dict[str, str]: return g.store -def test_global_context(): +async def activate_g() -> dict[str, str]: activate_stuff() + return await show_g() - with create_client(routes=[Path("/show", show_g)]) as client: - response = client.get("/show") + +def test_global_context(): + with pytest.raises(LookupError): + lilya_get_g() + with create_client(routes=[Path("/activate", activate_g), Path("/show", show_g)]) as client: + response = client.get("/activate") assert response.status_code == 200 assert response.json() == {"name": "Lilya", "age": 25} + response = client.get("/show") + assert response.status_code == 200 + assert response.json() == {} + def test_empty_global_context(): with create_client(routes=[Path("/show", show_g)]) as client: