Skip to content

Commit

Permalink
Fix Polars GroupBy Issue #3348 (#3349)
Browse files Browse the repository at this point in the history
# Fix Polars GroupBy Issue #3348

This PR fixes an issue where the code generated for Polars group-by
transformations did not correctly reference the original column names.

## Changes
- Modified the code generation in `print_code.py` to use `pl.col()` for
group by columns
- Added test case `test_polars_groupby_alias` to verify proper column
name handling in group by transformations
- Ensured that both the group-by keys and the aggregation expressions
reference the original column names correctly

## Testing
Added a new test that:
- Creates a test DataFrame with "group" and "age" columns
- Applies a group by transformation with max aggregation
- Verifies the transformed DataFrame structure and values
- Checks that the generated code correctly uses `pl.col()` for both
group by and aggregation columns

## Link to Devin run
https://app.devin.ai/sessions/ba11f083aa6b4f63857d6f1fbe11ac00

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Myles Scolnick <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 6, 2025
1 parent ad6a250 commit ce8cede
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 7 deletions.
18 changes: 11 additions & 7 deletions marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,33 +275,37 @@ def generate_where_clause_polars(where: Condition) -> str:
elif transform.type == TransformType.GROUP_BY:
column_ids, aggregation = transform.column_ids, transform.aggregation
aggs: list[str] = []
# Use _as_literal to properly escape column names
for column_id in all_columns:
if column_id not in column_ids:
col_ref = _as_literal(column_id)
agg_alias = f"{column_id}_{aggregation}"
if aggregation == "count":
aggs.append(
f'pl.col("{column_id}").count().alias("{column_id}_count")'
f"pl.col({col_ref}).count().alias({_as_literal(agg_alias)})"
)
elif aggregation == "sum":
aggs.append(
f'pl.col("{column_id}").sum().alias("{column_id}_sum")'
f"pl.col({col_ref}).sum().alias({_as_literal(agg_alias)})"
)
elif aggregation == "mean":
aggs.append(
f'pl.col("{column_id}").mean().alias("{column_id}_mean")'
f"pl.col({col_ref}).mean().alias({_as_literal(agg_alias)})"
)
elif aggregation == "median":
aggs.append(
f'pl.col("{column_id}").median().alias("{column_id}_median")'
f"pl.col({col_ref}).median().alias({_as_literal(agg_alias)})"
)
elif aggregation == "min":
aggs.append(
f'pl.col("{column_id}").min().alias("{column_id}_min")'
f"pl.col({col_ref}).min().alias({_as_literal(agg_alias)})"
)
elif aggregation == "max":
aggs.append(
f'pl.col("{column_id}").max().alias("{column_id}_max")'
f"pl.col({col_ref}).max().alias({_as_literal(agg_alias)})"
)
return f"{df_name}.group_by({_list_of_strings(column_ids)}, maintain_order=True).agg([{', '.join(aggs)}])" # noqa: E501
group_cols = [f"pl.col({_as_literal(col)})" for col in column_ids]
return f"{df_name}.group_by([{', '.join(group_cols)}], maintain_order=True).agg([{', '.join(aggs)}])" # noqa: E501

elif transform.type == TransformType.SELECT_COLUMNS:
column_ids = transform.column_ids
Expand Down
73 changes: 73 additions & 0 deletions tests/_plugins/ui/_impl/dataframes/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,79 @@ def test_dataframe_error_handling(df: Any) -> None:
with pytest.raises(ColumnNotFound):
subject._get_column_values(GetColumnValuesArgs(column="C"))

@staticmethod
@pytest.mark.skipif(not HAS_POLARS, reason="Polars not installed")
def test_polars_groupby_alias() -> None:
    """Check a Polars group-by/max transform and the code printed for it.

    The transform is applied through ``TransformsContainer`` and the
    resulting frame is inspected; afterwards the code produced by
    ``python_print_polars`` for the same transform is checked for
    ``pl.col(...)`` references to the original column names.
    """
    import polars as pl

    from marimo._plugins.ui._impl.dataframes.transforms.apply import (
        TransformsContainer,
        get_handler_for_dataframe,
    )
    from marimo._plugins.ui._impl.dataframes.transforms.print_code import (
        python_print_polars,
    )
    from marimo._plugins.ui._impl.dataframes.transforms.types import (
        GroupByTransform,
        Transformations,
        TransformType,
    )

    # Two groups with two ages each, so the per-group max is unambiguous.
    source_frame = pl.DataFrame(
        {
            "group": ["a", "a", "b", "b"],
            "age": [10, 20, 30, 40],
        }
    )

    container = TransformsContainer(
        source_frame, get_handler_for_dataframe(source_frame)
    )

    # Group on "group" and take the max of every remaining column.
    group_by_max = GroupByTransform(
        type=TransformType.GROUP_BY,
        column_ids=["group"],
        drop_na=True,
        aggregation="max",
    )
    grouped = container.apply(Transformations([group_by_max]))

    # Structure of the transformed frame.
    assert isinstance(grouped, pl.DataFrame)
    assert "group" in grouped.columns
    assert "age_max" in grouped.columns
    assert grouped.shape == (2, 2)
    # Max age for each group, in first-seen group order.
    assert grouped["age_max"].to_list() == [
        20,
        40,
    ]

    # Full content of the transformed frame, column by column.
    columns_as_lists = {}
    for name in grouped.columns:
        columns_as_lists[name] = grouped[name].to_list()
    assert columns_as_lists == {
        "group": ["a", "b"],
        "age_max": [20, 40],
    }

    # The printed code must refer to the original "age" column (aliased
    # to "age_max") and wrap the group-by key in pl.col() as well.
    generated = python_print_polars(
        "df",
        ["group", "age"],
        group_by_max,
    )
    assert 'pl.col("age")' in generated
    assert 'alias("age_max")' in generated
    assert 'pl.col("group")' in generated


@pytest.mark.skipif(
not HAS_IBIS or not HAS_POLARS,
Expand Down

0 comments on commit ce8cede

Please sign in to comment.