From 066bcab15202e4fb454a688c0ae4ef482aa37784 Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Tue, 7 Jan 2025 18:51:01 -0500 Subject: [PATCH] improvement: prompt to install missing deps during optional dependency checks --- marimo/_dependencies/dependencies.py | 5 +- marimo/_server/errors.py | 80 ++++++++++++++++++++++ marimo/_server/main.py | 39 +---------- marimo/_server/sessions.py | 11 ++++ marimo/_sql/sql.py | 3 +- tests/_dependencies/test_dependencies.py | 9 +++ tests/_server/api/test_middleware.py | 4 -- tests/_server/test_errors.py | 84 ++++++++++++++++++++++++ 8 files changed, 191 insertions(+), 44 deletions(-) create mode 100644 marimo/_server/errors.py create mode 100644 tests/_server/test_errors.py diff --git a/marimo/_dependencies/dependencies.py b/marimo/_dependencies/dependencies.py index cf2eef406c5..6afe5229e48 100644 --- a/marimo/_dependencies/dependencies.py +++ b/marimo/_dependencies/dependencies.py @@ -20,7 +20,7 @@ def has(self) -> bool: has_dep = importlib.util.find_spec(self.pkg) is not None if not has_dep: return False - except ModuleNotFoundError: + except (ModuleNotFoundError, importlib.metadata.PackageNotFoundError): # Could happen for nested imports (e.g. foo.bar) return False @@ -54,7 +54,8 @@ def require(self, why: str) -> None: if not self.has(): message = f"{self.pkg} is required {why}." sys.stderr.write(message + "\n\n") - raise ModuleNotFoundError(message) from None + # Including the `name` helps with auto-installations + raise ModuleNotFoundError(message, name=self.pkg) from None def require_at_version( self, diff --git a/marimo/_server/errors.py b/marimo/_server/errors.py new file mode 100644 index 00000000000..99a88ee29c7 --- /dev/null +++ b/marimo/_server/errors.py @@ -0,0 +1,80 @@ +# Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from starlette.exceptions import HTTPException +from starlette.responses import JSONResponse + +from marimo import _loggers +from marimo._messaging.ops import MissingPackageAlert +from marimo._runtime.packages.utils import is_python_isolated +from marimo._server.api.deps import AppState +from marimo._server.api.status import ( + HTTPException as MarimoHTTPException, + is_client_error, +) +from marimo._server.ids import ConsumerId +from marimo._server.model import SessionMode +from marimo._server.sessions import send_message_to_consumer + +if TYPE_CHECKING: + from starlette.requests import Request + +LOGGER = _loggers.marimo_logger() + + +# Convert exceptions to JSON responses +# In the case of a ModuleNotFoundError, we try to send a MissingPackageAlert to the client +# to install the missing package +async def handle_error(request: Request, response: Any) -> Any: + if isinstance(response, HTTPException): + # Turn 403s into 401s to collect auth + if response.status_code == 403: + return JSONResponse( + status_code=401, + content={"detail": "Authorization header required"}, + headers={"WWW-Authenticate": "Basic"}, + ) + return JSONResponse( + {"detail": response.detail}, + status_code=response.status_code, + headers=response.headers, + ) + if isinstance(response, MarimoHTTPException): + # Log server errors + if not is_client_error(response.status_code): + LOGGER.exception(response) + return JSONResponse( + {"detail": response.detail}, + status_code=response.status_code, + ) + if isinstance(response, ModuleNotFoundError) and response.name: + try: + app_state = AppState(request) + session_id = app_state.get_current_session_id() + session = app_state.get_current_session() + # If we're in an edit session, send an package installation request + if ( + session_id is not None + and session is not None + and app_state.mode == SessionMode.EDIT + ): + send_message_to_consumer( + session=session, + operation=MissingPackageAlert( + packages=[response.name], + isolated=is_python_isolated(), + ), + consumer_id=ConsumerId(session_id), + ) + return JSONResponse({"detail": str(response)}, status_code=500) + except Exception as e: + LOGGER.warning(f"Failed to send missing package alert: {e}") + if isinstance(response, NotImplementedError): + return JSONResponse({"detail": "Not supported"}, status_code=501) + if isinstance(response, TypeError): + return JSONResponse({"detail": str(response)}, status_code=500) + if isinstance(response, Exception): + return JSONResponse({"detail": str(response)}, status_code=500) + return response diff --git a/marimo/_server/main.py b/marimo/_server/main.py index 623f20b3585..8f7568c8762 100644 --- a/marimo/_server/main.py +++ b/marimo/_server/main.py @@ -1,14 +1,13 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, List, Optional from starlette.applications import Starlette from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.cors import CORSMiddleware -from starlette.responses import JSONResponse from marimo import _loggers from marimo._server.api.auth import ( @@ -25,49 +24,15 @@ from marimo._server.api.router import build_routes from marimo._server.api.status import ( HTTPException as MarimoHTTPException, - is_client_error, ) +from marimo._server.errors import handle_error if TYPE_CHECKING: - from starlette.requests import Request from starlette.types import Lifespan LOGGER = _loggers.marimo_logger() -# Convert exceptions to JSON responses -async def handle_error(request: Request, response: Any) -> Any: - del request - if isinstance(response, HTTPException): - # Turn 403s into 401s to collect auth - if response.status_code == 403: - return JSONResponse( - status_code=401, - content={"detail": "Authorization header required"}, - headers={"WWW-Authenticate": "Basic"}, - ) - return JSONResponse( - {"detail": response.detail}, - status_code=response.status_code, - headers=response.headers, - ) - if isinstance(response, MarimoHTTPException): - # Log server errors - if not is_client_error(response.status_code): - LOGGER.exception(response) - return JSONResponse( - {"detail": response.detail}, - status_code=response.status_code, - ) - if isinstance(response, NotImplementedError): - return JSONResponse({"detail": "Not supported"}, status_code=501) - if isinstance(response, TypeError): - return JSONResponse({"detail": str(response)}, status_code=500) - if isinstance(response, Exception): - return JSONResponse({"detail": str(response)}, status_code=500) - return response - - # Create app def create_starlette_app( *, diff --git a/marimo/_server/sessions.py b/marimo/_server/sessions.py index 20f4c9e2a2f..11ca538d2b6 100644 --- a/marimo/_server/sessions.py +++ b/marimo/_server/sessions.py @@ -1046,3 +1046,14 @@ def start(self) -> None: def stop(self) -> None: pass + + +def send_message_to_consumer( + session: Session, + operation: MessageOperation, + consumer_id: Optional[ConsumerId], +) -> None: + if session.connection_state() == ConnectionState.OPEN: + for consumer, c_id in session.room.consumers.items(): + if c_id == consumer_id: + consumer.write_operation(operation) diff --git a/marimo/_sql/sql.py b/marimo/_sql/sql.py index d706cd1589d..5fb82369c6e 100644 --- a/marimo/_sql/sql.py +++ b/marimo/_sql/sql.py @@ -85,7 +85,8 @@ def sql( else: raise ModuleNotFoundError( "pandas or polars is required to execute sql. " - + "You can install them with 'pip install pandas polars'" + + "You can install them with 'pip install pandas polars'", + name="polars", ) if output: diff --git a/tests/_dependencies/test_dependencies.py b/tests/_dependencies/test_dependencies.py index c1464c3f8fc..091fa4d2311 100644 --- a/tests/_dependencies/test_dependencies.py +++ b/tests/_dependencies/test_dependencies.py @@ -36,6 +36,8 @@ def test_without_dependencies() -> None: with pytest.raises(ModuleNotFoundError) as excinfo: missing.require("for testing") + assert excinfo.value.name == "missing" + assert "for testing" in str(excinfo.value) @@ -86,6 +88,13 @@ def test_versions(): ) +def test_has_as_version_when_not_installed(): + missing = Dependency("missing") + assert missing is not None + assert missing.has() is False + assert missing.has_at_version(min_version="2.0.0") is False + + def test_version_check(): # within range assert ( diff --git a/tests/_server/api/test_middleware.py b/tests/_server/api/test_middleware.py index 55f93e7a228..e4bfd30ca49 100644 --- a/tests/_server/api/test_middleware.py +++ b/tests/_server/api/test_middleware.py @@ -446,7 +446,6 @@ def test_proxy_static_file_streaming( content_length += len(chunk) assert content_length == 1024 * 1024 - @pytest.mark.asyncio async def test_http_client_streaming( self, app_with_proxy: Starlette ) -> None: @@ -465,7 +464,6 @@ async def test_http_client_streaming( assert len(chunks) > 0 await response.aclose() - @pytest.mark.asyncio async def test_http_client_body_types( self, app_with_proxy: Starlette ) -> None: @@ -499,7 +497,6 @@ def read(self, size: int | None = -1) -> bytes: assert response.status_code == 200 await response.aclose() - @pytest.mark.asyncio async def test_http_client_headers( self, app_with_proxy: Starlette ) -> None: @@ -564,7 +561,6 @@ def test_original_app_auth_still_works( assert response.status_code == 401, response.text assert response.headers.get("Set-Cookie") is None - @pytest.mark.asyncio async def test_proxy_websocket(self, app_with_proxy: Starlette) -> None: client = TestClient(app_with_proxy) with client.websocket_connect("/proxy/ws") as websocket: diff --git a/tests/_server/test_errors.py b/tests/_server/test_errors.py new file mode 100644 index 00000000000..57d7bbfe6e0 --- /dev/null +++ b/tests/_server/test_errors.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from starlette.exceptions import HTTPException +from starlette.requests import Request + +from marimo._server.api.status import HTTPException as MarimoHTTPException +from marimo._server.errors import handle_error +from marimo._server.model import SessionMode + + +async def test_http_exception(): + # Test 403 to 401 conversion + exc = HTTPException(status_code=403) + response = await handle_error(Request({"type": "http"}), exc) + assert response.status_code == 401 + assert response.body == b'{"detail":"Authorization header required"}' + assert response.headers["WWW-Authenticate"] == "Basic" + + # Test other HTTP exceptions + exc = HTTPException(status_code=404, detail="Not found") + response = await handle_error(Request({"type": "http"}), exc) + assert response.status_code == 404 + assert response.body == b'{"detail":"Not found"}' + + +async def test_marimo_http_exception(): + exc = MarimoHTTPException(status_code=400, detail="Bad request") + response = await handle_error(Request({"type": "http"}), exc) + assert response.status_code == 400 + assert response.body == b'{"detail":"Bad request"}' + + +async def test_module_not_found_error(): + # Mock AppState and session + mock_app_state = MagicMock() + mock_app_state.mode = SessionMode.EDIT + mock_app_state.get_current_session_id.return_value = "test_session" + mock_app_state.get_current_session.return_value = MagicMock() + with ( + patch("marimo._server.errors.AppState", return_value=mock_app_state), + patch( + "marimo._server.errors.send_message_to_consumer" + ) as mock_send_message, + ): + exc = ModuleNotFoundError( + "No module named 'missing_package'", name="missing_package" + ) + response = await handle_error(Request({"type": "http"}), exc) + + assert response.status_code == 500 + assert ( + response.body + == b'{"detail":"No module named \'missing_package\'"}' + ) + mock_send_message.assert_called_once() + + +async def test_not_implemented_error(): + exc = NotImplementedError("Feature not implemented") + response = await handle_error(Request({"type": "http"}), exc) + assert response.status_code == 501 + assert response.body == b'{"detail":"Not supported"}' + + +async def test_type_error(): + exc = TypeError("Invalid type") + response = await handle_error(Request({"type": "http"}), exc) + assert response.status_code == 500 + assert response.body == b'{"detail":"Invalid type"}' + + +async def test_generic_exception(): + exc = Exception("Something went wrong") + response = await handle_error(Request({"type": "http"}), exc) + assert response.status_code == 500 + assert response.body == b'{"detail":"Something went wrong"}' + + +async def test_non_exception_response(): + response = "Not an exception" + result = await handle_error(Request({"type": "http"}), response) + assert result == response