Skip to content

Commit

Permalink
fix: mo.notebook_location() with incorrect :// (#3393)
Browse files Browse the repository at this point in the history
When converting to `Path()`, `://` gets normalized to `:/`, so this adds
it back if it was removed
  • Loading branch information
mscolnick authored Jan 10, 2025
1 parent 9c07b3e commit b9090ce
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
22 changes: 17 additions & 5 deletions marimo/_runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,17 @@ def notebook_dir() -> pathlib.Path | None:
return None


class URLPath(pathlib.PurePosixPath):
"""
Wrapper around pathlib.Path that preserves the "://" in the URL protocol.
"""

def __str__(self) -> str:
return super().__str__().replace(":/", "://")


@mddoc
def notebook_location() -> pathlib.Path | None:
def notebook_location() -> pathlib.PurePath | None:
"""Get the location of the currently executing notebook.
In WASM, this is the URL of webpage, for example, `https://my-site.com`.
Expand All @@ -342,7 +351,7 @@ def notebook_location() -> pathlib.Path | None:
```
Returns:
pathlib.Path | None: A pathlib.Path object representing the URL or directory of the current
Path | None: A Path object representing the URL or directory of the current
notebook, or None if the notebook's directory cannot be determined.
"""
if is_pyodide():
Expand All @@ -352,11 +361,14 @@ def notebook_location() -> pathlib.Path | None:
# The location looks like https://marimo-team.github.io/marimo-gh-pages-template/notebooks/assets/worker-BxJ8HeOy.js
# We want to crawl out of the assets/ folder
if "assets" in path_location.parts:
return path_location.parent.parent
return path_location
return URLPath(str(path_location.parent.parent))
return URLPath(str(path_location))

else:
return notebook_dir()
nb_dir = notebook_dir()
if nb_dir is not None:
return nb_dir
return None


@dataclasses.dataclass
Expand Down
13 changes: 10 additions & 3 deletions tests/_runtime/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import sys
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Sequence

import pytest
Expand Down Expand Up @@ -1039,6 +1038,9 @@ async def test_notebook_location(
assert "dir" in k.globals
assert k.globals["dir"] is k.globals["loc"]

@pytest.mark.skipif(
sys.platform == "win32", reason="Windows paths behave differently"
)
async def test_notebook_location_for_pyodide(
self, any_kernel: Kernel, exec_req: ExecReqProvider
) -> None:
Expand All @@ -1060,8 +1062,13 @@ async def test_notebook_location_for_pyodide(
)
]
)
assert k.globals["loc"] == Path(
"https://marimo-team.github.io/marimo-gh-pages-template/notebooks"
assert (
str(k.globals["loc"])
== "https://marimo-team.github.io/marimo-gh-pages-template/notebooks"
)
assert (
str(k.globals["loc"] / "public" / "data.csv")
== "https://marimo-team.github.io/marimo-gh-pages-template/notebooks/public/data.csv"
)
finally:
del sys.modules["pyodide"]
Expand Down

0 comments on commit b9090ce

Please sign in to comment.