Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 6, 2025
1 parent e4f2938 commit 096a4b4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
14 changes: 7 additions & 7 deletions marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,29 +282,29 @@ def generate_where_clause_polars(where: Condition) -> str:
agg_alias = f"{column_id}_{aggregation}"
if aggregation == "count":
aggs.append(
f'pl.col({col_ref}).count().alias({_as_literal(agg_alias)})'
f"pl.col({col_ref}).count().alias({_as_literal(agg_alias)})"
)
elif aggregation == "sum":
aggs.append(
f'pl.col({col_ref}).sum().alias({_as_literal(agg_alias)})'
f"pl.col({col_ref}).sum().alias({_as_literal(agg_alias)})"
)
elif aggregation == "mean":
aggs.append(
f'pl.col({col_ref}).mean().alias({_as_literal(agg_alias)})'
f"pl.col({col_ref}).mean().alias({_as_literal(agg_alias)})"
)
elif aggregation == "median":
aggs.append(
f'pl.col({col_ref}).median().alias({_as_literal(agg_alias)})'
f"pl.col({col_ref}).median().alias({_as_literal(agg_alias)})"
)
elif aggregation == "min":
aggs.append(
f'pl.col({col_ref}).min().alias({_as_literal(agg_alias)})'
f"pl.col({col_ref}).min().alias({_as_literal(agg_alias)})"
)
elif aggregation == "max":
aggs.append(
f'pl.col({col_ref}).max().alias({_as_literal(agg_alias)})'
f"pl.col({col_ref}).max().alias({_as_literal(agg_alias)})"
)
group_cols = [f'pl.col({_as_literal(col)})' for col in column_ids]
group_cols = [f"pl.col({_as_literal(col)})" for col in column_ids]
return f"{df_name}.group_by([{', '.join(group_cols)}], maintain_order=True).agg([{', '.join(aggs)}])" # noqa: E501

elif transform.type == TransformType.SELECT_COLUMNS:
Expand Down
20 changes: 13 additions & 7 deletions tests/_plugins/ui/_impl/dataframes/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,24 +280,26 @@ def test_polars_groupby_alias() -> None:
import marimo as mo

# Create a test dataframe with age and group columns
df = pl.DataFrame({
"group": ["a", "a", "b", "b"],
"age": [10, 20, 30, 40],
})
df = pl.DataFrame(
{
"group": ["a", "a", "b", "b"],
"age": [10, 20, 30, 40],
}
)
# Test the transformation directly using TransformsContainer
from marimo._plugins.ui._impl.dataframes.transforms.apply import (
TransformsContainer,
get_handler_for_dataframe,
)
from marimo._plugins.ui._impl.dataframes.transforms.types import (
GroupByTransform,
TransformType,
Transformations,
TransformType,
)

handler = get_handler_for_dataframe(df)
transform_container = TransformsContainer(df, handler)

# Create and apply the transformation
transform = GroupByTransform(
type=TransformType.GROUP_BY,
Expand All @@ -313,7 +315,10 @@ def test_polars_groupby_alias() -> None:
assert "group" in transformed_df.columns
assert "age_max" in transformed_df.columns
assert transformed_df.shape == (2, 2)
assert transformed_df["age_max"].to_list() == [20, 40] # max age for each group
assert transformed_df["age_max"].to_list() == [
20,
40,
] # max age for each group

# The resulting frame should have correct column names and values
# Convert to dict and verify values
Expand All @@ -330,6 +335,7 @@ def test_polars_groupby_alias() -> None:
from marimo._plugins.ui._impl.dataframes.transforms.print_code import (
python_print_polars,
)

code = python_print_polars(
"df",
["group", "age"],
Expand Down

0 comments on commit 096a4b4

Please sign in to comment.