From ae94f017c5e407653a1dd2686390aa6a2e342602 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jan 2024 01:59:29 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- conftest.py | 4 +- docs/source/conf.py | 4 +- etc/api_examples/api_intro.ipynb | 8 +- kernel_gateway/base/handlers.py | 5 +- kernel_gateway/gatewayapp.py | 74 +++++++++++---- kernel_gateway/jupyter_websocket/__init__.py | 8 +- kernel_gateway/notebook_http/__init__.py | 36 +++++--- kernel_gateway/notebook_http/cell/parser.py | 4 +- kernel_gateway/notebook_http/handlers.py | 12 ++- .../notebook_http/swagger/parser.py | 25 ++++-- kernel_gateway/services/kernels/handlers.py | 9 +- kernel_gateway/services/kernels/manager.py | 16 +++- kernel_gateway/services/kernels/pool.py | 4 +- kernel_gateway/services/sessions/handlers.py | 10 ++- .../services/sessions/sessionmanager.py | 12 ++- .../tests/notebook_http/cell/test_parser.py | 8 +- .../notebook_http/swagger/test_builders.py | 16 +++- .../notebook_http/swagger/test_parser.py | 32 ++++--- .../tests/notebook_http/test_request_utils.py | 10 ++- .../tests/resources/responses.ipynb | 4 +- .../tests/test_jupyter_websocket.py | 90 ++++++++++++++----- kernel_gateway/tests/test_notebook_http.py | 64 +++++++++---- 22 files changed, 341 insertions(+), 114 deletions(-) diff --git a/conftest.py b/conftest.py index 3ab563a..b4f6733 100644 --- a/conftest.py +++ b/conftest.py @@ -76,7 +76,9 @@ async def initialize_app(): jp_asyncio_loop.run_until_complete(initialize_app()) # Reroute all logging StreamHandlers away from stdin/stdout since pytest hijacks # these streams and closes them at unfortunate times. - stream_handlers = [h for h in app.log.handlers if isinstance(h, logging.StreamHandler)] + stream_handlers = [ + h for h in app.log.handlers if isinstance(h, logging.StreamHandler) + ] for handler in stream_handlers: handler.setStream(jp_logging_stream) app.log.propagate = True diff --git a/docs/source/conf.py b/docs/source/conf.py index 8b6645a..c425ea7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -259,7 +259,9 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [(master_doc, "kernel_gateway", "Kernel Gateway Documentation", [author], 1)] +man_pages = [ + (master_doc, "kernel_gateway", "Kernel Gateway Documentation", [author], 1) +] # If true, show URL addresses after external links. # man_show_urls = False diff --git a/etc/api_examples/api_intro.ipynb b/etc/api_examples/api_intro.ipynb index 10d8abd..c1f58f6 100644 --- a/etc/api_examples/api_intro.ipynb +++ b/etc/api_examples/api_intro.ipynb @@ -300,7 +300,9 @@ "body = req[\"body\"]\n", "contact_id = req[\"path\"][\"contact_id\"]\n", "if contact_id in contacts:\n", - " contacts[contact_id].update({field: body[field] for field in fields if field in body})\n", + " contacts[contact_id].update(\n", + " {field: body[field] for field in fields if field in body}\n", + " )\n", " status = 200\n", " print(json.dumps({\"contact_id\": contacts[contact_id]}))\n", "else:\n", @@ -618,7 +620,9 @@ " first_contact_id = post_resp.json()[\"contact_id\"]\n", "\n", " # update the contact\n", - " put_resp = requests.put(URL + \"/contacts/\" + first_contact_id, {\"phone\": \"919-444-5601\"})\n", + " put_resp = requests.put(\n", + " URL + \"/contacts/\" + first_contact_id, {\"phone\": \"919-444-5601\"}\n", + " )\n", " put_resp.raise_for_status()\n", " print(\"\\nupdated a contact:\", put_resp.json())\n", "\n", diff --git a/kernel_gateway/base/handlers.py b/kernel_gateway/base/handlers.py index c2dc7e1..bd8f839 100644 --- a/kernel_gateway/base/handlers.py +++ b/kernel_gateway/base/handlers.py @@ -9,7 +9,10 @@ class APIVersionHandler( - TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, server_handlers.APIVersionHandler + TokenAuthorizationMixin, + CORSMixin, + JSONErrorsMixin, + server_handlers.APIVersionHandler, ): """Extends the notebook server base API handler with token auth, CORS, and JSON errors. diff --git a/kernel_gateway/gatewayapp.py b/kernel_gateway/gatewayapp.py index 89d4024..3bb2a8c 100644 --- a/kernel_gateway/gatewayapp.py +++ b/kernel_gateway/gatewayapp.py @@ -23,12 +23,26 @@ from jupyter_core.paths import secure_write from jupyter_server.auth.authorizer import AllowAllAuthorizer, Authorizer from jupyter_server.serverapp import random_ports -from jupyter_server.services.kernels.connection.base import BaseKernelWebsocketConnection -from jupyter_server.services.kernels.connection.channels import ZMQChannelsWebsocketConnection +from jupyter_server.services.kernels.connection.base import ( + BaseKernelWebsocketConnection, +) +from jupyter_server.services.kernels.connection.channels import ( + ZMQChannelsWebsocketConnection, +) from jupyter_server.services.kernels.kernelmanager import MappingKernelManager from tornado import httpserver, ioloop, web from tornado.log import LogFormatter, enable_pretty_logging -from traitlets import Bytes, CBool, Instance, Integer, List, Type, Unicode, default, observe +from traitlets import ( + Bytes, + CBool, + Instance, + Integer, + List, + Type, + Unicode, + default, + observe, +) from ._version import __version__ from .auth.identity import GatewayIdentityProvider @@ -84,7 +98,9 @@ class KernelGatewayApp(JupyterApp): port_env = "KG_PORT" port_default_value = 8888 port = Integer( - port_default_value, config=True, help="Port on which to listen (KG_PORT env var)" + port_default_value, + config=True, + help="Port on which to listen (KG_PORT env var)", ) @default("port") @@ -106,7 +122,9 @@ def port_retries_default(self): ip_env = "KG_IP" ip_default_value = "127.0.0.1" ip = Unicode( - ip_default_value, config=True, help="IP address on which to listen (KG_IP env var)" + ip_default_value, + config=True, + help="IP address on which to listen (KG_IP env var)", ) @default("ip") @@ -129,7 +147,8 @@ def base_url_default(self): # Token authorization auth_token_env = "KG_AUTH_TOKEN" auth_token = Unicode( - config=True, help="Authorization token required for all requests (KG_AUTH_TOKEN env var)" + config=True, + help="Authorization token required for all requests (KG_AUTH_TOKEN env var)", ) @default("auth_token") @@ -149,7 +168,8 @@ def allow_credentials_default(self): allow_headers_env = "KG_ALLOW_HEADERS" allow_headers = Unicode( - config=True, help="Sets the Access-Control-Allow-Headers header. (KG_ALLOW_HEADERS env var)" + config=True, + help="Sets the Access-Control-Allow-Headers header. (KG_ALLOW_HEADERS env var)", ) @default("allow_headers") @@ -158,7 +178,8 @@ def allow_headers_default(self): allow_methods_env = "KG_ALLOW_METHODS" allow_methods = Unicode( - config=True, help="Sets the Access-Control-Allow-Methods header. (KG_ALLOW_METHODS env var)" + config=True, + help="Sets the Access-Control-Allow-Methods header. (KG_ALLOW_METHODS env var)", ) @default("allow_methods") @@ -167,7 +188,8 @@ def allow_methods_default(self): allow_origin_env = "KG_ALLOW_ORIGIN" allow_origin = Unicode( - config=True, help="Sets the Access-Control-Allow-Origin header. (KG_ALLOW_ORIGIN env var)" + config=True, + help="Sets the Access-Control-Allow-Origin header. (KG_ALLOW_ORIGIN env var)", ) @default("allow_origin") @@ -381,7 +403,9 @@ def _default_cookie_secret(self): def _write_cookie_secret_file(self, secret): """write my secret to my secret_file""" - self.log.info("Writing Jupyter server cookie secret to %s", self.cookie_secret_file) + self.log.info( + "Writing Jupyter server cookie secret to %s", self.cookie_secret_file + ) try: with secure_write(self.cookie_secret_file, True) as f: f.write(secret) @@ -404,16 +428,16 @@ def _write_cookie_secret_file(self, secret): @default("ws_ping_interval") def _ws_ping_interval_default(self) -> int: - return int(os.getenv(self.ws_ping_interval_env, self.ws_ping_interval_default_value)) + return int( + os.getenv(self.ws_ping_interval_env, self.ws_ping_interval_default_value) + ) _log_formatter_cls = LogFormatter # traitlet default is LevelFormatter @default("log_format") def _default_log_format(self) -> str: """override default log format to include milliseconds""" - return ( - "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s" - ) + return "%(color)s[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s]%(end_color)s %(message)s" kernel_spec_manager = Instance(KernelSpecManager, allow_none=True) @@ -571,7 +595,9 @@ def init_configurables(self): **kwargs, ) - self.session_manager = SessionManager(log=self.log, kernel_manager=self.kernel_manager) + self.session_manager = SessionManager( + log=self.log, kernel_manager=self.kernel_manager + ) self.contents_manager = None self.identity_provider = self.identity_provider_class(parent=self, log=self.log) @@ -643,7 +669,9 @@ def init_webapp(self): ) # promote the current personality's "config" tagged traitlet values to webapp settings - for trait_name, trait_value in self.personality.class_traits(config=True).items(): + for trait_name, trait_value in self.personality.class_traits( + config=True + ).items(): kg_name = "kg_" + trait_name # a personality's traitlets may not overwrite the kernel gateway's if kg_name not in self.web_app.settings: @@ -695,9 +723,14 @@ def init_http_server(self): self.http_server.listen(port, self.ip) except OSError as e: if e.errno == errno.EADDRINUSE: - self.log.info("The port %i is already in use, trying another port." % port) + self.log.info( + "The port %i is already in use, trying another port." % port + ) continue - elif e.errno in (errno.EACCES, getattr(errno, "WSAEACCES", errno.EACCES)): + elif e.errno in ( + errno.EACCES, + getattr(errno, "WSAEACCES", errno.EACCES), + ): self.log.warning("Permission to listen on port %i denied" % port) continue else: @@ -786,7 +819,10 @@ def start_app(self): super().start() self.log.info( "Jupyter Kernel Gateway {} is available at http{}://{}:{}".format( - KernelGatewayApp.version, "s" if self.keyfile else "", self.ip, self.port + KernelGatewayApp.version, + "s" if self.keyfile else "", + self.ip, + self.port, ) ) diff --git a/kernel_gateway/jupyter_websocket/__init__.py b/kernel_gateway/jupyter_websocket/__init__.py index 8640c7e..073b033 100644 --- a/kernel_gateway/jupyter_websocket/__init__.py +++ b/kernel_gateway/jupyter_websocket/__init__.py @@ -11,7 +11,9 @@ from ..base.handlers import default_handlers as default_base_handlers from ..services.kernels.handlers import default_handlers as default_kernel_handlers from ..services.kernels.pool import KernelPool -from ..services.kernelspecs.handlers import default_handlers as default_kernelspec_handlers +from ..services.kernelspecs.handlers import ( + default_handlers as default_kernelspec_handlers, +) from ..services.sessions.handlers import default_handlers as default_session_handlers from .handlers import default_handlers as default_api_handlers @@ -51,7 +53,9 @@ def __init__(self, *args, **kwargs): self.kernel_pool = KernelPool() async def init_configurables(self): - await self.kernel_pool.initialize(self.parent.prespawn_count, self.parent.kernel_manager) + await self.kernel_pool.initialize( + self.parent.prespawn_count, self.parent.kernel_manager + ) def create_request_handlers(self): """Create default Jupyter handlers and redefine them off of the diff --git a/kernel_gateway/notebook_http/__init__.py b/kernel_gateway/notebook_http/__init__.py index 877d398..314c4dd 100644 --- a/kernel_gateway/notebook_http/__init__.py +++ b/kernel_gateway/notebook_http/__init__.py @@ -33,11 +33,15 @@ class NotebookHTTPPersonality(LoggingConfigurable): @default("cell_parser") def cell_parser_default(self): - return os.getenv(self.cell_parser_env, "kernel_gateway.notebook_http.cell.parser") + return os.getenv( + self.cell_parser_env, "kernel_gateway.notebook_http.cell.parser" + ) # Intentionally not defining an env var option for a dict type comment_prefix = Dict( - {"scala": "//", None: "#"}, config=True, help="Maps kernel language to code comment syntax" + {"scala": "//", None: "#"}, + config=True, + help="Maps kernel language to code comment syntax", ) allow_notebook_download_env = "KG_ALLOW_NOTEBOOK_DOWNLOAD" @@ -71,7 +75,9 @@ def __init__(self, *args, **kwargs): # Build the parser using the comment syntax for the notebook language func = cell_parser_module.create_parser try: - kernel_language = self.parent.seed_notebook["metadata"]["language_info"]["name"] + kernel_language = self.parent.seed_notebook["metadata"]["language_info"][ + "name" + ] except (AttributeError, KeyError): kernel_language = None prefix = self.comment_prefix.get(kernel_language, "#") @@ -85,7 +91,9 @@ def __init__(self, *args, **kwargs): self.kernel_pool = ManagedKernelPool() async def init_configurables(self): - await self.kernel_pool.initialize(self.parent.prespawn_count, self.parent.kernel_manager) + await self.kernel_pool.initialize( + self.parent.prespawn_count, self.parent.kernel_manager + ) def create_request_handlers(self): """Create handlers and redefine them off of the base_url path. Assumes @@ -97,13 +105,17 @@ def create_request_handlers(self): if self.allow_notebook_download: path = url_path_join("/", self.parent.base_url, r"/_api/source") self.log.info(f"Registering resource: {path}, methods: (GET)") - handlers.append((path, NotebookDownloadHandler, {"path": self.parent.seed_uri})) + handlers.append( + (path, NotebookDownloadHandler, {"path": self.parent.seed_uri}) + ) # Register a static path handler if configuration allows if self.static_path is not None: path = url_path_join("/", self.parent.base_url, r"/public/(.*)") self.log.info(f"Registering resource: {path}, methods: (GET)") - handlers.append((path, tornado.web.StaticFileHandler, {"path": self.static_path})) + handlers.append( + (path, tornado.web.StaticFileHandler, {"path": self.static_path}) + ) # Discover the notebook endpoints and their implementations endpoints = self.api_parser.endpoints(self.parent.kernel_manager.seed_source) @@ -118,14 +130,18 @@ def create_request_handlers(self): # Cycle through the (endpoint_path, source) tuples and register their handlers for endpoint_path, verb_source_map in endpoints: parameterized_path = parameterize_path(endpoint_path) - parameterized_path = url_path_join("/", self.parent.base_url, parameterized_path) + parameterized_path = url_path_join( + "/", self.parent.base_url, parameterized_path + ) self.log.info( "Registering resource: {}, methods: ({})".format( parameterized_path, list(verb_source_map.keys()) ) ) response_source_map = ( - response_sources[endpoint_path] if endpoint_path in response_sources else {} + response_sources[endpoint_path] + if endpoint_path in response_sources + else {} ) handler_args = { "sources": verb_source_map, @@ -159,9 +175,9 @@ def should_seed_cell(self, code): """Determines whether the given code cell source should be executed when seeding a new kernel.""" # seed cells that are uninvolved with the presented API - return not self.api_parser.is_api_cell(code) and not self.api_parser.is_api_response_cell( + return not self.api_parser.is_api_cell( code - ) + ) and not self.api_parser.is_api_response_cell(code) async def shutdown(self): """Stop all kernels in the pool.""" diff --git a/kernel_gateway/notebook_http/cell/parser.py b/kernel_gateway/notebook_http/cell/parser.py index 2f0208d..2119421 100644 --- a/kernel_gateway/notebook_http/cell/parser.py +++ b/kernel_gateway/notebook_http/cell/parser.py @@ -75,7 +75,9 @@ class APICellParser(LoggingConfigurable): def __init__(self, comment_prefix, notebook_cells=None, **kwargs): super().__init__(**kwargs) - self.kernelspec_api_indicator = re.compile(self.api_indicator.format(comment_prefix)) + self.kernelspec_api_indicator = re.compile( + self.api_indicator.format(comment_prefix) + ) self.kernelspec_api_response_indicator = re.compile( self.api_response_indicator.format(comment_prefix) ) diff --git a/kernel_gateway/notebook_http/handlers.py b/kernel_gateway/notebook_http/handlers.py index 55e3494..b079bf1 100644 --- a/kernel_gateway/notebook_http/handlers.py +++ b/kernel_gateway/notebook_http/handlers.py @@ -53,7 +53,9 @@ class NotebookAPIHandler( are identified, parsed, and associated with HTTP verbs and paths. """ - def initialize(self, sources, response_sources, kernel_pool, kernel_name, kernel_language=""): + def initialize( + self, sources, response_sources, kernel_pool, kernel_name, kernel_language="" + ): self.kernel_pool = kernel_pool self.sources = sources self.kernel_name = kernel_name @@ -202,12 +204,16 @@ async def _handle_request(self): # Run the request and source code and yield until there's a result access_log.debug(f"Request code for notebook cell is: {request_code}") await self.execute_code(kernel_client, kernel_id, request_code) - source_result = await self.execute_code(kernel_client, kernel_id, source_code) + source_result = await self.execute_code( + kernel_client, kernel_id, source_code + ) # If a response code cell exists, execute it if self.request.method in self.response_sources: response_code = self.response_sources[self.request.method] - response_future = self.execute_code(kernel_client, kernel_id, response_code) + response_future = self.execute_code( + kernel_client, kernel_id, response_code + ) # Wait for the response and parse the json value response_result = await response_future diff --git a/kernel_gateway/notebook_http/swagger/parser.py b/kernel_gateway/notebook_http/swagger/parser.py index c8cff5d..f826915 100644 --- a/kernel_gateway/notebook_http/swagger/parser.py +++ b/kernel_gateway/notebook_http/swagger/parser.py @@ -68,7 +68,9 @@ class SwaggerCellParser(LoggingConfigurable): """ operation_indicator = Unicode(default_value=r"{}\s*operationId:\s*(.*)") - operation_response_indicator = Unicode(default_value=r"{}\s*ResponseInfo\s+operationId:\s*(.*)") + operation_response_indicator = Unicode( + default_value=r"{}\s*ResponseInfo\s+operationId:\s*(.*)" + ) notebook_cells = List() def __init__(self, comment_prefix, notebook_cells, **kwargs): @@ -98,7 +100,9 @@ def __init__(self, comment_prefix, notebook_cells, **kwargs): for endpoint in self.swagger["paths"].keys(): for verb in self.swagger["paths"][endpoint].keys(): if "operationId" in self.swagger["paths"][endpoint][verb]: - operationId = self.swagger["paths"][endpoint][verb]["operationId"] + operationId = self.swagger["paths"][endpoint][verb][ + "operationId" + ] operationIdsDeclared.append(operationId) for operationId in operationIdsDeclared: if operationId not in operationIdsFound: @@ -227,18 +231,23 @@ def _endpoint_verb_source_mappings(self, source_cells, operationIdRegex): for verb in self.swagger["paths"][endpoint].keys(): if ( "operationId" in self.swagger["paths"][endpoint][verb] - and self.swagger["paths"][endpoint][verb]["operationId"] in operationIds + and self.swagger["paths"][endpoint][verb]["operationId"] + in operationIds ): operationId = self.swagger["paths"][endpoint][verb]["operationId"] if "parameters" in self.swagger["paths"][endpoint][verb]: endpoint_with_param = endpoint ## do we need to sort these names as well? - for parameter in self.swagger["paths"][endpoint][verb]["parameters"]: + for parameter in self.swagger["paths"][endpoint][verb][ + "parameters" + ]: if "name" in parameter: endpoint_with_param = "/:".join( [endpoint_with_param, parameter["name"]] ) - mappings.setdefault(endpoint_with_param, {}).setdefault(verb, "") + mappings.setdefault(endpoint_with_param, {}).setdefault( + verb, "" + ) mappings[endpoint_with_param][verb] = operationIds[operationId] else: mappings.setdefault(endpoint, {}).setdefault(verb, "") @@ -269,7 +278,8 @@ def get_cell_endpoint_and_verb(self, cell_source): for verb in self.swagger["paths"][endpoint].keys(): if ( "operationId" in self.swagger["paths"][endpoint][verb] - and self.swagger["paths"][endpoint][verb]["operationId"] == operationId + and self.swagger["paths"][endpoint][verb]["operationId"] + == operationId ): return (endpoint, verb) return (None, None) @@ -294,7 +304,8 @@ def get_path_content(self, cell_source): for verb in self.swagger["paths"][endpoint].keys(): if ( "operationId" in self.swagger["paths"][endpoint][verb] - and self.swagger["paths"][endpoint][verb]["operationId"] == operationId + and self.swagger["paths"][endpoint][verb]["operationId"] + == operationId ): return self.swagger["paths"][endpoint][verb] # mismatched operationId? return a default diff --git a/kernel_gateway/services/kernels/handlers.py b/kernel_gateway/services/kernels/handlers.py index deaf0c5..ee30e0b 100644 --- a/kernel_gateway/services/kernels/handlers.py +++ b/kernel_gateway/services/kernels/handlers.py @@ -12,7 +12,10 @@ class MainKernelHandler( - TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, server_handlers.MainKernelHandler + TokenAuthorizationMixin, + CORSMixin, + JSONErrorsMixin, + server_handlers.MainKernelHandler, ): """Extends the notebook main kernel handler with token auth, CORS, and JSON errors. @@ -74,7 +77,9 @@ async def post(self): # No way to override the call to start_kernel on the kernel manager # so do a temporary partial (ugh) orig_start = self.kernel_manager.start_kernel - self.kernel_manager.start_kernel = partial(self.kernel_manager.start_kernel, env=env) + self.kernel_manager.start_kernel = partial( + self.kernel_manager.start_kernel, env=env + ) try: await super().post() finally: diff --git a/kernel_gateway/services/kernels/manager.py b/kernel_gateway/services/kernels/manager.py index 45b5b3e..c441176 100644 --- a/kernel_gateway/services/kernels/manager.py +++ b/kernel_gateway/services/kernels/manager.py @@ -22,7 +22,9 @@ def _default_root_dir(self): return os.getcwd() def _kernel_manager_class_default(self): - return "kernel_gateway.services.kernels.manager.KernelGatewayIOLoopKernelManager" + return ( + "kernel_gateway.services.kernels.manager.KernelGatewayIOLoopKernelManager" + ) @property def seed_kernelspec(self) -> Optional[str]: @@ -43,7 +45,9 @@ def seed_kernelspec(self) -> Optional[str]: if self.parent.force_kernel_name: self._seed_kernelspec = self.parent.force_kernel_name else: - self._seed_kernelspec = self.parent.seed_notebook["metadata"]["kernelspec"]["name"] + self._seed_kernelspec = self.parent.seed_notebook["metadata"][ + "kernelspec" + ]["name"] else: self._seed_kernelspec = None @@ -86,7 +90,9 @@ async def start_kernel(self, *args, **kwargs): """ if self.parent.force_kernel_name: kwargs["kernel_name"] = self.parent.force_kernel_name - kernel_id = await super(SeedingMappingKernelManager, self).start_kernel(*args, **kwargs) + kernel_id = await super(SeedingMappingKernelManager, self).start_kernel( + *args, **kwargs + ) if kernel_id and self.seed_source is not None: # Only run source if the kernel spec matches the notebook kernel spec @@ -117,7 +123,9 @@ async def start_kernel(self, *args, **kwargs): client.stop_channels() # Shutdown the kernel await self.shutdown_kernel(kernel_id) - raise RuntimeError("Error seeding kernel memory", msg["content"]) + raise RuntimeError( + "Error seeding kernel memory", msg["content"] + ) # Shutdown the channels to remove any lingering ZMQ messages client.stop_channels() return kernel_id diff --git a/kernel_gateway/services/kernels/pool.py b/kernel_gateway/services/kernels/pool.py index c6d74b2..dc41866 100644 --- a/kernel_gateway/services/kernels/pool.py +++ b/kernel_gateway/services/kernels/pool.py @@ -109,7 +109,9 @@ async def initialize(self, prespawn_count, kernel_manager, **kwargs): # Create clients and iopub handlers for prespawned kernels for kernel_id in kernel_ids: - self.kernel_clients[kernel_id] = kernel_manager.get_kernel(kernel_id).client() + self.kernel_clients[kernel_id] = kernel_manager.get_kernel( + kernel_id + ).client() self.kernel_pool.append(kernel_id) iopub = self.kernel_manager.connect_iopub(kernel_id) iopub.on_recv(self.create_on_reply(kernel_id)) diff --git a/kernel_gateway/services/sessions/handlers.py b/kernel_gateway/services/sessions/handlers.py index d328866..61c5a23 100644 --- a/kernel_gateway/services/sessions/handlers.py +++ b/kernel_gateway/services/sessions/handlers.py @@ -9,7 +9,10 @@ class SessionRootHandler( - TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, server_handlers.SessionRootHandler + TokenAuthorizationMixin, + CORSMixin, + JSONErrorsMixin, + server_handlers.SessionRootHandler, ): """Extends the notebook root session handler with token auth, CORS, and JSON errors. @@ -24,7 +27,10 @@ async def get(self): tornado.web.HTTPError If kg_list_kernels is False, respond with 403 Forbidden """ - if "kg_list_kernels" not in self.settings or self.settings["kg_list_kernels"] != True: + if ( + "kg_list_kernels" not in self.settings + or self.settings["kg_list_kernels"] != True + ): raise tornado.web.HTTPError(403, "Forbidden") else: await super(SessionRootHandler, self).get() diff --git a/kernel_gateway/services/sessions/sessionmanager.py b/kernel_gateway/services/sessions/sessionmanager.py index cd9fd74..c3967f0 100644 --- a/kernel_gateway/services/sessions/sessionmanager.py +++ b/kernel_gateway/services/sessions/sessionmanager.py @@ -75,10 +75,14 @@ async def create_session( """ session_id = self.new_session_id() # allow nbm to specify kernels cwd - kernel_id = await self.kernel_manager.start_kernel(path=path, kernel_name=kernel_name) + kernel_id = await self.kernel_manager.start_kernel( + path=path, kernel_name=kernel_name + ) return self.save_session(session_id, path=path, kernel_id=kernel_id) - def save_session(self, session_id, path=None, kernel_id=None, *args, **kwargs) -> dict: + def save_session( + self, session_id, path=None, kernel_id=None, *args, **kwargs + ) -> dict: """Saves the metadata for the session with the given `session_id`. Given a `session_id` (and any other of the arguments), this method @@ -98,7 +102,9 @@ def save_session(self, session_id, path=None, kernel_id=None, *args, **kwargs) - dict Session model with `session_id`, `path`, and `kernel_id` keys """ - self._sessions.append({"session_id": session_id, "path": path, "kernel_id": kernel_id}) + self._sessions.append( + {"session_id": session_id, "path": path, "kernel_id": kernel_id} + ) return self.get_session(session_id=session_id) diff --git a/kernel_gateway/tests/notebook_http/cell/test_parser.py b/kernel_gateway/tests/notebook_http/cell/test_parser.py index bb06261..b593182 100644 --- a/kernel_gateway/tests/notebook_http/cell/test_parser.py +++ b/kernel_gateway/tests/notebook_http/cell/test_parser.py @@ -30,7 +30,9 @@ def test_endpoint_sort_default_strategy(self): for index in range(len(expected_values)): endpoint, _ = endpoints[index] - assert expected_values[index] == endpoint, "Endpoint was not found in expected order" + assert ( + expected_values[index] == endpoint + ), "Endpoint was not found in expected order" def test_endpoint_sort_custom_strategy(self): """Parser should sort duplicate endpoint paths using a custom sort @@ -53,7 +55,9 @@ def custom_sort_fun(endpoint): for index in range(len(expected_values)): endpoint, _ = endpoints[index] - assert expected_values[index] == endpoint, "Endpoint was not found in expected order" + assert ( + expected_values[index] == endpoint + ), "Endpoint was not found in expected order" def test_get_cell_endpoint_and_verb(self): """Parser should extract API endpoint and verb from cell annotations.""" diff --git a/kernel_gateway/tests/notebook_http/swagger/test_builders.py b/kernel_gateway/tests/notebook_http/swagger/test_builders.py index 8a650b5..e3f6432 100644 --- a/kernel_gateway/tests/notebook_http/swagger/test_builders.py +++ b/kernel_gateway/tests/notebook_http/swagger/test_builders.py @@ -26,7 +26,9 @@ def test_add_cell_adds_api_cell_to_spec(self): builder = SwaggerSpecBuilder(APICellParser(comment_prefix="#")) builder.add_cell("# GET /some/resource") result = builder.build() - assert result["paths"]["/some/resource"] == expected, "Title was not set to new value" + assert ( + result["paths"]["/some/resource"] == expected + ), "Title was not set to new value" def test_all_swagger_preserved_in_spec(self): """Builder should store the swagger documented cell.""" @@ -83,18 +85,24 @@ def test_all_swagger_preserved_in_spec(self): assert ( result["info"]["title"] == json.loads(expected)["info"]["title"] ), "title was not preserved" - assert json.dumps(result["paths"]["/some/resource"], sort_keys=True) == json.dumps( + assert json.dumps( + result["paths"]["/some/resource"], sort_keys=True + ) == json.dumps( json.loads(expected)["paths"]["/some/resource"], sort_keys=True ), "operations were not as expected" new_title = "new title. same contents." builder.set_default_title(new_title) result = builder.build() - assert result["info"]["title"] != new_title, "title should not have been changed" + assert ( + result["info"]["title"] != new_title + ), "title should not have been changed" def test_add_undocumented_cell_does_not_add_non_api_cell_to_spec(self): """Builder should store ignore non-API cells.""" - builder = SwaggerSpecBuilder(SwaggerCellParser(comment_prefix="#", notebook_cells=[])) + builder = SwaggerSpecBuilder( + SwaggerCellParser(comment_prefix="#", notebook_cells=[]) + ) builder.add_cell("regular code cell") builder.add_cell("# regular commented cell") result = builder.build() diff --git a/kernel_gateway/tests/notebook_http/swagger/test_parser.py b/kernel_gateway/tests/notebook_http/swagger/test_parser.py index 7bd61b5..27c673e 100644 --- a/kernel_gateway/tests/notebook_http/swagger/test_parser.py +++ b/kernel_gateway/tests/notebook_http/swagger/test_parser.py @@ -30,14 +30,17 @@ def test_basic_is_api_cell(self): } ], ) - assert parser.is_api_cell("#operationId:foo"), "API cell was not detected with " + str( - parser.kernelspec_operation_indicator + assert parser.is_api_cell("#operationId:foo"), ( + "API cell was not detected with " + + str(parser.kernelspec_operation_indicator) ) - assert parser.is_api_cell("# operationId:foo"), "API cell was not detected with " + str( - parser.kernelspec_operation_indicator + assert parser.is_api_cell("# operationId:foo"), ( + "API cell was not detected with " + + str(parser.kernelspec_operation_indicator) ) - assert parser.is_api_cell("#operationId: foo"), "API cell was not detected with " + str( - parser.kernelspec_operation_indicator + assert parser.is_api_cell("#operationId: foo"), ( + "API cell was not detected with " + + str(parser.kernelspec_operation_indicator) ) assert parser.is_api_cell("no") is False, "API cell was detected" assert parser.is_api_cell("# another comment") is False, "API cell was detected" @@ -118,13 +121,17 @@ def custom_sort_fun(endpoint): return 2 parser = SwaggerCellParser(comment_prefix="#", notebook_cells=source_cells) - endpoints = parser.endpoints((cell["source"] for cell in source_cells), custom_sort_fun) + endpoints = parser.endpoints( + (cell["source"] for cell in source_cells), custom_sort_fun + ) print(str(endpoints)) expected_values = ["/+", "/a", "/1"] for index in range(len(expected_values)): endpoint, _ = endpoints[index] - assert expected_values[index] == endpoint, "Endpoint was not found in expected order" + assert ( + expected_values[index] == endpoint + ), "Endpoint was not found in expected order" def test_get_cell_endpoint_and_verb(self): """Parser should extract API endpoint and verb from cell annotations.""" @@ -139,7 +146,9 @@ def test_get_cell_endpoint_and_verb(self): endpoint, verb = parser.get_cell_endpoint_and_verb("# operationId: getFoo") assert endpoint == "/foo", "Endpoint was not extracted correctly" assert verb.lower() == "get", "Endpoint was not extracted correctly" - endpoint, verb = parser.get_cell_endpoint_and_verb("# operationId: post_bar_Quo") + endpoint, verb = parser.get_cell_endpoint_and_verb( + "# operationId: post_bar_Quo" + ) assert endpoint == "/bar/quo", "Endpoint was not extracted correctly" assert verb.lower() == "post", "Endpoint was not extracted correctly" @@ -173,7 +182,10 @@ def test_endpoint_concatenation(self): assert len(endpoints["/foo/:bar"]) == 2 assert endpoints["/foo"]["post"] == "# operationId: postFooBody \n" assert endpoints["/foo/:bar"]["get"] == "# operationId: getFoo\n" - assert endpoints["/foo/:bar"]["put"] == "# operationId: putFoo\n# operationId: putFoo\n" + assert ( + endpoints["/foo/:bar"]["put"] + == "# operationId: putFoo\n# operationId: putFoo\n" + ) def test_endpoint_response_concatenation(self): """Parser should concatenate multiple response cells with the same verb+path.""" diff --git a/kernel_gateway/tests/notebook_http/test_request_utils.py b/kernel_gateway/tests/notebook_http/test_request_utils.py index 0f754c7..75bb917 100644 --- a/kernel_gateway/tests/notebook_http/test_request_utils.py +++ b/kernel_gateway/tests/notebook_http/test_request_utils.py @@ -87,7 +87,9 @@ def test_parse_body_defaults_to_text_plain(self): request.body = b'{"foo" : "bar"}' request.headers = {} result = parse_body(request) - self.assertEqual(result, '{"foo" : "bar"}', "Did not properly handle body = empty string.") + self.assertEqual( + result, '{"foo" : "bar"}', "Did not properly handle body = empty string." + ) def test_parse_args(self): """Should parse URL argument byte streams to strings.""" @@ -116,7 +118,11 @@ def test_headers_to_dict(self): """Should parse headers into a dictionary.""" result = headers_to_dict( MockHeaders( - [("Content-Type", "application/json"), ("Set-Cookie", "A=B"), ("Set-Cookie", "C=D")] + [ + ("Content-Type", "application/json"), + ("Set-Cookie", "A=B"), + ("Set-Cookie", "C=D"), + ] ) ) self.assertEqual( diff --git a/kernel_gateway/tests/resources/responses.ipynb b/kernel_gateway/tests/resources/responses.ipynb index 0021108..3639a90 100644 --- a/kernel_gateway/tests/resources/responses.ipynb +++ b/kernel_gateway/tests/resources/responses.ipynb @@ -79,7 +79,9 @@ "outputs": [], "source": [ "# ResponseInfo GET /etag\n", - "print(json.dumps({\"headers\": {\"Content-Type\": \"application/json\", \"Etag\": \"1234567890\"}}))" + "print(\n", + " json.dumps({\"headers\": {\"Content-Type\": \"application/json\", \"Etag\": \"1234567890\"}})\n", + ")" ] }, { diff --git a/kernel_gateway/tests/test_jupyter_websocket.py b/kernel_gateway/tests/test_jupyter_websocket.py index 4cce7e9..9637b1c 100644 --- a/kernel_gateway/tests/test_jupyter_websocket.py +++ b/kernel_gateway/tests/test_jupyter_websocket.py @@ -73,7 +73,12 @@ def get_execute_request(code: str) -> dict: }, "parent_header": {}, "channel": "shell", - "content": {"code": code, "silent": False, "store_history": False, "user_expressions": {}}, + "content": { + "code": code, + "silent": False, + "store_history": False, + "user_expressions": {}, + }, "metadata": {}, "buffers": {}, } @@ -93,7 +98,9 @@ async def await_stream(ws): class TestDefaults: """Tests gateway behavior.""" - @pytest.mark.parametrize("jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],)) + @pytest.mark.parametrize( + "jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],) + ) async def test_startup(self, jp_fetch, jp_argv): """Root of kernels resource should be OK.""" response = await jp_fetch("api", "kernels", method="GET") @@ -116,7 +123,9 @@ async def test_headless(self, jp_fetch): async def test_check_origin(self, jp_fetch, jp_web_app): """Allow origin setting should pass through to base handlers.""" with pytest.raises(HTTPClientError) as e: - await jp_fetch("api", "kernelspecs", headers={"Origin": "fake.com:8888"}, method="GET") + await jp_fetch( + "api", "kernelspecs", headers={"Origin": "fake.com:8888"}, method="GET" + ) assert e.value.code == 404 jp_web_app.settings["allow_origin"] = "*" @@ -138,7 +147,9 @@ async def test_check_origin(self, jp_fetch, jp_web_app): ), ), ) - async def test_config_bad_api_value(self, jp_configurable_serverapp, jp_server_config): + async def test_config_bad_api_value( + self, jp_configurable_serverapp, jp_server_config + ): """Should raise an ImportError for nonexistent API personality modules.""" with pytest.raises(ImportError): await jp_configurable_serverapp() @@ -161,7 +172,9 @@ async def test_options_without_auth_token(self, jp_fetch, jp_web_app): ), ), ) - async def test_auth_token(self, jp_server_config, jp_fetch, jp_web_app, jp_ws_fetch): + async def test_auth_token( + self, jp_server_config, jp_fetch, jp_web_app, jp_ws_fetch + ): """All server endpoints should check the configured auth token.""" # Request API without the token @@ -180,19 +193,28 @@ async def test_auth_token(self, jp_server_config, jp_fetch, jp_web_app, jp_ws_fe # Request kernelspecs without the token with pytest.raises(HTTPClientError) as e: - await jp_fetch("api", "kernelspecs", method="GET", headers={"Authorization": ""}) + await jp_fetch( + "api", "kernelspecs", method="GET", headers={"Authorization": ""} + ) assert e.value.response.code == 401 # Now with it response = await jp_fetch( - "api", "kernelspecs", method="GET", headers={"Authorization": "token fake-token"} + "api", + "kernelspecs", + method="GET", + headers={"Authorization": "token fake-token"}, ) assert response.code == 200 # Request a kernel without the token with pytest.raises(HTTPClientError) as e: await jp_fetch( - "api", "kernels", method="POST", body="{}", headers={"Authorization": ""} + "api", + "kernels", + method="POST", + body="{}", + headers={"Authorization": ""}, ) assert e.value.response.code == 401 @@ -210,12 +232,18 @@ async def test_auth_token(self, jp_server_config, jp_fetch, jp_web_app, jp_ws_fe # Request kernel info without the token with pytest.raises(HTTPClientError) as e: - await jp_fetch("api", "kernels", kernel_id, method="GET", headers={"Authorization": ""}) + await jp_fetch( + "api", "kernels", kernel_id, method="GET", headers={"Authorization": ""} + ) assert e.value.response.code == 401 # Now with it response = await jp_fetch( - "api", "kernels", kernel_id, method="GET", headers={"Authorization": "token fake-token"} + "api", + "kernels", + kernel_id, + method="GET", + headers={"Authorization": "token fake-token"}, ) assert response.code == 200 @@ -230,7 +258,11 @@ async def test_auth_token(self, jp_server_config, jp_fetch, jp_web_app, jp_ws_fe # Now request the websocket with the token ws = await jp_ws_fetch( - "api", "kernels", kernel_id, "channels", headers={"Authorization": "token fake-token"} + "api", + "kernels", + kernel_id, + "channels", + headers={"Authorization": "token fake-token"}, ) ws.close() @@ -249,7 +281,10 @@ async def test_cors_headers(self, jp_fetch, jp_web_app): response = await jp_fetch("api", "kernels", method="GET") assert response.code == 200 assert response.headers["Access-Control-Allow-Credentials"] == "false" - assert response.headers["Access-Control-Allow-Headers"] == "Authorization,Content-Type" + assert ( + response.headers["Access-Control-Allow-Headers"] + == "Authorization,Content-Type" + ) assert response.headers["Access-Control-Allow-Methods"] == "GET,POST" assert response.headers["Access-Control-Allow-Origin"] == "https://jupyter.org" assert response.headers["Access-Control-Expose-Headers"] == "X-My-Fake-Header" @@ -281,7 +316,9 @@ async def test_max_kernels(self, jp_fetch, jp_web_app): # Shut down the kernel kernel = json_decode(response.body) - response = await jp_fetch("api", "kernels", url_escape(kernel["id"]), method="DELETE") + response = await jp_fetch( + "api", "kernels", url_escape(kernel["id"]), method="DELETE" + ) assert response.code == 204 # Try creation again @@ -430,7 +467,9 @@ async def test_json_errors(self, jp_fetch): body = json_decode(e.value.response.body) assert body["reason"] == "Not Found" - @pytest.mark.parametrize("jp_argv", (["--JupyterWebsocketPersonality.env_whitelist=TEST_VAR"],)) + @pytest.mark.parametrize( + "jp_argv", (["--JupyterWebsocketPersonality.env_whitelist=TEST_VAR"],) + ) async def test_kernel_env(self, spawn_kernel, jp_argv): """Kernel should start with environment vars defined in the request.""" @@ -488,7 +527,9 @@ async def test_kernel_env_auth_token(self, monkeypatch, spawn_kernel): ws.close() -@pytest.mark.parametrize("jp_argv", (["--KernelGatewayApp.default_kernel_name=fake-kernel"],)) +@pytest.mark.parametrize( + "jp_argv", (["--KernelGatewayApp.default_kernel_name=fake-kernel"],) +) class TestCustomDefaultKernel: """Tests gateway behavior when setting a custom default kernelspec.""" @@ -515,7 +556,9 @@ class TestForceKernel: async def test_force_kernel_name(self, jp_argv, jp_fetch): """Should create a Python kernel.""" - response = await jp_fetch("api", "kernels", method="POST", body='{"name": "fake-kernel"}') + response = await jp_fetch( + "api", "kernels", method="POST", body='{"name": "fake-kernel"}' + ) assert response.code == 201 kernel = json_decode(response.body) assert kernel["name"] == "python3" @@ -524,7 +567,9 @@ async def test_force_kernel_name(self, jp_argv, jp_fetch): class TestEnableDiscovery: """Tests gateway behavior with kernel listing enabled.""" - @pytest.mark.parametrize("jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],)) + @pytest.mark.parametrize( + "jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],) + ) async def test_enable_kernel_list(self, jp_fetch, jp_argv): """The list of kernels, sessions, and activities should be available.""" @@ -563,7 +608,9 @@ def test_prespawn_max_conflict(self): class TestBaseURL: """Tests gateway behavior when a custom base URL is configured.""" - @pytest.mark.parametrize("jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],)) + @pytest.mark.parametrize( + "jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],) + ) @pytest.mark.parametrize("jp_base_url", ("/fake/path",)) async def test_base_url(self, jp_base_url, jp_argv, jp_fetch): """Server should mount resources under configured base.""" @@ -576,7 +623,9 @@ async def test_base_url(self, jp_base_url, jp_argv, jp_fetch): class TestRelativeBaseURL: """Tests gateway behavior when a relative base URL is configured.""" - @pytest.mark.parametrize("jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],)) + @pytest.mark.parametrize( + "jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],) + ) @pytest.mark.parametrize("jp_base_url", ("/fake/path",)) async def test_base_url(self, jp_base_url, jp_argv, jp_fetch): """Server should mount resources under fixed base.""" @@ -591,7 +640,8 @@ class TestSeedURI: """Tests gateway behavior when a seeding kernel memory with code from a notebook.""" @pytest.mark.parametrize( - "jp_argv", ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'zen.ipynb')}"],) + "jp_argv", + ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'zen.ipynb')}"],), ) async def test_seed(self, jp_argv, spawn_kernel): """Kernel should have variables pre-seeded from the notebook.""" diff --git a/kernel_gateway/tests/test_notebook_http.py b/kernel_gateway/tests/test_notebook_http.py index 1393ff1..ce6dd27 100644 --- a/kernel_gateway/tests/test_notebook_http.py +++ b/kernel_gateway/tests/test_notebook_http.py @@ -39,14 +39,20 @@ async def test_api_get_endpoint_with_path_param(self, jp_fetch): """GET HTTP method should be callable with a path param""" response = await jp_fetch("hello", "governor", method="GET") assert response.code == 200, "GET endpoint did not return 200." - assert response.body == b"hello governor\n", "Unexpected body in response to GET." + assert ( + response.body == b"hello governor\n" + ), "Unexpected body in response to GET." async def test_api_get_endpoint_with_query_param(self, jp_fetch): """GET HTTP method should be callable with a query param""" - response = await jp_fetch("hello", "person", params={"person": "governor"}, method="GET") + response = await jp_fetch( + "hello", "person", params={"person": "governor"}, method="GET" + ) assert response.code == 200, "GET endpoint did not return 200." print(f"response.body = '{response.body}'") - assert response.body == b"hello governor\n", "Unexpected body in response to GET." + assert ( + response.body == b"hello governor\n" + ), "Unexpected body in response to GET." async def test_api_get_endpoint_with_multiple_query_params(self, jp_fetch): """GET HTTP method should be callable with multiple query params""" @@ -54,7 +60,9 @@ async def test_api_get_endpoint_with_multiple_query_params(self, jp_fetch): "hello", "persons", params={"person": "governor, rick"}, method="GET" ) assert response.code == 200, "GET endpoint did not return 200." - assert response.body == b"hello governor, rick\n", "Unexpected body in response to GET." + assert ( + response.body == b"hello governor, rick\n" + ), "Unexpected body in response to GET." async def test_api_put_endpoint(self, jp_fetch): """PUT HTTP method should be callable""" @@ -126,7 +134,12 @@ async def test_api_undefined(self, jp_fetch): async def test_api_access_http_header(self, jp_fetch): """HTTP endpoints should be able to access request headers""" - content_types = ["text/plain", "application/json", "application/atom+xml", "foo"] + content_types = [ + "text/plain", + "application/json", + "application/atom+xml", + "foo", + ] for content_type in content_types: response = await jp_fetch( "content-type", method="GET", headers={"Content-Type": content_type} @@ -147,7 +160,9 @@ async def test_format_request_code_escaped_integration(self, jp_fetch): headers={"If-None-Match": '""9a28a9262f954494a8de7442c63d6d0715ce0998""'}, ) assert response.code == 200, "GET endpoint did not return 200." - assert response.body == b"hello governor\n", "Unexpected body in response to GET." + assert ( + response.body == b"hello governor\n" + ), "Unexpected body in response to GET." async def test_blocked_download_notebook_source(self, jp_fetch): """Notebook source should not exist under the path /_api/source when @@ -155,7 +170,9 @@ async def test_blocked_download_notebook_source(self, jp_fetch): """ with pytest.raises(HTTPClientError) as e: await jp_fetch("_api", "source", method="GET") - assert e.value.code == 404, "/_api/source found when allow_notebook_download is false" + assert ( + e.value.code == 404 + ), "/_api/source found when allow_notebook_download is false" async def test_blocked_public(self, jp_fetch): """Public static assets should not exist under the path /public when @@ -169,7 +186,9 @@ async def test_api_returns_execute_result(self, jp_fetch): """GET HTTP method should return the result of cell execution""" response = await jp_fetch("execute_result", method="GET") assert response.code == 200, "GET endpoint did not return 200." - assert response.body == b'{"text/plain": "2"}', "Unexpected body in response to GET." + assert ( + response.body == b'{"text/plain": "2"}' + ), "Unexpected body in response to GET." async def test_cells_concatenate(self, jp_fetch): """Multiple cells with the same verb and path should concatenate.""" @@ -181,11 +200,14 @@ async def test_kernel_gateway_environment_set(self, jp_fetch): """GET HTTP method should be callable with multiple query params""" response = await jp_fetch("env_kernel_gateway", method="GET") assert response.code == 200, "GET endpoint did not return 200." - assert response.body == b"KERNEL_GATEWAY is 1\n", "Unexpected body in response to GET." + assert ( + response.body == b"KERNEL_GATEWAY is 1\n" + ), "Unexpected body in response to GET." @pytest.mark.parametrize( - "jp_argv", ([f"--NotebookHTTPPersonality.static_path={os.path.join(RESOURCES, 'public')}"],) + "jp_argv", + ([f"--NotebookHTTPPersonality.static_path={os.path.join(RESOURCES, 'public')}"],), ) class TestPublicStatic: """Tests gateway behavior when public static assets are enabled.""" @@ -197,21 +219,26 @@ async def test_get_public(self, jp_fetch, jp_argv): assert response.headers.get("Content-Type") == "text/html" -@pytest.mark.parametrize("jp_argv", (["--NotebookHTTPPersonality.allow_notebook_download=True"],)) +@pytest.mark.parametrize( + "jp_argv", (["--NotebookHTTPPersonality.allow_notebook_download=True"],) +) class TestSourceDownload: """Tests gateway behavior when notebook download is allowed.""" async def test_download_notebook_source(self, jp_fetch, jp_argv): """Notebook source should exist under the path `/_api/source`.""" response = await jp_fetch("_api", "source", method="GET") - assert response.code == 200, "/_api/source did not correctly return the downloaded notebook" + assert ( + response.code == 200 + ), "/_api/source did not correctly return the downloaded notebook" nb = json.loads(response.body) for key in ["cells", "metadata", "nbformat", "nbformat_minor"]: assert key in nb @pytest.mark.parametrize( - "jp_argv", ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'responses.ipynb')}"],) + "jp_argv", + ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'responses.ipynb')}"],), ) class TestCustomResponse: """Tests gateway behavior when the notebook contains ResponseInfo cells.""" @@ -295,7 +322,8 @@ async def test_locking_semaphore_of_kernel_resources(self, jp_fetch, jp_argv): @pytest.mark.parametrize( - "jp_argv", ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'simple_api.ipynb')}"],) + "jp_argv", + ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'simple_api.ipynb')}"],), ) class TestSwaggerSpec: async def test_generation_of_swagger_spec(self, jp_fetch, jp_argv): @@ -315,8 +343,12 @@ async def test_generation_of_swagger_spec(self, jp_fetch, jp_argv): response = await jp_fetch("_api", "spec", "swagger.json", method="GET") result = json.loads(response.body.decode("UTF-8")) - assert response.code == 200, "Swagger spec endpoint did not return the correct status code" - assert result == expected_response, "Swagger spec endpoint did not return the correct value" + assert ( + response.code == 200 + ), "Swagger spec endpoint did not return the correct status code" + assert ( + result == expected_response + ), "Swagger spec endpoint did not return the correct value" assert ( SwaggerSpecHandler.output is not None ), "Swagger spec output wasn't cached for later requests"