Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
5b3f3f5
feat(DRAFT): Add `CompliantGroupBy`
dangotbanned Mar 19, 2025
16a9230
fix: `3.8` compat
dangotbanned Mar 19, 2025
6f68f3a
chore: export to `_compliant`
dangotbanned Mar 19, 2025
5aaceb7
refactor: Implement `ArrowGroupBy`
dangotbanned Mar 19, 2025
41ea47a
fix(typing): Get an agreement on variance
dangotbanned Mar 19, 2025
9637662
refactor: Implement `PandasLikeGroupBy`?
dangotbanned Mar 19, 2025
810fa20
refactor: Implement `DaskLazyGroupBy`
dangotbanned Mar 19, 2025
e84cb98
Merge branch 'main' into compliant-group-by
dangotbanned Mar 19, 2025
eb04cc8
Merge branch 'main' into compliant-group-by
dangotbanned Mar 20, 2025
c354dd8
Merge branch 'main' into compliant-group-by
dangotbanned Mar 20, 2025
0a2ee8f
chore(typing): Utilize (#2251)
dangotbanned Mar 20, 2025
2e3291b
refactor: `CompliantGroupBy._remap_expr_name`
dangotbanned Mar 20, 2025
799720b
refactor: `CompliantGroupBy._leaf_name`
dangotbanned Mar 20, 2025
25eb682
refactor: `CompliantGroupBy._is_simple`
dangotbanned Mar 20, 2025
51910ac
fix: pre `3.13` protocol support
dangotbanned Mar 20, 2025
d335950
refactor: Move most of `CompliantGroupBy` -> `DepthTrackingGroupBy`
dangotbanned Mar 20, 2025
c8aa213
refactor(DRAFT): Start simplifying, aligning lazy group bys
dangotbanned Mar 20, 2025
9701ac4
help `mypy`
dangotbanned Mar 20, 2025
c99d9e7
refactor: Adds `LazyGroupBy`
dangotbanned Mar 20, 2025
9de0f1a
Merge branch 'main' into compliant-group-by
dangotbanned Mar 20, 2025
90fb3b4
feat(typing): `Polars*GroupBy` and `Compliant*Frame`
dangotbanned Mar 20, 2025
2438d0a
chore: Rename pattern, remove temp doc
dangotbanned Mar 20, 2025
1e242f2
Long variable names < types & docs
dangotbanned Mar 20, 2025
8c7ad01
refactor: listcomp < genexpr
dangotbanned Mar 20, 2025
7393bf8
chore(typing): Temp ignore for false positive
dangotbanned Mar 20, 2025
8e63e27
avoid unused import
dangotbanned Mar 20, 2025
2a3b2f7
refactor: Use `Implementation.is_pandas`
dangotbanned Mar 20, 2025
8a06e68
Merge remote-tracking branch 'upstream/main' into compliant-group-by
dangotbanned Mar 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
def group_by(self: Self, *keys: str, drop_null_keys: bool) -> ArrowGroupBy:
from narwhals._arrow.group_by import ArrowGroupBy

return ArrowGroupBy(self, list(keys), drop_null_keys=drop_null_keys)
return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys)

def join(
self: Self,
Expand Down
98 changes: 42 additions & 56 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from __future__ import annotations

import collections
import re
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from typing import Iterator
from typing import Mapping
from typing import Sequence

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._arrow.utils import extract_py_scalar
from narwhals._compliant import EagerGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._expression_parsing import is_elementary_expression
from narwhals.utils import generate_temporary_column_name

if TYPE_CHECKING:
Expand All @@ -21,78 +23,61 @@
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.typing import Incomplete
from narwhals._compliant.group_by import NarwhalsAggregation


class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]):
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] = {
"sum": "sum",
"mean": "mean",
"median": "approximate_median",
"max": "max",
"min": "min",
"std": "stddev",
"var": "variance",
"len": "count",
"n_unique": "count_distinct",
"count": "count",
}

POLARS_TO_ARROW_AGGREGATIONS = {
"sum": "sum",
"mean": "mean",
"median": "approximate_median",
"max": "max",
"min": "min",
"std": "stddev",
"var": "variance",
"len": "count",
"n_unique": "count_distinct",
"count": "count",
}


class ArrowGroupBy:
def __init__(
self: Self, df: ArrowDataFrame, keys: list[str], *, drop_null_keys: bool
self,
compliant_frame: ArrowDataFrame,
keys: Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
if drop_null_keys:
self._df = df.drop_nulls(keys)
self._compliant_frame = compliant_frame.drop_nulls(keys)
else:
self._df = df
self._keys = keys.copy()
self._grouped = pa.TableGroupBy(self._df._native_frame, self._keys)
self._compliant_frame = compliant_frame
self._keys: list[str] = list(keys)
self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)

def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
all_simple_aggs = True
for expr in exprs:
if not (
is_elementary_expression(expr)
and re.sub(r"(\w+->)", "", expr._function_name)
in POLARS_TO_ARROW_AGGREGATIONS
):
all_simple_aggs = False
break

if not all_simple_aggs:
msg = (
"Non-trivial complex aggregation found.\n\n"
"Hint: you were probably trying to apply a non-elementary aggregation with a "
"pyarrow table.\n"
"Please rewrite your query such that group-by aggregations "
"are elementary. For example, instead of:\n\n"
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
"use:\n\n"
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
)
raise ValueError(msg)

self._ensure_all_simple(exprs)
aggs: list[tuple[str, str, Any]] = []
expected_pyarrow_column_names: list[str] = self._keys.copy()
new_column_names: list[str] = self._keys.copy()

for expr in exprs:
output_names, aliases = evaluate_output_names_and_aliases(
expr, self._df, self._keys
expr, self.compliant, self._keys
)

if expr._depth == 0:
# e.g. agg(nw.len()) # noqa: ERA001
# e.g. `agg(nw.len())`
if expr._function_name != "len": # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)

new_column_names.append(aliases[0])
expected_pyarrow_column_names.append(f"{self._keys[0]}_count")
aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))

continue

function_name = re.sub(r"(\w+->)", "", expr._function_name)
function_name = self._leaf_name(expr)
if function_name in {"std", "var"}:
option: Any = pc.VarianceOptions(ddof=expr._call_kwargs["ddof"])
elif function_name in {"len", "n_unique"}:
Expand All @@ -102,8 +87,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
else:
option = None

function_name = POLARS_TO_ARROW_AGGREGATIONS[function_name]

function_name = self._remap_expr_name(function_name)
new_column_names.extend(aliases)
expected_pyarrow_column_names.extend(
[f"{output_name}_{function_name}" for output_name in output_names]
Expand Down Expand Up @@ -133,18 +117,20 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
]
new_column_names = [new_column_names[i] for i in index_map]
result_simple = result_simple.rename_columns(new_column_names)
if self._df._backend_version < (12, 0, 0):
if self.compliant._backend_version < (12, 0, 0):
columns = result_simple.column_names
result_simple = result_simple.select(
[*self._keys, *[col for col in columns if col not in self._keys]]
)
return self._df._from_native_frame(result_simple)
return self.compliant._from_native_frame(result_simple)

def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns)
col_token = generate_temporary_column_name(
n_bytes=8, columns=self.compliant.columns
)
null_token: str = "__null_token_value__" # noqa: S105

table = self._df._native_frame
table = self.compliant.native
# NOTE: stubs fail in multiple places for `ChunkedArray`
it, separator_scalar = cast_to_comparable_string_types(
*(table[key] for key in self._keys), separator=""
Expand All @@ -160,7 +146,7 @@ def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
)
table = table.add_column(i=0, field_=col_token, column=key_values)
for v in pc.unique(key_values):
t = self._df._from_native_frame(
t = self.compliant._from_native_frame(
table.filter(pc.equal(table[col_token], v)).drop([col_token])
)
row = t.simple_select(*self._keys).row(0)
Expand Down
8 changes: 8 additions & 0 deletions narwhals/_compliant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.expr import EagerExpr
from narwhals._compliant.expr import LazyExpr
from narwhals._compliant.group_by import CompliantGroupBy
from narwhals._compliant.group_by import DepthTrackingGroupBy
from narwhals._compliant.group_by import EagerGroupBy
from narwhals._compliant.group_by import LazyGroupBy
from narwhals._compliant.namespace import CompliantNamespace
from narwhals._compliant.namespace import EagerNamespace
from narwhals._compliant.selectors import CompliantSelector
Expand All @@ -31,16 +35,19 @@
"CompliantExpr",
"CompliantExprT",
"CompliantFrameT",
"CompliantGroupBy",
"CompliantLazyFrame",
"CompliantNamespace",
"CompliantSelector",
"CompliantSelectorNamespace",
"CompliantSeries",
"CompliantSeriesOrNativeExprT_co",
"CompliantSeriesT",
"DepthTrackingGroupBy",
"EagerDataFrame",
"EagerDataFrameT",
"EagerExpr",
"EagerGroupBy",
"EagerNamespace",
"EagerSelectorNamespace",
"EagerSeries",
Expand All @@ -49,6 +56,7 @@
"EvalSeries",
"IntoCompliantExpr",
"LazyExpr",
"LazyGroupBy",
"LazySelectorNamespace",
"NativeFrameT_co",
"NativeSeriesT_co",
Expand Down
12 changes: 9 additions & 3 deletions narwhals/_compliant/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from typing_extensions import Self
from typing_extensions import TypeAlias

from narwhals._compliant.group_by import CompliantGroupBy
from narwhals.dtypes import DType
from narwhals.typing import SizeUnit
from narwhals.typing import _2DArray
Expand Down Expand Up @@ -67,7 +68,8 @@ def aggregate(self, *exprs: CompliantExprT_contra) -> Self:

(so, no broadcasting is necessary).
"""
return self.select(*exprs)
# NOTE: Ignore is to avoid an intermittent false positive
return self.select(*exprs) # pyright: ignore[reportArgumentType]

@property
def native(self) -> NativeFrameT_co:
Expand All @@ -91,7 +93,9 @@ def explode(self: Self, columns: Sequence[str]) -> Self: ...
def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
def gather_every(self, n: int, offset: int) -> Self: ...
def get_column(self, name: str) -> CompliantSeriesT: ...
def group_by(self, *keys: str, drop_null_keys: bool) -> Incomplete: ...
def group_by(
self, *keys: str, drop_null_keys: bool
) -> CompliantGroupBy[Self, Any]: ...
def head(self, n: int) -> Self: ...
def item(self, row: int | None, column: int | str | None) -> Any: ...
def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
Expand Down Expand Up @@ -218,7 +222,9 @@ def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
"`LazyFrame.gather_every` is deprecated and will be removed in a future version."
)
def gather_every(self, n: int, offset: int) -> Self: ...
def group_by(self, *keys: str, drop_null_keys: bool) -> Incomplete: ...
def group_by(
self, *keys: str, drop_null_keys: bool
) -> CompliantGroupBy[Self, Any]: ...
def head(self, n: int) -> Self: ...
def join(
self: Self,
Expand Down
8 changes: 8 additions & 0 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,14 @@ def __invert__(self) -> Self: ...
def broadcast(
self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]
) -> Self: ...
def _is_multi_output_agg(self) -> bool:
    """Tell whether this expression is rooted at a multi-output selection.

    The root is the part of `_function_name` that precedes the first `"->"`.
    The result is `True` when that root is `"all"` or `"selector"`.

    Here we skip the keys for such expressions, else they would appear
    duplicated in the output:

        df.group_by("a").agg(nw.all().mean())
    """
    # `partition` yields the text before the first separator, or the whole
    # string when no "->" is present.
    root, _, _ = self._function_name.partition("->")
    return root in {"all", "selector"}


class EagerExpr(
Expand Down
Loading
Loading