Skip to content

Commit

Permalink
fixes context, make global_context threadsafe (#131)
Browse files Browse the repository at this point in the history
Changes:

- make global context request specific as documented and thread safe
- fix repr
- expose token in set and cleanup
  • Loading branch information
devkral authored Jan 17, 2025
1 parent d830822 commit 967ce88
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 14 deletions.
51 changes: 44 additions & 7 deletions lilya/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextvars
import copy
import warnings
from functools import lru_cache
from typing import TYPE_CHECKING, Annotated, Any, cast

from lilya.datastructures import URL
Expand Down Expand Up @@ -233,12 +232,50 @@ def __repr__(self: G) -> str:
return f"{self.__class__.__name__}()"


@lru_cache
g_context: contextvars.ContextVar[G] = contextvars.ContextVar("g_context")


def get_g() -> G:
return G()
return g_context.get()


class LazyGProxy:
"""
A proxy to lazily evaluate attributes of the current request.
"""

@property
def store(self) -> Any:
return get_g().store

g = get_g()
def __getattr__(self, item: str) -> Any:
return getattr(get_g(), item)

def __setattr__(self, key: Any, value: Any) -> None:
get_g().__setattr__(key, value)

def __delattr__(self, key: Any) -> None:
delattr(get_g(), key)

def __copy__(self) -> G:
try:
return get_g()
except LookupError:
return G()

def __len__(self) -> int:
return len(get_g())

def __getitem__(self, key: str) -> Any:
return get_g()[key]

def __repr__(self: LazyRequestProxy) -> str:
if get_g() is None:
return "<LazyGProxy [Unevaluated]>"
return '<LazyGProxy "G()">'


g: G = cast("G", LazyGProxy())


class LazyRequestProxy:
Expand All @@ -258,7 +295,7 @@ def __getattr__(self, item: str) -> Any:
return getattr(request, item)

def __repr__(self: LazyRequestProxy) -> str:
if self._request_getter is None:
if self._request_getter() is None:
return "<LazyRequestProxy [Unevaluated]>"
return '<LazyRequestProxy "Request()">'

Expand All @@ -275,11 +312,11 @@ class RequestContext:
lazy_request = LazyRequestProxy(lambda: RequestContext._request_context.get(None))

@classmethod
def set_request(cls, request: Request) -> None:
def set_request(cls, request: Request) -> contextvars.Token:
"""
Set the current request in the context.
"""
cls._request_context.set(request)
return cls._request_context.set(request)

@classmethod
def get_request(cls) -> Request:
Expand Down
5 changes: 3 additions & 2 deletions lilya/middleware/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC

from lilya.context import g
from lilya.context import G, g_context
from lilya.protocols.middleware import MiddlewareProtocol
from lilya.types import ASGIApp, Receive, Scope, Send

Expand Down Expand Up @@ -47,7 +47,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Returns:
None
"""
token = g_context.set(G())
try:
await self.app(scope, receive, send)
finally:
g.clear()
g_context.reset(token)
2 changes: 1 addition & 1 deletion lilya/middleware/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return

global_request = Request(scope, receive)
token = RequestContext._request_context.set(global_request)
token = RequestContext.set_request(global_request)
try:
await self.app(scope, receive, send)
finally:
Expand Down
19 changes: 15 additions & 4 deletions tests/test_global_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import functools

from lilya.context import g
import pytest

from lilya.context import g, get_g as lilya_get_g
from lilya.routing import Path
from lilya.testclient import create_client

Expand All @@ -14,14 +16,23 @@ async def show_g() -> dict[str, str]:
return g.store


def test_global_context():
async def activate_g() -> dict[str, str]:
activate_stuff()
return await show_g()

with create_client(routes=[Path("/show", show_g)]) as client:
response = client.get("/show")

def test_global_context():
with pytest.raises(LookupError):
lilya_get_g()
with create_client(routes=[Path("/activate", activate_g), Path("/show", show_g)]) as client:
response = client.get("/activate")
assert response.status_code == 200
assert response.json() == {"name": "Lilya", "age": 25}

response = client.get("/show")
assert response.status_code == 200
assert response.json() == {}


def test_empty_global_context():
with create_client(routes=[Path("/show", show_g)]) as client:
Expand Down

0 comments on commit 967ce88

Please sign in to comment.