From 5b3f3f5965d526fda5f64a4541e048546d85dfbe Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Mar 2025 18:47:57 +0000 Subject: [PATCH 01/23] feat(DRAFT): Add `CompliantGroupBy` Based on (https://github.com/narwhals-dev/narwhals/issues/2184#issuecomment-2710918460) --- narwhals/_compliant/group_by.py | 82 +++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 narwhals/_compliant/group_by.py diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py new file mode 100644 index 0000000000..4924b05c05 --- /dev/null +++ b/narwhals/_compliant/group_by.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING +from typing import Any +from typing import ClassVar +from typing import Iterator +from typing import Mapping +from typing import Protocol +from typing import Sequence + +from narwhals._compliant.typing import CompliantDataFrameT +from narwhals._compliant.typing import CompliantExprT +from narwhals._compliant.typing import CompliantFrameT +from narwhals._compliant.typing import NativeFrameT_co +from narwhals._expression_parsing import is_elementary_expression + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from narwhals._compliant.dataframe import CompliantDataFrame + from narwhals._compliant.dataframe import CompliantLazyFrame + + Frame: TypeAlias = "CompliantDataFrame[Any, Any, NativeFrameT_co] | CompliantLazyFrame[Any, NativeFrameT_co]" + +__all__ = ["CompliantGroupBy", "EagerGroupBy"] + + +# NOTE: Type checkers disagree +# - `pyright` wants invariant `*Expr` +# - `mypy` want contravariant `*Expr` +class CompliantGroupBy(Protocol[CompliantFrameT, CompliantExprT]): # type: ignore[misc] + _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] + _compliant_frame: CompliantFrameT + _keys: Sequence[str] + + def __init__( + self, + compliant_frame: CompliantFrameT, + keys: Sequence[str], + *, + drop_null_keys: bool, + ) -> None: ... + @property + def compliant(self) -> CompliantFrameT: + return self._compliant_frame + + @property + def native( + self: CompliantGroupBy[Frame[NativeFrameT_co], CompliantExprT], + ) -> NativeFrameT_co: + return self.compliant.native + + def agg(self, *exprs: CompliantExprT) -> CompliantFrameT: ... + + def _ensure_all_simple(self, exprs: Sequence[CompliantExprT]) -> None: + for expr in exprs: + if ( + not is_elementary_expression(expr) + and re.sub(r"(\w+->)", "", expr._function_name) + in self._NARWHALS_TO_NATIVE_AGGREGATIONS + ): + # NOTE: Need to define `_implementation` in both protocols + name = self.compliant._implementation.name.lower() # type: ignore # noqa: PGH003 + msg = ( + f"Non-trivial complex aggregation found.\n\n" + f"Hint: you were probably trying to apply a non-elementary aggregation with a" + f"{name!r} table.\n" + "Please rewrite your query such that group-by aggregations " + "are elementary. For example, instead of:\n\n" + " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" + "use:\n\n" + " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" + ) + raise ValueError(msg) + + +class EagerGroupBy( # type: ignore[misc] + CompliantGroupBy[CompliantDataFrameT, CompliantExprT], + Protocol[CompliantDataFrameT, CompliantExprT], +): + def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT]]: ... From 16a9230996788bfdb1332ba54db81c86afe8ba0a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Mar 2025 18:56:06 +0000 Subject: [PATCH 02/23] fix: `3.8` compat Needed for `__init__` support --- narwhals/_compliant/group_by.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 4924b05c05..8338f9ff9c 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -1,12 +1,12 @@ from __future__ import annotations import re +import sys from typing import TYPE_CHECKING from typing import Any from typing import ClassVar from typing import Iterator from typing import Mapping -from typing import Protocol from typing import Sequence from narwhals._compliant.typing import CompliantDataFrameT @@ -23,13 +23,23 @@ Frame: TypeAlias = "CompliantDataFrame[Any, Any, NativeFrameT_co] | CompliantLazyFrame[Any, NativeFrameT_co]" +if not TYPE_CHECKING: # pragma: no cover + if sys.version_info >= (3, 9): + from typing import Protocol as Protocol38 + else: + from typing import Generic as Protocol38 +else: # pragma: no cover + # TODO @dangotbanned: Remove after dropping `3.8` (#2084) + # - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386 + from typing import Protocol as Protocol38 + __all__ = ["CompliantGroupBy", "EagerGroupBy"] # NOTE: Type checkers disagree # - `pyright` wants invariant `*Expr` # - `mypy` want contravariant `*Expr` -class CompliantGroupBy(Protocol[CompliantFrameT, CompliantExprT]): # type: ignore[misc] +class CompliantGroupBy(Protocol38[CompliantFrameT, CompliantExprT]): # type: ignore[misc] _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] _compliant_frame: CompliantFrameT _keys: Sequence[str] @@ -60,7 +70,7 @@ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT]) -> None: and re.sub(r"(\w+->)", "", expr._function_name) in self._NARWHALS_TO_NATIVE_AGGREGATIONS ): - # NOTE: Need to define `_implementation` in both protocols + # NOTE: Need to define `_implementation` in both protocols (#2251) name = self.compliant._implementation.name.lower() # type: ignore # noqa: PGH003 msg = ( f"Non-trivial complex aggregation found.\n\n" @@ -77,6 +87,6 @@ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT]) -> None: class EagerGroupBy( # type: ignore[misc] CompliantGroupBy[CompliantDataFrameT, CompliantExprT], - Protocol[CompliantDataFrameT, CompliantExprT], + Protocol38[CompliantDataFrameT, CompliantExprT], ): def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT]]: ... From 6f68f3aea341a5f372ca19a49f4fc61b3792c328 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Mar 2025 18:59:56 +0000 Subject: [PATCH 03/23] chore: export to `_compliant` --- narwhals/_compliant/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/narwhals/_compliant/__init__.py b/narwhals/_compliant/__init__.py index 2c65296a66..dfa709aaa8 100644 --- a/narwhals/_compliant/__init__.py +++ b/narwhals/_compliant/__init__.py @@ -6,6 +6,8 @@ from narwhals._compliant.expr import CompliantExpr from narwhals._compliant.expr import EagerExpr from narwhals._compliant.expr import LazyExpr +from narwhals._compliant.group_by import CompliantGroupBy +from narwhals._compliant.group_by import EagerGroupBy from narwhals._compliant.namespace import CompliantNamespace from narwhals._compliant.namespace import EagerNamespace from narwhals._compliant.selectors import CompliantSelector @@ -30,6 +32,7 @@ "CompliantExpr", "CompliantExprT", "CompliantFrameT", + "CompliantGroupBy", "CompliantLazyFrame", "CompliantNamespace", "CompliantSelector", @@ -40,6 +43,7 @@ "EagerDataFrame", "EagerDataFrameT", "EagerExpr", + "EagerGroupBy", "EagerNamespace", "EagerSelectorNamespace", "EagerSeries", From 5aaceb78cd65a5644355380103d37e45f764467e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Mar 2025 19:15:52 +0000 Subject: [PATCH 04/23] refactor: Implement `ArrowGroupBy` --- narwhals/_arrow/dataframe.py | 2 +- narwhals/_arrow/group_by.py | 90 ++++++++++++++++-------------------- 2 files changed, 40 insertions(+), 52 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 209d831bf1..d88d80fd4e 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -399,7 +399,7 @@ def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: def group_by(self: Self, *keys: str, drop_null_keys: bool) -> ArrowGroupBy: from narwhals._arrow.group_by import ArrowGroupBy - return ArrowGroupBy(self, list(keys), drop_null_keys=drop_null_keys) + return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys) def join( self: Self, diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 25ec346dd2..5a5a83d054 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -4,15 +4,19 @@ import re from typing import TYPE_CHECKING from typing import Any +from typing import ClassVar from typing import Iterator +from typing import Mapping +from typing import Sequence import pyarrow as pa import pyarrow.compute as pc +from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.utils import cast_to_comparable_string_types from narwhals._arrow.utils import extract_py_scalar +from narwhals._compliant import EagerGroupBy from narwhals._expression_parsing import evaluate_output_names_and_aliases -from narwhals._expression_parsing import is_elementary_expression from narwhals.utils import generate_temporary_column_name if TYPE_CHECKING: @@ -22,62 +26,44 @@ from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.typing import Incomplete -POLARS_TO_ARROW_AGGREGATIONS = { - "sum": "sum", - "mean": "mean", - "median": "approximate_median", - "max": "max", - "min": "min", - "std": "stddev", - "var": "variance", - "len": "count", - "n_unique": "count_distinct", - "count": "count", -} - - -class ArrowGroupBy: + +class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]): + _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] = { + "sum": "sum", + "mean": "mean", + "median": "approximate_median", + "max": "max", + "min": "min", + "std": "stddev", + "var": "variance", + "len": "count", + "n_unique": "count_distinct", + "count": "count", + } + def __init__( - self: Self, df: ArrowDataFrame, keys: list[str], *, drop_null_keys: bool + self, + compliant_frame: ArrowDataFrame, + keys: Sequence[str], + *, + drop_null_keys: bool, ) -> None: if drop_null_keys: - self._df = df.drop_nulls(keys) + self._compliant_frame = compliant_frame.drop_nulls(keys) else: - self._df = df - self._keys = keys.copy() - self._grouped = pa.TableGroupBy(self._df._native_frame, self._keys) + self._compliant_frame = compliant_frame + self._keys: list[str] = list(keys) + self._grouped = pa.TableGroupBy(self.compliant.native, self._keys) def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: - all_simple_aggs = True - for expr in exprs: - if not ( - is_elementary_expression(expr) - and re.sub(r"(\w+->)", "", expr._function_name) - in POLARS_TO_ARROW_AGGREGATIONS - ): - all_simple_aggs = False - break - - if not all_simple_aggs: - msg = ( - "Non-trivial complex aggregation found.\n\n" - "Hint: you were probably trying to apply a non-elementary aggregation with a " - "pyarrow table.\n" - "Please rewrite your query such that group-by aggregations " - "are elementary. For example, instead of:\n\n" - " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" - "use:\n\n" - " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" - ) - raise ValueError(msg) - + self._ensure_all_simple(exprs) aggs: list[tuple[str, str, Any]] = [] expected_pyarrow_column_names: list[str] = self._keys.copy() new_column_names: list[str] = self._keys.copy() for expr in exprs: output_names, aliases = evaluate_output_names_and_aliases( - expr, self._df, self._keys + expr, self.compliant, self._keys ) if expr._depth == 0: @@ -102,7 +88,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: else: option = None - function_name = POLARS_TO_ARROW_AGGREGATIONS[function_name] + function_name = self._NARWHALS_TO_NATIVE_AGGREGATIONS[function_name] new_column_names.extend(aliases) expected_pyarrow_column_names.extend( @@ -133,18 +119,20 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: ] new_column_names = [new_column_names[i] for i in index_map] result_simple = result_simple.rename_columns(new_column_names) - if self._df._backend_version < (12, 0, 0): + if self.compliant._backend_version < (12, 0, 0): columns = result_simple.column_names result_simple = result_simple.select( [*self._keys, *[col for col in columns if col not in self._keys]] ) - return self._df._from_native_frame(result_simple) + return self.compliant._from_native_frame(result_simple) def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]: - col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns) + col_token = generate_temporary_column_name( + n_bytes=8, columns=self.compliant.columns + ) null_token: str = "__null_token_value__" # noqa: S105 - table = self._df._native_frame + table = self.compliant.native # NOTE: stubs fail in multiple places for `ChunkedArray` it, separator_scalar = cast_to_comparable_string_types( *(table[key] for key in self._keys), separator="" @@ -160,7 +148,7 @@ def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]: ) table = table.add_column(i=0, field_=col_token, column=key_values) for v in pc.unique(key_values): - t = self._df._from_native_frame( + t = self.compliant._from_native_frame( table.filter(pc.equal(table[col_token], v)).drop([col_token]) ) row = t.simple_select(*self._keys).row(0) From 41ea47a740d79df11064a01a813ac72c59989b6f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Mar 2025 19:22:51 +0000 Subject: [PATCH 05/23] fix(typing): Get an agreement on variance - Dropped `native` as it got too complex - I've done this fake `Any` thing to make `mypy` understand in multiple places --- narwhals/_compliant/group_by.py | 46 ++++++++++----------------------- narwhals/_compliant/typing.py | 8 ++++++ 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 8338f9ff9c..4d3e63a434 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -9,20 +9,11 @@ from typing import Mapping from typing import Sequence -from narwhals._compliant.typing import CompliantDataFrameT -from narwhals._compliant.typing import CompliantExprT -from narwhals._compliant.typing import CompliantFrameT -from narwhals._compliant.typing import NativeFrameT_co +from narwhals._compliant.typing import CompliantDataFrameT_co +from narwhals._compliant.typing import CompliantExprT_contra +from narwhals._compliant.typing import CompliantFrameT_co from narwhals._expression_parsing import is_elementary_expression -if TYPE_CHECKING: - from typing_extensions import TypeAlias - - from narwhals._compliant.dataframe import CompliantDataFrame - from narwhals._compliant.dataframe import CompliantLazyFrame - - Frame: TypeAlias = "CompliantDataFrame[Any, Any, NativeFrameT_co] | CompliantLazyFrame[Any, NativeFrameT_co]" - if not TYPE_CHECKING: # pragma: no cover if sys.version_info >= (3, 9): from typing import Protocol as Protocol38 @@ -36,34 +27,25 @@ __all__ = ["CompliantGroupBy", "EagerGroupBy"] -# NOTE: Type checkers disagree -# - `pyright` wants invariant `*Expr` -# - `mypy` want contravariant `*Expr` -class CompliantGroupBy(Protocol38[CompliantFrameT, CompliantExprT]): # type: ignore[misc] +class CompliantGroupBy(Protocol38[CompliantFrameT_co, CompliantExprT_contra]): _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] - _compliant_frame: CompliantFrameT + _compliant_frame: Any _keys: Sequence[str] def __init__( self, - compliant_frame: CompliantFrameT, + compliant_frame: CompliantFrameT_co, keys: Sequence[str], *, drop_null_keys: bool, ) -> None: ... @property - def compliant(self) -> CompliantFrameT: - return self._compliant_frame - - @property - def native( - self: CompliantGroupBy[Frame[NativeFrameT_co], CompliantExprT], - ) -> NativeFrameT_co: - return self.compliant.native + def compliant(self) -> CompliantFrameT_co: + return self._compliant_frame # type: ignore[no-any-return] - def agg(self, *exprs: CompliantExprT) -> CompliantFrameT: ... + def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ... - def _ensure_all_simple(self, exprs: Sequence[CompliantExprT]) -> None: + def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: for expr in exprs: if ( not is_elementary_expression(expr) @@ -85,8 +67,8 @@ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT]) -> None: raise ValueError(msg) -class EagerGroupBy( # type: ignore[misc] - CompliantGroupBy[CompliantDataFrameT, CompliantExprT], - Protocol38[CompliantDataFrameT, CompliantExprT], +class EagerGroupBy( + CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra], + Protocol38[CompliantDataFrameT_co, CompliantExprT_contra], ): - def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT]]: ... + def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ... diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index 1da99e6685..e3e2fb4a12 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -42,9 +42,17 @@ "CompliantFrameT", bound="CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any]", ) +CompliantFrameT_co = TypeVar( + "CompliantFrameT_co", + bound="CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any]", + covariant=True, +) CompliantDataFrameT = TypeVar( "CompliantDataFrameT", bound="CompliantDataFrame[Any, Any, Any]" ) +CompliantDataFrameT_co = TypeVar( + "CompliantDataFrameT_co", bound="CompliantDataFrame[Any, Any, Any]", covariant=True +) CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame[Any, Any]") IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co" CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]") From 96376621b7ba717d218a74799b69899c86441cc9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Mar 2025 21:11:41 +0000 Subject: [PATCH 06/23] refactor: Implement `PandasLikeGroupBy`? - Having a hard time working out what is going on here - All I've changed is what the refs are named --- narwhals/_arrow/group_by.py | 1 + narwhals/_compliant/group_by.py | 1 + narwhals/_dask/expr.py | 6 ++- narwhals/_pandas_like/dataframe.py | 6 +-- narwhals/_pandas_like/expr.py | 5 +- narwhals/_pandas_like/group_by.py | 83 +++++++++++++++++------------- 6 files changed, 59 insertions(+), 43 deletions(-) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 5a5a83d054..690be487e1 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -45,6 +45,7 @@ def __init__( self, compliant_frame: ArrowDataFrame, keys: Sequence[str], + /, *, drop_null_keys: bool, ) -> None: diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 4d3e63a434..497938b47b 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -36,6 +36,7 @@ def __init__( self, compliant_frame: CompliantFrameT_co, keys: Sequence[str], + /, *, drop_null_keys: bool, ) -> None: ... diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 8d87a96e41..d2dbd7b02d 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -550,7 +550,11 @@ def over( order_by: Sequence[str] | None, ) -> Self: # pandas is a required dependency of dask so it's safe to import this - from narwhals._pandas_like.group_by import AGGREGATIONS_TO_PANDAS_EQUIVALENT + from narwhals._pandas_like.group_by import PandasLikeGroupBy + + AGGREGATIONS_TO_PANDAS_EQUIVALENT = ( # noqa: N806 + PandasLikeGroupBy._NARWHALS_TO_NATIVE_AGGREGATIONS + ) if not partition_by: assert order_by is not None # help type checkers # noqa: S101 diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 49b2e86ed6..fbb4f41777 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -580,11 +580,7 @@ def collect( def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PandasLikeGroupBy: from narwhals._pandas_like.group_by import PandasLikeGroupBy - return PandasLikeGroupBy( - self, - list(keys), - drop_null_keys=drop_null_keys, - ) + return PandasLikeGroupBy(self, keys, drop_null_keys=drop_null_keys) def join( self: Self, diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 104a4efc00..1168972bd5 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -11,7 +11,7 @@ from narwhals._expression_parsing import ExprKind from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._expression_parsing import is_elementary_expression -from narwhals._pandas_like.group_by import AGGREGATIONS_TO_PANDAS_EQUIVALENT +from narwhals._pandas_like.group_by import PandasLikeGroupBy from narwhals._pandas_like.series import PandasLikeSeries from narwhals.exceptions import ColumnNotFoundError from narwhals.utils import generate_temporary_column_name @@ -223,6 +223,9 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: ) raise NotImplementedError(msg) else: + AGGREGATIONS_TO_PANDAS_EQUIVALENT = ( # noqa: N806 + PandasLikeGroupBy._NARWHALS_TO_NATIVE_AGGREGATIONS + ) function_name: str = re.sub(r"(\w+->)", "", self._function_name) pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get( function_name, diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 0da57297fc..c345dc8e00 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -5,8 +5,12 @@ import warnings from typing import TYPE_CHECKING from typing import Any +from typing import ClassVar from typing import Iterator +from typing import Mapping +from typing import Sequence +from narwhals._compliant import EagerGroupBy from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._expression_parsing import is_elementary_expression from narwhals._pandas_like.utils import horizontal_concat @@ -22,39 +26,44 @@ from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr -AGGREGATIONS_TO_PANDAS_EQUIVALENT = { - "sum": "sum", - "mean": "mean", - "median": "median", - "max": "max", - "min": "min", - "std": "std", - "var": "var", - "len": "size", - "n_unique": "nunique", - "count": "count", -} +class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr"]): + _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] = { + "sum": "sum", + "mean": "mean", + "median": "median", + "max": "max", + "min": "min", + "std": "std", + "var": "var", + "len": "size", + "n_unique": "nunique", + "count": "count", + } -class PandasLikeGroupBy: def __init__( - self: Self, df: PandasLikeDataFrame, keys: list[str], *, drop_null_keys: bool + self: Self, + df: PandasLikeDataFrame, + keys: Sequence[str], + /, + *, + drop_null_keys: bool, ) -> None: - self._df = df - self._keys = keys + self._compliant_frame = df + self._keys: list[str] = list(keys) # Drop index to avoid potential collisions: # https://github.com/narwhals-dev/narwhals/issues/1907. - if set(df._native_frame.index.names).intersection(df.columns): - native_frame = df._native_frame.reset_index(drop=True) + if set(df.native.index.names).intersection(df.columns): + native_frame = df.native.reset_index(drop=True) else: - native_frame = df._native_frame + native_frame = df.native if ( - self._df._implementation is Implementation.PANDAS - and self._df._backend_version < (1, 1) + self.compliant._implementation is Implementation.PANDAS + and self.compliant._backend_version < (1, 1) ): # pragma: no cover if ( not drop_null_keys - and self._df.simple_select(*self._keys)._native_frame.isna().any().any() + and self.compliant.simple_select(*self._keys).native.isna().any().any() ): msg = "Grouping by null values is not supported in pandas < 1.1.0" raise NotImplementedError(msg) @@ -74,19 +83,21 @@ def __init__( ) def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR0915 - implementation = self._df._implementation - backend_version = self._df._backend_version + implementation = self.compliant._implementation + backend_version = self.compliant._backend_version new_names: list[str] = self._keys.copy() all_aggs_are_simple = True for expr in exprs: - _, aliases = evaluate_output_names_and_aliases(expr, self._df, self._keys) + _, aliases = evaluate_output_names_and_aliases( + expr, self.compliant, self._keys + ) new_names.extend(aliases) if not ( is_elementary_expression(expr) and re.sub(r"(\w+->)", "", expr._function_name) - in AGGREGATIONS_TO_PANDAS_EQUIVALENT + in self._NARWHALS_TO_NATIVE_AGGREGATIONS ): all_aggs_are_simple = False @@ -111,11 +122,11 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR if all_aggs_are_simple: for expr in exprs: output_names, aliases = evaluate_output_names_and_aliases( - expr, self._df, self._keys + expr, self.compliant, self._keys ) if expr._depth == 0: # e.g. agg(nw.len()) # noqa: ERA001 - function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT.get( + function_name = self._NARWHALS_TO_NATIVE_AGGREGATIONS.get( expr._function_name, expr._function_name ) simple_aggs_functions.add(function_name) @@ -128,7 +139,7 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR # e.g. agg(nw.mean('a')) # noqa: ERA001 function_name = re.sub(r"(\w+->)", "", expr._function_name) - function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT.get( + function_name = self._NARWHALS_TO_NATIVE_AGGREGATIONS.get( function_name, function_name ) @@ -247,17 +258,17 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR ) else: # No aggregation provided - result = self._df.__native_namespace__().DataFrame( + result = self.compliant.__native_namespace__().DataFrame( list(self._grouped.groups.keys()), columns=self._keys ) # Keep inplace=True to avoid making a redundant copy. # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files result.reset_index(inplace=True) # noqa: PD002 - return self._df._from_native_frame( + return self.compliant._from_native_frame( select_columns_by_name(result, new_names, backend_version, implementation) ) - if self._df._native_frame.empty: + if self.compliant.native.empty: # Don't even attempt this, it's way too inconsistent across pandas versions. msg = ( "No results for group-by aggregation.\n\n" @@ -285,9 +296,9 @@ def func(df: Any) -> Any: out_group = [] out_names = [] for expr in exprs: - results_keys = expr(self._df._from_native_frame(df)) + results_keys = expr(self.compliant._from_native_frame(df)) for result_keys in results_keys: - out_group.append(result_keys._native_series.iloc[0]) + out_group.append(result_keys.native.iloc[0]) out_names.append(result_keys.name) return native_series_from_iterable( out_group, @@ -305,7 +316,7 @@ def func(df: Any) -> Any: # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files result_complex.reset_index(inplace=True) # noqa: PD002 - return self._df._from_native_frame( + return self.compliant._from_native_frame( select_columns_by_name( result_complex, new_names, backend_version, implementation ) @@ -319,4 +330,4 @@ def __iter__(self: Self) -> Iterator[tuple[Any, PandasLikeDataFrame]]: category=FutureWarning, ) for key, group in self._grouped: - yield (key, self._df._from_native_frame(group)) + yield (key, self.compliant._from_native_frame(group)) From 810fa209dba43e2bfb81c3195f00e25b89e227bc Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Mar 2025 21:42:51 +0000 Subject: [PATCH 07/23] refactor: Implement `DaskLazyGroupBy` Much happier with this than the `pandas` one --- narwhals/_dask/dataframe.py | 2 +- narwhals/_dask/group_by.py | 131 ++++++++++++------------------------ 2 files changed, 43 insertions(+), 90 deletions(-) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 58add7a79b..6103afa7fb 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -388,7 +388,7 @@ def join_asof( def group_by(self: Self, *by: str, drop_null_keys: bool) -> DaskLazyGroupBy: from narwhals._dask.group_by import DaskLazyGroupBy - return DaskLazyGroupBy(self, list(by), drop_null_keys=drop_null_keys) + return DaskLazyGroupBy(self, by, drop_null_keys=drop_null_keys) def tail(self: Self, n: int) -> Self: # pragma: no cover native_frame = self._native_frame diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 6fadc8ab3d..298411c05d 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -5,13 +5,14 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import ClassVar from typing import Mapping from typing import Sequence import dask.dataframe as dd +from narwhals._compliant import CompliantGroupBy from narwhals._expression_parsing import evaluate_output_names_and_aliases -from narwhals._expression_parsing import is_elementary_expression try: import dask.dataframe.dask_expr as dx @@ -54,90 +55,54 @@ def std(ddof: int) -> _AggFn: return partial(_DaskGroupBy.std, ddof=ddof) -POLARS_TO_DASK_AGGREGATIONS: Mapping[str, Aggregation] = { - "sum": "sum", - "mean": "mean", - "median": "median", - "max": "max", - "min": "min", - "std": std, - "var": var, - "len": "size", - "n_unique": n_unique, - "count": "count", -} +class DaskLazyGroupBy(CompliantGroupBy["DaskLazyFrame", "DaskExpr"]): + _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Aggregation]] = { + "sum": "sum", + "mean": "mean", + "median": "median", + "max": "max", + "min": "min", + "std": std, + "var": var, + "len": "size", + "n_unique": n_unique, + "count": "count", + } - -class DaskLazyGroupBy: def __init__( - self: Self, df: DaskLazyFrame, keys: list[str], *, drop_null_keys: bool + self: Self, df: DaskLazyFrame, keys: Sequence[str], /, *, drop_null_keys: bool ) -> None: - self._df: DaskLazyFrame = df - self._keys = keys - self._grouped = self._df._native_frame.groupby( - list(self._keys), - dropna=drop_null_keys, - observed=True, - ) - - def agg( - self: Self, - *exprs: DaskExpr, - ) -> DaskLazyFrame: - return agg_dask( - self._df, - self._grouped, - exprs, - self._keys, - self._from_native_frame, + self._compliant_frame = df + self._keys: list[str] = list(keys) + self._grouped = self.compliant.native.groupby( + list(self._keys), dropna=drop_null_keys, observed=True ) - def _from_native_frame(self: Self, df: dd.DataFrame) -> DaskLazyFrame: + def agg(self: Self, *exprs: DaskExpr) -> DaskLazyFrame: from narwhals._dask.dataframe import DaskLazyFrame - return DaskLazyFrame( - df, - backend_version=self._df._backend_version, - version=self._df._version, - ) - - -def agg_dask( - df: DaskLazyFrame, - grouped: Any, - exprs: Sequence[DaskExpr], - keys: list[str], - from_dataframe: Callable[[Any], DaskLazyFrame], -) -> DaskLazyFrame: - """This should be the fastpath, but cuDF is too far behind to use it. - - - https://github.com/rapidsai/cudf/issues/15118 - - https://github.com/rapidsai/cudf/issues/15084 - """ - if not exprs: - # No aggregation provided - return df.simple_select(*keys).unique(subset=keys, keep="any") - - all_simple_aggs = True - for expr in exprs: - if not ( - is_elementary_expression(expr) - and re.sub(r"(\w+->)", "", expr._function_name) in POLARS_TO_DASK_AGGREGATIONS - ): - all_simple_aggs = False - break - - if all_simple_aggs: + if not exprs: + # No aggregation provided + return self.compliant.simple_select(*self._keys).unique( + self._keys, keep="any" + ) + self._ensure_all_simple(exprs) + # This should be the fastpath, but cuDF is too far behind to use it. + # - https://github.com/rapidsai/cudf/issues/15118 + # - https://github.com/rapidsai/cudf/issues/15084 + POLARS_TO_DASK_AGGREGATIONS = self._NARWHALS_TO_NATIVE_AGGREGATIONS # noqa: N806 simple_aggregations: dict[str, tuple[str, Aggregation]] = {} for expr in exprs: - output_names, aliases = evaluate_output_names_and_aliases(expr, df, keys) + output_names, aliases = evaluate_output_names_and_aliases( + expr, self.compliant, self._keys + ) if expr._depth == 0: # e.g. agg(nw.len()) # noqa: ERA001 function_name = POLARS_TO_DASK_AGGREGATIONS.get( expr._function_name, expr._function_name ) simple_aggregations.update( - dict.fromkeys(aliases, (keys[0], function_name)) + dict.fromkeys(aliases, (self._keys[0], function_name)) ) continue @@ -150,24 +115,12 @@ def agg_dask( if callable(agg_function) else agg_function ) - simple_aggregations.update( - { - alias: (output_name, agg_function) - for alias, output_name in zip(aliases, output_names) - } + (alias, (output_name, agg_function)) + for alias, output_name in zip(aliases, output_names) ) - result_simple = grouped.agg(**simple_aggregations) - return from_dataframe(result_simple.reset_index()) - - msg = ( - "Non-trivial complex aggregation found.\n\n" - "Hint: you were probably trying to apply a non-elementary aggregation with a " - "dask dataframe.\n" - "Please rewrite your query such that group-by aggregations " - "are elementary. For example, instead of:\n\n" - " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" - "use:\n\n" - " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" - ) - raise ValueError(msg) + return DaskLazyFrame( + self._grouped.agg(**simple_aggregations).reset_index(), + backend_version=self.compliant._backend_version, + version=self.compliant._version, + ) From 0a2ee8f7913d8da9fa5b78ea77247a12a2193447 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 11:57:58 +0000 Subject: [PATCH 08/23] chore(typing): Utilize (#2251) Fixes https://github.com/narwhals-dev/narwhals/actions/runs/13968615945/job/39104706209?pr=2252 --- narwhals/_compliant/group_by.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 497938b47b..06879f42fb 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -53,8 +53,7 @@ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: and re.sub(r"(\w+->)", "", expr._function_name) in self._NARWHALS_TO_NATIVE_AGGREGATIONS ): - # NOTE: Need to define `_implementation` in both protocols (#2251) - name = self.compliant._implementation.name.lower() # type: ignore # noqa: PGH003 + name = self.compliant._implementation.name.lower() msg = ( f"Non-trivial complex aggregation found.\n\n" f"Hint: you were probably trying to apply a non-elementary aggregation with a" From 2e3291b30382fe01925ae0ccdb077ae14a04ad9d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 12:58:06 +0000 Subject: [PATCH 09/23] refactor: `CompliantGroupBy._remap_expr_name` - Shorted in each backend - Added docs - Accounts for `dask` deviation from `str` --- narwhals/_arrow/group_by.py | 6 ++---- narwhals/_compliant/group_by.py | 22 +++++++++++++++++++++- narwhals/_dask/group_by.py | 30 ++++++++++++------------------ narwhals/_pandas_like/group_by.py | 12 ++++-------- 4 files changed, 39 insertions(+), 31 deletions(-) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 690be487e1..aff31071e5 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -68,7 +68,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: ) if expr._depth == 0: - # e.g. agg(nw.len()) # noqa: ERA001 + # e.g. `agg(nw.len())` if expr._function_name != "len": # pragma: no cover msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) @@ -76,7 +76,6 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: new_column_names.append(aliases[0]) expected_pyarrow_column_names.append(f"{self._keys[0]}_count") aggs.append((self._keys[0], "count", pc.CountOptions(mode="all"))) - continue function_name = re.sub(r"(\w+->)", "", expr._function_name) @@ -89,8 +88,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: else: option = None - function_name = self._NARWHALS_TO_NATIVE_AGGREGATIONS[function_name] - + function_name = self._remap_expr_name(function_name) new_column_names.extend(aliases) expected_pyarrow_column_names.extend( [f"{output_name}_{function_name}" for output_name in output_names] diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 06879f42fb..3f86208646 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -13,6 +13,7 @@ from narwhals._compliant.typing import CompliantExprT_contra from narwhals._compliant.typing import CompliantFrameT_co from narwhals._expression_parsing import is_elementary_expression +from narwhals._translate import TypeVar # type: ignore[attr-defined] if not TYPE_CHECKING: # pragma: no cover if sys.version_info >= (3, 9): @@ -26,8 +27,15 @@ __all__ = ["CompliantGroupBy", "EagerGroupBy"] +NativeAggregationT_co = TypeVar("NativeAggregationT_co", covariant=True, default="str") +"""Some backends *may* return a `Callable` instead of a `str` referring to one.""" -class CompliantGroupBy(Protocol38[CompliantFrameT_co, CompliantExprT_contra]): + +# TODO @dangotbanned: Compile and assign a name to `r"(\w+->)"` +# - Then make `re.sub(r"(\w+->)", "", expr._function_name)` a method +class CompliantGroupBy( + Protocol38[CompliantFrameT_co, CompliantExprT_contra, NativeAggregationT_co] +): _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] _compliant_frame: Any _keys: Sequence[str] @@ -66,6 +74,18 @@ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: ) raise ValueError(msg) + @classmethod + def _remap_expr_name(cls, name: str, /) -> NativeAggregationT_co: + """Replace `name`, with some native representation. + + Arguments: + name: Name of a `nw.Expr` aggregation method. + + Returns: + A native compatible representation. + """ + return cls._NARWHALS_TO_NATIVE_AGGREGATIONS.get(name, name) + class EagerGroupBy( CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra], diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 298411c05d..05d4af1477 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -30,12 +30,14 @@ PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any] _AggFn: TypeAlias = Callable[..., Any] - Aggregation: TypeAlias = "str | _AggFn" from dask_expr._groupby import GroupBy as _DaskGroupBy else: _DaskGroupBy = dx._groupby.GroupBy +Aggregation: TypeAlias = "str | _AggFn" +"""The name of an aggregation function, or the function itself.""" + def n_unique() -> dd.Aggregation: def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]: @@ -55,7 +57,7 @@ def std(ddof: int) -> _AggFn: return partial(_DaskGroupBy.std, ddof=ddof) -class DaskLazyGroupBy(CompliantGroupBy["DaskLazyFrame", "DaskExpr"]): +class DaskLazyGroupBy(CompliantGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]): _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Aggregation]] = { "sum": "sum", "mean": "mean", @@ -90,33 +92,25 @@ def agg(self: Self, *exprs: DaskExpr) -> DaskLazyFrame: # This should be the fastpath, but cuDF is too far behind to use it. # - https://github.com/rapidsai/cudf/issues/15118 # - https://github.com/rapidsai/cudf/issues/15084 - POLARS_TO_DASK_AGGREGATIONS = self._NARWHALS_TO_NATIVE_AGGREGATIONS # noqa: N806 simple_aggregations: dict[str, tuple[str, Aggregation]] = {} for expr in exprs: output_names, aliases = evaluate_output_names_and_aliases( expr, self.compliant, self._keys ) if expr._depth == 0: - # e.g. agg(nw.len()) # noqa: ERA001 - function_name = POLARS_TO_DASK_AGGREGATIONS.get( - expr._function_name, expr._function_name - ) - simple_aggregations.update( - dict.fromkeys(aliases, (self._keys[0], function_name)) - ) + # e.g. `agg(nw.len())` + column = self._keys[0] + agg_fn = self._remap_expr_name(expr._function_name) + simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn))) continue - # e.g. agg(nw.mean('a')) # noqa: ERA001 + # e.g. `agg(nw.mean('a'))` function_name = re.sub(r"(\w+->)", "", expr._function_name) - agg_function = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name) + agg_fn = self._remap_expr_name(function_name) # deal with n_unique case in a "lazy" mode to not depend on dask globally - agg_function = ( - agg_function(**expr._call_kwargs) - if callable(agg_function) - else agg_function - ) + agg_fn = agg_fn(**expr._call_kwargs) if callable(agg_fn) else agg_fn simple_aggregations.update( - (alias, (output_name, agg_function)) + (alias, (output_name, agg_fn)) for alias, output_name in zip(aliases, output_names) ) return DaskLazyFrame( diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index c345dc8e00..571205a738 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -125,10 +125,8 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR expr, self.compliant, self._keys ) if expr._depth == 0: - # e.g. agg(nw.len()) # noqa: ERA001 - function_name = self._NARWHALS_TO_NATIVE_AGGREGATIONS.get( - expr._function_name, expr._function_name - ) + # e.g. `agg(nw.len())` + function_name = self._remap_expr_name(expr._function_name) simple_aggs_functions.add(function_name) for alias in aliases: @@ -137,11 +135,9 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR simple_agg_new_names.append(alias) continue - # e.g. agg(nw.mean('a')) # noqa: ERA001 + # e.g. `agg(nw.mean('a'))` function_name = re.sub(r"(\w+->)", "", expr._function_name) - function_name = self._NARWHALS_TO_NATIVE_AGGREGATIONS.get( - function_name, function_name - ) + function_name = self._remap_expr_name(function_name) is_n_unique = function_name == "nunique" is_std = function_name == "std" From 799720bf123397eda9bb10fc5f4b9034e5cfbcc8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:31:45 +0000 Subject: [PATCH 10/23] refactor: `CompliantGroupBy._leaf_name` - More performant to compile a single pattern and reuse everywhere - Gives a name to a common op - Backend code is shorter --- narwhals/_arrow/group_by.py | 3 +-- narwhals/_compliant/group_by.py | 24 ++++++++++++++++++++---- narwhals/_compliant/typing.py | 5 +++-- narwhals/_dask/expr.py | 3 +-- narwhals/_dask/group_by.py | 4 +--- narwhals/_pandas_like/expr.py | 3 +-- narwhals/_pandas_like/group_by.py | 9 ++------- 7 files changed, 29 insertions(+), 22 deletions(-) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index aff31071e5..6beeeb5517 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections -import re from typing import TYPE_CHECKING from typing import Any from typing import ClassVar @@ -78,7 +77,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: aggs.append((self._keys[0], "count", pc.CountOptions(mode="all"))) continue - function_name = re.sub(r"(\w+->)", "", expr._function_name) + function_name = self._leaf_name(expr) if function_name in {"std", "var"}: option: Any = pc.VarianceOptions(ddof=expr._call_kwargs["ddof"]) elif function_name in {"len", "n_unique"}: diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 3f86208646..031ed7fd5f 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -10,6 +10,7 @@ from typing import Sequence from narwhals._compliant.typing import CompliantDataFrameT_co +from narwhals._compliant.typing import CompliantExprAny from narwhals._compliant.typing import CompliantExprT_contra from narwhals._compliant.typing import CompliantFrameT_co from narwhals._expression_parsing import is_elementary_expression @@ -31,8 +32,19 @@ """Some backends *may* return a `Callable` instead of a `str` referring to one.""" -# TODO @dangotbanned: Compile and assign a name to `r"(\w+->)"` -# - Then make `re.sub(r"(\w+->)", "", expr._function_name)` a method +UNNAMED_PATTERN: re.Pattern[str] = re.compile(r"(\w+->)") +"""I'm unsure what this should be called. + +Seems to be used as a way to get `thing_n`: + + "thing_1->thing_2->...->thing_n" + +But with the assumption that `depth` is constrained below `2` (maybe?). + +**In isolation - the pattern doesn't mean any of that.** 🤔 +""" + + class CompliantGroupBy( Protocol38[CompliantFrameT_co, CompliantExprT_contra, NativeAggregationT_co] ): @@ -58,8 +70,7 @@ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: for expr in exprs: if ( not is_elementary_expression(expr) - and re.sub(r"(\w+->)", "", expr._function_name) - in self._NARWHALS_TO_NATIVE_AGGREGATIONS + and self._leaf_name(expr) in self._NARWHALS_TO_NATIVE_AGGREGATIONS ): name = self.compliant._implementation.name.lower() msg = ( @@ -86,6 +97,11 @@ def _remap_expr_name(cls, name: str, /) -> NativeAggregationT_co: """ return cls._NARWHALS_TO_NATIVE_AGGREGATIONS.get(name, name) + @classmethod + def _leaf_name(cls, expr: CompliantExprAny, /) -> str: + """Return the last function name in the chain defined by `expr`.""" + return UNNAMED_PATTERN.sub("", expr._function_name) + class EagerGroupBy( CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra], diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index 565a032d2e..56eafdfa72 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -58,9 +58,10 @@ ) CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame[Any, Any]") IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co" -CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]") +CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]" +CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny) CompliantExprT_contra = TypeVar( - "CompliantExprT_contra", bound="CompliantExpr[Any, Any]", contravariant=True + "CompliantExprT_contra", bound=CompliantExprAny, contravariant=True ) EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any]") diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index d2dbd7b02d..a793bdb914 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re import warnings from typing import TYPE_CHECKING from typing import Any @@ -571,7 +570,7 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: ) raise NotImplementedError(msg) else: - function_name = re.sub(r"(\w+->)", "", self._function_name) + function_name = PandasLikeGroupBy._leaf_name(self) try: dask_function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT[function_name] except KeyError: diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 05d4af1477..0199af7a58 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from functools import partial from typing import TYPE_CHECKING from typing import Any @@ -105,8 +104,7 @@ def agg(self: Self, *exprs: DaskExpr) -> DaskLazyFrame: continue # e.g. `agg(nw.mean('a'))` - function_name = re.sub(r"(\w+->)", "", expr._function_name) - agg_fn = self._remap_expr_name(function_name) + agg_fn = self._remap_expr_name(self._leaf_name(expr)) # deal with n_unique case in a "lazy" mode to not depend on dask globally agg_fn = agg_fn(**expr._call_kwargs) if callable(agg_fn) else agg_fn simple_aggregations.update( diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 1168972bd5..ec5da3d86f 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -226,7 +225,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: AGGREGATIONS_TO_PANDAS_EQUIVALENT = ( # noqa: N806 PandasLikeGroupBy._NARWHALS_TO_NATIVE_AGGREGATIONS ) - function_name: str = re.sub(r"(\w+->)", "", self._function_name) + function_name = PandasLikeGroupBy._leaf_name(self) pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get( function_name, AGGREGATIONS_TO_PANDAS_EQUIVALENT.get(function_name), diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 571205a738..dcc37faefe 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections -import re import warnings from typing import TYPE_CHECKING from typing import Any @@ -93,11 +92,9 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR expr, self.compliant, self._keys ) new_names.extend(aliases) - if not ( is_elementary_expression(expr) - and re.sub(r"(\w+->)", "", expr._function_name) - in self._NARWHALS_TO_NATIVE_AGGREGATIONS + and self._leaf_name(expr) in self._NARWHALS_TO_NATIVE_AGGREGATIONS ): all_aggs_are_simple = False @@ -136,9 +133,7 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR continue # e.g. `agg(nw.mean('a'))` - function_name = re.sub(r"(\w+->)", "", expr._function_name) - function_name = self._remap_expr_name(function_name) - + function_name = self._remap_expr_name(self._leaf_name(expr)) is_n_unique = function_name == "nunique" is_std = function_name == "std" is_var = function_name == "var" From 25eb682f4d9ef0260869bd4818cd38dd965740ef Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:43:35 +0000 Subject: [PATCH 11/23] refactor: `CompliantGroupBy._is_simple` Keeps this part of `pandas` in sync, despite it doing some extra name stuff --- narwhals/_compliant/group_by.py | 13 +++++++++---- narwhals/_pandas_like/group_by.py | 6 +----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 031ed7fd5f..df254036c1 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -68,10 +68,7 @@ def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ... def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: for expr in exprs: - if ( - not is_elementary_expression(expr) - and self._leaf_name(expr) in self._NARWHALS_TO_NATIVE_AGGREGATIONS - ): + if not self._is_simple(expr): name = self.compliant._implementation.name.lower() msg = ( f"Non-trivial complex aggregation found.\n\n" @@ -85,6 +82,14 @@ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: ) raise ValueError(msg) + @classmethod + def _is_simple(cls, expr: CompliantExprAny, /) -> bool: + """Return `True` is we can efficiently use `expr` in a native `group_by` context.""" + return ( + is_elementary_expression(expr) + and cls._leaf_name(expr) in cls._NARWHALS_TO_NATIVE_AGGREGATIONS + ) + @classmethod def _remap_expr_name(cls, name: str, /) -> NativeAggregationT_co: """Replace `name`, with some native representation. diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index dcc37faefe..576524e4bb 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -11,7 +11,6 @@ from narwhals._compliant import EagerGroupBy from narwhals._expression_parsing import evaluate_output_names_and_aliases -from narwhals._expression_parsing import is_elementary_expression from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import native_series_from_iterable from narwhals._pandas_like.utils import select_columns_by_name @@ -92,10 +91,7 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR expr, self.compliant, self._keys ) new_names.extend(aliases) - if not ( - is_elementary_expression(expr) - and self._leaf_name(expr) in self._NARWHALS_TO_NATIVE_AGGREGATIONS - ): + if not self._is_simple(expr): all_aggs_are_simple = False # dict of {output_name: root_name} that we count n_unique on From 51910ac35cf85eb02e762b1e3532aae325557149 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:51:43 +0000 Subject: [PATCH 12/23] fix: pre `3.13` protocol support https://results.pre-commit.ci/run/github/760058710/1742478308._GAah8QXT2i9H9TbtcMHwQ --- narwhals/_compliant/group_by.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index df254036c1..aebd57be35 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -109,7 +109,7 @@ def _leaf_name(cls, expr: CompliantExprAny, /) -> str: class EagerGroupBy( - CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra], + CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra, str], Protocol38[CompliantDataFrameT_co, CompliantExprT_contra], ): def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ... From d33595095a828c5312f9525159fb21d6e8b3bbbf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 14:20:12 +0000 Subject: [PATCH 13/23] refactor: Move most of `CompliantGroupBy` -> `DepthTrackingGroupBy` - `_duckdb` and `_spark_like` don't need these parts (only `_dask` does) - Also avoids needing a `TypeVar` default, which caused some issues in https://github.com/narwhals-dev/narwhals/actions/runs/13970848097/job/39111959826?pr=2252 --- narwhals/_compliant/__init__.py | 2 ++ narwhals/_compliant/group_by.py | 32 +++++++++++++++++++++----------- narwhals/_dask/group_by.py | 4 ++-- narwhals/_duckdb/group_by.py | 1 + narwhals/_spark_like/group_by.py | 1 + 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/narwhals/_compliant/__init__.py b/narwhals/_compliant/__init__.py index b1995b2e0e..0c228994e6 100644 --- a/narwhals/_compliant/__init__.py +++ b/narwhals/_compliant/__init__.py @@ -7,6 +7,7 @@ from narwhals._compliant.expr import EagerExpr from narwhals._compliant.expr import LazyExpr from narwhals._compliant.group_by import CompliantGroupBy +from narwhals._compliant.group_by import DepthTrackingGroupBy from narwhals._compliant.group_by import EagerGroupBy from narwhals._compliant.namespace import CompliantNamespace from narwhals._compliant.namespace import EagerNamespace @@ -41,6 +42,7 @@ "CompliantSeries", "CompliantSeriesOrNativeExprT_co", "CompliantSeriesT", + "DepthTrackingGroupBy", "EagerDataFrame", "EagerDataFrameT", "EagerExpr", diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index aebd57be35..b9f4ce5e40 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -4,17 +4,18 @@ import sys from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import ClassVar from typing import Iterator from typing import Mapping from typing import Sequence +from typing import TypeVar from narwhals._compliant.typing import CompliantDataFrameT_co from narwhals._compliant.typing import CompliantExprAny from narwhals._compliant.typing import CompliantExprT_contra from narwhals._compliant.typing import CompliantFrameT_co from narwhals._expression_parsing import is_elementary_expression -from narwhals._translate import TypeVar # type: ignore[attr-defined] if not TYPE_CHECKING: # pragma: no cover if sys.version_info >= (3, 9): @@ -26,9 +27,11 @@ # - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386 from typing import Protocol as Protocol38 -__all__ = ["CompliantGroupBy", "EagerGroupBy"] +__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy"] -NativeAggregationT_co = TypeVar("NativeAggregationT_co", covariant=True, default="str") +NativeAggregationT_co = TypeVar( + "NativeAggregationT_co", bound="str | Callable[...,Any]", covariant=True +) """Some backends *may* return a `Callable` instead of a `str` referring to one.""" @@ -45,13 +48,14 @@ """ -class CompliantGroupBy( - Protocol38[CompliantFrameT_co, CompliantExprT_contra, NativeAggregationT_co] -): - _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] +class CompliantGroupBy(Protocol38[CompliantFrameT_co, CompliantExprT_contra]): _compliant_frame: Any _keys: Sequence[str] + @property + def compliant(self) -> CompliantFrameT_co: + return self._compliant_frame # type: ignore[no-any-return] + def __init__( self, compliant_frame: CompliantFrameT_co, @@ -60,12 +64,18 @@ def __init__( *, drop_null_keys: bool, ) -> None: ... - @property - def compliant(self) -> CompliantFrameT_co: - return self._compliant_frame # type: ignore[no-any-return] def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ... + +class DepthTrackingGroupBy( + CompliantGroupBy[CompliantFrameT_co, CompliantExprT_contra], + Protocol38[CompliantFrameT_co, CompliantExprT_contra, NativeAggregationT_co], +): + """`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`.""" + + _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] + def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: for expr in exprs: if not self._is_simple(expr): @@ -109,7 +119,7 @@ def _leaf_name(cls, expr: CompliantExprAny, /) -> str: class EagerGroupBy( - CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra, str], + DepthTrackingGroupBy[CompliantDataFrameT_co, CompliantExprT_contra, str], Protocol38[CompliantDataFrameT_co, CompliantExprT_contra], ): def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ... diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 0199af7a58..0a1d085dde 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -10,7 +10,7 @@ import dask.dataframe as dd -from narwhals._compliant import CompliantGroupBy +from narwhals._compliant import DepthTrackingGroupBy from narwhals._expression_parsing import evaluate_output_names_and_aliases try: @@ -56,7 +56,7 @@ def std(ddof: int) -> _AggFn: return partial(_DaskGroupBy.std, ddof=ddof) -class DaskLazyGroupBy(CompliantGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]): +class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]): _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Aggregation]] = { "sum": "sum", "mean": "mean", diff --git a/narwhals/_duckdb/group_by.py b/narwhals/_duckdb/group_by.py index c9c75de5d9..7ffae18c9d 100644 --- a/narwhals/_duckdb/group_by.py +++ b/narwhals/_duckdb/group_by.py @@ -10,6 +10,7 @@ from narwhals._duckdb.expr import DuckDBExpr +# NOTE: No depth-tracking class DuckDBGroupBy: def __init__( self: Self, diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index a0acdcfe3a..7b8e8f395d 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -9,6 +9,7 @@ from narwhals._spark_like.expr import SparkLikeExpr +# NOTE: No depth-tracking class SparkLikeLazyGroupBy: def __init__( self: Self, From c8aa213dea551b4a4d482c09c09dc2532a0a1e7d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 14:38:52 +0000 Subject: [PATCH 14/23] refactor(DRAFT): Start simplifying, aligning lazy group bys - These two are almost identical - Trying to reduce them as much as possible, before moving the common parts to `nw._compliant.LazyGroupBy` --- narwhals/_duckdb/dataframe.py | 4 +-- narwhals/_duckdb/group_by.py | 42 ++++++++++++++-------------- narwhals/_spark_like/dataframe.py | 4 +-- narwhals/_spark_like/group_by.py | 46 +++++++++++++++---------------- 4 files changed, 44 insertions(+), 52 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 5ac1642b23..45504da2e3 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -238,9 +238,7 @@ def _from_native_frame(self: Self, df: duckdb.DuckDBPyRelation) -> Self: def group_by(self: Self, *keys: str, drop_null_keys: bool) -> DuckDBGroupBy: from narwhals._duckdb.group_by import DuckDBGroupBy - return DuckDBGroupBy( - compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys - ) + return DuckDBGroupBy(self, keys, drop_null_keys=drop_null_keys) def rename(self: Self, mapping: Mapping[str, str]) -> Self: df = self._native_frame diff --git a/narwhals/_duckdb/group_by.py b/narwhals/_duckdb/group_by.py index 7ffae18c9d..0475a29d75 100644 --- a/narwhals/_duckdb/group_by.py +++ b/narwhals/_duckdb/group_by.py @@ -1,6 +1,9 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Sequence + +from narwhals._compliant import CompliantGroupBy if TYPE_CHECKING: from duckdb import Expression @@ -11,45 +14,40 @@ # NOTE: No depth-tracking -class DuckDBGroupBy: +class DuckDBGroupBy(CompliantGroupBy["DuckDBLazyFrame", "DuckDBExpr"]): def __init__( self: Self, - compliant_frame: DuckDBLazyFrame, - keys: list[str], - drop_null_keys: bool, # noqa: FBT001 + df: DuckDBLazyFrame, + keys: Sequence[str], + /, + *, + drop_null_keys: bool, ) -> None: - if drop_null_keys: - self._compliant_frame = compliant_frame.drop_nulls(subset=None) - else: - self._compliant_frame = compliant_frame - self._keys = keys + self._compliant_frame = df.drop_nulls(subset=None) if drop_null_keys else df + self._keys = list(keys) def agg(self: Self, *exprs: DuckDBExpr) -> DuckDBLazyFrame: agg_columns: list[str | Expression] = list(self._keys) - df = self._compliant_frame for expr in exprs: - output_names = expr._evaluate_output_names(df) + output_names = expr._evaluate_output_names(self.compliant) aliases = ( output_names if expr._alias_output_names is None else expr._alias_output_names(output_names) ) - native_expressions = expr(df) + native_expressions = expr(self.compliant) exclude = ( self._keys if expr._function_name.split("->", maxsplit=1)[0] in {"all", "selector"} else [] ) agg_columns.extend( - [ - native_expression.alias(alias) - for native_expression, output_name, alias in zip( - native_expressions, output_names, aliases - ) - if output_name not in exclude - ] + native_expression.alias(alias) + for native_expression, output_name, alias in zip( + native_expressions, output_names, aliases + ) + if output_name not in exclude ) - - return self._compliant_frame._from_native_frame( - self._compliant_frame._native_frame.aggregate(agg_columns) # type: ignore[arg-type] + return self.compliant._from_native_frame( + self.compliant.native.aggregate(agg_columns) # type: ignore[arg-type] ) diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 0b829a2f23..cd1577c8eb 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -273,9 +273,7 @@ def head(self: Self, n: int) -> Self: def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy: from narwhals._spark_like.group_by import SparkLikeLazyGroupBy - return SparkLikeLazyGroupBy( - compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys - ) + return SparkLikeLazyGroupBy(self, keys, drop_null_keys=drop_null_keys) def sort( self: Self, diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index 7b8e8f395d..0268a685f7 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -1,6 +1,9 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Sequence + +from narwhals._compliant import CompliantGroupBy if TYPE_CHECKING: from typing_extensions import Self @@ -10,49 +13,44 @@ # NOTE: No depth-tracking -class SparkLikeLazyGroupBy: +class SparkLikeLazyGroupBy(CompliantGroupBy["SparkLikeLazyFrame", "SparkLikeExpr"]): def __init__( self: Self, - compliant_frame: SparkLikeLazyFrame, - keys: list[str], - drop_null_keys: bool, # noqa: FBT001 + df: SparkLikeLazyFrame, + keys: Sequence[str], + /, + *, + drop_null_keys: bool, ) -> None: - if drop_null_keys: - self._compliant_frame = compliant_frame.drop_nulls(subset=None) - else: - self._compliant_frame = compliant_frame - self._keys = keys + self._compliant_frame = df.drop_nulls(subset=None) if drop_null_keys else df + self._keys = list(keys) def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame: agg_columns = [] - df = self._compliant_frame for expr in exprs: - output_names = expr._evaluate_output_names(df) + output_names = expr._evaluate_output_names(self.compliant) aliases = ( output_names if expr._alias_output_names is None else expr._alias_output_names(output_names) ) - native_expressions = expr(df) + native_expressions = expr(self.compliant) exclude = ( self._keys if expr._function_name.split("->", maxsplit=1)[0] in {"all", "selector"} else [] ) agg_columns.extend( - [ - native_expression.alias(alias) - for native_expression, output_name, alias in zip( - native_expressions, output_names, aliases - ) - if output_name not in exclude - ] + native_expression.alias(alias) + for native_expression, output_name, alias in zip( + native_expressions, output_names, aliases + ) + if output_name not in exclude ) - if not agg_columns: - return self._compliant_frame._from_native_frame( - self._compliant_frame._native_frame.select(*self._keys).dropDuplicates() + return self.compliant._from_native_frame( + self.compliant.native.select(*self._keys).dropDuplicates() ) - return self._compliant_frame._from_native_frame( - self._compliant_frame._native_frame.groupBy(*self._keys).agg(*agg_columns) + return self.compliant._from_native_frame( + self.compliant.native.groupBy(*self._keys).agg(*agg_columns) ) From 9701ac46cd91f0a386001fe53d34a6ba61553e31 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 14:43:06 +0000 Subject: [PATCH 15/23] help `mypy` https://github.com/narwhals-dev/narwhals/actions/runs/13972033263/job/39116058196?pr=2252 --- narwhals/_spark_like/group_by.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index 0268a685f7..9b48201fe9 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -6,6 +6,7 @@ from narwhals._compliant import CompliantGroupBy if TYPE_CHECKING: + from sqlframe.base.column import Column from typing_extensions import Self from narwhals._spark_like.dataframe import SparkLikeLazyFrame @@ -26,7 +27,7 @@ def __init__( self._keys = list(keys) def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame: - agg_columns = [] + agg_columns: list[Column] = [] for expr in exprs: output_names = expr._evaluate_output_names(self.compliant) aliases = ( From c99d9e71fac0f5520dafb8ab429879c128369487 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 16:54:29 +0000 Subject: [PATCH 16/23] refactor: Adds `LazyGroupBy` - Greatly simplifies what each backend needs to implement - Avoids creating and combining intermediate lists - Avoids performing a `not in exclude` check, where `exclude` is empty - Identified and documented a new common method `CompliantExpr._is_multi_output_agg` --- narwhals/_compliant/__init__.py | 2 ++ narwhals/_compliant/expr.py | 8 ++++++++ narwhals/_compliant/group_by.py | 33 ++++++++++++++++++++++++++++++- narwhals/_compliant/typing.py | 7 +++++++ narwhals/_duckdb/group_by.py | 30 +++++----------------------- narwhals/_expression_parsing.py | 4 +--- narwhals/_spark_like/group_by.py | 34 ++++++-------------------------- 7 files changed, 61 insertions(+), 57 deletions(-) diff --git a/narwhals/_compliant/__init__.py b/narwhals/_compliant/__init__.py index 0c228994e6..f223aa6a3f 100644 --- a/narwhals/_compliant/__init__.py +++ b/narwhals/_compliant/__init__.py @@ -9,6 +9,7 @@ from narwhals._compliant.group_by import CompliantGroupBy from narwhals._compliant.group_by import DepthTrackingGroupBy from narwhals._compliant.group_by import EagerGroupBy +from narwhals._compliant.group_by import LazyGroupBy from narwhals._compliant.namespace import CompliantNamespace from narwhals._compliant.namespace import EagerNamespace from narwhals._compliant.selectors import CompliantSelector @@ -55,6 +56,7 @@ "EvalSeries", "IntoCompliantExpr", "LazyExpr", + "LazyGroupBy", "LazySelectorNamespace", "NativeFrameT_co", "NativeSeriesT_co", diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index a922c66e78..ec251a6ac8 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -270,6 +270,14 @@ def __invert__(self) -> Self: ... def broadcast( self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL] ) -> Self: ... + def _is_multi_output_agg(self) -> bool: + """Return `True` for multi-output aggregations. + + Here we skip the keys, else they would appear duplicated in the output: + + df.group_by("a").agg(nw.all().mean()) + """ + return self._function_name.split("->", maxsplit=1)[0] in {"all", "selector"} class EagerExpr( diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index b9f4ce5e40..15bdeb3071 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -6,6 +6,7 @@ from typing import Any from typing import Callable from typing import ClassVar +from typing import Iterable from typing import Iterator from typing import Mapping from typing import Sequence @@ -15,6 +16,9 @@ from narwhals._compliant.typing import CompliantExprAny from narwhals._compliant.typing import CompliantExprT_contra from narwhals._compliant.typing import CompliantFrameT_co +from narwhals._compliant.typing import CompliantLazyFrameT_co +from narwhals._compliant.typing import LazyExprT_contra +from narwhals._compliant.typing import NativeExprT_co from narwhals._expression_parsing import is_elementary_expression if not TYPE_CHECKING: # pragma: no cover @@ -27,7 +31,7 @@ # - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386 from typing import Protocol as Protocol38 -__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy"] +__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy", "LazyGroupBy"] NativeAggregationT_co = TypeVar( "NativeAggregationT_co", bound="str | Callable[...,Any]", covariant=True @@ -123,3 +127,30 @@ class EagerGroupBy( Protocol38[CompliantDataFrameT_co, CompliantExprT_contra], ): def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ... + + +class LazyGroupBy( + CompliantGroupBy[CompliantLazyFrameT_co, LazyExprT_contra], + Protocol38[CompliantLazyFrameT_co, LazyExprT_contra, NativeExprT_co], +): + def _evaluate_expr(self, expr: LazyExprT_contra, /) -> Iterator[NativeExprT_co]: + output_names = expr._evaluate_output_names(self.compliant) + aliases = ( + expr._alias_output_names(output_names) + if expr._alias_output_names + else output_names + ) + native_exprs = expr(self.compliant) + if expr._is_multi_output_agg(): + for native_expr, name, alias in zip(native_exprs, output_names, aliases): + if name not in self._keys: + yield native_expr.alias(alias) + else: + for native_expr, alias in zip(native_exprs, aliases): + yield native_expr.alias(alias) + + def _evaluate_exprs( + self, exprs: Iterable[LazyExprT_contra], / + ) -> Iterator[NativeExprT_co]: + for expr in exprs: + yield from self._evaluate_expr(expr) diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index 56eafdfa72..c1771bde16 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -14,6 +14,7 @@ from narwhals._compliant.dataframe import EagerDataFrame from narwhals._compliant.expr import CompliantExpr from narwhals._compliant.expr import EagerExpr + from narwhals._compliant.expr import LazyExpr from narwhals._compliant.expr import NativeExpr from narwhals._compliant.namespace import EagerNamespace from narwhals._compliant.series import CompliantSeries @@ -57,6 +58,9 @@ "CompliantDataFrameT_co", bound="CompliantDataFrame[Any, Any, Any]", covariant=True ) CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame[Any, Any]") +CompliantLazyFrameT_co = TypeVar( + "CompliantLazyFrameT_co", bound="CompliantLazyFrame[Any, Any]", covariant=True +) IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co" CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]" CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny) @@ -74,6 +78,9 @@ EagerNamespaceAny: TypeAlias = ( "EagerNamespace[EagerDataFrame[Any, Any, Any], EagerSeries[Any], EagerExpr[Any, Any]]" ) +LazyExprT_contra = TypeVar( + "LazyExprT_contra", bound="LazyExpr[Any, Any]", contravariant=True +) AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]] AliasName: TypeAlias = Callable[[str], str] diff --git a/narwhals/_duckdb/group_by.py b/narwhals/_duckdb/group_by.py index 0475a29d75..8e5110e184 100644 --- a/narwhals/_duckdb/group_by.py +++ b/narwhals/_duckdb/group_by.py @@ -1,20 +1,20 @@ from __future__ import annotations +from itertools import chain from typing import TYPE_CHECKING from typing import Sequence -from narwhals._compliant import CompliantGroupBy +from narwhals._compliant import LazyGroupBy if TYPE_CHECKING: - from duckdb import Expression + from duckdb import Expression # noqa: F401 from typing_extensions import Self from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.expr import DuckDBExpr -# NOTE: No depth-tracking -class DuckDBGroupBy(CompliantGroupBy["DuckDBLazyFrame", "DuckDBExpr"]): +class DuckDBGroupBy(LazyGroupBy["DuckDBLazyFrame", "DuckDBExpr", "Expression"]): def __init__( self: Self, df: DuckDBLazyFrame, @@ -27,27 +27,7 @@ def __init__( self._keys = list(keys) def agg(self: Self, *exprs: DuckDBExpr) -> DuckDBLazyFrame: - agg_columns: list[str | Expression] = list(self._keys) - for expr in exprs: - output_names = expr._evaluate_output_names(self.compliant) - aliases = ( - output_names - if expr._alias_output_names is None - else expr._alias_output_names(output_names) - ) - native_expressions = expr(self.compliant) - exclude = ( - self._keys - if expr._function_name.split("->", maxsplit=1)[0] in {"all", "selector"} - else [] - ) - agg_columns.extend( - native_expression.alias(alias) - for native_expression, output_name, alias in zip( - native_expressions, output_names, aliases - ) - if output_name not in exclude - ) + agg_columns = list(chain(self._keys, self._evaluate_exprs(exprs))) return self.compliant._from_native_frame( self.compliant.native.aggregate(agg_columns) # type: ignore[arg-type] ) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 2c5f8cb506..91e6bcc2ee 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -120,9 +120,7 @@ def evaluate_output_names_and_aliases( if expr._alias_output_names is None else expr._alias_output_names(output_names) ) - if expr._function_name.split("->", maxsplit=1)[0] in {"all", "selector"}: - # For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`, we skip - # the keys, else they would appear duplicated in the output. + if expr._is_multi_output_agg(): output_names, aliases = zip( *[(x, alias) for x, alias in zip(output_names, aliases) if x not in exclude] ) diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index 9b48201fe9..77d46602a3 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -3,18 +3,17 @@ from typing import TYPE_CHECKING from typing import Sequence -from narwhals._compliant import CompliantGroupBy +from narwhals._compliant import LazyGroupBy if TYPE_CHECKING: - from sqlframe.base.column import Column + from sqlframe.base.column import Column # noqa: F401 from typing_extensions import Self from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.expr import SparkLikeExpr -# NOTE: No depth-tracking -class SparkLikeLazyGroupBy(CompliantGroupBy["SparkLikeLazyFrame", "SparkLikeExpr"]): +class SparkLikeLazyGroupBy(LazyGroupBy["SparkLikeLazyFrame", "SparkLikeExpr", "Column"]): def __init__( self: Self, df: SparkLikeLazyFrame, @@ -27,31 +26,10 @@ def __init__( self._keys = list(keys) def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame: - agg_columns: list[Column] = [] - for expr in exprs: - output_names = expr._evaluate_output_names(self.compliant) - aliases = ( - output_names - if expr._alias_output_names is None - else expr._alias_output_names(output_names) - ) - native_expressions = expr(self.compliant) - exclude = ( - self._keys - if expr._function_name.split("->", maxsplit=1)[0] in {"all", "selector"} - else [] - ) - agg_columns.extend( - native_expression.alias(alias) - for native_expression, output_name, alias in zip( - native_expressions, output_names, aliases - ) - if output_name not in exclude - ) - if not agg_columns: + if agg_columns := list(self._evaluate_exprs(exprs)): return self.compliant._from_native_frame( - self.compliant.native.select(*self._keys).dropDuplicates() + self.compliant.native.groupBy(*self._keys).agg(*agg_columns) ) return self.compliant._from_native_frame( - self.compliant.native.groupBy(*self._keys).agg(*agg_columns) + self.compliant.native.select(*self._keys).dropDuplicates() ) From 90fb3b4ef1b48444709b24f820e5d8a301eaca35 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 17:48:02 +0000 Subject: [PATCH 17/23] feat(typing): `Polars*GroupBy` and `Compliant*Frame` - Will need to make (inavriant) `CompliantExprT` to type the second part correctly - Believe that is also causing this weird error ``` Argument of type "CompliantExprT_contra@CompliantDataFrame" cannot be assigned to parameter "exprs" of type "CompliantExprT_contra@CompliantDataFrame" in function "select" Type "CompliantExprT_contra@CompliantDataFrame" is not assignable to type "CompliantExprT_contra@CompliantDataFrame" Pylance(reportArgumentType) ``` --- narwhals/_compliant/dataframe.py | 9 +++++++-- narwhals/_polars/dataframe.py | 8 ++++---- narwhals/_polars/group_by.py | 33 +++++++++++++++++++++++--------- 3 files changed, 35 insertions(+), 15 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 5cb593530f..dfff57f46d 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -32,6 +32,7 @@ from typing_extensions import Self from typing_extensions import TypeAlias + from narwhals._compliant.group_by import CompliantGroupBy from narwhals.dtypes import DType from narwhals.typing import SizeUnit from narwhals.typing import _2DArray @@ -91,7 +92,9 @@ def explode(self: Self, columns: Sequence[str]) -> Self: ... def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ... def gather_every(self, n: int, offset: int) -> Self: ... def get_column(self, name: str) -> CompliantSeriesT: ... - def group_by(self, *keys: str, drop_null_keys: bool) -> Incomplete: ... + def group_by( + self, *keys: str, drop_null_keys: bool + ) -> CompliantGroupBy[Self, Any]: ... def head(self, n: int) -> Self: ... def item(self, row: int | None, column: int | str | None) -> Any: ... def iter_columns(self) -> Iterator[CompliantSeriesT]: ... @@ -218,7 +221,9 @@ def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ... "`LazyFrame.gather_every` is deprecated and will be removed in a future version." ) def gather_every(self, n: int, offset: int) -> Self: ... - def group_by(self, *keys: str, drop_null_keys: bool) -> Incomplete: ... + def group_by( + self, *keys: str, drop_null_keys: bool + ) -> CompliantGroupBy[Self, Any]: ... def head(self, n: int) -> Self: ... def join( self: Self, diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 3ccce65580..09c24042c9 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -347,10 +347,10 @@ def to_dict( else: return self.native.to_dict(as_series=False) - def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsGroupBy: + def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PolarsGroupBy: from narwhals._polars.group_by import PolarsGroupBy - return PolarsGroupBy(self, list(by), drop_null_keys=drop_null_keys) + return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys) def with_row_index(self: Self, name: str) -> Self: if self._backend_version < (0, 20, 4): @@ -572,10 +572,10 @@ def collect( msg = f"Unsupported `backend` value: {backend}" # pragma: no cover raise ValueError(msg) # pragma: no cover - def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsLazyGroupBy: + def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PolarsLazyGroupBy: from narwhals._polars.group_by import PolarsLazyGroupBy - return PolarsLazyGroupBy(self, list(by), drop_null_keys=drop_null_keys) + return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys) def with_row_index(self: Self, name: str) -> Self: if self._backend_version < (0, 20, 4): diff --git a/narwhals/_polars/group_by.py b/narwhals/_polars/group_by.py index 655e0e9bb8..48178f5fac 100644 --- a/narwhals/_polars/group_by.py +++ b/narwhals/_polars/group_by.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Iterator +from typing import Sequence from typing import cast from narwhals._polars.utils import extract_native @@ -17,32 +18,46 @@ class PolarsGroupBy: + _compliant_frame: PolarsDataFrame + _keys: Sequence[str] + + @property + def compliant(self) -> PolarsDataFrame: + return self._compliant_frame + def __init__( - self: Self, df: PolarsDataFrame, keys: list[str], *, drop_null_keys: bool + self, df: PolarsDataFrame, keys: Sequence[str], /, *, drop_null_keys: bool ) -> None: - self._compliant_frame: PolarsDataFrame = df - self.keys: list[str] = keys + self._compliant_frame = df + self._keys = list(keys) df = df.drop_nulls(keys) if drop_null_keys else df self._grouped: NativeGroupBy = df._native_frame.group_by(keys) def agg(self: Self, *aggs: PolarsExpr) -> PolarsDataFrame: - from_native = self._compliant_frame._from_native_frame + from_native = self.compliant._from_native_frame return from_native(self._grouped.agg(extract_native(arg) for arg in aggs)) def __iter__(self: Self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]: for key, df in self._grouped: - yield tuple(cast("str", key)), self._compliant_frame._from_native_frame(df) + yield tuple(cast("str", key)), self.compliant._from_native_frame(df) class PolarsLazyGroupBy: + _compliant_frame: PolarsLazyFrame + _keys: Sequence[str] + + @property + def compliant(self) -> PolarsLazyFrame: + return self._compliant_frame + def __init__( - self: Self, df: PolarsLazyFrame, keys: list[str], *, drop_null_keys: bool + self, df: PolarsLazyFrame, keys: Sequence[str], /, *, drop_null_keys: bool ) -> None: - self._compliant_frame: PolarsLazyFrame = df - self.keys: list[str] = keys + self._compliant_frame = df + self._keys = list(keys) df = df.drop_nulls(keys) if drop_null_keys else df self._grouped: NativeLazyGroupBy = df._native_frame.group_by(keys) def agg(self: Self, *aggs: PolarsExpr) -> PolarsLazyFrame: - from_native = self._compliant_frame._from_native_frame + from_native = self.compliant._from_native_frame return from_native(self._grouped.agg(extract_native(arg) for arg in aggs)) From 2438d0a5acc1ee20ee7905cdf0888e17ed77aae9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 18:43:39 +0000 Subject: [PATCH 18/23] chore: Rename pattern, remove temp doc --- narwhals/_compliant/group_by.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 15bdeb3071..947608b2d5 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -39,17 +39,7 @@ """Some backends *may* return a `Callable` instead of a `str` referring to one.""" -UNNAMED_PATTERN: re.Pattern[str] = re.compile(r"(\w+->)") -"""I'm unsure what this should be called. - -Seems to be used as a way to get `thing_n`: - - "thing_1->thing_2->...->thing_n" - -But with the assumption that `depth` is constrained below `2` (maybe?). - -**In isolation - the pattern doesn't mean any of that.** 🤔 -""" +_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)") class CompliantGroupBy(Protocol38[CompliantFrameT_co, CompliantExprT_contra]): @@ -119,7 +109,7 @@ def _remap_expr_name(cls, name: str, /) -> NativeAggregationT_co: @classmethod def _leaf_name(cls, expr: CompliantExprAny, /) -> str: """Return the last function name in the chain defined by `expr`.""" - return UNNAMED_PATTERN.sub("", expr._function_name) + return _RE_LEAF_NAME.sub("", expr._function_name) class EagerGroupBy( From 1e242f23b3d158a9cf498338b0667b5ae7d40744 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 19:16:00 +0000 Subject: [PATCH 19/23] Long variable names < types & docs --- narwhals/_arrow/group_by.py | 3 ++- narwhals/_compliant/group_by.py | 38 ++++++++++++++++++++++--------- narwhals/_dask/expr.py | 8 ++----- narwhals/_dask/group_by.py | 3 ++- narwhals/_pandas_like/expr.py | 8 ++----- narwhals/_pandas_like/group_by.py | 3 ++- 6 files changed, 37 insertions(+), 26 deletions(-) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 6beeeb5517..8d25c9d9af 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -24,10 +24,11 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.typing import Incomplete + from narwhals._compliant.group_by import NarwhalsAggregation class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]): - _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] = { + _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] = { "sum": "sum", "mean": "mean", "median": "approximate_median", diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 947608b2d5..866e98c6e8 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -8,6 +8,7 @@ from typing import ClassVar from typing import Iterable from typing import Iterator +from typing import Literal from typing import Mapping from typing import Sequence from typing import TypeVar @@ -21,6 +22,9 @@ from narwhals._compliant.typing import NativeExprT_co from narwhals._expression_parsing import is_elementary_expression +if TYPE_CHECKING: + from typing_extensions import TypeAlias + if not TYPE_CHECKING: # pragma: no cover if sys.version_info >= (3, 9): from typing import Protocol as Protocol38 @@ -31,12 +35,20 @@ # - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386 from typing import Protocol as Protocol38 -__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy", "LazyGroupBy"] +__all__ = [ + "CompliantGroupBy", + "DepthTrackingGroupBy", + "EagerGroupBy", + "LazyGroupBy", + "NarwhalsAggregation", +] NativeAggregationT_co = TypeVar( - "NativeAggregationT_co", bound="str | Callable[...,Any]", covariant=True + "NativeAggregationT_co", bound="str | Callable[..., Any]", covariant=True ) -"""Some backends *may* return a `Callable` instead of a `str` referring to one.""" +NarwhalsAggregation: TypeAlias = Literal[ + "sum", "mean", "median", "max", "min", "std", "var", "len", "n_unique", "count" +] _RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)") @@ -68,7 +80,12 @@ class DepthTrackingGroupBy( ): """`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`.""" - _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] + _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] + """Mapping from `narwhals` to native representation. + + Note: + - `Dask` *may* return a `Callable` instead of a `str` referring to one. + """ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: for expr in exprs: @@ -89,13 +106,12 @@ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: @classmethod def _is_simple(cls, expr: CompliantExprAny, /) -> bool: """Return `True` is we can efficiently use `expr` in a native `group_by` context.""" - return ( - is_elementary_expression(expr) - and cls._leaf_name(expr) in cls._NARWHALS_TO_NATIVE_AGGREGATIONS - ) + return is_elementary_expression(expr) and cls._leaf_name(expr) in cls._REMAP_AGGS @classmethod - def _remap_expr_name(cls, name: str, /) -> NativeAggregationT_co: + def _remap_expr_name( + cls, name: NarwhalsAggregation | Any, / + ) -> NativeAggregationT_co: """Replace `name`, with some native representation. Arguments: @@ -104,10 +120,10 @@ def _remap_expr_name(cls, name: str, /) -> NativeAggregationT_co: Returns: A native compatible representation. """ - return cls._NARWHALS_TO_NATIVE_AGGREGATIONS.get(name, name) + return cls._REMAP_AGGS.get(name, name) @classmethod - def _leaf_name(cls, expr: CompliantExprAny, /) -> str: + def _leaf_name(cls, expr: CompliantExprAny, /) -> NarwhalsAggregation | Any: """Return the last function name in the chain defined by `expr`.""" return _RE_LEAF_NAME.sub("", expr._function_name) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index a793bdb914..ab99a5e15e 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -551,10 +551,6 @@ def over( # pandas is a required dependency of dask so it's safe to import this from narwhals._pandas_like.group_by import PandasLikeGroupBy - AGGREGATIONS_TO_PANDAS_EQUIVALENT = ( # noqa: N806 - PandasLikeGroupBy._NARWHALS_TO_NATIVE_AGGREGATIONS - ) - if not partition_by: assert order_by is not None # help type checkers # noqa: S101 @@ -572,12 +568,12 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: else: function_name = PandasLikeGroupBy._leaf_name(self) try: - dask_function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT[function_name] + dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name] except KeyError: # window functions are unsupported: https://github.com/dask/dask/issues/11806 msg = ( f"Unsupported function: {function_name} in `over` context.\n\n" - f"Supported functions are {', '.join(AGGREGATIONS_TO_PANDAS_EQUIVALENT)}\n" + f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n" ) raise NotImplementedError(msg) from None diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 0a1d085dde..cca29870e8 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -24,6 +24,7 @@ from typing_extensions import Self from typing_extensions import TypeAlias + from narwhals._compliant.group_by import NarwhalsAggregation from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr @@ -57,7 +58,7 @@ def std(ddof: int) -> _AggFn: class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]): - _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Aggregation]] = { + _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = { "sum": "sum", "mean": "mean", "median": "median", diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index ec5da3d86f..be93b9c0a3 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -222,19 +222,15 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: ) raise NotImplementedError(msg) else: - AGGREGATIONS_TO_PANDAS_EQUIVALENT = ( # noqa: N806 - PandasLikeGroupBy._NARWHALS_TO_NATIVE_AGGREGATIONS - ) function_name = PandasLikeGroupBy._leaf_name(self) pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get( - function_name, - AGGREGATIONS_TO_PANDAS_EQUIVALENT.get(function_name), + function_name, PandasLikeGroupBy._REMAP_AGGS.get(function_name) ) if pandas_function_name is None: msg = ( f"Unsupported function: {function_name} in `over` context.\n\n" f"Supported functions are {', '.join(WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT)}\n" - f"and {', '.join(AGGREGATIONS_TO_PANDAS_EQUIVALENT)}." + f"and {', '.join(PandasLikeGroupBy._REMAP_AGGS)}." ) raise NotImplementedError(msg) pandas_kwargs = window_kwargs_to_pandas_equivalent( diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 576524e4bb..86419f9a63 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -21,12 +21,13 @@ if TYPE_CHECKING: from typing_extensions import Self + from narwhals._compliant.group_by import NarwhalsAggregation from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr"]): - _NARWHALS_TO_NATIVE_AGGREGATIONS: ClassVar[Mapping[str, Any]] = { + _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] = { "sum": "sum", "mean": "mean", "median": "median", From 8c7ad01266704cb20c298c882742bd02a669c438 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 19:20:41 +0000 Subject: [PATCH 20/23] refactor: listcomp < genexpr --- narwhals/_pandas_like/group_by.py | 32 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 86419f9a63..8fadf3dc69 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -203,27 +203,23 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR if std_aggs: result_aggs.extend( - [ - set_columns( - self._grouped[std_output_names].std(ddof=ddof), - columns=std_aliases, - implementation=implementation, - backend_version=backend_version, - ) - for ddof, (std_output_names, std_aliases) in std_aggs.items() - ] + set_columns( + self._grouped[std_output_names].std(ddof=ddof), + columns=std_aliases, + implementation=implementation, + backend_version=backend_version, + ) + for ddof, (std_output_names, std_aliases) in std_aggs.items() ) if var_aggs: result_aggs.extend( - [ - set_columns( - self._grouped[var_output_names].var(ddof=ddof), - columns=var_aliases, - implementation=implementation, - backend_version=backend_version, - ) - for ddof, (var_output_names, var_aliases) in var_aggs.items() - ] + set_columns( + self._grouped[var_output_names].var(ddof=ddof), + columns=var_aliases, + implementation=implementation, + backend_version=backend_version, + ) + for ddof, (var_output_names, var_aliases) in var_aggs.items() ) if result_aggs: From 7393bf8930bcc6e098e2a924df885ca3c488d1eb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 20:29:51 +0000 Subject: [PATCH 21/23] chore(typing): Temp ignore for false positive --- narwhals/_compliant/dataframe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index dfff57f46d..06b8a00c44 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -68,7 +68,8 @@ def aggregate(self, *exprs: CompliantExprT_contra) -> Self: (so, no broadcasting is necessary). """ - return self.select(*exprs) + # NOTE: Ignore is to avoid an intermittent false positive + return self.select(*exprs) # pyright: ignore[reportArgumentType] @property def native(self) -> NativeFrameT_co: From 8e63e2720723c50ce2ed52f1e155fa60b24c92cf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 20:34:50 +0000 Subject: [PATCH 22/23] avoid unused import --- narwhals/_arrow/group_by.py | 1 - 1 file changed, 1 deletion(-) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 8d25c9d9af..c178a940b3 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -11,7 +11,6 @@ import pyarrow as pa import pyarrow.compute as pc -from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.utils import cast_to_comparable_string_types from narwhals._arrow.utils import extract_py_scalar from narwhals._compliant import EagerGroupBy From 2a3b2f726f54636c0464e54b6831eff2dce0b0e2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Mar 2025 20:38:59 +0000 Subject: [PATCH 23/23] refactor: Use `Implementation.is_pandas` --- narwhals/_pandas_like/group_by.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 8fadf3dc69..bed561f24e 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -15,7 +15,6 @@ from narwhals._pandas_like.utils import native_series_from_iterable from narwhals._pandas_like.utils import select_columns_by_name from narwhals._pandas_like.utils import set_columns -from narwhals.utils import Implementation from narwhals.utils import find_stacklevel if TYPE_CHECKING: @@ -57,7 +56,7 @@ def __init__( else: native_frame = df.native if ( - self.compliant._implementation is Implementation.PANDAS + self.compliant._implementation.is_pandas() and self.compliant._backend_version < (1, 1) ): # pragma: no cover if ( @@ -291,7 +290,7 @@ def func(df: Any) -> Any: implementation=implementation, ) - if implementation is Implementation.PANDAS and backend_version >= (2, 2): + if implementation.is_pandas() and backend_version >= (2, 2): result_complex = self._grouped.apply(func, include_groups=False) else: # pragma: no cover result_complex = self._grouped.apply(func)