Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improvement: prompt to install missing deps during optional dependency checks #3363

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions marimo/_dependencies/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def has(self) -> bool:
has_dep = importlib.util.find_spec(self.pkg) is not None
if not has_dep:
return False
except ModuleNotFoundError:
except (ModuleNotFoundError, importlib.metadata.PackageNotFoundError):
# Could happen for nested imports (e.g. foo.bar)
return False

Expand Down Expand Up @@ -54,7 +54,8 @@ def require(self, why: str) -> None:
if not self.has():
message = f"{self.pkg} is required {why}."
sys.stderr.write(message + "\n\n")
raise ModuleNotFoundError(message) from None
# Including the `name` helps with auto-installations
raise ModuleNotFoundError(message, name=self.pkg) from None

def require_at_version(
self,
Expand Down
80 changes: 80 additions & 0 deletions marimo/_server/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from starlette.exceptions import HTTPException
from starlette.responses import JSONResponse

from marimo import _loggers
from marimo._messaging.ops import MissingPackageAlert
from marimo._runtime.packages.utils import is_python_isolated
from marimo._server.api.deps import AppState
from marimo._server.api.status import (
HTTPException as MarimoHTTPException,
is_client_error,
)
from marimo._server.ids import ConsumerId
from marimo._server.model import SessionMode
from marimo._server.sessions import send_message_to_consumer

if TYPE_CHECKING:
from starlette.requests import Request

LOGGER = _loggers.marimo_logger()


# Convert exceptions to JSON responses
# In the case of a ModuleNotFoundError, we try to send a MissingPackageAlert to the client
# to install the missing package
async def handle_error(request: Request, response: Any) -> Any:
if isinstance(response, HTTPException):
# Turn 403s into 401s to collect auth
if response.status_code == 403:
return JSONResponse(
status_code=401,
content={"detail": "Authorization header required"},
headers={"WWW-Authenticate": "Basic"},
)
return JSONResponse(
{"detail": response.detail},
status_code=response.status_code,
headers=response.headers,
)
if isinstance(response, MarimoHTTPException):
# Log server errors
if not is_client_error(response.status_code):
LOGGER.exception(response)
return JSONResponse(
{"detail": response.detail},
status_code=response.status_code,
)
if isinstance(response, ModuleNotFoundError) and response.name:
try:
app_state = AppState(request)
session_id = app_state.get_current_session_id()
session = app_state.get_current_session()
# If we're in an edit session, send an package installation request
if (
session_id is not None
and session is not None
and app_state.mode == SessionMode.EDIT
):
send_message_to_consumer(
session=session,
operation=MissingPackageAlert(
packages=[response.name],
isolated=is_python_isolated(),
),
consumer_id=ConsumerId(session_id),
)
return JSONResponse({"detail": str(response)}, status_code=500)
except Exception as e:
LOGGER.warning(f"Failed to send missing package alert: {e}")
if isinstance(response, NotImplementedError):
return JSONResponse({"detail": "Not supported"}, status_code=501)
if isinstance(response, TypeError):
return JSONResponse({"detail": str(response)}, status_code=500)
if isinstance(response, Exception):
return JSONResponse({"detail": str(response)}, status_code=500)
return response
39 changes: 2 additions & 37 deletions marimo/_server/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, List, Optional

from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse

from marimo import _loggers
from marimo._server.api.auth import (
Expand All @@ -25,49 +24,15 @@
from marimo._server.api.router import build_routes
from marimo._server.api.status import (
HTTPException as MarimoHTTPException,
is_client_error,
)
from marimo._server.errors import handle_error

if TYPE_CHECKING:
from starlette.requests import Request
from starlette.types import Lifespan

LOGGER = _loggers.marimo_logger()


# Convert exceptions to JSON responses
async def handle_error(request: Request, response: Any) -> Any:
del request
if isinstance(response, HTTPException):
# Turn 403s into 401s to collect auth
if response.status_code == 403:
return JSONResponse(
status_code=401,
content={"detail": "Authorization header required"},
headers={"WWW-Authenticate": "Basic"},
)
return JSONResponse(
{"detail": response.detail},
status_code=response.status_code,
headers=response.headers,
)
if isinstance(response, MarimoHTTPException):
# Log server errors
if not is_client_error(response.status_code):
LOGGER.exception(response)
return JSONResponse(
{"detail": response.detail},
status_code=response.status_code,
)
if isinstance(response, NotImplementedError):
return JSONResponse({"detail": "Not supported"}, status_code=501)
if isinstance(response, TypeError):
return JSONResponse({"detail": str(response)}, status_code=500)
if isinstance(response, Exception):
return JSONResponse({"detail": str(response)}, status_code=500)
return response


# Create app
def create_starlette_app(
*,
Expand Down
11 changes: 11 additions & 0 deletions marimo/_server/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,3 +1046,14 @@ def start(self) -> None:

def stop(self) -> None:
pass


def send_message_to_consumer(
session: Session,
operation: MessageOperation,
consumer_id: Optional[ConsumerId],
) -> None:
if session.connection_state() == ConnectionState.OPEN:
for consumer, c_id in session.room.consumers.items():
if c_id == consumer_id:
consumer.write_operation(operation)
3 changes: 2 additions & 1 deletion marimo/_sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def sql(
else:
raise ModuleNotFoundError(
"pandas or polars is required to execute sql. "
+ "You can install them with 'pip install pandas polars'"
+ "You can install them with 'pip install pandas polars'",
name="polars",
)

if output:
Expand Down
9 changes: 9 additions & 0 deletions tests/_dependencies/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def test_without_dependencies() -> None:
with pytest.raises(ModuleNotFoundError) as excinfo:
missing.require("for testing")

assert excinfo.value.name == "missing"

assert "for testing" in str(excinfo.value)


Expand Down Expand Up @@ -86,6 +88,13 @@ def test_versions():
)


def test_has_as_version_when_not_installed():
missing = Dependency("missing")
assert missing is not None
assert missing.has() is False
assert missing.has_at_version(min_version="2.0.0") is False


def test_version_check():
# within range
assert (
Expand Down
4 changes: 0 additions & 4 deletions tests/_server/api/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,6 @@ def test_proxy_static_file_streaming(
content_length += len(chunk)
assert content_length == 1024 * 1024

@pytest.mark.asyncio
async def test_http_client_streaming(
self, app_with_proxy: Starlette
) -> None:
Expand All @@ -465,7 +464,6 @@ async def test_http_client_streaming(
assert len(chunks) > 0
await response.aclose()

@pytest.mark.asyncio
async def test_http_client_body_types(
self, app_with_proxy: Starlette
) -> None:
Expand Down Expand Up @@ -499,7 +497,6 @@ def read(self, size: int | None = -1) -> bytes:
assert response.status_code == 200
await response.aclose()

@pytest.mark.asyncio
async def test_http_client_headers(
self, app_with_proxy: Starlette
) -> None:
Expand Down Expand Up @@ -564,7 +561,6 @@ def test_original_app_auth_still_works(
assert response.status_code == 401, response.text
assert response.headers.get("Set-Cookie") is None

@pytest.mark.asyncio
async def test_proxy_websocket(self, app_with_proxy: Starlette) -> None:
client = TestClient(app_with_proxy)
with client.websocket_connect("/proxy/ws") as websocket:
Expand Down
84 changes: 84 additions & 0 deletions tests/_server/test_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

from unittest.mock import MagicMock, patch

from starlette.exceptions import HTTPException
from starlette.requests import Request

from marimo._server.api.status import HTTPException as MarimoHTTPException
from marimo._server.errors import handle_error
from marimo._server.model import SessionMode


async def test_http_exception():
# Test 403 to 401 conversion
exc = HTTPException(status_code=403)
response = await handle_error(Request({"type": "http"}), exc)
assert response.status_code == 401
assert response.body == b'{"detail":"Authorization header required"}'
assert response.headers["WWW-Authenticate"] == "Basic"

# Test other HTTP exceptions
exc = HTTPException(status_code=404, detail="Not found")
response = await handle_error(Request({"type": "http"}), exc)
assert response.status_code == 404
assert response.body == b'{"detail":"Not found"}'


async def test_marimo_http_exception():
exc = MarimoHTTPException(status_code=400, detail="Bad request")
response = await handle_error(Request({"type": "http"}), exc)
assert response.status_code == 400
assert response.body == b'{"detail":"Bad request"}'


async def test_module_not_found_error():
# Mock AppState and session
mock_app_state = MagicMock()
mock_app_state.mode = SessionMode.EDIT
mock_app_state.get_current_session_id.return_value = "test_session"
mock_app_state.get_current_session.return_value = MagicMock()
with (
patch("marimo._server.errors.AppState", return_value=mock_app_state),
patch(
"marimo._server.errors.send_message_to_consumer"
) as mock_send_message,
):
exc = ModuleNotFoundError(
"No module named 'missing_package'", name="missing_package"
)
response = await handle_error(Request({"type": "http"}), exc)

assert response.status_code == 500
assert (
response.body
== b'{"detail":"No module named \'missing_package\'"}'
)
mock_send_message.assert_called_once()


async def test_not_implemented_error():
exc = NotImplementedError("Feature not implemented")
response = await handle_error(Request({"type": "http"}), exc)
assert response.status_code == 501
assert response.body == b'{"detail":"Not supported"}'


async def test_type_error():
exc = TypeError("Invalid type")
response = await handle_error(Request({"type": "http"}), exc)
assert response.status_code == 500
assert response.body == b'{"detail":"Invalid type"}'


async def test_generic_exception():
exc = Exception("Something went wrong")
response = await handle_error(Request({"type": "http"}), exc)
assert response.status_code == 500
assert response.body == b'{"detail":"Something went wrong"}'


async def test_non_exception_response():
response = "Not an exception"
result = await handle_error(Request({"type": "http"}), response)
assert result == response
Loading