diff --git a/esmerald/transformers/signature.py b/esmerald/transformers/signature.py index 5bed2438..dd9def3a 100644 --- a/esmerald/transformers/signature.py +++ b/esmerald/transformers/signature.py @@ -16,6 +16,7 @@ get_origin, ) +from lilya.exceptions import HTTPException as LilyaHTTPException from orjson import loads from pydantic import ValidationError, create_model from pydantic.fields import FieldInfo @@ -200,6 +201,10 @@ def encode_value(encoder: "Encoder", annotation: Any, value: Any) -> Any: decoded_list = extract_arguments(annotation) annotation = decoded_list[0] # type: ignore + if is_requires(value): + kwargs[key] = await async_resolve_dependencies(value.dependency) + continue + kwargs[key] = encode_value(encoder, annotation, value) return kwargs @@ -296,7 +301,7 @@ def extract_error_message(exception: Exception) -> Dict[str, Any]: return str(exception) # type: ignore try: - if isinstance(exception, HTTPException): + if isinstance(exception, (HTTPException, LilyaHTTPException)): return exception method, url = get_connection_info(connection) diff --git a/esmerald/utils/dependencies.py b/esmerald/utils/dependencies.py index 12419710..4f064ff9 100644 --- a/esmerald/utils/dependencies.py +++ b/esmerald/utils/dependencies.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Union from lilya.compat import run_sync +from lilya.context import request_context from esmerald import params from esmerald.security.scopes import Scopes @@ -78,6 +79,10 @@ async def async_resolve_dependencies(func: Any, overrides: Union[Dict[str, Any]] kwargs = {} for name, param in signature.parameters.items(): + # If in one of the requirements happens to be Security, we need to resolve it + # By passing the Request object to the dependency + if isinstance(param.default, params.Security): + kwargs[name] = await param.default.dependency(request_context) if isinstance(param.default, params.Requires): dep_func = param.default.dependency dep_func = overrides.get(dep_func, dep_func) # type: ignore diff --git a/tests/dependencies/test_injects_with_fastapi_examples.py b/tests/dependencies/test_injects_with_fastapi_examples.py index b0fe40d5..b7ab684d 100644 --- a/tests/dependencies/test_injects_with_fastapi_examples.py +++ b/tests/dependencies/test_injects_with_fastapi_examples.py @@ -1,12 +1,19 @@ -from typing import AsyncGenerator, Generator +from typing import Any, AsyncGenerator, Callable, Generator, Optional import pytest -from esmerald import Esmerald, Gateway, get -from esmerald.param_functions import Requires +from esmerald import Esmerald, Gateway, get, params from esmerald.testclient import EsmeraldTestClient +def BaseRequires( + dependency: Optional[Callable[..., Any]] = None, + *, + use_cache: bool = True, +) -> Any: + return params.BaseRequires(dependency=dependency, use_cache=use_cache) + + class CallableDependency: # pragma: no cover def __call__(self, value: str) -> str: return value @@ -49,53 +56,53 @@ async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]: @get("/callable-dependency") -async def get_callable_dependency(value: str = Requires(callable_dependency)) -> str: +async def get_callable_dependency(value: str = BaseRequires(callable_dependency)) -> str: return value @get("/callable-gen-dependency") -async def get_callable_gen_dependency(value: str = Requires(callable_gen_dependency)) -> str: +async def get_callable_gen_dependency(value: str = BaseRequires(callable_gen_dependency)) -> str: return value @get("/async-callable-dependency") async def get_async_callable_dependency( - value: str = Requires(async_callable_dependency), + value: str = BaseRequires(async_callable_dependency), ) -> str: return value @get("/async-callable-gen-dependency") async def get_async_callable_gen_dependency( - value: str = Requires(async_callable_gen_dependency), + value: str = BaseRequires(async_callable_gen_dependency), ) -> str: return value @get("/synchronous-method-dependency") async def get_synchronous_method_dependency( - value: str = Requires(methods_dependency.synchronous), + value: str = BaseRequires(methods_dependency.synchronous), ) -> str: return value @get("/synchronous-method-gen-dependency") async def get_synchronous_method_gen_dependency( - value: str = Requires(methods_dependency.synchronous_gen), + value: str = BaseRequires(methods_dependency.synchronous_gen), ) -> str: return value @get("/asynchronous-method-dependency") async def get_asynchronous_method_dependency( - value: str = Requires(methods_dependency.asynchronous), + value: str = BaseRequires(methods_dependency.asynchronous), ) -> str: return value @get("/asynchronous-method-gen-dependency") async def get_asynchronous_method_gen_dependency( - value: str = Requires(methods_dependency.asynchronous_gen), + value: str = BaseRequires(methods_dependency.asynchronous_gen), ) -> str: return value diff --git a/tests/security/http/test_security_api_key_with_requires.py b/tests/security/http/test_security_api_key_with_requires.py new file mode 100644 index 00000000..2b982042 --- /dev/null +++ b/tests/security/http/test_security_api_key_with_requires.py @@ -0,0 +1,160 @@ +from typing import Any + +from lilya.middleware import DefineMiddleware +from lilya.middleware.request_context import RequestContextMiddleware +from pydantic import BaseModel + +from esmerald import Gateway, Requires, Security, get +from esmerald.security.api_key import APIKeyInCookie +from esmerald.testclient import create_client + +api_key = APIKeyInCookie(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + user = User(username=oauth_header) + return user + + +@get("/users/me", security=[api_key]) +def read_current_user(current_user: User = Requires(get_current_user)) -> Any: + return current_user + + +def test_security_api_key(): + with create_client( + routes=[ + Gateway(handler=read_current_user), + ], + middleware=[DefineMiddleware(RequestContextMiddleware)], + ) as client: + response = client.get("/users/me", cookies={"key": "secret"}) + assert response.status_code == 200, response.text + assert response.json() == {"username": "secret"} + + +def test_security_api_key_no_key(): + with create_client( + routes=[ + Gateway(handler=read_current_user), + ], + middleware=[DefineMiddleware(RequestContextMiddleware)], + ) as client: + response = client.get("/users/me") + assert response.status_code == 403, response.text + assert response.json() == {"detail": "Not authenticated"} + + +def test_openapi_schema(): + with create_client( + routes=[ + Gateway(handler=read_current_user), + ], + enable_openapi=True, + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "Highly scalable, performant, easy to learn and for every application.", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/users/me": { + "get": { + "summary": "Read Current User", + "description": "", + "operationId": "read_current_user_users_me_get", + "deprecated": False, + "security": [ + { + "APIKeyInCookie": { + "type": "apiKey", + "name": "key", + "in": "cookie", + "scheme_name": "APIKeyInCookie", + } + } + ], + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/DataField"} + } + }, + }, + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {"type": "string"}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + "components": { + "schemas": { + "DataField": { + "properties": {"current_user": {"$ref": "#/components/schemas/User"}}, + "type": "object", + "required": ["current_user"], + "title": "DataField", + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "User": { + "properties": {"username": {"type": "string", "title": "Username"}}, + "type": "object", + "required": ["username"], + "title": "User", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + }, + "securitySchemes": { + "APIKeyInCookie": {"type": "apiKey", "name": "key", "in": "cookie"} + }, + }, + }