Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
22efc12
feat(expr-ir): Add new `list.*` aggregations
dangotbanned Dec 14, 2025
02032f9
test: Add `list_agg_test`
dangotbanned Dec 14, 2025
2c2fa08
chore: Add to compliant-level
dangotbanned Dec 14, 2025
c8d09ed
feat(DRAFT): Porting (#3332)
dangotbanned Dec 14, 2025
0cb1f5c
fix: Ignore nulls on `list.sum`
dangotbanned Dec 14, 2025
a867206
simplify `list.sum`, break `list.median`
dangotbanned Dec 14, 2025
e9c3656
simplify `list.{max,mean,min}`
dangotbanned Dec 14, 2025
7cd45d6
fix: Let `median` take the simpler path
dangotbanned Dec 14, 2025
501480f
test: "Fix" `list.median` test
dangotbanned Dec 14, 2025
b5a78b0
test: Try removing `xfail`?
dangotbanned Dec 14, 2025
a3a43a4
test: Shrink list tests
dangotbanned Dec 14, 2025
e99f97a
test: Add `test_list_agg_scalar`
dangotbanned Dec 14, 2025
5d0376e
why are you like this mypy?
dangotbanned Dec 14, 2025
f8f9909
perf: Add `ListScalar` fastpaths
dangotbanned Dec 15, 2025
abd4843
Move to `group_by`, generalize, fix `<pyarrow.ListScalar: None>`
dangotbanned Dec 15, 2025
a7c9ee1
test: Make scalar cases less of a disaster
dangotbanned Dec 15, 2025
3fefcdb
feat(expr-ir): Add `list.{all,any}`
dangotbanned Dec 15, 2025
96b6638
feat(expr-ir): Add `list.{first,last}`
dangotbanned Dec 15, 2025
5b310c6
feat(expr-ir): Add `list.n_unique`
dangotbanned Dec 15, 2025
86a3060
docs: Rephrase `explode_with_indices`
dangotbanned Dec 16, 2025
d232439
style: re-align
dangotbanned Dec 16, 2025
92d0b74
Apply suggestions from code review
dangotbanned Dec 16, 2025
76ba623
refactor: Simplify double negations
dangotbanned Dec 16, 2025
f6de206
ooh nice, we don't need `ignore_nulls=False` this way!
dangotbanned Dec 16, 2025
fc761a9
refactor: Rename `len_eq_0` -> `is_sublist_empty`
dangotbanned Dec 16, 2025
f74c4dd
docs: More clearly demo `arrow.options` lazy mappings
dangotbanned Dec 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions narwhals/_plan/arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
is_seq_column,
)
from narwhals._plan.arrow import functions as fn
from narwhals._plan.arrow.group_by import AggSpec
from narwhals._plan.arrow.series import ArrowSeries as Series
from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co
from narwhals._plan.common import temp
Expand Down Expand Up @@ -994,6 +995,24 @@ def join(self, node: FExpr[lists.Join], frame: Frame, name: str) -> Expr | Scala
)
return self.with_native(result, name)

def aggregate(
    self, node: FExpr[lists.Aggregation], frame: Frame, name: str
) -> Expr | Scalar:
    """Evaluate a `list.*` aggregation, reducing each sublist to a single value.

    Builds an `AggSpec` from the node's function, evaluates the child
    expression, then runs the spec over the child's native list data.
    """
    spec = AggSpec._from_list_agg(node.function, "values")
    child = node.input[0].dispatch(self.compliant, frame, name)
    return self.with_native(spec.agg_list(child.native), name)

# Every `list.*` aggregation shares the same dispatch path.
min = aggregate
max = aggregate
mean = aggregate
median = aggregate
sum = aggregate
any = aggregate
all = aggregate
first = aggregate
last = aggregate
n_unique = aggregate


class ArrowStringNamespace(
ExprStringNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT]
Expand Down
22 changes: 22 additions & 0 deletions narwhals/_plan/arrow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,17 @@ def explode(
return chunked_array(_list_explode(safe))

def explode_with_indices(self, native: ChunkedList | ListArray) -> pa.Table:
    """Explode list elements, expanding one-level into a table indexing the origin.

    Returns a 2-column table, with names `"idx"` and `"values"`:

    >>> from narwhals._plan.arrow import functions as fn
    >>>
    >>> arr = fn.array([[1, 2, 3], None, [4, 5, 6], []])
    >>> fn.ExplodeBuilder().explode_with_indices(arr).to_pydict()
    {'idx': [0, 0, 0, 1, 2, 2, 2, 3], 'values': [1, 2, 3, None, 4, 5, 6, None]}
    # ^ Which sublist values come from ^ The exploded values themselves
    """
    source = self._fill_with_null(native) if self.options.any() else native
    indices = _list_parent_indices(source)
    values = _list_explode(source)
    return concat_horizontal([indices, values], ["idx", "values"])
Expand Down Expand Up @@ -1042,6 +1053,12 @@ def _str_zfill_compat(
)


@t.overload
def when_then(
predicate: ChunkedArray[BooleanScalar], then: ScalarAny
) -> ChunkedArrayAny: ...
@t.overload
def when_then(predicate: Array[BooleanScalar], then: ScalarAny) -> ArrayAny: ...
@t.overload
def when_then(
predicate: Predicate, then: SameArrowT, otherwise: SameArrowT | None
Expand All @@ -1059,6 +1076,11 @@ def when_then(
def when_then(
    predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None
) -> Incomplete:
    """Thin wrapper around `pyarrow.compute.if_else`.

    - Supports a 2-arg form, like `pl.when(...).then(...)`
    - Accepts python literals, but only in the `otherwise` position
    """
    # Promote a python literal to an arrow scalar matching `then`'s type.
    fallback = (
        lit(otherwise, then.type) if is_non_nested_literal(otherwise) else otherwise
    )
    return pc.if_else(predicate, then, fallback)
Expand Down
143 changes: 128 additions & 15 deletions narwhals/_plan/arrow/group_by.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, overload
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import
Expand All @@ -13,7 +13,7 @@
from narwhals._plan.common import temp
from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy
from narwhals._plan.expressions import aggregation as agg
from narwhals._utils import Implementation
from narwhals._utils import Implementation, qualified_type_name
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
Expand All @@ -26,16 +26,17 @@
ArrayAny,
ChunkedArray,
ChunkedArrayAny,
ChunkedList,
ChunkedOrScalarAny,
Indices,
ListScalar,
ScalarAny,
)
from narwhals._plan.expressions import NamedIR
from narwhals._plan.typing import Seq

Incomplete: TypeAlias = Any

# NOTE: Unless stated otherwise, all aggregations have 2 variants:
# - `<function>` (pc.Function.kind == "scalar_aggregate")
# - `hash_<function>` (pc.Function.kind == "hash_aggregate")
SUPPORTED_AGG: Mapping[type[agg.AggExpr], acero.Aggregation] = {
agg.Sum: "hash_sum",
agg.Mean: "hash_mean",
Expand All @@ -51,9 +52,19 @@
agg.Last: "hash_last",
fn.MinMax: "hash_min_max",
}
SUPPORTED_LIST_AGG: Mapping[type[ir.lists.Aggregation], type[agg.AggExpr]] = {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually, it'll make more sense to define this somewhere in _plan/expressions/ - but this'll do for now

ir.lists.Mean: agg.Mean,
ir.lists.Median: agg.Median,
ir.lists.Max: agg.Max,
ir.lists.Min: agg.Min,
ir.lists.Sum: agg.Sum,
ir.lists.First: agg.First,
ir.lists.Last: agg.Last,
ir.lists.NUnique: agg.NUnique,
}
SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = {
ir.Len: "hash_count_all",
ir.Column: "hash_list", # `hash_aggregate` only
ir.Column: "hash_list",
}

_version_dependent: dict[Any, acero.Aggregation] = {}
Expand All @@ -65,16 +76,42 @@
SUPPORTED_FUNCTION: Mapping[type[ir.Function], acero.Aggregation] = {
ir.boolean.All: "hash_all",
ir.boolean.Any: "hash_any",
ir.functions.Unique: "hash_distinct", # `hash_aggregate` only
ir.functions.Unique: "hash_distinct",
ir.functions.NullCount: "hash_count",
**_version_dependent,
}

del _version_dependent


SUPPORTED_LIST_FUNCTION: Mapping[type[ir.lists.Aggregation], type[ir.Function]] = {
ir.lists.Any: ir.boolean.Any,
ir.lists.All: ir.boolean.All,
}
Comment on lines +87 to +90
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect these to be aggregations as well (and go in SUPPORTED_LIST_AGG), but apparently aren't? That would simplify _from_list_agg

Copy link
Member Author

@dangotbanned dangotbanned Dec 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I too was confused by this initially 😄

But this is something I've inherited from polars

polars

Here

But why?

There might be stronger reasoning that I haven't found yet, but from what I understand:

  • All AggExprs aggregate to a single value
  • Some FunctionExprs aggregate (but not many), and they are marked with FunctionOptions.aggregation

If I had to guess, it may be that these aggregating functions place additional constraints on their inputs.
These two cases must also have Boolean inputs.
Some others like NullCount do not observe order (I haven't added this concept, it was new 😅)

That being said, I wouldn't be opposed to deviating from upstream here if it can simplify things 🙂

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After some git archaeology on the parts of (https://github.com/pola-rs/polars/tree/05f6f2db49c6721f8208b3bbcca5ec54568e34c4/crates) that I'm vaguely familiar with - I still wasn't able to find anything explaining what defines something being an AggExpr vs a FunctionExpr which aggregates.

An interesting find though was that IRAggExpr have a corresponding GroupByMethod - but this is a deprecated feature that is still geting updated? 🤔


I think until I understand why there's distinction between the two - it might be best to assume someone smarter than me made the right decision 😅

I'm definitely curious though and wanna revisit it later


SCALAR_OUTPUT_TYPE: Mapping[acero.Aggregation, pa.DataType] = {
"all": fn.BOOL,
"any": fn.BOOL,
"approximate_median": fn.F64,
"count": fn.I64,
"count_all": fn.I64,
"count_distinct": fn.I64,
"kurtosis": fn.F64,
"mean": fn.F64,
"skew": fn.F64,
"stddev": fn.F64,
"variance": fn.F64,
}
"""Scalar aggregates that have an output type **not** dependent on input types*.

For use in list aggregates, where the input was null.

*Except `"mean"` will preserve `Decimal`, if that's where we started.
"""


class AggSpec:
__slots__ = ("agg", "name", "option", "target")
__slots__ = ("_function", "_name", "_option", "_target")

def __init__(
self,
Expand All @@ -83,19 +120,22 @@ def __init__(
option: acero.Opts = None,
name: acero.OutputName = "",
) -> None:
self.target = target
self.agg = agg
self.option = option
self.name = name or str(target)
self._target = target
self._function: acero.Aggregation = agg
self._option: acero.Opts = option
self._name: acero.OutputName = name or str(target)

@property
def use_threads(self) -> bool:
    """Whether acero may execute this aggregation multi-threaded.

    Some aggregations are order-sensitive and must run single-threaded;
    see https://github.com/apache/arrow/issues/36709.
    """
    return acero.can_thread(self._function)

def __iter__(self) -> Iterator[acero.Target | acero.Aggregation | acero.Opts]:
    """Lets us duck-type as a 4-tuple of `(target, function, option, name)`."""
    yield from (self._target, self._function, self._option, self._name)

def __repr__(self) -> str:
    # Mirrors the positional `__init__` order: (target, agg, option, name).
    return f"{type(self).__name__}({self._target!r}, {self._function!r}, {self._option!r}, {self._name!r})"

@classmethod
def from_named_ir(cls, named_ir: NamedIR) -> Self:
Expand Down Expand Up @@ -141,6 +181,24 @@ def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self:
def _from_function(cls, tp: type[ir.Function], name: str) -> Self:
    # Spec for an aggregating `Function` (e.g. `any`/`all`); `name` doubles as the output name.
    return cls(name, SUPPORTED_FUNCTION[tp], options.FUNCTION.get(tp), name)

@classmethod
def _from_list_agg(cls, list_agg: ir.lists.Aggregation, /, name: str) -> Self:
    """Build a spec for a `list.*` aggregation node.

    Resolves the node either to a plain aggregation (`SUPPORTED_LIST_AGG`)
    or to an aggregating function (`SUPPORTED_LIST_FUNCTION`), raising
    `NotImplementedError` for anything unsupported.
    """
    tp = type(list_agg)
    agg_tp = SUPPORTED_LIST_AGG.get(tp)
    if agg_tp is not None:
        if agg_tp in {agg.Std, agg.Var}:
            # `ddof` lives on the `ListFunction` node and has no mapping yet.
            msg = (
                f"TODO: {qualified_type_name(list_agg)!r} needs access to `ddof`.\n"
                "Add some sugar around mapping `ListFunction.<parameter>` -> `AggExpr.<parameter>`\n"
                "or using `Immutable.__immutable_keys__`"
            )
            raise NotImplementedError(msg)
        return cls(name, SUPPORTED_AGG[agg_tp], options.LIST_AGG.get(tp), name)
    func_tp = SUPPORTED_LIST_FUNCTION.get(tp)
    if func_tp is not None:
        return cls(name, SUPPORTED_FUNCTION[func_tp], options.LIST_AGG.get(tp), name)
    raise NotImplementedError(tp)

@classmethod
def any(cls, name: str) -> Self:
    """Shorthand spec for a boolean `any` reduction over `name`."""
    return cls._from_function(ir.boolean.Any, name)
Expand All @@ -155,6 +213,29 @@ def implode(cls, name: str) -> Self:
# https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-plan/src/dsl/expr/mod.rs#L44
return cls(name, SUPPORTED_IR[ir.Column], None, name)

@overload
def agg_list(self, native: ChunkedList) -> ChunkedArrayAny: ...
@overload
def agg_list(self, native: ListScalar) -> ScalarAny: ...
def agg_list(self, native: ChunkedList | ListScalar) -> ChunkedOrScalarAny:
    """Execute this aggregation over the values in *each* list, reducing *each* to a single value."""
    result: ChunkedOrScalarAny
    if isinstance(native, pa.Scalar):
        # Fastpath: a single `ListScalar` aggregates directly via the scalar
        # aggregate function - no explode/group_by round-trip required.
        scalar = cast("pa.ListScalar[Any]", native)
        func = HASH_TO_SCALAR_NAME[self._function]
        if not scalar.is_valid:
            # A null list yields a typed null; functions with a fixed output
            # type take priority over the list's element type.
            return fn.lit(None, SCALAR_OUTPUT_TYPE.get(func, scalar.type.value_type))
        result = pc.call_function(func, [scalar.values], self._option)
        return result
    # General path: explode into ("idx", "values"), then aggregate grouped by "idx".
    result = self.over_index(fn.ExplodeBuilder().explode_with_indices(native), "idx")
    # Mask out results for null sublists, so they stay null in the output.
    result = fn.when_then(native.is_valid(), result)
    if self._is_n_unique():
        # NOTE: Exploding `[]` becomes `[None]` - so we need to adjust the unique count *iff* we were unlucky
        is_sublist_empty = fn.eq(fn.list_len(native), fn.lit(0))
        if fn.any_(is_sublist_empty).as_py():
            result = fn.when_then(is_sublist_empty, fn.lit(0), result)
    return result

def over(self, native: pa.Table, keys: Iterable[acero.Field]) -> pa.Table:
"""Sugar for `native.group_by(keys).aggregate([self])`.

Expand All @@ -167,7 +248,10 @@ def over_index(self, native: pa.Table, index_column: str) -> ChunkedArrayAny:

Returns a single, (unnamed) array, representing the aggregation results.
"""
return acero.group_by_table(native, [index_column], [self]).column(self.name)
return acero.group_by_table(native, [index_column], [self]).column(self._name)

def _is_n_unique(self) -> bool:
    # `n_unique` needs special-casing in `agg_list`: exploding `[]` yields `[None]`.
    return self._function == SUPPORTED_AGG[agg.NUnique]


def group_by_error(
Expand Down Expand Up @@ -321,3 +405,32 @@ def _partition_by_many(
# E.g, to push down column selection to *before* collection
# Not needed for this task though
yield acero.collect(source, acero.filter(key == v), select)


def _generate_hash_to_scalar_name() -> Mapping[acero.Aggregation, acero.Aggregation]:
    """Derive the `hash_<name>` -> `<name>` lookup from the supported-aggregation tables.

    Names that only exist as hash aggregates are excluded; scalar names are
    additionally mapped to themselves.
    """
    mappings = (SUPPORTED_AGG, SUPPORTED_IR, SUPPORTED_FUNCTION)
    # These exist *only* as `hash_aggregate` functions - no scalar counterpart.
    hash_only = {"hash_distinct", "hash_list", "hash_one"}
    hash_names = set[str](chain.from_iterable(m.values() for m in mappings)) - hash_only
    table = {hash_name: hash_name.removeprefix("hash_") for hash_name in hash_names}
    # NOTE: Support both of these when using `AggSpec` directly for scalar aggregates
    # `(..., "hash_mean", ..., ...)`
    # `(..., "mean", ..., ...)`
    for scalar_name in list(table.values()):
        table[scalar_name] = scalar_name
    return cast("Mapping[acero.Aggregation, acero.Aggregation]", table)


# TODO @dangotbanned: Replace this with a lazier version
# Don't really want this running at import-time, but using `ModuleType.__getattr__` means
# defining it somewhere else
HASH_TO_SCALAR_NAME: Mapping[acero.Aggregation, acero.Aggregation] = (
_generate_hash_to_scalar_name()
)
"""Mapping between [Hash aggregate] and [Scalar aggregate] names.

Dynamically built for use in `ListScalar` aggregations, accounting for version availability.

[Hash aggregate]: https://arrow.apache.org/docs/dev/cpp/compute.html#grouped-aggregations-group-by
[Scalar aggregate]: https://arrow.apache.org/docs/dev/cpp/compute.html#aggregations
"""
Loading
Loading