Skip to content

Commit

Permalink
Add tests for requires with Security
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Jan 13, 2025
1 parent 48c7ef1 commit 22e8329
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 12 deletions.
7 changes: 6 additions & 1 deletion esmerald/transformers/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_origin,
)

from lilya.exceptions import HTTPException as LilyaHTTPException
from orjson import loads
from pydantic import ValidationError, create_model
from pydantic.fields import FieldInfo
Expand Down Expand Up @@ -200,6 +201,10 @@ def encode_value(encoder: "Encoder", annotation: Any, value: Any) -> Any:
decoded_list = extract_arguments(annotation)
annotation = decoded_list[0] # type: ignore

if is_requires(value):
kwargs[key] = await async_resolve_dependencies(value.dependency)
continue

kwargs[key] = encode_value(encoder, annotation, value)

return kwargs
Expand Down Expand Up @@ -296,7 +301,7 @@ def extract_error_message(exception: Exception) -> Dict[str, Any]:
return str(exception) # type: ignore

try:
if isinstance(exception, HTTPException):
if isinstance(exception, (HTTPException, LilyaHTTPException)):
return exception

method, url = get_connection_info(connection)
Expand Down
5 changes: 5 additions & 0 deletions esmerald/utils/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, Union

from lilya.compat import run_sync
from lilya.context import request_context

from esmerald import params
from esmerald.security.scopes import Scopes
Expand Down Expand Up @@ -78,6 +79,10 @@ async def async_resolve_dependencies(func: Any, overrides: Union[Dict[str, Any]]
kwargs = {}

for name, param in signature.parameters.items():
# If in one of the requirements happens to be Security, we need to resolve it
# By passing the Request object to the dependency
if isinstance(param.default, params.Security):
kwargs[name] = await param.default.dependency(request_context)
if isinstance(param.default, params.Requires):
dep_func = param.default.dependency
dep_func = overrides.get(dep_func, dep_func) # type: ignore
Expand Down
29 changes: 18 additions & 11 deletions tests/dependencies/test_injects_with_fastapi_examples.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from typing import AsyncGenerator, Generator
from typing import Any, AsyncGenerator, Callable, Generator, Optional

import pytest

from esmerald import Esmerald, Gateway, get
from esmerald.param_functions import Requires
from esmerald import Esmerald, Gateway, get, params
from esmerald.testclient import EsmeraldTestClient


def BaseRequires(
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True,
) -> Any:
return params.BaseRequires(dependency=dependency, use_cache=use_cache)


class CallableDependency: # pragma: no cover
def __call__(self, value: str) -> str:
return value
Expand Down Expand Up @@ -49,53 +56,53 @@ async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]:


@get("/callable-dependency")
async def get_callable_dependency(value: str = Requires(callable_dependency)) -> str:
async def get_callable_dependency(value: str = BaseRequires(callable_dependency)) -> str:
return value


@get("/callable-gen-dependency")
async def get_callable_gen_dependency(value: str = Requires(callable_gen_dependency)) -> str:
async def get_callable_gen_dependency(value: str = BaseRequires(callable_gen_dependency)) -> str:
return value


@get("/async-callable-dependency")
async def get_async_callable_dependency(
value: str = Requires(async_callable_dependency),
value: str = BaseRequires(async_callable_dependency),
) -> str:
return value


@get("/async-callable-gen-dependency")
async def get_async_callable_gen_dependency(
value: str = Requires(async_callable_gen_dependency),
value: str = BaseRequires(async_callable_gen_dependency),
) -> str:
return value


@get("/synchronous-method-dependency")
async def get_synchronous_method_dependency(
value: str = Requires(methods_dependency.synchronous),
value: str = BaseRequires(methods_dependency.synchronous),
) -> str:
return value


@get("/synchronous-method-gen-dependency")
async def get_synchronous_method_gen_dependency(
value: str = Requires(methods_dependency.synchronous_gen),
value: str = BaseRequires(methods_dependency.synchronous_gen),
) -> str:
return value


@get("/asynchronous-method-dependency")
async def get_asynchronous_method_dependency(
value: str = Requires(methods_dependency.asynchronous),
value: str = BaseRequires(methods_dependency.asynchronous),
) -> str:
return value


@get("/asynchronous-method-gen-dependency")
async def get_asynchronous_method_gen_dependency(
value: str = Requires(methods_dependency.asynchronous_gen),
value: str = BaseRequires(methods_dependency.asynchronous_gen),
) -> str:
return value

Expand Down
160 changes: 160 additions & 0 deletions tests/security/http/test_security_api_key_with_requires.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from typing import Any

from lilya.middleware import DefineMiddleware
from lilya.middleware.request_context import RequestContextMiddleware
from pydantic import BaseModel

from esmerald import Gateway, Requires, Security, get
from esmerald.security.api_key import APIKeyInCookie
from esmerald.testclient import create_client

api_key = APIKeyInCookie(name="key")


class User(BaseModel):
username: str


def get_current_user(oauth_header: str = Security(api_key)):
user = User(username=oauth_header)
return user


@get("/users/me", security=[api_key])
def read_current_user(current_user: User = Requires(get_current_user)) -> Any:
return current_user


def test_security_api_key():
with create_client(
routes=[
Gateway(handler=read_current_user),
],
middleware=[DefineMiddleware(RequestContextMiddleware)],
) as client:
response = client.get("/users/me", cookies={"key": "secret"})
assert response.status_code == 200, response.text
assert response.json() == {"username": "secret"}


def test_security_api_key_no_key():
with create_client(
routes=[
Gateway(handler=read_current_user),
],
middleware=[DefineMiddleware(RequestContextMiddleware)],
) as client:
response = client.get("/users/me")
assert response.status_code == 403, response.text
assert response.json() == {"detail": "Not authenticated"}


def test_openapi_schema():
with create_client(
routes=[
Gateway(handler=read_current_user),
],
enable_openapi=True,
) as client:
response = client.get("/openapi.json")
assert response.status_code == 200, response.text

assert response.json() == {
"openapi": "3.1.0",
"info": {
"title": "Esmerald",
"summary": "Esmerald application",
"description": "Highly scalable, performant, easy to learn and for every application.",
"contact": {"name": "admin", "email": "[email protected]"},
"version": client.app.version,
},
"servers": [{"url": "/"}],
"paths": {
"/users/me": {
"get": {
"summary": "Read Current User",
"description": "",
"operationId": "read_current_user_users_me_get",
"deprecated": False,
"security": [
{
"APIKeyInCookie": {
"type": "apiKey",
"name": "key",
"in": "cookie",
"scheme_name": "APIKeyInCookie",
}
}
],
"requestBody": {
"required": True,
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/DataField"}
}
},
},
"responses": {
"200": {
"description": "Successful response",
"content": {"application/json": {"schema": {"type": "string"}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
}
}
},
"components": {
"schemas": {
"DataField": {
"properties": {"current_user": {"$ref": "#/components/schemas/User"}},
"type": "object",
"required": ["current_user"],
"title": "DataField",
},
"HTTPValidationError": {
"properties": {
"detail": {
"items": {"$ref": "#/components/schemas/ValidationError"},
"type": "array",
"title": "Detail",
}
},
"type": "object",
"title": "HTTPValidationError",
},
"User": {
"properties": {"username": {"type": "string", "title": "Username"}},
"type": "object",
"required": ["username"],
"title": "User",
},
"ValidationError": {
"properties": {
"loc": {
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
"type": "array",
"title": "Location",
},
"msg": {"type": "string", "title": "Message"},
"type": {"type": "string", "title": "Error Type"},
},
"type": "object",
"required": ["loc", "msg", "type"],
"title": "ValidationError",
},
},
"securitySchemes": {
"APIKeyInCookie": {"type": "apiKey", "name": "key", "in": "cookie"}
},
},
}

0 comments on commit 22e8329

Please sign in to comment.