diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py
index 77a88fa796..5a464af29f 100644
--- a/narwhals/_pandas_like/group_by.py
+++ b/narwhals/_pandas_like/group_by.py
@@ -1,24 +1,150 @@
 from __future__ import annotations

-import collections
 import warnings
-from typing import TYPE_CHECKING, Any, ClassVar
+from functools import lru_cache
+from itertools import chain
+from operator import methodcaller
+from typing import TYPE_CHECKING, Any, ClassVar, Literal

 from narwhals._compliant import EagerGroupBy
 from narwhals._expression_parsing import evaluate_output_names_and_aliases
-from narwhals._pandas_like.utils import select_columns_by_name
 from narwhals._utils import find_stacklevel
+from narwhals.dependencies import is_pandas_like_dataframe

 if TYPE_CHECKING:
-    from collections.abc import Iterator, Mapping, Sequence
+    from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence

-    from narwhals._compliant.typing import NarwhalsAggregation
+    import pandas as pd
+    from pandas.api.typing import DataFrameGroupBy as _NativeGroupBy
+    from typing_extensions import TypeAlias, Unpack
+
+    from narwhals._compliant.typing import NarwhalsAggregation, ScalarKwargs
     from narwhals._pandas_like.dataframe import PandasLikeDataFrame
     from narwhals._pandas_like.expr import PandasLikeExpr

+    NativeGroupBy: TypeAlias = "_NativeGroupBy[tuple[str, ...], Literal[True]]"
+
+NativeApply: TypeAlias = "Callable[[pd.DataFrame], pd.Series[Any]]"
+InefficientNativeAggregation: TypeAlias = Literal["cov", "skew"]
+NativeAggregation: TypeAlias = Literal[
+    "any",
+    "all",
+    "count",
+    "first",
+    "idxmax",
+    "idxmin",
+    "last",
+    "max",
+    "mean",
+    "median",
+    "min",
+    "nunique",
+    "prod",
+    "quantile",
+    "sem",
+    "size",
+    "std",
+    "sum",
+    "var",
+    InefficientNativeAggregation,
+]
+"""https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html#built-in-aggregation-methods"""
+
+_NativeAgg: TypeAlias = "Callable[[Any], pd.DataFrame | pd.Series[Any]]"
+"""Equivalent to a partial method call on `DataFrameGroupBy`."""
+
+
+NonStrHashable: TypeAlias = Any
+"""Because `pandas` allows *"names"* like that 😭"""
+
+
+@lru_cache(maxsize=32)
+def _native_agg(name: NativeAggregation, /, **kwds: Unpack[ScalarKwargs]) -> _NativeAgg:
+    if name == "nunique":
+        return methodcaller(name, dropna=False)
+    if not kwds or kwds.get("ddof") == 1:
+        return methodcaller(name)
+    return methodcaller(name, **kwds)
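+
+
+# Illustrative note (added for exposition, not part of the PR itself): the
+# cached factory above returns e.g. `methodcaller("var", ddof=0)`, a callable
+# equivalent to `lambda grouped: grouped.var(ddof=0)`, so a single cache entry
+# serves every `DataFrameGroupBy` it is later applied to.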
+ """ + df = group_by.compliant + exclude = group_by.exclude + self.output_names, self.aliases = evaluate_output_names_and_aliases( + self.expr, df, exclude + ) + return self -class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", str]): - _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] = { + def _getitem_aggs( + self, group_by: PandasLikeGroupBy, / + ) -> pd.DataFrame | pd.Series[Any]: + """Evaluate the wrapped expression as a group_by operation.""" + result: pd.DataFrame | pd.Series[Any] + names = self.output_names + if self.is_len() and self.is_anonymous(): + result = group_by._grouped.size() + else: + select = names[0] if len(names) == 1 else list(names) + result = self.native_agg()(group_by._grouped[select]) + if is_pandas_like_dataframe(result): + result.columns = list(self.aliases) + else: + result.name = self.aliases[0] + return result + + def is_len(self) -> bool: + return self.leaf_name == "len" + + def is_anonymous(self) -> bool: + return self.expr._depth == 0 + + @property + def kwargs(self) -> ScalarKwargs: + return self.expr._scalar_kwargs + + @property + def leaf_name(self) -> NarwhalsAggregation | Any: + if name := self._leaf_name: + return name + self._leaf_name = PandasLikeGroupBy._leaf_name(self.expr) + return self._leaf_name + + def native_agg(self) -> _NativeAgg: + """Return a partial `DataFrameGroupBy` method, missing only `self`.""" + return _native_agg( + PandasLikeGroupBy._remap_expr_name(self.leaf_name), **self.kwargs + ) + + +class PandasLikeGroupBy( + EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", NativeAggregation] +): + _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, NativeAggregation]] = { "sum": "sum", "mean": "mean", "median": "median", @@ -31,6 +157,19 @@ class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", st "count": "count", "quantile": "quantile", } + _original_columns: tuple[str, ...] + """Column names *prior* to any aliasing in `ParseKeysGroupBy`.""" + + _keys: list[str] + """Stores the **aliased** version of group keys from `ParseKeysGroupBy`.""" + + _output_key_names: list[str] + """Stores the **original** version of group keys.""" + + @property + def exclude(self) -> tuple[str, ...]: + """Group keys to ignore when expanding multi-output aggregations.""" + return self._exclude def __init__( self, @@ -40,231 +179,106 @@ def __init__( *, drop_null_keys: bool, ) -> None: - self._df = df + self._original_columns = tuple(df.columns) self._drop_null_keys = drop_null_keys self._compliant_frame, self._keys, self._output_key_names = self._parse_keys( - df, keys=keys + df, keys ) + self._exclude: tuple[str, ...] = (*self._keys, *self._output_key_names) # Drop index to avoid potential collisions: # https://github.com/narwhals-dev/narwhals/issues/1907. 
+
+
+class PandasLikeGroupBy(
+    EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", NativeAggregation]
+):
+    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, NativeAggregation]] = {
         "sum": "sum",
         "mean": "mean",
         "median": "median",
@@ -31,6 +157,19 @@ class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", st
         "count": "count",
         "quantile": "quantile",
     }
+    _original_columns: tuple[str, ...]
+    """Column names *prior* to any aliasing in `ParseKeysGroupBy`."""
+
+    _keys: list[str]
+    """Stores the **aliased** version of group keys from `ParseKeysGroupBy`."""
+
+    _output_key_names: list[str]
+    """Stores the **original** version of group keys."""
+
+    @property
+    def exclude(self) -> tuple[str, ...]:
+        """Group keys to ignore when expanding multi-output aggregations."""
+        return self._exclude

     def __init__(
         self,
@@ -40,231 +179,106 @@ def __init__(
         *,
         drop_null_keys: bool,
     ) -> None:
-        self._df = df
+        self._original_columns = tuple(df.columns)
         self._drop_null_keys = drop_null_keys
         self._compliant_frame, self._keys, self._output_key_names = self._parse_keys(
-            df, keys=keys
+            df, keys
         )
+        self._exclude: tuple[str, ...] = (*self._keys, *self._output_key_names)
         # Drop index to avoid potential collisions:
         # https://github.com/narwhals-dev/narwhals/issues/1907.
-        if set(self.compliant.native.index.names).intersection(self.compliant.columns):
-            native_frame = self.compliant.native.reset_index(drop=True)
-        else:
-            native_frame = self.compliant.native
-
-        self._grouped = native_frame.groupby(
-            list(self._keys),
+        native = self.compliant.native
+        if set(native.index.names).intersection(self.compliant.columns):
+            native = native.reset_index(drop=True)
+        self._grouped: NativeGroupBy = native.groupby(
+            self._keys.copy(),
             sort=False,
             as_index=True,
             dropna=drop_null_keys,
             observed=True,
         )
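+        # NOTE (descriptive comment, added for exposition): `observed=True`
+        # keeps only the groups actually present for Categorical keys,
+        # `dropna=drop_null_keys` maps Narwhals' null-key handling onto
+        # pandas', and passing `self._keys.copy()` avoids sharing the mutable
+        # key list with the grouper.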

-    def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame:  # noqa: C901, PLR0912, PLR0914, PLR0915
-        implementation = self.compliant._implementation
-        backend_version = self.compliant._backend_version
-        new_names: list[str] = self._keys.copy()
-
+    def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame:
         all_aggs_are_simple = True
-        exclude = (*self._keys, *self._output_key_names)
+        agg_exprs: list[AggExpr] = []
         for expr in exprs:
-            _, aliases = evaluate_output_names_and_aliases(expr, self.compliant, exclude)
-            new_names.extend(aliases)
+            agg_exprs.append(AggExpr(expr).with_expand_names(self))
             if not self._is_simple(expr):
                 all_aggs_are_simple = False

-        # dict of {output_name: root_name} that we count n_unique on
-        # We need to do this separately from the rest so that we
-        # can pass the `dropna` kwargs.
-        nunique_aggs: dict[str, str] = {}
-        simple_aggs: dict[str, list[str]] = collections.defaultdict(list)
-        simple_aggs_functions: set[str] = set()
-
-        # ddof to (output_names, aliases) mapping
-        std_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
-            lambda: ([], [])
-        )
-        var_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
-            lambda: ([], [])
-        )
-
-        expected_old_names: list[str] = []
-        simple_agg_new_names: list[str] = []
-
-        if all_aggs_are_simple:  # noqa: PLR1702
-            for expr in exprs:
-                output_names, aliases = evaluate_output_names_and_aliases(
-                    expr, self.compliant, exclude
-                )
-                if expr._depth == 0:
-                    # e.g. `agg(nw.len())`
-                    function_name = self._remap_expr_name(expr._function_name)
-                    simple_aggs_functions.add(function_name)
-
-                    for alias in aliases:
-                        expected_old_names.append(f"{self._keys[0]}_{function_name}")
-                        simple_aggs[self._keys[0]].append(function_name)
-                        simple_agg_new_names.append(alias)
-                    continue
-
-                # e.g. `agg(nw.mean('a'))`
-                function_name = self._remap_expr_name(self._leaf_name(expr))
-                is_n_unique = function_name == "nunique"
-                is_std = function_name == "std"
-                is_var = function_name == "var"
-                for output_name, alias in zip(output_names, aliases):
-                    if is_n_unique:
-                        nunique_aggs[alias] = output_name
-                    elif is_std and (ddof := expr._scalar_kwargs["ddof"]) != 1:  # pyright: ignore[reportTypedDictNotRequiredAccess]
-                        std_aggs[ddof][0].append(output_name)
-                        std_aggs[ddof][1].append(alias)
-                    elif is_var and (ddof := expr._scalar_kwargs["ddof"]) != 1:  # pyright: ignore[reportTypedDictNotRequiredAccess]
-                        var_aggs[ddof][0].append(output_name)
-                        var_aggs[ddof][1].append(alias)
-                    else:
-                        expected_old_names.append(f"{output_name}_{function_name}")
-                        simple_aggs[output_name].append(function_name)
-                        simple_agg_new_names.append(alias)
-                        simple_aggs_functions.add(function_name)
-
-        result_aggs = []
-
-        if simple_aggs:
-            # Fast path for single aggregation such as `df.groupby(...).mean()`
-            if (
-                len(simple_aggs_functions) == 1
-                and (agg_method := simple_aggs_functions.pop()) != "size"
-                and len(simple_aggs) > 1
-            ):
-                result_simple_aggs = getattr(
-                    self._grouped[list(simple_aggs.keys())], agg_method
-                )()
-                result_simple_aggs.columns = [
-                    f"{a}_{agg_method}" for a in result_simple_aggs.columns
-                ]
-            else:
-                result_simple_aggs = self._grouped.agg(simple_aggs)
-                result_simple_aggs.columns = [
-                    f"{a}_{b}" for a, b in result_simple_aggs.columns
-                ]
-            if not (
-                set(result_simple_aggs.columns) == set(expected_old_names)
-                and len(result_simple_aggs.columns) == len(expected_old_names)
-            ):  # pragma: no cover
-                msg = (
-                    f"Safety assertion failed, expected {expected_old_names} "
-                    f"got {result_simple_aggs.columns}, "
-                    "please report a bug at https://github.com/narwhals-dev/narwhals/issues"
-                )
-                raise AssertionError(msg)
-
-            # Rename columns, being very careful
-            expected_old_names_indices: dict[str, list[int]] = (
-                collections.defaultdict(list)
-            )
-            for idx, item in enumerate(expected_old_names):
-                expected_old_names_indices[item].append(idx)
-            index_map: list[int] = [
-                expected_old_names_indices[item].pop(0)
-                for item in result_simple_aggs.columns
-            ]
-            result_simple_aggs.columns = [simple_agg_new_names[i] for i in index_map]
-            result_aggs.append(result_simple_aggs)
-
-        if nunique_aggs:
-            result_nunique_aggs = self._grouped[list(nunique_aggs.values())].nunique(
-                dropna=False
-            )
-            result_nunique_aggs.columns = list(nunique_aggs.keys())
-
-            result_aggs.append(result_nunique_aggs)
-
-        if std_aggs:
-            for ddof, (std_output_names, std_aliases) in std_aggs.items():
-                _aggregation = self._grouped[std_output_names].std(ddof=ddof)
-                # `_aggregation` is a new object so it's OK to operate inplace.
-                _aggregation.columns = std_aliases
-                result_aggs.append(_aggregation)
-        if var_aggs:
-            for ddof, (var_output_names, var_aliases) in var_aggs.items():
-                _aggregation = self._grouped[var_output_names].var(ddof=ddof)
-                # `_aggregation` is a new object so it's OK to operate inplace.
-                _aggregation.columns = var_aliases
-                result_aggs.append(_aggregation)
-
-        if result_aggs:
-            output_names_counter = collections.Counter(
-                c for frame in result_aggs for c in frame
-            )
-            if any(v > 1 for v in output_names_counter.values()):
-                msg = ""
-                for key, value in output_names_counter.items():
-                    if value > 1:
-                        msg += f"\n- '{key}' {value} times"
-                    else:  # pragma: no cover
-                        pass
-                msg = f"Expected unique output names, got:{msg}"
-                raise ValueError(msg)
-            namespace = self.compliant.__narwhals_namespace__()
-            result = namespace._concat_horizontal(result_aggs)
+        if all_aggs_are_simple:
+            result: pd.DataFrame
+            if agg_exprs:
+                ns = self.compliant.__narwhals_namespace__()
+                result = ns._concat_horizontal(self._getitem_aggs(agg_exprs))
             else:
-                # No aggregation provided
                 result = self.compliant.__native_namespace__().DataFrame(
-                    list(self._grouped.groups.keys()), columns=self._keys
+                    list(self._grouped.groups), columns=self._keys
                 )
-            # Keep inplace=True to avoid making a redundant copy.
-            # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
-            result.reset_index(inplace=True)
-            return self.compliant._with_native(
-                select_columns_by_name(result, new_names, implementation)
-            ).rename(dict(zip(self._keys, self._output_key_names)))
-
-        if self.compliant.native.empty:
-            # Don't even attempt this, it's way too inconsistent across pandas versions.
-            msg = (
-                "No results for group-by aggregation.\n\n"
-                "Hint: you were probably trying to apply a non-elementary aggregation with a "
-                "pandas-like API.\n"
-                "Please rewrite your query such that group-by aggregations "
-                "are elementary. For example, instead of:\n\n"
-                "    df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
-                "use:\n\n"
-                "    df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
-            )
-            raise ValueError(msg)
-
-        warnings.warn(
-            "Found complex group-by expression, which can't be expressed efficiently with the "
-            "pandas API. If you can, please rewrite your query such that group-by aggregations "
-            "are simple (e.g. mean, std, min, max, ...). \n\n"
-            "Please see: "
-            "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/",
-            UserWarning,
-            stacklevel=find_stacklevel(),
+        elif self.compliant.native.empty:
+            raise empty_results_error()
+        else:
+            result = self._apply_aggs(exprs)
+        return self._select_results(result, agg_exprs)
+
+    def _select_results(
+        self, df: pd.DataFrame, /, agg_exprs: Sequence[AggExpr]
+    ) -> PandasLikeDataFrame:
+        """Responsible for remapping temp column names back to original.
+
+        See `ParseKeysGroupBy`.
+        """
+        # NOTE: Keep `inplace=True` to avoid making a redundant copy.
+        # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
+        df.reset_index(inplace=True)  # noqa: PD002
+        new_names = chain.from_iterable(e.aliases for e in agg_exprs)
+        return (
+            self.compliant._with_native(df, validate_column_names=False)
+            .simple_select(*self._keys, *new_names)
+            .rename(dict(zip(self._keys, self._output_key_names)))
         )

-        def func(df: Any) -> Any:
+    def _getitem_aggs(
+        self, exprs: Iterable[AggExpr], /
+    ) -> list[pd.DataFrame | pd.Series[Any]]:
+        return [e._getitem_aggs(self) for e in exprs]
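+
+    # Dispatch overview (illustrative comment): `agg` routes elementary
+    # expressions through `_getitem_aggs` - one native call per expression,
+    # e.g. `grouped["b"].mean()` - while anything more complex falls through
+    # to `_apply_aggs`, the slow `groupby.apply` path below.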
+
+    def _apply_aggs(self, exprs: Iterable[PandasLikeExpr]) -> pd.DataFrame:
+        """Stub issue for `include_groups` [pandas-dev/pandas-stubs#1270].
+
+        - [User guide] mentions `include_groups` 4 times without deprecation.
+        - [`DataFrameGroupBy.apply`] doc says the default value of `True` is deprecated since `2.2.0`.
+        - `False` is explicitly the only *non-deprecated* option, but entirely omitted since [pandas-dev/pandas-stubs#1268].
+
+        [pandas-dev/pandas-stubs#1270]: https://github.com/pandas-dev/pandas-stubs/issues/1270
+        [User guide]: https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html
+        [`DataFrameGroupBy.apply`]: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.core.groupby.DataFrameGroupBy.apply.html
+        [pandas-dev/pandas-stubs#1268]: https://github.com/pandas-dev/pandas-stubs/pull/1268
+        """
+        warn_complex_group_by()
+        impl = self.compliant._implementation
+        func = self._apply_exprs_function(exprs)
+        apply = self._grouped.apply
+        if impl.is_pandas() and impl._backend_version() >= (2, 2):
+            return apply(func, include_groups=False)  # type: ignore[call-overload]
+        else:  # pragma: no cover
+            return apply(func)
+
+    def _apply_exprs_function(self, exprs: Iterable[PandasLikeExpr]) -> NativeApply:
+        ns = self.compliant.__narwhals_namespace__()
+        into_series = ns._series.from_iterable
+
+        def fn(df: pd.DataFrame) -> pd.Series[Any]:
             out_group = []
             out_names = []
             for expr in exprs:
                 results_keys = expr(self.compliant._with_native(df))
-                for result_keys in results_keys:
-                    out_group.append(result_keys.native.iloc[0])
-                    out_names.append(result_keys.name)
-            ns = self.compliant.__narwhals_namespace__()
-            return ns._series.from_iterable(out_group, index=out_names, context=ns).native
-
-        if implementation.is_pandas() and backend_version >= (2, 2):
-            result_complex = self._grouped.apply(func, include_groups=False)
-        else:  # pragma: no cover
-            result_complex = self._grouped.apply(func)
+                for keys in results_keys:
+                    out_group.append(keys.native.iloc[0])
+                    out_names.append(keys.name)
+            return into_series(out_group, index=out_names, context=ns).native

-        # Keep inplace=True to avoid making a redundant copy.
-        # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
-        result_complex.reset_index(inplace=True)
-        return self.compliant._with_native(
-            select_columns_by_name(result_complex, new_names, implementation)
-        ).rename(dict(zip(self._keys, self._output_key_names)))
+        return fn

     def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
         with warnings.catch_warnings():
@@ -273,9 +287,33 @@ def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
                 message=".*a length 1 tuple will be returned",
                 category=FutureWarning,
             )
-
+            with_native = self.compliant._with_native
             for key, group in self._grouped:
-                yield (
-                    key,
-                    self.compliant._with_native(group).simple_select(*self._df.columns),
-                )
+                yield (key, with_native(group).simple_select(*self._original_columns))
+
+
+def empty_results_error() -> ValueError:
+    """Don't even attempt this, it's way too inconsistent across pandas versions."""
+    msg = (
+        "No results for group-by aggregation.\n\n"
+        "Hint: you were probably trying to apply a non-elementary aggregation with a "
+        "pandas-like API.\n"
+        "Please rewrite your query such that group-by aggregations "
+        "are elementary. For example, instead of:\n\n"
+        "    df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
+        "use:\n\n"
+        "    df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
+    )
+    return ValueError(msg)
+
+
+def warn_complex_group_by() -> None:
+    warnings.warn(
+        "Found complex group-by expression, which can't be expressed efficiently with the "
+        "pandas API. If you can, please rewrite your query such that group-by aggregations "
+        "are simple (e.g. mean, std, min, max, ...).\n\n"
+        "Please see: "
+        "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/",
+        UserWarning,
+        stacklevel=find_stacklevel(),
+    )
diff --git a/pyproject.toml b/pyproject.toml
index da2fb5caba..0f98e37c73 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,7 @@ typing = [ # keep some of these pinned and bump periodically so there's fewer s
     "duckdb==1.3.0",
     "hypothesis",
     "pytest",
-    "pandas-stubs==2.2.3.250308",
+    "pandas-stubs==2.3.0.250703",
     "typing_extensions",
     "mypy~=1.15.0",
     "pyright",
diff --git a/tests/frame/group_by_test.py b/tests/frame/group_by_test.py
index 2968d7f083..4fa98d077c 100644
--- a/tests/frame/group_by_test.py
+++ b/tests/frame/group_by_test.py
@@ -1,6 +1,9 @@
 from __future__ import annotations

+import datetime as dt
 import os
+import re
+from decimal import Decimal
 from typing import TYPE_CHECKING, Any

 import pandas as pd
@@ -8,9 +11,10 @@
 import pytest

 import narwhals as nw
-from narwhals.exceptions import ComputeError, InvalidOperationError
+from narwhals.exceptions import ComputeError, DuplicateError, InvalidOperationError
 from tests.utils import (
     PANDAS_VERSION,
+    POLARS_VERSION,
     PYARROW_VERSION,
     Constructor,
     ConstructorEager,
@@ -20,6 +24,9 @@
 if TYPE_CHECKING:
     from collections.abc import Mapping

+    from narwhals.typing import NonNestedLiteral
+
+
 data: Mapping[str, Any] = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]}

 df_pandas = pd.DataFrame(data)
@@ -86,6 +93,16 @@ def test_group_by_iter(constructor_eager: ConstructorEager) -> None:
     assert sorted(keys) == sorted(expected_keys)


+def test_group_by_iter_non_str_pandas() -> None:
+    expected = {"a": {0: [1], 1: ["a"]}, "b": {0: [2], 1: ["b"]}}
+    df = nw.from_native(pd.DataFrame({0: [1, 2], 1: ["a", "b"]}))
+    groups: dict[Any, Any] = {keys[0]: df for keys, df in df.group_by(1)}  # type: ignore[call-overload]
+    assert groups.keys() == {"a", "b"}
+    groups["a"] = groups["a"].to_dict(as_series=False)
+    groups["b"] = groups["b"].to_dict(as_series=False)
+    assert_equal_data(groups, expected)
+
+
 def test_group_by_nw_all(constructor: Constructor) -> None:
     df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]}))
     result = df.group_by("a").agg(nw.all().sum()).sort("a")
@@ -178,7 +195,10 @@ def test_group_by_n_unique_w_missing(constructor: Constructor) -> None:

 def test_group_by_same_name_twice() -> None:
     df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})
-    with pytest.raises(ValueError, match="Expected unique output names"):
+    pattern = re.compile(
+        "expected unique.+names.+'b'.+2 times", re.IGNORECASE | re.DOTALL
+    )
+    with pytest.raises(DuplicateError, match=pattern):
         nw.from_native(df).group_by("a").agg(nw.col("b").sum(), nw.col("b").n_unique())
@@ -317,9 +337,8 @@ def test_group_by_shift_raises(constructor: Constructor) -> None:
 def test_double_same_aggregation(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if any(x in str(constructor) for x in ("dask", "modin", "cudf")):
+    if any(x in str(constructor) for x in ("dask", "cudf")):
         # bugged in dask https://github.com/dask/dask/issues/11612
-        # and modin lol https://github.com/modin-project/modin/issues/7414
         # and cudf https://github.com/rapidsai/cudf/issues/17649
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6]}))
@@ -331,9 +350,8 @@ def test_double_same_aggregation(
 def test_all_kind_of_aggs(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if any(x in str(constructor) for x in ("dask", "cudf", "modin")):
+    if any(x in str(constructor) for x in ("dask", "cudf")):
         # bugged in dask https://github.com/dask/dask/issues/11612
-        # and modin lol https://github.com/modin-project/modin/issues/7414
         # and cudf https://github.com/rapidsai/cudf/issues/17649
         request.applymarker(pytest.mark.xfail)
     if "pandas" in str(constructor) and PANDAS_VERSION < (1, 4):
@@ -523,3 +541,81 @@ def test_renaming_edge_case(constructor: Constructor) -> None:
     result = nw.from_native(constructor(data)).group_by(nw.col("a")).agg(nw.all().min())
     expected = {"a": [0], "_a_tmp": [1], "b": [4]}
     assert_equal_data(result, expected)
+
+
+def test_group_by_len_1_column(
+    constructor: Constructor, request: pytest.FixtureRequest
+) -> None:
+    """Based on a failure from marimo.
+
+    - https://github.com/marimo-team/marimo/blob/036fd3ff89ef3a0e598bebb166637028024f98bc/tests/_plugins/ui/_impl/tables/test_narwhals.py#L1098-L1108
+    - https://github.com/marimo-team/marimo/blob/036fd3ff89ef3a0e598bebb166637028024f98bc/marimo/_plugins/ui/_impl/tables/narwhals_table.py#L163-L188
+    """
+    if any(x in str(constructor) for x in ("dask",)):
+        # `dask`
+        # ValueError: conflicting aggregation functions: [('size', 'a'), ('size', 'a')]
+        request.applymarker(pytest.mark.xfail)
+    data = {"a": [1, 2, 1, 2, 3, 4]}
+    expected = {"a": [1, 2, 3, 4], "len": [2, 2, 1, 1], "len_a": [2, 2, 1, 1]}
+    result = (
+        nw.from_native(constructor(data))
+        .group_by("a")
+        .agg(nw.len(), nw.len().alias("len_a"))
+        .sort("a")
+    )
+    assert_equal_data(result, expected)
+
+
+@pytest.mark.parametrize(
+    ("low", "high"),
+    [
+        ("A", "B"),
+        (1.5, 5.2),
+        (dt.datetime(2000, 1, 1), dt.datetime(2002, 1, 1)),
+        (dt.date(2000, 1, 1), dt.date(2002, 1, 1)),
+        (dt.time(5, 0, 0), dt.time(14, 0, 0)),
+        (dt.timedelta(32), dt.timedelta(800)),
+        (False, True),
+        (b"a", b"z"),
+        (Decimal("43.954"), Decimal("264.124")),
+    ],
+    ids=[
+        "str",
+        "float",
+        "datetime",
+        "date",
+        "time",
+        "timedelta",
+        "bool",
+        "bytes",
+        "Decimal",
+    ],
+)
+def test_group_by_no_preserve_dtype(
+    constructor_eager: ConstructorEager, low: NonNestedLiteral, high: NonNestedLiteral
+) -> None:
+    """Minimal repro for [`px.sunburst` failure].
+
+    The issue appeared for `n_unique`, but applies to any [aggregation that requires a function].
+
+    [`px.sunburst` failure]: https://github.com/narwhals-dev/narwhals/pull/2680#discussion_r2151972940
+    [aggregation that requires a function]: https://github.com/pandas-dev/pandas/issues/57317
+    """
+    if (
+        "polars" in str(constructor_eager)
+        and isinstance(low, Decimal)
+        and POLARS_VERSION < (1, 0, 0)
+    ):
+        pytest.skip("Decimal support in group_by for polars didn't stabilize until 1.0.0")
+    data = {
+        "col_a": ["A", "B", None, "A", "A", "B", None],
+        "col_b": [low, low, high, high, None, None, None],
+    }
+    expected = {"col_a": [None, "A", "B"], "n_unique": [2, 3, 2]}
+    frame = nw.from_native(constructor_eager(data))
+    result = (
+        frame.group_by("col_a").agg(n_unique=nw.col("col_b").n_unique()).sort("col_a")
+    )
+    actual_dtype = result.schema["n_unique"]
+    assert actual_dtype.is_integer()
+    assert_equal_data(result, expected)
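
A minimal sketch of the two code paths exercised above (illustrative only; it
uses just the public narwhals API already quoted in the diff's own error and
warning messages):

    import pandas as pd

    import narwhals as nw

    df = nw.from_native(pd.DataFrame({"a": [1, 1, 2], "b": [4.0, 5.0, 6.0]}))

    # Elementary aggregations: each compiles to a direct `DataFrameGroupBy`
    # call via `AggExpr.native_agg`.
    fast = df.group_by("a").agg(
        nw.col("b").mean(), nw.col("b").std(ddof=0).alias("b_std")
    )

    # Non-elementary expression: falls back to `groupby.apply` and emits the
    # `UserWarning` from `warn_complex_group_by`.
    slow = df.group_by("a").agg(nw.col("b").round(2).mean())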