diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 209d831bf1..d88d80fd4e 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -399,7 +399,7 @@ def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: def group_by(self: Self, *keys: str, drop_null_keys: bool) -> ArrowGroupBy: from narwhals._arrow.group_by import ArrowGroupBy - return ArrowGroupBy(self, list(keys), drop_null_keys=drop_null_keys) + return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys) def join( self: Self, diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 25ec346dd2..c178a940b3 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -1,18 +1,20 @@ from __future__ import annotations import collections -import re from typing import TYPE_CHECKING from typing import Any +from typing import ClassVar from typing import Iterator +from typing import Mapping +from typing import Sequence import pyarrow as pa import pyarrow.compute as pc from narwhals._arrow.utils import cast_to_comparable_string_types from narwhals._arrow.utils import extract_py_scalar +from narwhals._compliant import EagerGroupBy from narwhals._expression_parsing import evaluate_output_names_and_aliases -from narwhals._expression_parsing import is_elementary_expression from narwhals.utils import generate_temporary_column_name if TYPE_CHECKING: @@ -21,67 +23,51 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.typing import Incomplete + from narwhals._compliant.group_by import NarwhalsAggregation + + +class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]): + _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] = { + "sum": "sum", + "mean": "mean", + "median": "approximate_median", + "max": "max", + "min": "min", + "std": "stddev", + "var": "variance", + "len": "count", + "n_unique": "count_distinct", + "count": "count", + } 
-POLARS_TO_ARROW_AGGREGATIONS = { - "sum": "sum", - "mean": "mean", - "median": "approximate_median", - "max": "max", - "min": "min", - "std": "stddev", - "var": "variance", - "len": "count", - "n_unique": "count_distinct", - "count": "count", -} - - -class ArrowGroupBy: def __init__( - self: Self, df: ArrowDataFrame, keys: list[str], *, drop_null_keys: bool + self, + compliant_frame: ArrowDataFrame, + keys: Sequence[str], + /, + *, + drop_null_keys: bool, ) -> None: if drop_null_keys: - self._df = df.drop_nulls(keys) + self._compliant_frame = compliant_frame.drop_nulls(keys) else: - self._df = df - self._keys = keys.copy() - self._grouped = pa.TableGroupBy(self._df._native_frame, self._keys) + self._compliant_frame = compliant_frame + self._keys: list[str] = list(keys) + self._grouped = pa.TableGroupBy(self.compliant.native, self._keys) def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: - all_simple_aggs = True - for expr in exprs: - if not ( - is_elementary_expression(expr) - and re.sub(r"(\w+->)", "", expr._function_name) - in POLARS_TO_ARROW_AGGREGATIONS - ): - all_simple_aggs = False - break - - if not all_simple_aggs: - msg = ( - "Non-trivial complex aggregation found.\n\n" - "Hint: you were probably trying to apply a non-elementary aggregation with a " - "pyarrow table.\n" - "Please rewrite your query such that group-by aggregations " - "are elementary. For example, instead of:\n\n" - " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" - "use:\n\n" - " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" - ) - raise ValueError(msg) - + self._ensure_all_simple(exprs) aggs: list[tuple[str, str, Any]] = [] expected_pyarrow_column_names: list[str] = self._keys.copy() new_column_names: list[str] = self._keys.copy() for expr in exprs: output_names, aliases = evaluate_output_names_and_aliases( - expr, self._df, self._keys + expr, self.compliant, self._keys ) if expr._depth == 0: - # e.g. agg(nw.len()) # noqa: ERA001 + # e.g. 
`agg(nw.len())` if expr._function_name != "len": # pragma: no cover msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) @@ -89,10 +75,9 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: new_column_names.append(aliases[0]) expected_pyarrow_column_names.append(f"{self._keys[0]}_count") aggs.append((self._keys[0], "count", pc.CountOptions(mode="all"))) - continue - function_name = re.sub(r"(\w+->)", "", expr._function_name) + function_name = self._leaf_name(expr) if function_name in {"std", "var"}: option: Any = pc.VarianceOptions(ddof=expr._call_kwargs["ddof"]) elif function_name in {"len", "n_unique"}: @@ -102,8 +87,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: else: option = None - function_name = POLARS_TO_ARROW_AGGREGATIONS[function_name] - + function_name = self._remap_expr_name(function_name) new_column_names.extend(aliases) expected_pyarrow_column_names.extend( [f"{output_name}_{function_name}" for output_name in output_names] @@ -133,18 +117,20 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: ] new_column_names = [new_column_names[i] for i in index_map] result_simple = result_simple.rename_columns(new_column_names) - if self._df._backend_version < (12, 0, 0): + if self.compliant._backend_version < (12, 0, 0): columns = result_simple.column_names result_simple = result_simple.select( [*self._keys, *[col for col in columns if col not in self._keys]] ) - return self._df._from_native_frame(result_simple) + return self.compliant._from_native_frame(result_simple) def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]: - col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns) + col_token = generate_temporary_column_name( + n_bytes=8, columns=self.compliant.columns + ) null_token: str = "__null_token_value__" # noqa: S105 - table = self._df._native_frame + table = self.compliant.native # NOTE: stubs fail in multiple 
places for `ChunkedArray` it, separator_scalar = cast_to_comparable_string_types( *(table[key] for key in self._keys), separator="" @@ -160,7 +146,7 @@ def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]: ) table = table.add_column(i=0, field_=col_token, column=key_values) for v in pc.unique(key_values): - t = self._df._from_native_frame( + t = self.compliant._from_native_frame( table.filter(pc.equal(table[col_token], v)).drop([col_token]) ) row = t.simple_select(*self._keys).row(0) diff --git a/narwhals/_compliant/__init__.py b/narwhals/_compliant/__init__.py index 0a8bc59cd8..f223aa6a3f 100644 --- a/narwhals/_compliant/__init__.py +++ b/narwhals/_compliant/__init__.py @@ -6,6 +6,10 @@ from narwhals._compliant.expr import CompliantExpr from narwhals._compliant.expr import EagerExpr from narwhals._compliant.expr import LazyExpr +from narwhals._compliant.group_by import CompliantGroupBy +from narwhals._compliant.group_by import DepthTrackingGroupBy +from narwhals._compliant.group_by import EagerGroupBy +from narwhals._compliant.group_by import LazyGroupBy from narwhals._compliant.namespace import CompliantNamespace from narwhals._compliant.namespace import EagerNamespace from narwhals._compliant.selectors import CompliantSelector @@ -31,6 +35,7 @@ "CompliantExpr", "CompliantExprT", "CompliantFrameT", + "CompliantGroupBy", "CompliantLazyFrame", "CompliantNamespace", "CompliantSelector", @@ -38,9 +43,11 @@ "CompliantSeries", "CompliantSeriesOrNativeExprT_co", "CompliantSeriesT", + "DepthTrackingGroupBy", "EagerDataFrame", "EagerDataFrameT", "EagerExpr", + "EagerGroupBy", "EagerNamespace", "EagerSelectorNamespace", "EagerSeries", @@ -49,6 +56,7 @@ "EvalSeries", "IntoCompliantExpr", "LazyExpr", + "LazyGroupBy", "LazySelectorNamespace", "NativeFrameT_co", "NativeSeriesT_co", diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 5cb593530f..06b8a00c44 100644 --- a/narwhals/_compliant/dataframe.py +++ 
b/narwhals/_compliant/dataframe.py @@ -32,6 +32,7 @@ from typing_extensions import Self from typing_extensions import TypeAlias + from narwhals._compliant.group_by import CompliantGroupBy from narwhals.dtypes import DType from narwhals.typing import SizeUnit from narwhals.typing import _2DArray @@ -67,7 +68,8 @@ def aggregate(self, *exprs: CompliantExprT_contra) -> Self: (so, no broadcasting is necessary). """ - return self.select(*exprs) + # NOTE: Ignore is to avoid an intermittent false positive + return self.select(*exprs) # pyright: ignore[reportArgumentType] @property def native(self) -> NativeFrameT_co: @@ -91,7 +93,9 @@ def explode(self: Self, columns: Sequence[str]) -> Self: ... def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ... def gather_every(self, n: int, offset: int) -> Self: ... def get_column(self, name: str) -> CompliantSeriesT: ... - def group_by(self, *keys: str, drop_null_keys: bool) -> Incomplete: ... + def group_by( + self, *keys: str, drop_null_keys: bool + ) -> CompliantGroupBy[Self, Any]: ... def head(self, n: int) -> Self: ... def item(self, row: int | None, column: int | str | None) -> Any: ... def iter_columns(self) -> Iterator[CompliantSeriesT]: ... @@ -218,7 +222,9 @@ def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ... "`LazyFrame.gather_every` is deprecated and will be removed in a future version." ) def gather_every(self, n: int, offset: int) -> Self: ... - def group_by(self, *keys: str, drop_null_keys: bool) -> Incomplete: ... + def group_by( + self, *keys: str, drop_null_keys: bool + ) -> CompliantGroupBy[Self, Any]: ... def head(self, n: int) -> Self: ... def join( self: Self, diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index a922c66e78..ec251a6ac8 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -270,6 +270,14 @@ def __invert__(self) -> Self: ... 
def broadcast( self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL] ) -> Self: ... + def _is_multi_output_agg(self) -> bool: + """Return `True` for multi-output aggregations. + + Here we skip the keys, else they would appear duplicated in the output: + + df.group_by("a").agg(nw.all().mean()) + """ + return self._function_name.split("->", maxsplit=1)[0] in {"all", "selector"} class EagerExpr( diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py new file mode 100644 index 0000000000..866e98c6e8 --- /dev/null +++ b/narwhals/_compliant/group_by.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import re +import sys +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import ClassVar +from typing import Iterable +from typing import Iterator +from typing import Literal +from typing import Mapping +from typing import Sequence +from typing import TypeVar + +from narwhals._compliant.typing import CompliantDataFrameT_co +from narwhals._compliant.typing import CompliantExprAny +from narwhals._compliant.typing import CompliantExprT_contra +from narwhals._compliant.typing import CompliantFrameT_co +from narwhals._compliant.typing import CompliantLazyFrameT_co +from narwhals._compliant.typing import LazyExprT_contra +from narwhals._compliant.typing import NativeExprT_co +from narwhals._expression_parsing import is_elementary_expression + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + +if not TYPE_CHECKING: # pragma: no cover + if sys.version_info >= (3, 9): + from typing import Protocol as Protocol38 + else: + from typing import Generic as Protocol38 +else: # pragma: no cover + # TODO @dangotbanned: Remove after dropping `3.8` (#2084) + # - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386 + from typing import Protocol as Protocol38 + +__all__ = [ + "CompliantGroupBy", + "DepthTrackingGroupBy", + "EagerGroupBy", + "LazyGroupBy", + 
"NarwhalsAggregation", +] + +NativeAggregationT_co = TypeVar( + "NativeAggregationT_co", bound="str | Callable[..., Any]", covariant=True +) +NarwhalsAggregation: TypeAlias = Literal[ + "sum", "mean", "median", "max", "min", "std", "var", "len", "n_unique", "count" +] + + +_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)") + + +class CompliantGroupBy(Protocol38[CompliantFrameT_co, CompliantExprT_contra]): + _compliant_frame: Any + _keys: Sequence[str] + + @property + def compliant(self) -> CompliantFrameT_co: + return self._compliant_frame # type: ignore[no-any-return] + + def __init__( + self, + compliant_frame: CompliantFrameT_co, + keys: Sequence[str], + /, + *, + drop_null_keys: bool, + ) -> None: ... + + def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ... + + +class DepthTrackingGroupBy( + CompliantGroupBy[CompliantFrameT_co, CompliantExprT_contra], + Protocol38[CompliantFrameT_co, CompliantExprT_contra, NativeAggregationT_co], +): + """`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`.""" + + _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] + """Mapping from `narwhals` to native representation. + + Note: + - `Dask` *may* return a `Callable` instead of a `str` referring to one. + """ + + def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: + for expr in exprs: + if not self._is_simple(expr): + name = self.compliant._implementation.name.lower() + msg = ( + f"Non-trivial complex aggregation found.\n\n" + f"Hint: you were probably trying to apply a non-elementary aggregation with a " + f"{name!r} table.\n" + "Please rewrite your query such that group-by aggregations " + "are elementary. 
For example, instead of:\n\n" + " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" + "use:\n\n" + " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" + ) + raise ValueError(msg) + + @classmethod + def _is_simple(cls, expr: CompliantExprAny, /) -> bool: + """Return `True` if we can efficiently use `expr` in a native `group_by` context.""" + return is_elementary_expression(expr) and cls._leaf_name(expr) in cls._REMAP_AGGS + + @classmethod + def _remap_expr_name( + cls, name: NarwhalsAggregation | Any, / + ) -> NativeAggregationT_co: + """Replace `name` with some native representation. + + Arguments: + name: Name of a `nw.Expr` aggregation method. + + Returns: + A native compatible representation. + """ + return cls._REMAP_AGGS.get(name, name) + + @classmethod + def _leaf_name(cls, expr: CompliantExprAny, /) -> NarwhalsAggregation | Any: + """Return the last function name in the chain defined by `expr`.""" + return _RE_LEAF_NAME.sub("", expr._function_name) + + +class EagerGroupBy( + DepthTrackingGroupBy[CompliantDataFrameT_co, CompliantExprT_contra, str], + Protocol38[CompliantDataFrameT_co, CompliantExprT_contra], +): + def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ... 
+ + +class LazyGroupBy( + CompliantGroupBy[CompliantLazyFrameT_co, LazyExprT_contra], + Protocol38[CompliantLazyFrameT_co, LazyExprT_contra, NativeExprT_co], +): + def _evaluate_expr(self, expr: LazyExprT_contra, /) -> Iterator[NativeExprT_co]: + output_names = expr._evaluate_output_names(self.compliant) + aliases = ( + expr._alias_output_names(output_names) + if expr._alias_output_names + else output_names + ) + native_exprs = expr(self.compliant) + if expr._is_multi_output_agg(): + for native_expr, name, alias in zip(native_exprs, output_names, aliases): + if name not in self._keys: + yield native_expr.alias(alias) + else: + for native_expr, alias in zip(native_exprs, aliases): + yield native_expr.alias(alias) + + def _evaluate_exprs( + self, exprs: Iterable[LazyExprT_contra], / + ) -> Iterator[NativeExprT_co]: + for expr in exprs: + yield from self._evaluate_expr(expr) diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index 9a662e15a2..c1771bde16 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -14,6 +14,7 @@ from narwhals._compliant.dataframe import EagerDataFrame from narwhals._compliant.expr import CompliantExpr from narwhals._compliant.expr import EagerExpr + from narwhals._compliant.expr import LazyExpr from narwhals._compliant.expr import NativeExpr from narwhals._compliant.namespace import EagerNamespace from narwhals._compliant.series import CompliantSeries @@ -45,14 +46,26 @@ "CompliantFrameT", bound="CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any]", ) +CompliantFrameT_co = TypeVar( + "CompliantFrameT_co", + bound="CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any]", + covariant=True, +) CompliantDataFrameT = TypeVar( "CompliantDataFrameT", bound="CompliantDataFrame[Any, Any, Any]" ) +CompliantDataFrameT_co = TypeVar( + "CompliantDataFrameT_co", bound="CompliantDataFrame[Any, Any, Any]", covariant=True +) CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", 
bound="CompliantLazyFrame[Any, Any]") +CompliantLazyFrameT_co = TypeVar( + "CompliantLazyFrameT_co", bound="CompliantLazyFrame[Any, Any]", covariant=True +) IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co" -CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]") +CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]" +CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny) CompliantExprT_contra = TypeVar( - "CompliantExprT_contra", bound="CompliantExpr[Any, Any]", contravariant=True + "CompliantExprT_contra", bound=CompliantExprAny, contravariant=True ) EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any]") @@ -65,6 +78,9 @@ EagerNamespaceAny: TypeAlias = ( "EagerNamespace[EagerDataFrame[Any, Any, Any], EagerSeries[Any], EagerExpr[Any, Any]]" ) +LazyExprT_contra = TypeVar( + "LazyExprT_contra", bound="LazyExpr[Any, Any]", contravariant=True +) AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]] AliasName: TypeAlias = Callable[[str], str] diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 58add7a79b..6103afa7fb 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -388,7 +388,7 @@ def join_asof( def group_by(self: Self, *by: str, drop_null_keys: bool) -> DaskLazyGroupBy: from narwhals._dask.group_by import DaskLazyGroupBy - return DaskLazyGroupBy(self, list(by), drop_null_keys=drop_null_keys) + return DaskLazyGroupBy(self, by, drop_null_keys=drop_null_keys) def tail(self: Self, n: int) -> Self: # pragma: no cover native_frame = self._native_frame diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 8d87a96e41..ab99a5e15e 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re import warnings from typing import TYPE_CHECKING from typing import Any @@ -550,7 +549,7 @@ def over( 
order_by: Sequence[str] | None, ) -> Self: # pandas is a required dependency of dask so it's safe to import this - from narwhals._pandas_like.group_by import AGGREGATIONS_TO_PANDAS_EQUIVALENT + from narwhals._pandas_like.group_by import PandasLikeGroupBy if not partition_by: assert order_by is not None # help type checkers # noqa: S101 @@ -567,14 +566,14 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: ) raise NotImplementedError(msg) else: - function_name = re.sub(r"(\w+->)", "", self._function_name) + function_name = PandasLikeGroupBy._leaf_name(self) try: - dask_function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT[function_name] + dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name] except KeyError: # window functions are unsupported: https://github.com/dask/dask/issues/11806 msg = ( f"Unsupported function: {function_name} in `over` context.\n\n" - f"Supported functions are {', '.join(AGGREGATIONS_TO_PANDAS_EQUIVALENT)}\n" + f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n" ) raise NotImplementedError(msg) from None diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 6fadc8ab3d..cca29870e8 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -1,17 +1,17 @@ from __future__ import annotations -import re from functools import partial from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import ClassVar from typing import Mapping from typing import Sequence import dask.dataframe as dd +from narwhals._compliant import DepthTrackingGroupBy from narwhals._expression_parsing import evaluate_output_names_and_aliases -from narwhals._expression_parsing import is_elementary_expression try: import dask.dataframe.dask_expr as dx @@ -24,17 +24,20 @@ from typing_extensions import Self from typing_extensions import TypeAlias + from narwhals._compliant.group_by import NarwhalsAggregation from narwhals._dask.dataframe import DaskLazyFrame from 
narwhals._dask.expr import DaskExpr PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any] _AggFn: TypeAlias = Callable[..., Any] - Aggregation: TypeAlias = "str | _AggFn" from dask_expr._groupby import GroupBy as _DaskGroupBy else: _DaskGroupBy = dx._groupby.GroupBy +Aggregation: TypeAlias = "str | _AggFn" +"""The name of an aggregation function, or the function itself.""" + def n_unique() -> dd.Aggregation: def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]: @@ -54,120 +57,63 @@ def std(ddof: int) -> _AggFn: return partial(_DaskGroupBy.std, ddof=ddof) -POLARS_TO_DASK_AGGREGATIONS: Mapping[str, Aggregation] = { - "sum": "sum", - "mean": "mean", - "median": "median", - "max": "max", - "min": "min", - "std": std, - "var": var, - "len": "size", - "n_unique": n_unique, - "count": "count", -} - +class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]): + _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = { + "sum": "sum", + "mean": "mean", + "median": "median", + "max": "max", + "min": "min", + "std": std, + "var": var, + "len": "size", + "n_unique": n_unique, + "count": "count", + } -class DaskLazyGroupBy: def __init__( - self: Self, df: DaskLazyFrame, keys: list[str], *, drop_null_keys: bool + self: Self, df: DaskLazyFrame, keys: Sequence[str], /, *, drop_null_keys: bool ) -> None: - self._df: DaskLazyFrame = df - self._keys = keys - self._grouped = self._df._native_frame.groupby( - list(self._keys), - dropna=drop_null_keys, - observed=True, - ) - - def agg( - self: Self, - *exprs: DaskExpr, - ) -> DaskLazyFrame: - return agg_dask( - self._df, - self._grouped, - exprs, - self._keys, - self._from_native_frame, + self._compliant_frame = df + self._keys: list[str] = list(keys) + self._grouped = self.compliant.native.groupby( + list(self._keys), dropna=drop_null_keys, observed=True ) - def _from_native_frame(self: Self, df: dd.DataFrame) -> DaskLazyFrame: + def agg(self: Self, *exprs: DaskExpr) -> DaskLazyFrame: from 
narwhals._dask.dataframe import DaskLazyFrame - return DaskLazyFrame( - df, - backend_version=self._df._backend_version, - version=self._df._version, - ) - - -def agg_dask( - df: DaskLazyFrame, - grouped: Any, - exprs: Sequence[DaskExpr], - keys: list[str], - from_dataframe: Callable[[Any], DaskLazyFrame], -) -> DaskLazyFrame: - """This should be the fastpath, but cuDF is too far behind to use it. - - - https://github.com/rapidsai/cudf/issues/15118 - - https://github.com/rapidsai/cudf/issues/15084 - """ - if not exprs: - # No aggregation provided - return df.simple_select(*keys).unique(subset=keys, keep="any") - - all_simple_aggs = True - for expr in exprs: - if not ( - is_elementary_expression(expr) - and re.sub(r"(\w+->)", "", expr._function_name) in POLARS_TO_DASK_AGGREGATIONS - ): - all_simple_aggs = False - break - - if all_simple_aggs: + if not exprs: + # No aggregation provided + return self.compliant.simple_select(*self._keys).unique( + self._keys, keep="any" + ) + self._ensure_all_simple(exprs) + # This should be the fastpath, but cuDF is too far behind to use it. + # - https://github.com/rapidsai/cudf/issues/15118 + # - https://github.com/rapidsai/cudf/issues/15084 simple_aggregations: dict[str, tuple[str, Aggregation]] = {} for expr in exprs: - output_names, aliases = evaluate_output_names_and_aliases(expr, df, keys) + output_names, aliases = evaluate_output_names_and_aliases( + expr, self.compliant, self._keys + ) if expr._depth == 0: - # e.g. agg(nw.len()) # noqa: ERA001 - function_name = POLARS_TO_DASK_AGGREGATIONS.get( - expr._function_name, expr._function_name - ) - simple_aggregations.update( - dict.fromkeys(aliases, (keys[0], function_name)) - ) + # e.g. `agg(nw.len())` + column = self._keys[0] + agg_fn = self._remap_expr_name(expr._function_name) + simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn))) continue - # e.g. 
agg(nw.mean('a')) # noqa: ERA001 - function_name = re.sub(r"(\w+->)", "", expr._function_name) - agg_function = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name) + # e.g. `agg(nw.mean('a'))` + agg_fn = self._remap_expr_name(self._leaf_name(expr)) # deal with n_unique case in a "lazy" mode to not depend on dask globally - agg_function = ( - agg_function(**expr._call_kwargs) - if callable(agg_function) - else agg_function - ) - + agg_fn = agg_fn(**expr._call_kwargs) if callable(agg_fn) else agg_fn simple_aggregations.update( - { - alias: (output_name, agg_function) - for alias, output_name in zip(aliases, output_names) - } + (alias, (output_name, agg_fn)) + for alias, output_name in zip(aliases, output_names) ) - result_simple = grouped.agg(**simple_aggregations) - return from_dataframe(result_simple.reset_index()) - - msg = ( - "Non-trivial complex aggregation found.\n\n" - "Hint: you were probably trying to apply a non-elementary aggregation with a " - "dask dataframe.\n" - "Please rewrite your query such that group-by aggregations " - "are elementary. 
For example, instead of:\n\n" - " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" - "use:\n\n" - " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" - ) - raise ValueError(msg) + return DaskLazyFrame( + self._grouped.agg(**simple_aggregations).reset_index(), + backend_version=self.compliant._backend_version, + version=self.compliant._version, + ) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 5ac1642b23..45504da2e3 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -238,9 +238,7 @@ def _from_native_frame(self: Self, df: duckdb.DuckDBPyRelation) -> Self: def group_by(self: Self, *keys: str, drop_null_keys: bool) -> DuckDBGroupBy: from narwhals._duckdb.group_by import DuckDBGroupBy - return DuckDBGroupBy( - compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys - ) + return DuckDBGroupBy(self, keys, drop_null_keys=drop_null_keys) def rename(self: Self, mapping: Mapping[str, str]) -> Self: df = self._native_frame diff --git a/narwhals/_duckdb/group_by.py b/narwhals/_duckdb/group_by.py index c9c75de5d9..8e5110e184 100644 --- a/narwhals/_duckdb/group_by.py +++ b/narwhals/_duckdb/group_by.py @@ -1,54 +1,33 @@ from __future__ import annotations +from itertools import chain from typing import TYPE_CHECKING +from typing import Sequence + +from narwhals._compliant import LazyGroupBy if TYPE_CHECKING: - from duckdb import Expression + from duckdb import Expression # noqa: F401 from typing_extensions import Self from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.expr import DuckDBExpr -class DuckDBGroupBy: +class DuckDBGroupBy(LazyGroupBy["DuckDBLazyFrame", "DuckDBExpr", "Expression"]): def __init__( self: Self, - compliant_frame: DuckDBLazyFrame, - keys: list[str], - drop_null_keys: bool, # noqa: FBT001 + df: DuckDBLazyFrame, + keys: Sequence[str], + /, + *, + drop_null_keys: bool, ) -> None: - if drop_null_keys: - 
self._compliant_frame = compliant_frame.drop_nulls(subset=None) - else: - self._compliant_frame = compliant_frame - self._keys = keys + self._compliant_frame = df.drop_nulls(subset=None) if drop_null_keys else df + self._keys = list(keys) def agg(self: Self, *exprs: DuckDBExpr) -> DuckDBLazyFrame: - agg_columns: list[str | Expression] = list(self._keys) - df = self._compliant_frame - for expr in exprs: - output_names = expr._evaluate_output_names(df) - aliases = ( - output_names - if expr._alias_output_names is None - else expr._alias_output_names(output_names) - ) - native_expressions = expr(df) - exclude = ( - self._keys - if expr._function_name.split("->", maxsplit=1)[0] in {"all", "selector"} - else [] - ) - agg_columns.extend( - [ - native_expression.alias(alias) - for native_expression, output_name, alias in zip( - native_expressions, output_names, aliases - ) - if output_name not in exclude - ] - ) - - return self._compliant_frame._from_native_frame( - self._compliant_frame._native_frame.aggregate(agg_columns) # type: ignore[arg-type] + agg_columns = list(chain(self._keys, self._evaluate_exprs(exprs))) + return self.compliant._from_native_frame( + self.compliant.native.aggregate(agg_columns) # type: ignore[arg-type] ) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 2c5f8cb506..91e6bcc2ee 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -120,9 +120,7 @@ def evaluate_output_names_and_aliases( if expr._alias_output_names is None else expr._alias_output_names(output_names) ) - if expr._function_name.split("->", maxsplit=1)[0] in {"all", "selector"}: - # For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`, we skip - # the keys, else they would appear duplicated in the output. 
+ if expr._is_multi_output_agg(): output_names, aliases = zip( *[(x, alias) for x, alias in zip(output_names, aliases) if x not in exclude] ) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index be6fce1bf8..8605eac3ee 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -580,11 +580,7 @@ def collect( def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PandasLikeGroupBy: from narwhals._pandas_like.group_by import PandasLikeGroupBy - return PandasLikeGroupBy( - self, - list(keys), - drop_null_keys=drop_null_keys, - ) + return PandasLikeGroupBy(self, keys, drop_null_keys=drop_null_keys) def join( self: Self, diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 104a4efc00..be93b9c0a3 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -11,7 +10,7 @@ from narwhals._expression_parsing import ExprKind from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._expression_parsing import is_elementary_expression -from narwhals._pandas_like.group_by import AGGREGATIONS_TO_PANDAS_EQUIVALENT +from narwhals._pandas_like.group_by import PandasLikeGroupBy from narwhals._pandas_like.series import PandasLikeSeries from narwhals.exceptions import ColumnNotFoundError from narwhals.utils import generate_temporary_column_name @@ -223,16 +222,15 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: ) raise NotImplementedError(msg) else: - function_name: str = re.sub(r"(\w+->)", "", self._function_name) + function_name = PandasLikeGroupBy._leaf_name(self) pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get( - function_name, - AGGREGATIONS_TO_PANDAS_EQUIVALENT.get(function_name), + function_name, PandasLikeGroupBy._REMAP_AGGS.get(function_name) ) 
if pandas_function_name is None: msg = ( f"Unsupported function: {function_name} in `over` context.\n\n" f"Supported functions are {', '.join(WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT)}\n" - f"and {', '.join(AGGREGATIONS_TO_PANDAS_EQUIVALENT)}." + f"and {', '.join(PandasLikeGroupBy._REMAP_AGGS)}." ) raise NotImplementedError(msg) pandas_kwargs = window_kwargs_to_pandas_equivalent( diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 0da57297fc..bed561f24e 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -1,60 +1,67 @@ from __future__ import annotations import collections -import re import warnings from typing import TYPE_CHECKING from typing import Any +from typing import ClassVar from typing import Iterator +from typing import Mapping +from typing import Sequence +from narwhals._compliant import EagerGroupBy from narwhals._expression_parsing import evaluate_output_names_and_aliases -from narwhals._expression_parsing import is_elementary_expression from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import native_series_from_iterable from narwhals._pandas_like.utils import select_columns_by_name from narwhals._pandas_like.utils import set_columns -from narwhals.utils import Implementation from narwhals.utils import find_stacklevel if TYPE_CHECKING: from typing_extensions import Self + from narwhals._compliant.group_by import NarwhalsAggregation from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr -AGGREGATIONS_TO_PANDAS_EQUIVALENT = { - "sum": "sum", - "mean": "mean", - "median": "median", - "max": "max", - "min": "min", - "std": "std", - "var": "var", - "len": "size", - "n_unique": "nunique", - "count": "count", -} +class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr"]): + _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] = { + "sum": "sum", + "mean": "mean", 
+ "median": "median", + "max": "max", + "min": "min", + "std": "std", + "var": "var", + "len": "size", + "n_unique": "nunique", + "count": "count", + } -class PandasLikeGroupBy: def __init__( - self: Self, df: PandasLikeDataFrame, keys: list[str], *, drop_null_keys: bool + self: Self, + df: PandasLikeDataFrame, + keys: Sequence[str], + /, + *, + drop_null_keys: bool, ) -> None: - self._df = df - self._keys = keys + self._compliant_frame = df + self._keys: list[str] = list(keys) # Drop index to avoid potential collisions: # https://github.com/narwhals-dev/narwhals/issues/1907. - if set(df._native_frame.index.names).intersection(df.columns): - native_frame = df._native_frame.reset_index(drop=True) + if set(df.native.index.names).intersection(df.columns): + native_frame = df.native.reset_index(drop=True) else: - native_frame = df._native_frame + native_frame = df.native if ( - self._df._implementation is Implementation.PANDAS - and self._df._backend_version < (1, 1) + self.compliant._implementation.is_pandas() + and self.compliant._backend_version < (1, 1) ): # pragma: no cover if ( not drop_null_keys - and self._df.simple_select(*self._keys)._native_frame.isna().any().any() + and self.compliant.simple_select(*self._keys).native.isna().any().any() ): msg = "Grouping by null values is not supported in pandas < 1.1.0" raise NotImplementedError(msg) @@ -74,20 +81,17 @@ def __init__( ) def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR0915 - implementation = self._df._implementation - backend_version = self._df._backend_version + implementation = self.compliant._implementation + backend_version = self.compliant._backend_version new_names: list[str] = self._keys.copy() all_aggs_are_simple = True for expr in exprs: - _, aliases = evaluate_output_names_and_aliases(expr, self._df, self._keys) + _, aliases = evaluate_output_names_and_aliases( + expr, self.compliant, self._keys + ) new_names.extend(aliases) - - if not ( - 
is_elementary_expression(expr) - and re.sub(r"(\w+->)", "", expr._function_name) - in AGGREGATIONS_TO_PANDAS_EQUIVALENT - ): + if not self._is_simple(expr): all_aggs_are_simple = False # dict of {output_name: root_name} that we count n_unique on @@ -111,13 +115,11 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR if all_aggs_are_simple: for expr in exprs: output_names, aliases = evaluate_output_names_and_aliases( - expr, self._df, self._keys + expr, self.compliant, self._keys ) if expr._depth == 0: - # e.g. agg(nw.len()) # noqa: ERA001 - function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT.get( - expr._function_name, expr._function_name - ) + # e.g. `agg(nw.len())` + function_name = self._remap_expr_name(expr._function_name) simple_aggs_functions.add(function_name) for alias in aliases: @@ -126,12 +128,8 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR simple_agg_new_names.append(alias) continue - # e.g. agg(nw.mean('a')) # noqa: ERA001 - function_name = re.sub(r"(\w+->)", "", expr._function_name) - function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT.get( - function_name, function_name - ) - + # e.g. 
`agg(nw.mean('a'))` + function_name = self._remap_expr_name(self._leaf_name(expr)) is_n_unique = function_name == "nunique" is_std = function_name == "std" is_var = function_name == "var" @@ -204,27 +202,23 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR if std_aggs: result_aggs.extend( - [ - set_columns( - self._grouped[std_output_names].std(ddof=ddof), - columns=std_aliases, - implementation=implementation, - backend_version=backend_version, - ) - for ddof, (std_output_names, std_aliases) in std_aggs.items() - ] + set_columns( + self._grouped[std_output_names].std(ddof=ddof), + columns=std_aliases, + implementation=implementation, + backend_version=backend_version, + ) + for ddof, (std_output_names, std_aliases) in std_aggs.items() ) if var_aggs: result_aggs.extend( - [ - set_columns( - self._grouped[var_output_names].var(ddof=ddof), - columns=var_aliases, - implementation=implementation, - backend_version=backend_version, - ) - for ddof, (var_output_names, var_aliases) in var_aggs.items() - ] + set_columns( + self._grouped[var_output_names].var(ddof=ddof), + columns=var_aliases, + implementation=implementation, + backend_version=backend_version, + ) + for ddof, (var_output_names, var_aliases) in var_aggs.items() ) if result_aggs: @@ -247,17 +241,17 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR ) else: # No aggregation provided - result = self._df.__native_namespace__().DataFrame( + result = self.compliant.__native_namespace__().DataFrame( list(self._grouped.groups.keys()), columns=self._keys ) # Keep inplace=True to avoid making a redundant copy. 
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files result.reset_index(inplace=True) # noqa: PD002 - return self._df._from_native_frame( + return self.compliant._from_native_frame( select_columns_by_name(result, new_names, backend_version, implementation) ) - if self._df._native_frame.empty: + if self.compliant.native.empty: # Don't even attempt this, it's way too inconsistent across pandas versions. msg = ( "No results for group-by aggregation.\n\n" @@ -285,9 +279,9 @@ def func(df: Any) -> Any: out_group = [] out_names = [] for expr in exprs: - results_keys = expr(self._df._from_native_frame(df)) + results_keys = expr(self.compliant._from_native_frame(df)) for result_keys in results_keys: - out_group.append(result_keys._native_series.iloc[0]) + out_group.append(result_keys.native.iloc[0]) out_names.append(result_keys.name) return native_series_from_iterable( out_group, @@ -296,7 +290,7 @@ def func(df: Any) -> Any: implementation=implementation, ) - if implementation is Implementation.PANDAS and backend_version >= (2, 2): + if implementation.is_pandas() and backend_version >= (2, 2): result_complex = self._grouped.apply(func, include_groups=False) else: # pragma: no cover result_complex = self._grouped.apply(func) @@ -305,7 +299,7 @@ def func(df: Any) -> Any: # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files result_complex.reset_index(inplace=True) # noqa: PD002 - return self._df._from_native_frame( + return self.compliant._from_native_frame( select_columns_by_name( result_complex, new_names, backend_version, implementation ) @@ -319,4 +313,4 @@ def __iter__(self: Self) -> Iterator[tuple[Any, PandasLikeDataFrame]]: category=FutureWarning, ) for key, group in self._grouped: - yield (key, self._df._from_native_frame(group)) + yield (key, self.compliant._from_native_frame(group)) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 3ccce65580..09c24042c9 
100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -347,10 +347,10 @@ def to_dict( else: return self.native.to_dict(as_series=False) - def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsGroupBy: + def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PolarsGroupBy: from narwhals._polars.group_by import PolarsGroupBy - return PolarsGroupBy(self, list(by), drop_null_keys=drop_null_keys) + return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys) def with_row_index(self: Self, name: str) -> Self: if self._backend_version < (0, 20, 4): @@ -572,10 +572,10 @@ def collect( msg = f"Unsupported `backend` value: {backend}" # pragma: no cover raise ValueError(msg) # pragma: no cover - def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsLazyGroupBy: + def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PolarsLazyGroupBy: from narwhals._polars.group_by import PolarsLazyGroupBy - return PolarsLazyGroupBy(self, list(by), drop_null_keys=drop_null_keys) + return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys) def with_row_index(self: Self, name: str) -> Self: if self._backend_version < (0, 20, 4): diff --git a/narwhals/_polars/group_by.py b/narwhals/_polars/group_by.py index 655e0e9bb8..48178f5fac 100644 --- a/narwhals/_polars/group_by.py +++ b/narwhals/_polars/group_by.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Iterator +from typing import Sequence from typing import cast from narwhals._polars.utils import extract_native @@ -17,32 +18,46 @@ class PolarsGroupBy: + _compliant_frame: PolarsDataFrame + _keys: Sequence[str] + + @property + def compliant(self) -> PolarsDataFrame: + return self._compliant_frame + def __init__( - self: Self, df: PolarsDataFrame, keys: list[str], *, drop_null_keys: bool + self, df: PolarsDataFrame, keys: Sequence[str], /, *, drop_null_keys: bool ) -> None: - self._compliant_frame: PolarsDataFrame = df - self.keys: list[str] = keys + 
self._compliant_frame = df + self._keys = list(keys) df = df.drop_nulls(keys) if drop_null_keys else df self._grouped: NativeGroupBy = df._native_frame.group_by(keys) def agg(self: Self, *aggs: PolarsExpr) -> PolarsDataFrame: - from_native = self._compliant_frame._from_native_frame + from_native = self.compliant._from_native_frame return from_native(self._grouped.agg(extract_native(arg) for arg in aggs)) def __iter__(self: Self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]: for key, df in self._grouped: - yield tuple(cast("str", key)), self._compliant_frame._from_native_frame(df) + yield tuple(cast("str", key)), self.compliant._from_native_frame(df) class PolarsLazyGroupBy: + _compliant_frame: PolarsLazyFrame + _keys: Sequence[str] + + @property + def compliant(self) -> PolarsLazyFrame: + return self._compliant_frame + def __init__( - self: Self, df: PolarsLazyFrame, keys: list[str], *, drop_null_keys: bool + self, df: PolarsLazyFrame, keys: Sequence[str], /, *, drop_null_keys: bool ) -> None: - self._compliant_frame: PolarsLazyFrame = df - self.keys: list[str] = keys + self._compliant_frame = df + self._keys = list(keys) df = df.drop_nulls(keys) if drop_null_keys else df self._grouped: NativeLazyGroupBy = df._native_frame.group_by(keys) def agg(self: Self, *aggs: PolarsExpr) -> PolarsLazyFrame: - from_native = self._compliant_frame._from_native_frame + from_native = self.compliant._from_native_frame return from_native(self._grouped.agg(extract_native(arg) for arg in aggs)) diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 0b829a2f23..cd1577c8eb 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -273,9 +273,7 @@ def head(self: Self, n: int) -> Self: def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy: from narwhals._spark_like.group_by import SparkLikeLazyGroupBy - return SparkLikeLazyGroupBy( - compliant_frame=self, keys=list(keys), 
drop_null_keys=drop_null_keys - ) + return SparkLikeLazyGroupBy(self, keys, drop_null_keys=drop_null_keys) def sort( self: Self, diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index a0acdcfe3a..77d46602a3 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -1,57 +1,35 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Sequence + +from narwhals._compliant import LazyGroupBy if TYPE_CHECKING: + from sqlframe.base.column import Column # noqa: F401 from typing_extensions import Self from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.expr import SparkLikeExpr -class SparkLikeLazyGroupBy: +class SparkLikeLazyGroupBy(LazyGroupBy["SparkLikeLazyFrame", "SparkLikeExpr", "Column"]): def __init__( self: Self, - compliant_frame: SparkLikeLazyFrame, - keys: list[str], - drop_null_keys: bool, # noqa: FBT001 + df: SparkLikeLazyFrame, + keys: Sequence[str], + /, + *, + drop_null_keys: bool, ) -> None: - if drop_null_keys: - self._compliant_frame = compliant_frame.drop_nulls(subset=None) - else: - self._compliant_frame = compliant_frame - self._keys = keys + self._compliant_frame = df.drop_nulls(subset=None) if drop_null_keys else df + self._keys = list(keys) def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame: - agg_columns = [] - df = self._compliant_frame - for expr in exprs: - output_names = expr._evaluate_output_names(df) - aliases = ( - output_names - if expr._alias_output_names is None - else expr._alias_output_names(output_names) - ) - native_expressions = expr(df) - exclude = ( - self._keys - if expr._function_name.split("->", maxsplit=1)[0] in {"all", "selector"} - else [] - ) - agg_columns.extend( - [ - native_expression.alias(alias) - for native_expression, output_name, alias in zip( - native_expressions, output_names, aliases - ) - if output_name not in exclude - ] - ) - - if not agg_columns: - return 
self._compliant_frame._from_native_frame( - self._compliant_frame._native_frame.select(*self._keys).dropDuplicates() + if agg_columns := list(self._evaluate_exprs(exprs)): + return self.compliant._from_native_frame( + self.compliant.native.groupBy(*self._keys).agg(*agg_columns) ) - return self._compliant_frame._from_native_frame( - self._compliant_frame._native_frame.groupBy(*self._keys).agg(*agg_columns) + return self.compliant._from_native_frame( + self.compliant.native.select(*self._keys).dropDuplicates() )