Skip to content

Commit

Permalink
feat: support pyarrow and column major data in mo.ui.table (#1091)
Browse files Browse the repository at this point in the history
* feat: support pyarrow and column major data in mo.ui.table

* fix test

* fixes

* maybe fix type

* fixes

* fix circular import

* fix

* py3.8 compat

---------

Co-authored-by: Akshay Agrawal <[email protected]>
  • Loading branch information
mscolnick and akshayka authored Apr 9, 2024
1 parent 005052f commit f2d5ba7
Show file tree
Hide file tree
Showing 20 changed files with 907 additions and 277 deletions.
10 changes: 10 additions & 0 deletions marimo/_dependencies/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ def require_plotly(why: str) -> None:
+ "You can install it with 'pip install plotly'"
) from None

@staticmethod
def has(pkg: str) -> bool:
    """Return True if the package named `pkg` is installed.

    **Args.**

    - pkg: importable module name, e.g. `"pandas"`; dotted names
      (`"pkg.submodule"`) are accepted.

    **Returns.**

    True if an import spec can be found for `pkg`; False otherwise,
    including when a parent package of a dotted name cannot be imported.
    """
    try:
        return importlib.util.find_spec(pkg) is not None
    except ImportError:
        # For a dotted name, find_spec imports the parent package first and
        # raises ModuleNotFoundError (an ImportError) when it is missing.
        # A dependency probe should answer "not installed", not crash.
        return False

@staticmethod
def has_openai() -> bool:
"""Return True if openai is installed."""
Expand All @@ -107,6 +112,11 @@ def has_pandas() -> bool:
"""Return True if pandas is installed."""
return importlib.util.find_spec("pandas") is not None

@staticmethod
def has_pyarrow() -> bool:
    """Report whether the pyarrow package can be located for import."""
    spec = importlib.util.find_spec("pyarrow")
    return spec is not None

@staticmethod
def has_polars() -> bool:
"""Return True if polars is installed."""
Expand Down
49 changes: 3 additions & 46 deletions marimo/_output/data/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Copyright 2024 Marimo. All rights reserved.
import base64
import io
from typing import TYPE_CHECKING, Union
from typing import Union

from marimo._dependencies.dependencies import DependencyManager
from marimo._plugins.core.media import is_data_empty
from marimo._runtime.context import get_context
from marimo._runtime.virtual_file import (
Expand All @@ -12,10 +11,6 @@
VirtualFileLifecycleItem,
)

if TYPE_CHECKING:
import pandas as pd
import polars as pl


def pdf(data: bytes) -> VirtualFile:
"""Create a virtual file from a PDF.
Expand Down Expand Up @@ -65,9 +60,7 @@ def audio(data: bytes, ext: str = "wav") -> VirtualFile:
return item.virtual_file


def csv(
data: Union[str, bytes, io.BytesIO, "pd.DataFrame", "pl.DataFrame"]
) -> VirtualFile:
def csv(data: Union[str, bytes, io.BytesIO]) -> VirtualFile:
"""Create a virtual file for CSV data.
**Args.**
Expand All @@ -79,30 +72,10 @@ def csv(
A `VirtualFile` object.
"""
# Pandas DataFrame
if DependencyManager.has_pandas():
import pandas as pd

if isinstance(data, pd.DataFrame):
buffer = data.to_csv(
index=False,
).encode("utf-8")
return any_data(buffer, ext="csv")

# Polars DataFrame
if DependencyManager.has_polars():
import polars as pl

if isinstance(data, pl.DataFrame):
buffer = data.write_csv().encode("utf-8")
return any_data(buffer, ext="csv")

return any_data(data, ext="csv") # type: ignore


def json(
data: Union[str, bytes, io.BytesIO, "pd.DataFrame", "pl.DataFrame"]
) -> VirtualFile:
def json(data: Union[str, bytes, io.BytesIO]) -> VirtualFile:
"""Create a virtual file for JSON data.
**Args.**
Expand All @@ -114,22 +87,6 @@ def json(
A `VirtualFile` object.
"""
# Pandas DataFrame
if DependencyManager.has_pandas():
import pandas as pd

if isinstance(data, pd.DataFrame):
buffer = data.to_json(orient="records").encode("utf-8")
return any_data(buffer, ext="json")

# Polars DataFrame
if DependencyManager.has_polars():
import polars as pl

if isinstance(data, pl.DataFrame):
buffer = data.write_json(row_oriented=True).encode("utf-8")
return any_data(buffer, ext="json")

return any_data(data, ext="json") # type: ignore


Expand Down
11 changes: 8 additions & 3 deletions marimo/_plugins/ui/_impl/data_explorer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright 2023 Marimo. All rights reserved.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Dict, Final, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Final, Optional, Union

from marimo._plugins.ui._impl.tables.utils import get_table_manager

if TYPE_CHECKING:
import pandas as pd
import polars as pl


import marimo._output.data.data as mo_data
Expand Down Expand Up @@ -36,18 +39,20 @@ class data_explorer(UIElement[Dict[str, Any], Dict[str, Any]]):

def __init__(
self,
df: pd.DataFrame,
df: Union[pd.DataFrame, pl.DataFrame],
on_change: Optional[Callable[[Dict[str, Any]], None]] = None,
) -> None:
self._data = df

manager = get_table_manager(df)

super().__init__(
component_name=data_explorer._name,
initial_value={},
on_change=on_change,
label="",
args={
"data": mo_data.csv(df).url,
"data": mo_data.csv(manager.to_csv()).url,
},
)

Expand Down
11 changes: 8 additions & 3 deletions marimo/_plugins/ui/_impl/dataframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import sys
from typing import TYPE_CHECKING, Any, Callable, Dict, Final, List, Optional

from marimo._plugins.ui._impl.tables.pandas_table import (
PandasTableManagerFactory,
)

if TYPE_CHECKING:
import pandas as pd

Expand All @@ -13,7 +17,6 @@
import marimo._output.data.data as mo_data
from marimo._output.rich_help import mddoc
from marimo._plugins.ui._core.ui_element import UIElement
from marimo._plugins.ui._impl.utils.dataframe import get_row_headers
from marimo._runtime.functions import EmptyArgs, Function
from marimo._utils.parse_dataclass import parse_raw

Expand Down Expand Up @@ -92,6 +95,7 @@ def __init__(
pass

self._data = df
self._manager = PandasTableManagerFactory.create()(df)
self._transform_container = TransformsContainer(df)
self._error: Optional[str] = None

Expand Down Expand Up @@ -129,13 +133,14 @@ def get_dataframe(self, _args: EmptyArgs) -> GetDataFrameResponse:
if self._error is not None:
raise Exception(self._error)

url = mo_data.csv(self._value.head(LIMIT)).url
manager = PandasTableManagerFactory.create()(self._value.head(LIMIT))
url = mo_data.csv(manager.to_csv()).url
total_rows = len(self._value)
return GetDataFrameResponse(
url=url,
total_rows=total_rows,
has_more=total_rows > LIMIT,
row_headers=get_row_headers(self._value),
row_headers=manager.get_row_headers(),
)

def get_column_values(
Expand Down
Loading

0 comments on commit f2d5ba7

Please sign in to comment.