Skip to content

Commit

Permalink
fix(polars-groupby): reference original column name in aggregator #3348
Browse files Browse the repository at this point in the history
Co-Authored-By: Myles Scolnick <[email protected]>
  • Loading branch information
devin-ai-integration[bot] and mscolnick committed Jan 6, 2025
1 parent a63cd29 commit e4f2938
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 7 deletions.
18 changes: 11 additions & 7 deletions marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,33 +275,37 @@ def generate_where_clause_polars(where: Condition) -> str:
elif transform.type == TransformType.GROUP_BY:
column_ids, aggregation = transform.column_ids, transform.aggregation
aggs: list[str] = []
# Use _as_literal to properly escape column names
for column_id in all_columns:
if column_id not in column_ids:
col_ref = _as_literal(column_id)
agg_alias = f"{column_id}_{aggregation}"
if aggregation == "count":
aggs.append(
f'pl.col("{column_id}").count().alias("{column_id}_count")'
f'pl.col({col_ref}).count().alias({_as_literal(agg_alias)})'
)
elif aggregation == "sum":
aggs.append(
f'pl.col("{column_id}").sum().alias("{column_id}_sum")'
f'pl.col({col_ref}).sum().alias({_as_literal(agg_alias)})'
)
elif aggregation == "mean":
aggs.append(
f'pl.col("{column_id}").mean().alias("{column_id}_mean")'
f'pl.col({col_ref}).mean().alias({_as_literal(agg_alias)})'
)
elif aggregation == "median":
aggs.append(
f'pl.col("{column_id}").median().alias("{column_id}_median")'
f'pl.col({col_ref}).median().alias({_as_literal(agg_alias)})'
)
elif aggregation == "min":
aggs.append(
f'pl.col("{column_id}").min().alias("{column_id}_min")'
f'pl.col({col_ref}).min().alias({_as_literal(agg_alias)})'
)
elif aggregation == "max":
aggs.append(
f'pl.col("{column_id}").max().alias("{column_id}_max")'
f'pl.col({col_ref}).max().alias({_as_literal(agg_alias)})'
)
return f"{df_name}.group_by({_list_of_strings(column_ids)}, maintain_order=True).agg([{', '.join(aggs)}])" # noqa: E501
group_cols = [f'pl.col({_as_literal(col)})' for col in column_ids]
return f"{df_name}.group_by([{', '.join(group_cols)}], maintain_order=True).agg([{', '.join(aggs)}])" # noqa: E501

elif transform.type == TransformType.SELECT_COLUMNS:
column_ids = transform.column_ids
Expand Down
69 changes: 69 additions & 0 deletions tests/_plugins/ui/_impl/dataframes/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,75 @@ def test_dataframe_error_handling(df: Any) -> None:
with pytest.raises(ColumnNotFound):
subject._get_column_values(GetColumnValuesArgs(column="C"))

@staticmethod
@pytest.mark.skipif(not HAS_POLARS, reason="Polars not installed")
def test_polars_groupby_alias() -> None:
"""Test that group by operations use original column names correctly."""
import polars as pl

import marimo as mo

# Create a test dataframe with age and group columns
df = pl.DataFrame({
"group": ["a", "a", "b", "b"],
"age": [10, 20, 30, 40],
})
# Test the transformation directly using TransformsContainer
from marimo._plugins.ui._impl.dataframes.transforms.apply import (
TransformsContainer,
get_handler_for_dataframe,
)
from marimo._plugins.ui._impl.dataframes.transforms.types import (
GroupByTransform,
TransformType,
Transformations,
)

handler = get_handler_for_dataframe(df)
transform_container = TransformsContainer(df, handler)

# Create and apply the transformation
transform = GroupByTransform(
type=TransformType.GROUP_BY,
column_ids=["group"],
drop_na=True,
aggregation="max",
)
transformations = Transformations([transform])
transformed_df = transform_container.apply(transformations)

# Verify the transformed DataFrame
assert isinstance(transformed_df, pl.DataFrame)
assert "group" in transformed_df.columns
assert "age_max" in transformed_df.columns
assert transformed_df.shape == (2, 2)
assert transformed_df["age_max"].to_list() == [20, 40] # max age for each group

# The resulting frame should have correct column names and values
# Convert to dict and verify values
result_dict = {
col: transformed_df[col].to_list()
for col in transformed_df.columns
}
assert result_dict == {
"group": ["a", "b"],
"age_max": [20, 40],
}

# Verify the generated code uses original column names
from marimo._plugins.ui._impl.dataframes.transforms.print_code import (
python_print_polars,
)
code = python_print_polars(
"df",
["group", "age"],
transform,
)
# Code should reference original "age" column, not "age_max"
assert 'pl.col("age")' in code
assert 'alias("age_max")' in code
assert 'pl.col("group")' in code # Original column name in group by


@pytest.mark.skipif(
not HAS_IBIS or not HAS_POLARS,
Expand Down

0 comments on commit e4f2938

Please sign in to comment.