From 99f4aa467f3a08c4ff89381e7d5d80953b5ed50e Mon Sep 17 00:00:00 2001 From: tarsil Date: Mon, 23 Oct 2023 17:20:00 +0100 Subject: [PATCH] Fix inheritance of tags * Add annotation for tags --- esmerald/applications.py | 142 +++++++++++++++++++++++++++++-- esmerald/conf/global_settings.py | 4 +- esmerald/config/openapi.py | 4 +- esmerald/openapi/models.py | 2 +- esmerald/openapi/openapi.py | 11 +-- esmerald/routing/base.py | 15 ++++ esmerald/routing/gateways.py | 8 ++ esmerald/routing/router.py | 3 + esmerald/testclient.py | 4 +- tests/openapi/test_tags.py | 99 +++++++++++++++++++++ 10 files changed, 273 insertions(+), 19 deletions(-) create mode 100644 tests/openapi/test_tags.py diff --git a/esmerald/applications.py b/esmerald/applications.py index d8461401..4382ac27 100644 --- a/esmerald/applications.py +++ b/esmerald/applications.py @@ -14,7 +14,7 @@ cast, ) -from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityScheme, Tag +from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityScheme from openapi_schemas_pydantic.v3_1_0.open_api import OpenAPI from pydantic import AnyUrl, ValidationError from starlette.applications import Starlette @@ -1123,10 +1123,142 @@ async def validation_error_exception_handler( """ ), ] = None, - on_startup: Optional[List["LifeSpanHandler"]] = None, - on_shutdown: Optional[List["LifeSpanHandler"]] = None, - lifespan: Optional[Lifespan[AppType]] = None, - tags: Optional[List[Tag]] = None, + on_startup: Annotated[ + Optional[List["LifeSpanHandler"]], + Doc( + """ + A `list` of events that are trigger upon the application + starts. + + Read more about the [events](https://esmerald.dev/lifespan-events/). + + **Example** + + ```python + from pydantic import BaseModel + from saffier import Database, Registry + + from esmerald import Esmerald, Gateway, post + + database = Database("postgresql+asyncpg://user:password@host:port/database") + registry = Registry(database=database) + + + class User(BaseModel): + name: str + email: str + password: str + retype_password: str + + + @post("/create", tags=["user"], description="Creates a new user in the database") + async def create_user(data: User) -> None: + # Logic to create the user + ... + + + app = Esmerald( + routes=[Gateway(handler=create_user)], + on_startup=[database.connect], + ) + ``` + """ + ), + ] = None, + on_shutdown: Annotated[ + Optional[List["LifeSpanHandler"]], + Doc( + """ + A `list` of events that are trigger upon the application + shuts down. + + Read more about the [events](https://esmerald.dev/lifespan-events/). + + **Example** + + ```python + from pydantic import BaseModel + from saffier import Database, Registry + + from esmerald import Esmerald, Gateway, post + + database = Database("postgresql+asyncpg://user:password@host:port/database") + registry = Registry(database=database) + + + class User(BaseModel): + name: str + email: str + password: str + retype_password: str + + + @post("/create", tags=["user"], description="Creates a new user in the database") + async def create_user(data: User) -> None: + # Logic to create the user + ... + + + app = Esmerald( + routes=[Gateway(handler=create_user)], + on_shutdown=[database.disconnect], + ) + ``` + """ + ), + ] = None, + lifespan: Annotated[ + Optional[Lifespan[AppType]], + Doc( + """ + A `lifespan` context manager handler. This is an alternative + to `on_startup` and `on_shutdown` and you **cannot used all combined**. + + Read more about the [lifespan](https://esmerald.dev/lifespan-events/). + """ + ), + ] = None, + tags: Annotated[ + Optional[List[str]], + Doc( + """ + A list of strings/enums tags to be applied to the *path operation*. + + It will be added to the generated OpenAPI documentation. + + **Note** almost everything in Esmerald can be done in [levels](https://esmerald.dev/application/levels/), which means + these tags on a Esmerald instance, means it will be added to every route even + if those routes also contain tags. + + **Example** + + ```python + from esmerald import Esmerald + + app = Esmerald(tags=["application"]) + ``` + + **Example with nested routes** + + When tags are added on a level bases, those are concatenated into the + final handler. + + ```python + from esmerald import Esmerald, Gateway, get + + @get("/home", tags=["home"]) + async def home() -> Dict[str, str]: + return {"hello": "world"} + + + app = Esmerald( + routes=[Gateway(handler=home)], + tags=["application"] + ) + ``` + """ + ), + ] = None, include_in_schema: Optional[bool] = None, deprecated: Optional[bool] = None, enable_openapi: Optional[bool] = None, diff --git a/esmerald/conf/global_settings.py b/esmerald/conf/global_settings.py index 06eb3307..b2db8e1b 100644 --- a/esmerald/conf/global_settings.py +++ b/esmerald/conf/global_settings.py @@ -1,7 +1,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union -from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityScheme, Tag +from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityScheme from pydantic import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.types import Lifespan @@ -51,7 +51,7 @@ class EsmeraldAPISettings(BaseSettings): response_cookies: Optional[ResponseCookies] = None response_headers: Optional[ResponseHeaders] = None include_in_schema: bool = True - tags: Optional[List[Tag]] = None + tags: Optional[List[str]] = None timezone: str = "UTC" use_tz: bool = False root_path: Optional[str] = "" diff --git a/esmerald/config/openapi.py b/esmerald/config/openapi.py index 9d140cde..1cf2c2eb 100644 --- a/esmerald/config/openapi.py +++ b/esmerald/config/openapi.py @@ -9,7 +9,7 @@ get_swagger_ui_html, get_swagger_ui_oauth2_redirect_html, ) -from esmerald.openapi.models import Contact, License, Tag +from esmerald.openapi.models import Contact, License from esmerald.openapi.openapi import get_openapi from esmerald.requests import Request from esmerald.responses import HTMLResponse, JSONResponse @@ -26,7 +26,7 @@ class OpenAPIConfig(BaseModel): license: Optional[License] = None security: Optional[List[SecurityScheme]] = None servers: Optional[List[Dict[str, Union[str, Any]]]] = None - tags: Optional[List[Tag]] = None + tags: Optional[List[str]] = None openapi_version: Optional[str] = None openapi_url: Optional[str] = None root_path_in_servers: bool = True diff --git a/esmerald/openapi/models.py b/esmerald/openapi/models.py index ff31da89..0f9f6733 100644 --- a/esmerald/openapi/models.py +++ b/esmerald/openapi/models.py @@ -110,6 +110,6 @@ class OpenAPI(BaseModel): webhooks: Optional[Dict[str, Union[PathItem, Reference]]] = None components: Optional[Components] = None security: Optional[List[Dict[str, List[str]]]] = None - tags: Optional[List[Tag]] = None + tags: Optional[List[str]] = None externalDocs: Optional[ExternalDocumentation] = None model_config = ConfigDict(extra="allow") diff --git a/esmerald/openapi/openapi.py b/esmerald/openapi/openapi.py index fc8b5d84..c5a96ea9 100644 --- a/esmerald/openapi/openapi.py +++ b/esmerald/openapi/openapi.py @@ -22,7 +22,6 @@ Operation, Parameter, SecurityScheme, - Tag, ) from esmerald.openapi.responses import create_internal_response from esmerald.openapi.utils import ( @@ -113,12 +112,12 @@ def get_fields_from_routes( def get_openapi_operation( - *, route: Union[router.HTTPHandler, Any], method: str, operation_ids: Set[str] + *, route: Union[router.HTTPHandler, Any], operation_ids: Set[str] ) -> Dict[str, Any]: # pragma: no cover operation = Operation() if route.tags: - operation.tags = cast("List[str]", route.tags) + operation.tags = route.get_handler_tags() if route.summary: operation.summary = route.summary @@ -240,9 +239,7 @@ def get_openapi_path( # For each method for method in route.handler.methods: - operation = get_openapi_operation( - route=handler, method=method, operation_ids=operation_ids - ) + operation = get_openapi_operation(route=handler, operation_ids=operation_ids) # If the parent if marked as deprecated, it takes precedence if is_deprecated or route.deprecated: operation["deprecated"] = is_deprecated if is_deprecated else route.deprecated @@ -411,7 +408,7 @@ def get_openapi( summary: Optional[str] = None, description: Optional[str] = None, routes: Sequence[BaseRoute], - tags: Optional[List[Tag]] = None, + tags: Optional[List[str]] = None, servers: Optional[List[Dict[str, Union[str, Any]]]] = None, terms_of_service: Optional[Union[str, AnyUrl]] = None, contact: Optional[Contact] = None, diff --git a/esmerald/routing/base.py b/esmerald/routing/base.py index b0f74804..211ef519 100644 --- a/esmerald/routing/base.py +++ b/esmerald/routing/base.py @@ -594,6 +594,21 @@ def get_security_schemes(self) -> List["SecurityScheme"]: security_schemes.extend(layer.security or []) return security_schemes + def get_handler_tags(self) -> List[str]: + """ + Returns all the tags associated with the handler + by checking the parents as well. + """ + tags: List[str] = [] + for layer in self.parent_levels: + tags.extend(layer.tags or []) + + tags_clean: List[str] = [] + for tag in tags: + if tag not in tags_clean: + tags_clean.append(tag) + return tags_clean + class BaseInterceptorMixin(BaseHandlerMixin): # pragma: no cover def get_interceptors(self) -> List["AsyncCallable"]: diff --git a/esmerald/routing/gateways.py b/esmerald/routing/gateways.py index 4dc64386..d45ff204 100644 --- a/esmerald/routing/gateways.py +++ b/esmerald/routing/gateways.py @@ -35,6 +35,7 @@ class Gateway(StarletteRoute, BaseInterceptorMixin): "interceptors", "permissions", "deprecated", + "tags", ) def __init__( @@ -53,6 +54,7 @@ def __init__( deprecated: Optional[bool] = None, is_from_router: bool = False, security: Optional[Sequence["SecurityScheme"]] = None, + tags: Optional[Sequence[str]] = None, ) -> None: if not path: path = "/" @@ -98,6 +100,7 @@ def __init__( self.deprecated = deprecated self.parent = parent self.security = security + self.tags = tags or [] ( handler.path_regex, handler.path_format, @@ -146,6 +149,7 @@ class WebSocketGateway(StarletteWebSocketRoute, BaseInterceptorMixin): "permissions", "parent", "security", + "tags", ) def __init__( @@ -221,6 +225,8 @@ class WebhookGateway(StarletteRoute, BaseInterceptorMixin): "exception_handlers", "interceptors", "permissions", + "security", + "tags", ) def __init__( @@ -232,6 +238,7 @@ def __init__( parent: Optional["ParentType"] = None, deprecated: Optional[bool] = None, security: Optional[Sequence["SecurityScheme"]] = None, + tags: Optional[Sequence[str]] = None, ) -> None: if is_class_and_subclass(handler, View): handler = handler(parent=self) # type: ignore @@ -262,6 +269,7 @@ def __init__( self.deprecated = deprecated self.parent = parent self.security = security + self.tags = tags or [] ( handler.path_regex, handler.path_format, diff --git a/esmerald/routing/router.py b/esmerald/routing/router.py index 034d309c..6bdf7a1f 100644 --- a/esmerald/routing/router.py +++ b/esmerald/routing/router.py @@ -995,6 +995,7 @@ class Include(Mount): "middleware", "deprecated", "security", + "tags", ) def __init__( @@ -1014,6 +1015,7 @@ def __init__( include_in_schema: Optional[bool] = True, deprecated: Optional[bool] = None, security: Optional[Sequence["SecurityScheme"]] = None, + tags: Optional[Sequence[str]] = None, ) -> None: self.path = path if not path: @@ -1050,6 +1052,7 @@ def __init__( self.response_headers = None self.parent = parent self.security = security or [] + self.tags = tags or [] if routes: routes = self.resolve_route_path_handler(routes) diff --git a/esmerald/testclient.py b/esmerald/testclient.py index 01d9143b..1f07c974 100644 --- a/esmerald/testclient.py +++ b/esmerald/testclient.py @@ -13,7 +13,7 @@ import httpx # noqa from httpx._client import CookieTypes -from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityScheme, Tag +from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityScheme from pydantic import AnyUrl from starlette.testclient import TestClient # noqa @@ -120,7 +120,7 @@ def create_client( lifespan: Optional[Callable[["Esmerald"], "AsyncContextManager"]] = None, cookies: Optional[CookieTypes] = None, redirect_slashes: Optional[bool] = None, - tags: Optional[List[Tag]] = None, + tags: Optional[List[str]] = None, webhooks: Optional[Sequence["WebhookGateway"]] = None, ) -> EsmeraldTestClient: return EsmeraldTestClient( diff --git a/tests/openapi/test_tags.py b/tests/openapi/test_tags.py new file mode 100644 index 00000000..97aff0ac --- /dev/null +++ b/tests/openapi/test_tags.py @@ -0,0 +1,99 @@ +from typing import Dict + +from esmerald import Esmerald, Gateway, Include, get +from esmerald.testclient import EsmeraldTestClient, create_client + + +@get("/bar", tags=["bar"]) +async def bar() -> Dict[str, str]: + return {"hello": "world"} + + +app = Esmerald( + routes=[Gateway(handler=bar)], + enable_openapi=True, + tags=["test"], +) + + +client = EsmeraldTestClient(app) + + +def test_openapi_schema_tags(test_client_factory): + response = client.get("/openapi.json") + + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "Highly scalable, performant, easy to learn and for every application.", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/bar": { + "get": { + "tags": ["test", "bar"], + "summary": "Bar", + "operationId": "bar_bar_get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {"type": "string"}}}, + } + }, + "deprecated": False, + } + } + }, + "tags": ["test"], + } + + +def test_tags_nested(test_client_factory): + with create_client( + routes=[ + Include( + routes=[Gateway(handler=bar, tags=["gateway"])], + tags=["include"], + ) + ], + enable_openapi=True, + tags=["test"], + ) as client: + response = client.get("/openapi.json") + + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "Highly scalable, performant, easy to learn and for every application.", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/bar": { + "get": { + "tags": ["test", "include", "gateway", "bar"], + "summary": "Bar", + "operationId": "bar_bar_get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {"type": "string"}}}, + } + }, + "deprecated": False, + } + } + }, + "tags": ["test"], + }