From 096a4b411646b947e9ea63850724f649b7a8beb0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:02:53 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../_impl/dataframes/transforms/print_code.py | 14 ++++++------- .../ui/_impl/dataframes/test_dataframe.py | 20 ++++++++++++------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py b/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py index 0f70eedd79b..6400a91c6f8 100644 --- a/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py +++ b/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py @@ -282,29 +282,29 @@ def generate_where_clause_polars(where: Condition) -> str: agg_alias = f"{column_id}_{aggregation}" if aggregation == "count": aggs.append( - f'pl.col({col_ref}).count().alias({_as_literal(agg_alias)})' + f"pl.col({col_ref}).count().alias({_as_literal(agg_alias)})" ) elif aggregation == "sum": aggs.append( - f'pl.col({col_ref}).sum().alias({_as_literal(agg_alias)})' + f"pl.col({col_ref}).sum().alias({_as_literal(agg_alias)})" ) elif aggregation == "mean": aggs.append( - f'pl.col({col_ref}).mean().alias({_as_literal(agg_alias)})' + f"pl.col({col_ref}).mean().alias({_as_literal(agg_alias)})" ) elif aggregation == "median": aggs.append( - f'pl.col({col_ref}).median().alias({_as_literal(agg_alias)})' + f"pl.col({col_ref}).median().alias({_as_literal(agg_alias)})" ) elif aggregation == "min": aggs.append( - f'pl.col({col_ref}).min().alias({_as_literal(agg_alias)})' + f"pl.col({col_ref}).min().alias({_as_literal(agg_alias)})" ) elif aggregation == "max": aggs.append( - f'pl.col({col_ref}).max().alias({_as_literal(agg_alias)})' + f"pl.col({col_ref}).max().alias({_as_literal(agg_alias)})" ) - group_cols = [f'pl.col({_as_literal(col)})' for col in column_ids] + group_cols = [f"pl.col({_as_literal(col)})" for col in column_ids] return f"{df_name}.group_by([{', '.join(group_cols)}], maintain_order=True).agg([{', '.join(aggs)}])" # noqa: E501 elif transform.type == TransformType.SELECT_COLUMNS: diff --git a/tests/_plugins/ui/_impl/dataframes/test_dataframe.py b/tests/_plugins/ui/_impl/dataframes/test_dataframe.py index 28f4e9cdf9d..d68fb5c9bc8 100644 --- a/tests/_plugins/ui/_impl/dataframes/test_dataframe.py +++ b/tests/_plugins/ui/_impl/dataframes/test_dataframe.py @@ -280,10 +280,12 @@ def test_polars_groupby_alias() -> None: import marimo as mo # Create a test dataframe with age and group columns - df = pl.DataFrame({ - "group": ["a", "a", "b", "b"], - "age": [10, 20, 30, 40], - }) + df = pl.DataFrame( + { + "group": ["a", "a", "b", "b"], + "age": [10, 20, 30, 40], + } + ) # Test the transformation directly using TransformsContainer from marimo._plugins.ui._impl.dataframes.transforms.apply import ( TransformsContainer, @@ -291,13 +293,13 @@ def test_polars_groupby_alias() -> None: ) from marimo._plugins.ui._impl.dataframes.transforms.types import ( GroupByTransform, - TransformType, Transformations, + TransformType, ) handler = get_handler_for_dataframe(df) transform_container = TransformsContainer(df, handler) - + # Create and apply the transformation transform = GroupByTransform( type=TransformType.GROUP_BY, @@ -313,7 +315,10 @@ def test_polars_groupby_alias() -> None: assert "group" in transformed_df.columns assert "age_max" in transformed_df.columns assert transformed_df.shape == (2, 2) - assert transformed_df["age_max"].to_list() == [20, 40] # max age for each group + assert transformed_df["age_max"].to_list() == [ + 20, + 40, + ] # max age for each group # The resulting frame should have correct column names and values # Convert to dict and verify values @@ -330,6 +335,7 @@ def test_polars_groupby_alias() -> None: from marimo._plugins.ui._impl.dataframes.transforms.print_code import ( python_print_polars, ) + code = python_print_polars( "df", ["group", "age"],