From 22efc1246ac79e38c5b304910234165108b81625 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 16:44:53 +0000 Subject: [PATCH 01/26] feat(expr-ir): Add new `list.*` aggregations Playing catch-up on #3332 --- narwhals/_plan/expressions/lists.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index e35b4fb41e..001d7c7419 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -20,6 +20,11 @@ # fmt: off class ListFunction(Function, accessor="list", options=FunctionOptions.elementwise): ... +class Min(ListFunction): ... +class Max(ListFunction): ... +class Mean(ListFunction): ... +class Median(ListFunction): ... +class Sum(ListFunction): ... class Len(ListFunction): ... class Unique(ListFunction): ... class Get(ListFunction): @@ -46,6 +51,11 @@ class IRListNamespace(IRNamespace): contains: ClassVar = Contains get: ClassVar = Get join: ClassVar = Join + min: ClassVar = Min + max: ClassVar = Max + mean: ClassVar = Mean + median: ClassVar = Median + sum: ClassVar = Sum class ExprListNamespace(ExprNamespace[IRListNamespace]): @@ -53,6 +63,21 @@ class ExprListNamespace(ExprNamespace[IRListNamespace]): def _ir_namespace(self) -> type[IRListNamespace]: return IRListNamespace + def min(self) -> Expr: + return self._with_unary(self._ir.min()) + + def max(self) -> Expr: + return self._with_unary(self._ir.max()) + + def mean(self) -> Expr: + return self._with_unary(self._ir.mean()) + + def median(self) -> Expr: + return self._with_unary(self._ir.median()) + + def sum(self) -> Expr: + return self._with_unary(self._ir.sum()) + def len(self) -> Expr: return self._with_unary(self._ir.len()) From 02032f92116acda9015a1139729c8de67aa7566b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 16:54:21 +0000 Subject: 
[PATCH 02/26] test: Add `list_agg_test` --- tests/plan/list_agg_test.py | 67 +++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/plan/list_agg_test.py diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py new file mode 100644 index 0000000000..f87d926398 --- /dev/null +++ b/tests/plan/list_agg_test.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe +from tests.utils import is_windows + +if TYPE_CHECKING: + from narwhals._plan.typing import OneOrIterable + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [[3, None, 2, 2, 4, None], [-1], None, [None, None, None], []]} + + +@pytest.fixture(scope="module") +def data_median(data: Data) -> Data: + return {"a": [*data["a"], [3, 4, None]]} + + +a = nwp.col("a") +cast_a = a.cast(nw.List(nw.Int32)) + + +XFAIL_NOT_IMPL = pytest.mark.xfail( + reason="TODO: ArrowExpr.list.", raises=NotImplementedError +) + + +@XFAIL_NOT_IMPL +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + (a.list.max(), {"a": [4, -1, None, None, None]}), + (a.list.mean(), {"a": [2.75, -1, None, None, None]}), + (a.list.min(), {"a": [2, -1, None, None, None]}), + (a.list.sum(), {"a": [11, -1, None, 0, 0]}), + ], +) +def test_list_agg( + data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data +) -> None: # pragma: no cover + df = dataframe(data).with_columns(cast_a) + result = df.select(exprs) + assert_equal_data(result, expected) + + +@XFAIL_NOT_IMPL +@pytest.mark.xfail( + is_windows() and sys.version_info < (3, 10), reason="Old pyarrow windows bad?" 
+) +def test_list_median(data_median: Data) -> None: # pragma: no cover + df = dataframe(data_median).with_columns(cast_a) + result = df.select(a.list.median()) + + # TODO @dangotbanned: Is this fixable with `FunctionOptions`? + expected = [2.5, -1, None, None, None, 3.5] + expected_pyarrow = [2.5, -1, None, None, None, 3] + expected = expected_pyarrow + assert_equal_data(result, {"a": expected}) From 2c2fa08eede23f78002276fa464e44becf9ec3fd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 17:06:58 +0000 Subject: [PATCH 03/26] chore: Add to compliant-level --- narwhals/_plan/arrow/expr.py | 6 ++++++ narwhals/_plan/compliant/accessors.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 74f26de213..3c3ce659d6 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -994,6 +994,12 @@ def join(self, node: FExpr[lists.Join], frame: Frame, name: str) -> Expr | Scala ) return self.with_native(result, name) + min = not_implemented() + max = not_implemented() + mean = not_implemented() + median = not_implemented() + sum = not_implemented() + class ArrowStringNamespace( ExprStringNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 26df3ff3ba..d7ef786a34 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -38,6 +38,21 @@ def unique( def join( self, node: FExpr[lists.Join], frame: FrameT_contra, name: str ) -> ExprT_co: ... + def min( + self, node: FExpr[lists.Min], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def max( + self, node: FExpr[lists.Max], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def mean( + self, node: FExpr[lists.Mean], frame: FrameT_contra, name: str + ) -> ExprT_co: ... 
+ def median( + self, node: FExpr[lists.Median], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def sum( + self, node: FExpr[lists.Sum], frame: FrameT_contra, name: str + ) -> ExprT_co: ... class ExprStringNamespace(Protocol[FrameT_contra, ExprT_co]): From c8d09ede0621dbfbfe85069aee5f81a205b6054b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 18:57:11 +0000 Subject: [PATCH 04/26] feat(DRAFT): Porting (#3332) Tried to keep everything as close to original as possible Next step is simplifying everything and fixing `list.sum` --- narwhals/_plan/arrow/expr.py | 54 +++++++++++++++++++++++++---- narwhals/_plan/arrow/functions.py | 22 ++++++++++++ narwhals/_plan/arrow/group_by.py | 17 ++++++++- narwhals/_plan/expressions/lists.py | 5 ++- tests/plan/list_agg_test.py | 22 ++++++------ 5 files changed, 100 insertions(+), 20 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 3c3ce659d6..fec69bcceb 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -15,8 +15,14 @@ is_seq_column, ) from narwhals._plan.arrow import functions as fn +from narwhals._plan.arrow.group_by import AggSpec from narwhals._plan.arrow.series import ArrowSeries as Series -from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co +from narwhals._plan.arrow.typing import ( + ChunkedOrArrayAny, + ChunkedOrScalarAny, + NativeScalar, + StoresNativeT_co, +) from narwhals._plan.common import temp from narwhals._plan.compliant.accessors import ( ExprCatNamespace, @@ -994,11 +1000,47 @@ def join(self, node: FExpr[lists.Join], frame: Frame, name: str) -> Expr | Scala ) return self.with_native(result, name) - min = not_implemented() - max = not_implemented() - mean = not_implemented() - median = not_implemented() - sum = not_implemented() + def aggregate( + self, node: FExpr[lists.Aggregation], frame: Frame, name: str + ) -> Expr | Scalar: + 
previous = node.input[0].dispatch(self.compliant, frame, name) + func = node.function + if isinstance(previous, ArrowScalar): + msg = f"TODO: ArrowScalar.{func!r}" + raise NotImplementedError(msg) + + native = previous.native + lists = native + # TODO @dangotbanned: Experiment with explode step + # These options are to mirror `main`, but setting them to `True` may simplify everything after? + builder = fn.ExplodeBuilder(empty_as_null=False, keep_nulls=False) + explode_w_idx = builder.explode_with_indices(lists) + idx, v = "idx", "values" + agg_result = ( + AggSpec._from_agg(type(func), v) + .over(explode_w_idx, [idx]) + .sort_by(idx) + .column(v) + ) + dtype: pa.DataType = agg_result.type + non_empty_mask = fn.not_eq(fn.list_len(lists), fn.lit(0)) + base_array: ChunkedOrArrayAny + if isinstance(func, ir.lists.Sum): + # Make sure sum of empty list is 0. + base_array = fn.when_then(fn.is_not_null(non_empty_mask), fn.lit(0, dtype)) + else: + base_array = fn.repeat_unchecked(fn.lit(None, dtype), len(lists)) + replaced = fn.replace_with_mask( + base_array, fn.fill_null(non_empty_mask, False), agg_result + ) + result = fn.chunked_array(replaced) + return self.with_native(result, name) + + min = aggregate + max = aggregate + mean = aggregate + median = aggregate + sum = aggregate class ArrowStringNamespace( diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 42b44983b5..5d9ef53a4a 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -442,6 +442,17 @@ def explode( return chunked_array(_list_explode(safe)) def explode_with_indices(self, native: ChunkedList | ListArray) -> pa.Table: + """Explode list elements, expanding one-level into a table indexing the origin. 
+ + Returns a 2-column table, with names `"idx"` and `"values"`: + + >>> from narwhals._plan.arrow import functions as fn + >>> + >>> arr = fn.array([[1, 2, 3], None, [4, 5, 6], []]) + >>> fn.ExplodeBuilder().explode_with_indices(arr).to_pydict() + {'idx': [0, 0, 0, 1, 2, 2, 2, 3], 'values': [1, 2, 3, None, 4, 5, 6, None]} + # ^ Which sublist we came from ^ The exploded values themselves + """ safe = self._fill_with_null(native) if self.options.any() else native arrays = [_list_parent_indices(safe), _list_explode(safe)] return concat_horizontal(arrays, ["idx", "values"]) @@ -1042,6 +1053,12 @@ def _str_zfill_compat( ) +@t.overload +def when_then( + predicate: ChunkedArray[BooleanScalar], then: ScalarAny +) -> ChunkedArrayAny: ... +@t.overload +def when_then(predicate: Array[BooleanScalar], then: ScalarAny) -> ArrayAny: ... @t.overload def when_then( predicate: Predicate, then: SameArrowT, otherwise: SameArrowT | None @@ -1059,6 +1076,11 @@ def when_then( def when_then( predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None ) -> Incomplete: + """Thin wrapper around `pyarrow.compute.if_else`. 
+ + - Supports a 2-arg form, like `pl.when(...).then(...)` + - Accepts python literals, but only in the `otherwise` position + """ if is_non_nested_literal(otherwise): otherwise = lit(otherwise, then.type) return pc.if_else(predicate, then, otherwise) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index af918a3e4c..0be29d2cae 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -13,7 +13,7 @@ from narwhals._plan.common import temp from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy from narwhals._plan.expressions import aggregation as agg -from narwhals._utils import Implementation +from narwhals._utils import Implementation, qualified_type_name from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: @@ -51,6 +51,13 @@ agg.Last: "hash_last", fn.MinMax: "hash_min_max", } +SUPPORTED_LIST_AGG: Mapping[type[ir.lists.Aggregation], type[agg.AggExpr]] = { + ir.lists.Mean: agg.Mean, + ir.lists.Median: agg.Median, + ir.lists.Max: agg.Max, + ir.lists.Min: agg.Min, + ir.lists.Sum: agg.Sum, +} SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { ir.Len: "hash_count_all", ir.Column: "hash_list", # `hash_aggregate` only @@ -141,6 +148,14 @@ def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self: def _from_function(cls, tp: type[ir.Function], name: str) -> Self: return cls(name, SUPPORTED_FUNCTION[tp], options.FUNCTION.get(tp), name) + @classmethod + def _from_agg(cls, tp: type[ir.lists.Aggregation | agg.AggExpr], name: str) -> Self: + tp_agg = SUPPORTED_LIST_AGG[tp] if issubclass(tp, ir.lists.ListFunction) else tp + if tp_agg in {agg.Std, agg.Var}: + msg = f"TODO: {qualified_type_name(agg)!r} needs access to `ddof`, so can't be passed in without an instance" + raise NotImplementedError(msg) + return cls(name, SUPPORTED_AGG[tp_agg], options.AGG.get(tp_agg), name) + @classmethod def any(cls, name: str) -> Self: return cls._from_function(ir.boolean.Any, 
name) diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 001d7c7419..a4882acb3f 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -11,7 +11,7 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._plan.expr import Expr from narwhals._plan.expressions import ExprIR, FunctionExpr as FExpr @@ -45,6 +45,9 @@ def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: return expr, item +Aggregation: TypeAlias = "Min | Max | Mean | Median | Sum" + + class IRListNamespace(IRNamespace): len: ClassVar = Len unique: ClassVar = Unique diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index f87d926398..3ec3362efb 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -29,34 +29,32 @@ def data_median(data: Data) -> Data: cast_a = a.cast(nw.List(nw.Int32)) -XFAIL_NOT_IMPL = pytest.mark.xfail( - reason="TODO: ArrowExpr.list.", raises=NotImplementedError -) - - -@XFAIL_NOT_IMPL @pytest.mark.parametrize( ("exprs", "expected"), [ (a.list.max(), {"a": [4, -1, None, None, None]}), (a.list.mean(), {"a": [2.75, -1, None, None, None]}), (a.list.min(), {"a": [2, -1, None, None, None]}), - (a.list.sum(), {"a": [11, -1, None, 0, 0]}), + pytest.param( + a.list.sum(), + {"a": [11, -1, None, 0, 0]}, + marks=pytest.mark.xfail( + reason="Mismatch at index 3, key a: None != 0", raises=AssertionError + ), + ), ], + ids=["max", "mean", "min", "sum"], ) -def test_list_agg( - data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data -) -> None: # pragma: no cover +def test_list_agg(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None: df = dataframe(data).with_columns(cast_a) result = df.select(exprs) assert_equal_data(result, expected) -@XFAIL_NOT_IMPL @pytest.mark.xfail( is_windows() and sys.version_info < (3, 10), reason="Old 
pyarrow windows bad?" ) -def test_list_median(data_median: Data) -> None: # pragma: no cover +def test_list_median(data_median: Data) -> None: df = dataframe(data_median).with_columns(cast_a) result = df.select(a.list.median()) From 0cb1f5c737823b9bcf11bf9929d3d6f3c8c037f8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 19:27:11 +0000 Subject: [PATCH 05/26] fix: Ignore nulls on `list.sum` There's definitely other steps that can be simplified now --- narwhals/_plan/arrow/expr.py | 2 +- narwhals/_plan/arrow/group_by.py | 9 ++++++--- narwhals/_plan/arrow/options.py | 11 +++++++++++ tests/plan/list_agg_test.py | 8 +------- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index fec69bcceb..8e2a60b6bf 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1017,7 +1017,7 @@ def aggregate( explode_w_idx = builder.explode_with_indices(lists) idx, v = "idx", "values" agg_result = ( - AggSpec._from_agg(type(func), v) + AggSpec._from_list_agg(type(func), v) .over(explode_w_idx, [idx]) .sort_by(idx) .column(v) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 0be29d2cae..886db057a2 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -104,6 +104,9 @@ def __iter__(self) -> Iterator[acero.Target | acero.Aggregation | acero.Opts]: """Let's us duck-type as a 4-tuple.""" yield from (self.target, self.agg, self.option, self.name) + def __repr__(self) -> str: + return f"{type(self).__name__}({self.target!r}, {self.agg!r}, {self.option!r}, {self.name!r})" + @classmethod def from_named_ir(cls, named_ir: NamedIR) -> Self: return cls.from_expr_ir(named_ir.expr, named_ir.name) @@ -149,12 +152,12 @@ def _from_function(cls, tp: type[ir.Function], name: str) -> Self: return cls(name, SUPPORTED_FUNCTION[tp], options.FUNCTION.get(tp), name) @classmethod - 
def _from_agg(cls, tp: type[ir.lists.Aggregation | agg.AggExpr], name: str) -> Self: - tp_agg = SUPPORTED_LIST_AGG[tp] if issubclass(tp, ir.lists.ListFunction) else tp + def _from_list_agg(cls, tp: type[ir.lists.Aggregation], name: str) -> Self: + tp_agg = SUPPORTED_LIST_AGG[tp] if tp_agg in {agg.Std, agg.Var}: msg = f"TODO: {qualified_type_name(agg)!r} needs access to `ddof`, so can't be passed in without an instance" raise NotImplementedError(msg) - return cls(name, SUPPORTED_AGG[tp_agg], options.AGG.get(tp_agg), name) + return cls(name, SUPPORTED_AGG[tp_agg], options.LIST_AGG.get(tp), name) @classmethod def any(cls, name: str) -> Self: diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index 3d44487bc7..a9f768f700 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -39,6 +39,7 @@ AGG: Mapping[type[agg.AggExpr], acero.AggregateOptions] FUNCTION: Mapping[type[ir.Function], acero.AggregateOptions] +LIST_AGG: Mapping[type[ir.lists.Aggregation], acero.AggregateOptions] _NULLS_LAST = True _NULLS_FIRST = False @@ -159,6 +160,12 @@ def _generate_agg() -> Mapping[type[agg.AggExpr], acero.AggregateOptions]: } +def _generate_list_agg() -> Mapping[type[ir.lists.Aggregation], acero.AggregateOptions]: + from narwhals._plan.expressions import lists + + return {lists.Sum: scalar_aggregate(ignore_nulls=True)} + + def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]: from narwhals._plan.expressions import boolean, functions @@ -183,5 +190,9 @@ def __getattr__(name: str) -> Any: global FUNCTION FUNCTION = _generate_function() return FUNCTION + if name == "LIST_AGG": + global LIST_AGG + LIST_AGG = _generate_list_agg() + return LIST_AGG msg = f"module {__name__!r} has no attribute {name!r}" raise AttributeError(msg) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index 3ec3362efb..57a27cb4e9 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ 
-35,13 +35,7 @@ def data_median(data: Data) -> Data: (a.list.max(), {"a": [4, -1, None, None, None]}), (a.list.mean(), {"a": [2.75, -1, None, None, None]}), (a.list.min(), {"a": [2, -1, None, None, None]}), - pytest.param( - a.list.sum(), - {"a": [11, -1, None, 0, 0]}, - marks=pytest.mark.xfail( - reason="Mismatch at index 3, key a: None != 0", raises=AssertionError - ), - ), + (a.list.sum(), {"a": [11, -1, None, 0, 0]}), ], ids=["max", "mean", "min", "sum"], ) From a8672061a309efeaa92ee2ac2509ecc955c1deed Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 20:13:15 +0000 Subject: [PATCH 06/26] simplify `list.sum`, break `list.median` you win some, you lose some I guess --- narwhals/_plan/arrow/expr.py | 20 ++++++-------------- tests/plan/list_agg_test.py | 2 ++ 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 8e2a60b6bf..8cb3f755e6 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -17,12 +17,7 @@ from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.group_by import AggSpec from narwhals._plan.arrow.series import ArrowSeries as Series -from narwhals._plan.arrow.typing import ( - ChunkedOrArrayAny, - ChunkedOrScalarAny, - NativeScalar, - StoresNativeT_co, -) +from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co from narwhals._plan.common import temp from narwhals._plan.compliant.accessors import ( ExprCatNamespace, @@ -1013,23 +1008,20 @@ def aggregate( lists = native # TODO @dangotbanned: Experiment with explode step # These options are to mirror `main`, but setting them to `True` may simplify everything after? 
- builder = fn.ExplodeBuilder(empty_as_null=False, keep_nulls=False) + builder = fn.ExplodeBuilder() explode_w_idx = builder.explode_with_indices(lists) idx, v = "idx", "values" agg_result = ( AggSpec._from_list_agg(type(func), v) .over(explode_w_idx, [idx]) - .sort_by(idx) + .sort_by(idx) # <--- won't be needed now that exploding keeps all indices .column(v) ) + if isinstance(func, ir.lists.Sum): + return self.with_native(fn.when_then(fn.is_not_null(lists), agg_result), name) dtype: pa.DataType = agg_result.type non_empty_mask = fn.not_eq(fn.list_len(lists), fn.lit(0)) - base_array: ChunkedOrArrayAny - if isinstance(func, ir.lists.Sum): - # Make sure sum of empty list is 0. - base_array = fn.when_then(fn.is_not_null(non_empty_mask), fn.lit(0, dtype)) - else: - base_array = fn.repeat_unchecked(fn.lit(None, dtype), len(lists)) + base_array = fn.repeat_unchecked(fn.lit(None, dtype), len(lists)) replaced = fn.replace_with_mask( base_array, fn.fill_null(non_empty_mask, False), agg_result ) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index 57a27cb4e9..4a5ca0005f 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -45,6 +45,8 @@ def test_list_agg(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> assert_equal_data(result, expected) +# TODO @dangotbanned: Fix after simplifying `list.{max,mean,min}` +@pytest.mark.xfail(reason="Mismatch at index 5, key a: None != 3", raises=AssertionError) @pytest.mark.xfail( is_windows() and sys.version_info < (3, 10), reason="Old pyarrow windows bad?" 
) From e9c3656c19e4f785390a2ecb33671c470521d89e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 20:30:13 +0000 Subject: [PATCH 07/26] simplify `list.{max,mean,min}` --- narwhals/_plan/arrow/expr.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 8cb3f755e6..1ad50504f1 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1008,16 +1008,12 @@ def aggregate( lists = native # TODO @dangotbanned: Experiment with explode step # These options are to mirror `main`, but setting them to `True` may simplify everything after? - builder = fn.ExplodeBuilder() - explode_w_idx = builder.explode_with_indices(lists) + explode_w_idx = fn.ExplodeBuilder().explode_with_indices(lists) idx, v = "idx", "values" agg_result = ( - AggSpec._from_list_agg(type(func), v) - .over(explode_w_idx, [idx]) - .sort_by(idx) # <--- won't be needed now that exploding keeps all indices - .column(v) + AggSpec._from_list_agg(type(func), v).over(explode_w_idx, [idx]).column(v) ) - if isinstance(func, ir.lists.Sum): + if not isinstance(func, ir.lists.Median): return self.with_native(fn.when_then(fn.is_not_null(lists), agg_result), name) dtype: pa.DataType = agg_result.type non_empty_mask = fn.not_eq(fn.list_len(lists), fn.lit(0)) From 7cd45d6802685281527a6e79e1b63d2a0c9bec95 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 20:47:37 +0000 Subject: [PATCH 08/26] fix: Let `median` take the simpler path --- narwhals/_plan/arrow/expr.py | 24 ++++-------------------- tests/plan/list_agg_test.py | 2 -- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 1ad50504f1..5f152ed75f 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -999,30 +999,14 @@ def 
aggregate( self, node: FExpr[lists.Aggregation], frame: Frame, name: str ) -> Expr | Scalar: previous = node.input[0].dispatch(self.compliant, frame, name) - func = node.function if isinstance(previous, ArrowScalar): - msg = f"TODO: ArrowScalar.{func!r}" + msg = f"TODO: ArrowScalar.{node.function!r}" raise NotImplementedError(msg) native = previous.native - lists = native - # TODO @dangotbanned: Experiment with explode step - # These options are to mirror `main`, but setting them to `True` may simplify everything after? - explode_w_idx = fn.ExplodeBuilder().explode_with_indices(lists) - idx, v = "idx", "values" - agg_result = ( - AggSpec._from_list_agg(type(func), v).over(explode_w_idx, [idx]).column(v) - ) - if not isinstance(func, ir.lists.Median): - return self.with_native(fn.when_then(fn.is_not_null(lists), agg_result), name) - dtype: pa.DataType = agg_result.type - non_empty_mask = fn.not_eq(fn.list_len(lists), fn.lit(0)) - base_array = fn.repeat_unchecked(fn.lit(None, dtype), len(lists)) - replaced = fn.replace_with_mask( - base_array, fn.fill_null(non_empty_mask, False), agg_result - ) - result = fn.chunked_array(replaced) - return self.with_native(result, name) + agg = AggSpec._from_list_agg(type(node.function), "values") + result = agg.over_index(fn.ExplodeBuilder().explode_with_indices(native), "idx") + return self.with_native(fn.when_then(native.is_valid(), result), name) min = aggregate max = aggregate diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index 4a5ca0005f..57a27cb4e9 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -45,8 +45,6 @@ def test_list_agg(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> assert_equal_data(result, expected) -# TODO @dangotbanned: Fix after simplifying `list.{max,mean,min}` -@pytest.mark.xfail(reason="Mismatch at index 5, key a: None != 3", raises=AssertionError) @pytest.mark.xfail( is_windows() and sys.version_info < (3, 10), reason="Old pyarrow windows 
bad?" ) From 501480f4928adbe6a1debcc312825dc2c3a43f74 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 21:40:59 +0000 Subject: [PATCH 09/26] test: "Fix" `list.median` test Demonstrated in (https://github.com/narwhals-dev/narwhals/pull/3332#discussion_r2617508167) The issue is unrelated to group_by and lists --- tests/plan/list_agg_test.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index 57a27cb4e9..603346c8dd 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -22,7 +22,10 @@ def data() -> Data: @pytest.fixture(scope="module") def data_median(data: Data) -> Data: - return {"a": [*data["a"], [3, 4, None]]} + # NOTE: `pyarrow` needs at least 3 (non-null) values to calculate `median` correctly + # Otherwise it picks the lowest non-null + # https://github.com/narwhals-dev/narwhals/pull/3332#discussion_r2617508167 + return {"a": [*data["a"], [3, 4, None, 4, None, 3]]} a = nwp.col("a") @@ -51,9 +54,4 @@ def test_list_agg(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> def test_list_median(data_median: Data) -> None: df = dataframe(data_median).with_columns(cast_a) result = df.select(a.list.median()) - - # TODO @dangotbanned: Is this fixable with `FunctionOptions`? - expected = [2.5, -1, None, None, None, 3.5] - expected_pyarrow = [2.5, -1, None, None, None, 3] - expected = expected_pyarrow - assert_equal_data(result, {"a": expected}) + assert_equal_data(result, {"a": [2.5, -1, None, None, None, 3.5]}) From b5a78b07fe6fb3d9bceb642587420d331a7b421f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 21:55:35 +0000 Subject: [PATCH 10/26] test: Try removing `xfail`? 
https://github.com/narwhals-dev/narwhals/actions/runs/20214632946/job/58025633966?pr=3353 --- tests/plan/list_agg_test.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index 603346c8dd..d8f76b9b19 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys from typing import TYPE_CHECKING import pytest @@ -8,7 +7,6 @@ import narwhals as nw import narwhals._plan as nwp from tests.plan.utils import assert_equal_data, dataframe -from tests.utils import is_windows if TYPE_CHECKING: from narwhals._plan.typing import OneOrIterable @@ -48,9 +46,6 @@ def test_list_agg(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> assert_equal_data(result, expected) -@pytest.mark.xfail( - is_windows() and sys.version_info < (3, 10), reason="Old pyarrow windows bad?" -) def test_list_median(data_median: Data) -> None: df = dataframe(data_median).with_columns(cast_a) result = df.select(a.list.median()) From a3a43a40391442096817d1406aef0bebcd8be5a9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 23:23:37 +0000 Subject: [PATCH 11/26] test: Shrink list tests --- tests/plan/list_agg_test.py | 45 ++++++++++++++++---------------- tests/plan/list_contains_test.py | 13 ++++----- tests/plan/list_join_test.py | 5 ++-- tests/plan/list_unique_test.py | 5 ++-- tests/plan/utils.py | 4 +++ 5 files changed, 38 insertions(+), 34 deletions(-) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index d8f76b9b19..f8a36d92c4 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Final import pytest @@ -11,19 +11,23 @@ if TYPE_CHECKING: from narwhals._plan.typing import OneOrIterable from tests.conftest import Data + from 
tests.plan.utils import SubList -@pytest.fixture(scope="module") -def data() -> Data: - return {"a": [[3, None, 2, 2, 4, None], [-1], None, [None, None, None], []]} +R1: Final[SubList[float]] = [3, None, 2, 2, 4, None] +R2: Final[SubList[float]] = [-1] +R3: Final[SubList[float]] = None +R4: Final[SubList[float]] = [None, None, None] +R5: Final[SubList[float]] = [] +# NOTE: `pyarrow` needs at least 3 (non-null) values to calculate `median` correctly +# Otherwise it picks the lowest non-null +# https://github.com/narwhals-dev/narwhals/pull/3332#discussion_r2617508167 +R6: Final[SubList[float]] = [3, 4, None, 4, None, 3] @pytest.fixture(scope="module") -def data_median(data: Data) -> Data: - # NOTE: `pyarrow` needs at least 3 (non-null) values to calculate `median` correctly - # Otherwise it picks the lowest non-null - # https://github.com/narwhals-dev/narwhals/pull/3332#discussion_r2617508167 - return {"a": [*data["a"], [3, 4, None, 4, None, 3]]} +def data() -> Data: + return {"a": [R1, R2, R3, R4, R5, R6]} a = nwp.col("a") @@ -33,20 +37,17 @@ def data_median(data: Data) -> Data: @pytest.mark.parametrize( ("exprs", "expected"), [ - (a.list.max(), {"a": [4, -1, None, None, None]}), - (a.list.mean(), {"a": [2.75, -1, None, None, None]}), - (a.list.min(), {"a": [2, -1, None, None, None]}), - (a.list.sum(), {"a": [11, -1, None, 0, 0]}), + (a.list.max(), [4, -1, None, None, None, 4]), + (a.list.mean(), [2.75, -1, None, None, None, 3.5]), + (a.list.min(), [2, -1, None, None, None, 3]), + (a.list.sum(), [11, -1, None, 0, 0, 14]), + (a.list.median(), [2.5, -1, None, None, None, 3.5]), ], - ids=["max", "mean", "min", "sum"], + ids=["max", "mean", "min", "sum", "median"], ) -def test_list_agg(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None: +def test_list_agg( + data: Data, exprs: OneOrIterable[nwp.Expr], expected: list[float | None] +) -> None: df = dataframe(data).with_columns(cast_a) result = df.select(exprs) - assert_equal_data(result, expected) - - 
-def test_list_median(data_median: Data) -> None: - df = dataframe(data_median).with_columns(cast_a) - result = df.select(a.list.median()) - assert_equal_data(result, {"a": [2.5, -1, None, None, None, 3.5]}) + assert_equal_data(result, {"a": expected}) diff --git a/tests/plan/list_contains_test.py b/tests/plan/list_contains_test.py index 0761434e97..186c186b3e 100644 --- a/tests/plan/list_contains_test.py +++ b/tests/plan/list_contains_test.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Final import pytest @@ -11,6 +11,7 @@ if TYPE_CHECKING: from narwhals._plan.typing import IntoExpr from tests.conftest import Data + from tests.plan.utils import SubList @pytest.fixture(scope="module") @@ -42,10 +43,10 @@ def test_list_contains(data: Data, item: IntoExpr, expected: list[bool | None]) assert_equal_data(result, {"a": expected}) -R1: Final[list[Any]] = [None, "A", "B", "A", "A", "B"] -R2: Final = None -R3: Final[list[Any]] = [] -R4: Final = [None] +R1: Final[SubList[str]] = [None, "A", "B", "A", "A", "B"] +R2: Final[SubList[str]] = None +R3: Final[SubList[str]] = [] +R4: Final[SubList[str]] = [None] @pytest.mark.parametrize( @@ -66,7 +67,7 @@ def test_list_contains(data: Data, item: IntoExpr, expected: list[bool | None]) ], ) def test_list_contains_scalar( - row: list[str | None] | None, item: IntoExpr, *, expected: bool | None + row: SubList[str], item: IntoExpr, *, expected: bool | None ) -> None: data = {"a": [row]} df = dataframe(data).select(a.cast(nw.List(nw.String))) diff --git a/tests/plan/list_join_test.py b/tests/plan/list_join_test.py index 881700f6c7..fa524f7880 100644 --- a/tests/plan/list_join_test.py +++ b/tests/plan/list_join_test.py @@ -10,14 +10,13 @@ if TYPE_CHECKING: from collections.abc import Sequence - from typing import Final, TypeVar + from typing import Final from typing_extensions import TypeAlias from tests.conftest import Data + from tests.plan.utils 
import SubList - T = TypeVar("T") - SubList: TypeAlias = list[T] | list[T | None] | list[None] | None SubListStr: TypeAlias = SubList[str] diff --git a/tests/plan/list_unique_test.py b/tests/plan/list_unique_test.py index 7f82e593b5..65763b0176 100644 --- a/tests/plan/list_unique_test.py +++ b/tests/plan/list_unique_test.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from tests.conftest import Data + from tests.plan.utils import SubList @pytest.fixture(scope="module") @@ -50,9 +51,7 @@ def test_list_unique(data: Data) -> None: ([None], [None]), ], ) -def test_list_unique_scalar( - row: list[str | None] | None, expected: list[str | None] | None -) -> None: +def test_list_unique_scalar(row: SubList[str], expected: SubList[str]) -> None: data = {"a": [row]} df = dataframe(data).select(a.cast(nw.List(nw.String))) # NOTE: Don't separate `first().list.unique()` diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 92446a48cd..4b87ea5bd5 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: import sys from collections.abc import Iterable, Mapping + from typing import TypeVar from typing_extensions import LiteralString, TypeAlias @@ -31,6 +32,9 @@ else: _Flags: TypeAlias = int + T = TypeVar("T") + SubList: TypeAlias = list[T] | list[T | None] | list[None] | None + def first(*names: str | Sequence[str]) -> nwp.Expr: return nwp.col(*names).first() From e99f97a870067bb3d80d2d5bb9f30d644875992a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 23:51:04 +0000 Subject: [PATCH 12/26] test: Add `test_list_agg_scalar` Pretty much undoes all the shrinking, but oh well - can clean up later ... 
--- tests/plan/list_agg_test.py | 53 +++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index f8a36d92c4..6b7b22eb5c 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -51,3 +51,56 @@ def test_list_agg( df = dataframe(data).with_columns(cast_a) result = df.select(exprs) assert_equal_data(result, {"a": expected}) + + +first = a.first() +first_list_max = first.list.max() +first_list_mean = first.list.mean() +first_list_min = first.list.min() +first_list_sum = first.list.sum() +first_list_median = first.list.median() + + +@pytest.mark.xfail(reason="TODO: ArrowScalar.list.", raises=NotImplementedError) +@pytest.mark.parametrize( + ("row", "expr", "expected"), + [ + (R1, first_list_max, 4), + (R2, first_list_max, -1), + (R3, first_list_max, None), + (R4, first_list_max, None), + (R5, first_list_max, None), + (R6, first_list_max, 4), + (R1, first_list_mean, 2.75), + (R2, first_list_mean, -1), + (R3, first_list_mean, None), + (R4, first_list_mean, None), + (R5, first_list_mean, None), + (R6, first_list_mean, 3.5), + (R1, first_list_min, 2), + (R2, first_list_min, -1), + (R3, first_list_min, None), + (R4, first_list_min, None), + (R5, first_list_min, None), + (R6, first_list_min, 3), + (R1, first_list_sum, 11), + (R2, first_list_sum, -1), + (R3, first_list_sum, None), + (R4, first_list_sum, 0), + (R5, first_list_sum, 0), + (R6, first_list_sum, 14), + (R1, first_list_median, 2.5), + (R2, first_list_median, -1), + (R3, first_list_median, None), + (R4, first_list_median, None), + (R5, first_list_median, None), + (R6, first_list_median, 3.5), + ], +) +def test_list_agg_scalar( + row: SubList[float], expr: nwp.Expr, expected: float | None +) -> None: # pragma: no cover + data = {"a": [row]} + df = dataframe(data).select(cast_a) + result = df.select(expr) + assert_equal_data(result, {"a": [expected]}) From 5d0376efd8179712dc8bb0585b529d305f903781 Mon Sep 17 
00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 23:57:18 +0000 Subject: [PATCH 13/26] why are you like this mypy? --- tests/plan/list_agg_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index 6b7b22eb5c..fd39b47b62 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -14,15 +14,15 @@ from tests.plan.utils import SubList -R1: Final[SubList[float]] = [3, None, 2, 2, 4, None] -R2: Final[SubList[float]] = [-1] +R1: Final[SubList[int]] = [3, None, 2, 2, 4, None] +R2: Final[SubList[int]] = [-1] R3: Final[SubList[float]] = None R4: Final[SubList[float]] = [None, None, None] R5: Final[SubList[float]] = [] # NOTE: `pyarrow` needs at least 3 (non-null) values to calculate `median` correctly # Otherwise it picks the lowest non-null # https://github.com/narwhals-dev/narwhals/pull/3332#discussion_r2617508167 -R6: Final[SubList[float]] = [3, 4, None, 4, None, 3] +R6: Final[SubList[int]] = [3, 4, None, 4, None, 3] @pytest.fixture(scope="module") From f8f9909fff844b2ea5d4176dc83d7f7309eaac0d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 15 Dec 2025 00:23:20 +0000 Subject: [PATCH 14/26] perf: Add `ListScalar` fastpaths --- narwhals/_plan/arrow/expr.py | 23 ++++++++++++++++++++--- tests/plan/list_agg_test.py | 4 ++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 5f152ed75f..d5947238aa 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,7 +1,16 @@ from __future__ import annotations from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, overload +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Protocol, + TypeVar, + cast, + overload, +) import pyarrow as pa # 
ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -1000,8 +1009,16 @@ def aggregate( ) -> Expr | Scalar: previous = node.input[0].dispatch(self.compliant, frame, name) if isinstance(previous, ArrowScalar): - msg = f"TODO: ArrowScalar.{node.function!r}" - raise NotImplementedError(msg) + scalar = cast("pa.ListScalar[Any]", previous.native) + if not scalar.is_valid: + return self.with_native(scalar, name) + + # TODO @dangotbanned: Do this in a less hacky way + agg = AggSpec._from_list_agg(type(node.function), "values") + func_name = agg.agg.removeprefix("hash_") + exploded = scalar.values + s_result: NativeScalar = pc.call_function(func_name, [exploded], agg.option) + return self.with_native(s_result, name) native = previous.native agg = AggSpec._from_list_agg(type(node.function), "values") diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index fd39b47b62..0e2ae8e7a4 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -61,7 +61,7 @@ def test_list_agg( first_list_median = first.list.median() -@pytest.mark.xfail(reason="TODO: ArrowScalar.list.", raises=NotImplementedError) +# TODO @dangotbanned: Shrink this @pytest.mark.parametrize( ("row", "expr", "expected"), [ @@ -99,7 +99,7 @@ def test_list_agg( ) def test_list_agg_scalar( row: SubList[float], expr: nwp.Expr, expected: float | None -) -> None: # pragma: no cover +) -> None: data = {"a": [row]} df = dataframe(data).select(cast_a) result = df.select(expr) From abd4843da75df3e20e09a3f0006502cff787b938 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 15 Dec 2025 15:13:04 +0000 Subject: [PATCH 15/26] Move to `group_by`, generalize, fix `` --- narwhals/_plan/arrow/expr.py | 29 +------- narwhals/_plan/arrow/group_by.py | 109 ++++++++++++++++++++++++++----- tests/plan/list_agg_test.py | 3 + 3 files changed, 98 insertions(+), 43 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py 
b/narwhals/_plan/arrow/expr.py index d5947238aa..e5253ef78d 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,16 +1,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Generic, - Protocol, - TypeVar, - cast, - overload, -) +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -1008,22 +999,8 @@ def aggregate( self, node: FExpr[lists.Aggregation], frame: Frame, name: str ) -> Expr | Scalar: previous = node.input[0].dispatch(self.compliant, frame, name) - if isinstance(previous, ArrowScalar): - scalar = cast("pa.ListScalar[Any]", previous.native) - if not scalar.is_valid: - return self.with_native(scalar, name) - - # TODO @dangotbanned: Do this in a less hacky way - agg = AggSpec._from_list_agg(type(node.function), "values") - func_name = agg.agg.removeprefix("hash_") - exploded = scalar.values - s_result: NativeScalar = pc.call_function(func_name, [exploded], agg.option) - return self.with_native(s_result, name) - - native = previous.native - agg = AggSpec._from_list_agg(type(node.function), "values") - result = agg.over_index(fn.ExplodeBuilder().explode_with_indices(native), "idx") - return self.with_native(fn.when_then(native.is_valid(), result), name) + agg = AggSpec._from_list_agg(node.function, "values") + return self.with_native(agg.agg_list(previous.native), name) min = aggregate max = aggregate diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 886db057a2..68a3b1516b 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import TYPE_CHECKING, Any, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, cast, overload import pyarrow as pa # 
ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -26,16 +26,17 @@ ArrayAny, ChunkedArray, ChunkedArrayAny, + ChunkedList, + ChunkedOrScalarAny, Indices, + ListScalar, + ScalarAny, ) from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq Incomplete: TypeAlias = Any -# NOTE: Unless stated otherwise, all aggregations have 2 variants: -# - `` (pc.Function.kind == "scalar_aggregate") -# - `hash_` (pc.Function.kind == "hash_aggregate") SUPPORTED_AGG: Mapping[type[agg.AggExpr], acero.Aggregation] = { agg.Sum: "hash_sum", agg.Mean: "hash_mean", @@ -60,7 +61,7 @@ } SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { ir.Len: "hash_count_all", - ir.Column: "hash_list", # `hash_aggregate` only + ir.Column: "hash_list", } _version_dependent: dict[Any, acero.Aggregation] = {} @@ -72,16 +73,36 @@ SUPPORTED_FUNCTION: Mapping[type[ir.Function], acero.Aggregation] = { ir.boolean.All: "hash_all", ir.boolean.Any: "hash_any", - ir.functions.Unique: "hash_distinct", # `hash_aggregate` only + ir.functions.Unique: "hash_distinct", ir.functions.NullCount: "hash_count", **_version_dependent, } del _version_dependent +SCALAR_OUTPUT_TYPE: Mapping[acero.Aggregation, pa.DataType] = { + "all": fn.BOOL, + "any": fn.BOOL, + "approximate_median": fn.F64, + "count": fn.I64, + "count_all": fn.I64, + "count_distinct": fn.I64, + "kurtosis": fn.F64, + "mean": fn.F64, + "skew": fn.F64, + "stddev": fn.F64, + "variance": fn.F64, +} +"""Scalar aggregates that have an output type **not** dependent on input types*. + +For use in list aggregates, where the input was null. + +*Except `"mean"` can be `Decimal`. 
+""" + class AggSpec: - __slots__ = ("agg", "name", "option", "target") + __slots__ = ("_function", "_name", "_option", "_target") def __init__( self, @@ -90,22 +111,22 @@ def __init__( option: acero.Opts = None, name: acero.OutputName = "", ) -> None: - self.target = target - self.agg = agg - self.option = option - self.name = name or str(target) + self._target = target + self._function: acero.Aggregation = agg + self._option: acero.Opts = option + self._name: acero.OutputName = name or str(target) @property def use_threads(self) -> bool: """See https://github.com/apache/arrow/issues/36709.""" - return acero.can_thread(self.agg) + return acero.can_thread(self._function) def __iter__(self) -> Iterator[acero.Target | acero.Aggregation | acero.Opts]: """Let's us duck-type as a 4-tuple.""" - yield from (self.target, self.agg, self.option, self.name) + yield from (self._target, self._function, self._option, self._name) def __repr__(self) -> str: - return f"{type(self).__name__}({self.target!r}, {self.agg!r}, {self.option!r}, {self.name!r})" + return f"{type(self).__name__}({self._target!r}, {self._function!r}, {self._option!r}, {self._name!r})" @classmethod def from_named_ir(cls, named_ir: NamedIR) -> Self: @@ -152,10 +173,15 @@ def _from_function(cls, tp: type[ir.Function], name: str) -> Self: return cls(name, SUPPORTED_FUNCTION[tp], options.FUNCTION.get(tp), name) @classmethod - def _from_list_agg(cls, tp: type[ir.lists.Aggregation], name: str) -> Self: + def _from_list_agg(cls, list_agg: ir.lists.Aggregation, /, name: str) -> Self: + tp = type(list_agg) tp_agg = SUPPORTED_LIST_AGG[tp] if tp_agg in {agg.Std, agg.Var}: - msg = f"TODO: {qualified_type_name(agg)!r} needs access to `ddof`, so can't be passed in without an instance" + msg = ( + f"TODO: {qualified_type_name(list_agg)!r} needs access to `ddof`.\n" + "Add some sugar around mapping `ListFunction.` -> `AggExpr.`\n" + "or using `Immutable.__immutable_keys__`" + ) raise NotImplementedError(msg) return cls(name, 
SUPPORTED_AGG[tp_agg], options.LIST_AGG.get(tp), name) @@ -173,6 +199,26 @@ def implode(cls, name: str) -> Self: # https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-plan/src/dsl/expr/mod.rs#L44 return cls(name, SUPPORTED_IR[ir.Column], None, name) + @overload + def agg_list(self, native: ChunkedList) -> ChunkedArrayAny: ... + @overload + def agg_list(self, native: ListScalar) -> ScalarAny: ... + def agg_list(self, native: ChunkedList | ListScalar) -> ChunkedOrScalarAny: + """Execute this aggregation over the values in *each* list, reducing *each* to a single value.""" + result: ChunkedOrScalarAny + if isinstance(native, pa.Scalar): + scalar = cast("pa.ListScalar[Any]", native) + func = HASH_TO_SCALAR_NAME[self._function] + if not scalar.is_valid: + return fn.lit(None, SCALAR_OUTPUT_TYPE.get(func, scalar.type.value_type)) + result = pc.call_function(func, [scalar.values], self._option) + else: + result = self.over_index( + fn.ExplodeBuilder().explode_with_indices(native), "idx" + ) + result = fn.when_then(native.is_valid(), result) + return result + def over(self, native: pa.Table, keys: Iterable[acero.Field]) -> pa.Table: """Sugar for `native.group_by(keys).aggregate([self])`. @@ -185,7 +231,7 @@ def over_index(self, native: pa.Table, index_column: str) -> ChunkedArrayAny: Returns a single, (unnamed) array, representing the aggregation results. 
""" - return acero.group_by_table(native, [index_column], [self]).column(self.name) + return acero.group_by_table(native, [index_column], [self]).column(self._name) def group_by_error( @@ -339,3 +385,32 @@ def _partition_by_many( # E.g, to push down column selection to *before* collection # Not needed for this task though yield acero.collect(source, acero.filter(key == v), select) + + +def _generate_hash_to_scalar_name() -> Mapping[acero.Aggregation, acero.Aggregation]: + nw_to_hash = SUPPORTED_AGG, SUPPORTED_IR, SUPPORTED_FUNCTION + only_hash = {"hash_distinct", "hash_list", "hash_one"} + targets = set[str](chain.from_iterable(m.values() for m in nw_to_hash)) - only_hash + hash_to_scalar = {hash_name: hash_name.removeprefix("hash_") for hash_name in targets} + # NOTE: Support both of these when using `AggSpec` directly for scalar aggregates + # `(..., "hash_mean", ..., ...)` + # `(..., "mean", ..., ...)` + scalar_names = hash_to_scalar.values() + scalar_to_scalar = zip(scalar_names, scalar_names) + hash_to_scalar.update(dict(scalar_to_scalar)) + return cast("Mapping[acero.Aggregation, acero.Aggregation]", hash_to_scalar) + + +# TODO @dangotbanned: Replace this with a lazier version +# Don't really want this running at import-time, but using `ModuleType.__getattr__` means +# defining it somewhere else +HASH_TO_SCALAR_NAME: Mapping[acero.Aggregation, acero.Aggregation] = ( + _generate_hash_to_scalar_name() +) +"""Mapping between [Hash aggregate] and [Scalar aggregate] names. + +Dynamically built for use in `ListScalar` aggregations, accounting for version availability. 
+ +[Hash aggregate]: https://arrow.apache.org/docs/dev/cpp/compute.html#grouped-aggregations-group-by +[Scalar aggregate]: https://arrow.apache.org/docs/dev/cpp/compute.html#aggregations +""" diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index 0e2ae8e7a4..0672c8d23c 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -103,4 +103,7 @@ def test_list_agg_scalar( data = {"a": [row]} df = dataframe(data).select(cast_a) result = df.select(expr) + # NOTE: Doing a pure noop on `` will pass `assert_equal_data`, + # but will have the wrong dtype when compared with a non-null agg + assert result.collect_schema()["a"] != nw.List assert_equal_data(result, {"a": [expected]}) From a7c9ee13916937276d5710f5dd949647afab8f95 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 15 Dec 2025 16:19:55 +0000 Subject: [PATCH 16/26] test: Make scalar cases less of a disaster The line diff might only be small right now, but polars has 8 more list aggs ... 
--- tests/plan/list_agg_test.py | 95 +++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 51 deletions(-) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index 0672c8d23c..f698f71481 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -1,28 +1,46 @@ from __future__ import annotations +from itertools import chain from typing import TYPE_CHECKING, Final import pytest import narwhals as nw import narwhals._plan as nwp +from narwhals._plan._dispatch import get_dispatch_name +from narwhals._utils import zip_strict from tests.plan.utils import assert_equal_data, dataframe if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from _pytest.mark import ParameterSet + from typing_extensions import TypeAlias + from narwhals._plan.typing import OneOrIterable from tests.conftest import Data from tests.plan.utils import SubList + SubListNumeric: TypeAlias = SubList[int] | SubList[float] -R1: Final[SubList[int]] = [3, None, 2, 2, 4, None] -R2: Final[SubList[int]] = [-1] -R3: Final[SubList[float]] = None -R4: Final[SubList[float]] = [None, None, None] -R5: Final[SubList[float]] = [] + +R1: Final[SubListNumeric] = [3, None, 2, 2, 4, None] +R2: Final[SubListNumeric] = [-1] +R3: Final[SubListNumeric] = None +R4: Final[SubListNumeric] = [None, None, None] +R5: Final[SubListNumeric] = [] # NOTE: `pyarrow` needs at least 3 (non-null) values to calculate `median` correctly # Otherwise it picks the lowest non-null # https://github.com/narwhals-dev/narwhals/pull/3332#discussion_r2617508167 -R6: Final[SubList[int]] = [3, 4, None, 4, None, 3] +R6: Final[SubListNumeric] = [3, 4, None, 4, None, 3] + +ROWS: Final[tuple[SubListNumeric, ...]] = R1, R2, R3, R4, R5, R6 + +EXPECTED_MAX = [4, -1, None, None, None, 4] +EXPECTED_MEAN = [2.75, -1, None, None, None, 3.5] +EXPECTED_MIN = [2, -1, None, None, None, 3] +EXPECTED_SUM = [11, -1, None, 0, 0, 14] +EXPECTED_MEDIAN = [2.5, -1, None, None, None, 3.5] 
@pytest.fixture(scope="module") @@ -37,11 +55,11 @@ def data() -> Data: @pytest.mark.parametrize( ("exprs", "expected"), [ - (a.list.max(), [4, -1, None, None, None, 4]), - (a.list.mean(), [2.75, -1, None, None, None, 3.5]), - (a.list.min(), [2, -1, None, None, None, 3]), - (a.list.sum(), [11, -1, None, 0, 0, 14]), - (a.list.median(), [2.5, -1, None, None, None, 3.5]), + (a.list.max(), EXPECTED_MAX), + (a.list.mean(), EXPECTED_MEAN), + (a.list.min(), EXPECTED_MIN), + (a.list.sum(), EXPECTED_SUM), + (a.list.median(), EXPECTED_MEDIAN), ], ids=["max", "mean", "min", "sum", "median"], ) @@ -53,52 +71,27 @@ def test_list_agg( assert_equal_data(result, {"a": expected}) -first = a.first() -first_list_max = first.list.max() -first_list_mean = first.list.mean() -first_list_min = first.list.min() -first_list_sum = first.list.sum() -first_list_median = first.list.median() +def cases_scalar( + expr: nwp.Expr, expected: Sequence[float | None] +) -> Iterator[ParameterSet]: + for idx, row_expected in enumerate(zip_strict(ROWS, expected), start=1): + row, out = row_expected + name = get_dispatch_name(expr._ir).removeprefix("list.") + yield pytest.param(row, expr, out, id=f"{name}-R{idx}") -# TODO @dangotbanned: Shrink this @pytest.mark.parametrize( ("row", "expr", "expected"), - [ - (R1, first_list_max, 4), - (R2, first_list_max, -1), - (R3, first_list_max, None), - (R4, first_list_max, None), - (R5, first_list_max, None), - (R6, first_list_max, 4), - (R1, first_list_mean, 2.75), - (R2, first_list_mean, -1), - (R3, first_list_mean, None), - (R4, first_list_mean, None), - (R5, first_list_mean, None), - (R6, first_list_mean, 3.5), - (R1, first_list_min, 2), - (R2, first_list_min, -1), - (R3, first_list_min, None), - (R4, first_list_min, None), - (R5, first_list_min, None), - (R6, first_list_min, 3), - (R1, first_list_sum, 11), - (R2, first_list_sum, -1), - (R3, first_list_sum, None), - (R4, first_list_sum, 0), - (R5, first_list_sum, 0), - (R6, first_list_sum, 14), - (R1, 
first_list_median, 2.5), - (R2, first_list_median, -1), - (R3, first_list_median, None), - (R4, first_list_median, None), - (R5, first_list_median, None), - (R6, first_list_median, 3.5), - ], + chain( + cases_scalar(a.first().list.max(), EXPECTED_MAX), + cases_scalar(a.first().list.mean(), EXPECTED_MEAN), + cases_scalar(a.first().list.min(), EXPECTED_MIN), + cases_scalar(a.first().list.sum(), EXPECTED_SUM), + cases_scalar(a.first().list.median(), EXPECTED_MEDIAN), + ), ) def test_list_agg_scalar( - row: SubList[float], expr: nwp.Expr, expected: float | None + row: SubListNumeric, expr: nwp.Expr, expected: float | None ) -> None: data = {"a": [row]} df = dataframe(data).select(cast_a) From 3fefcdb216ea9c44195027566282a3c21fde1b0a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 15 Dec 2025 17:29:50 +0000 Subject: [PATCH 17/26] feat(expr-ir): Add `list.{all,any}` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mostly just plumbing things together 😄 --- narwhals/_plan/arrow/expr.py | 2 + narwhals/_plan/arrow/group_by.py | 29 ++++++--- narwhals/_plan/arrow/options.py | 6 +- narwhals/_plan/compliant/accessors.py | 6 ++ narwhals/_plan/expressions/lists.py | 12 +++- tests/plan/list_agg_test.py | 84 ++++++++++++++++----------- 6 files changed, 95 insertions(+), 44 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index e5253ef78d..d794c1b08e 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1007,6 +1007,8 @@ def aggregate( mean = aggregate median = aggregate sum = aggregate + any = aggregate + all = aggregate class ArrowStringNamespace( diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 68a3b1516b..43b2906436 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -80,6 +80,12 @@ del _version_dependent + +SUPPORTED_LIST_FUNCTION: 
Mapping[type[ir.lists.Aggregation], type[ir.Function]] = { + ir.lists.Any: ir.boolean.Any, + ir.lists.All: ir.boolean.All, +} + SCALAR_OUTPUT_TYPE: Mapping[acero.Aggregation, pa.DataType] = { "all": fn.BOOL, "any": fn.BOOL, @@ -175,15 +181,20 @@ def _from_function(cls, tp: type[ir.Function], name: str) -> Self: @classmethod def _from_list_agg(cls, list_agg: ir.lists.Aggregation, /, name: str) -> Self: tp = type(list_agg) - tp_agg = SUPPORTED_LIST_AGG[tp] - if tp_agg in {agg.Std, agg.Var}: - msg = ( - f"TODO: {qualified_type_name(list_agg)!r} needs access to `ddof`.\n" - "Add some sugar around mapping `ListFunction.` -> `AggExpr.`\n" - "or using `Immutable.__immutable_keys__`" - ) - raise NotImplementedError(msg) - return cls(name, SUPPORTED_AGG[tp_agg], options.LIST_AGG.get(tp), name) + if tp_agg := SUPPORTED_LIST_AGG.get(tp): + if tp_agg in {agg.Std, agg.Var}: + msg = ( + f"TODO: {qualified_type_name(list_agg)!r} needs access to `ddof`.\n" + "Add some sugar around mapping `ListFunction.` -> `AggExpr.`\n" + "or using `Immutable.__immutable_keys__`" + ) + raise NotImplementedError(msg) + fn_name = SUPPORTED_AGG[tp_agg] + elif tp_func := SUPPORTED_LIST_FUNCTION.get(tp): + fn_name = SUPPORTED_FUNCTION[tp_func] + else: + raise NotImplementedError(tp) + return cls(name, fn_name, options.LIST_AGG.get(tp), name) @classmethod def any(cls, name: str) -> Self: diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index a9f768f700..158a4e3323 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -163,7 +163,11 @@ def _generate_agg() -> Mapping[type[agg.AggExpr], acero.AggregateOptions]: def _generate_list_agg() -> Mapping[type[ir.lists.Aggregation], acero.AggregateOptions]: from narwhals._plan.expressions import lists - return {lists.Sum: scalar_aggregate(ignore_nulls=True)} + return { + lists.Sum: scalar_aggregate(ignore_nulls=True), + lists.All: scalar_aggregate(ignore_nulls=True), + lists.Any: 
scalar_aggregate(ignore_nulls=True), + } def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]: diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index d7ef786a34..44e2624b04 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -53,6 +53,12 @@ def median( def sum( self, node: FExpr[lists.Sum], frame: FrameT_contra, name: str ) -> ExprT_co: ... + def any( + self, node: FExpr[lists.Any], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def all( + self, node: FExpr[lists.All], frame: FrameT_contra, name: str + ) -> ExprT_co: ... class ExprStringNamespace(Protocol[FrameT_contra, ExprT_co]): diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index a4882acb3f..69787cc2d8 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -20,6 +20,8 @@ # fmt: off class ListFunction(Function, accessor="list", options=FunctionOptions.elementwise): ... +class Any(ListFunction): ... +class All(ListFunction): ... class Min(ListFunction): ... class Max(ListFunction): ... class Mean(ListFunction): ... 
@@ -45,7 +47,7 @@ def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: return expr, item -Aggregation: TypeAlias = "Min | Max | Mean | Median | Sum" +Aggregation: TypeAlias = "Any | All | Min | Max | Mean | Median | Sum" class IRListNamespace(IRNamespace): @@ -59,6 +61,8 @@ class IRListNamespace(IRNamespace): mean: ClassVar = Mean median: ClassVar = Median sum: ClassVar = Sum + any: ClassVar = Any + all: ClassVar = All class ExprListNamespace(ExprNamespace[IRListNamespace]): @@ -106,3 +110,9 @@ def join(self, separator: str, *, ignore_nulls: bool = True) -> Expr: return self._with_unary( self._ir.join(separator=separator, ignore_nulls=ignore_nulls) ) + + def any(self) -> Expr: + return self._with_unary(self._ir.any()) + + def all(self) -> Expr: + return self._with_unary(self._ir.all()) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index f698f71481..9e49d0306b 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -12,90 +12,108 @@ from tests.plan.utils import assert_equal_data, dataframe if TYPE_CHECKING: - from collections.abc import Iterator, Sequence + from collections.abc import Iterable, Iterator, Sequence from _pytest.mark import ParameterSet - from typing_extensions import TypeAlias - from narwhals._plan.typing import OneOrIterable + from narwhals.typing import NonNestedLiteral from tests.conftest import Data from tests.plan.utils import SubList - SubListNumeric: TypeAlias = SubList[int] | SubList[float] - -R1: Final[SubListNumeric] = [3, None, 2, 2, 4, None] -R2: Final[SubListNumeric] = [-1] -R3: Final[SubListNumeric] = None -R4: Final[SubListNumeric] = [None, None, None] -R5: Final[SubListNumeric] = [] +ROWS_N: Final[tuple[SubList[int] | SubList[float], ...]] = ( + [3, None, 2, 2, 4, None], + [-1], + None, + [None, None, None], + [], + [3, 4, None, 4, None, 3], +) # NOTE: `pyarrow` needs at least 3 (non-null) values to calculate `median` correctly # Otherwise it picks the lowest non-null # 
https://github.com/narwhals-dev/narwhals/pull/3332#discussion_r2617508167 -R6: Final[SubListNumeric] = [3, 4, None, 4, None, 3] -ROWS: Final[tuple[SubListNumeric, ...]] = R1, R2, R3, R4, R5, R6 + +ROWS_B: Final[tuple[SubList[bool], ...]] = ( + [True, True], + [False, True], + [False, False], + [None], + [], + None, +) EXPECTED_MAX = [4, -1, None, None, None, 4] EXPECTED_MEAN = [2.75, -1, None, None, None, 3.5] EXPECTED_MIN = [2, -1, None, None, None, 3] EXPECTED_SUM = [11, -1, None, 0, 0, 14] EXPECTED_MEDIAN = [2.5, -1, None, None, None, 3.5] +EXPECTED_ALL = [True, False, False, True, True, None] +EXPECTED_ANY = [True, True, False, False, False, None] @pytest.fixture(scope="module") def data() -> Data: - return {"a": [R1, R2, R3, R4, R5, R6]} + return {"a": [*ROWS_N], "b": [*ROWS_B]} a = nwp.col("a") +b = nwp.col("b") cast_a = a.cast(nw.List(nw.Int32)) +cast_b = b.cast(nw.List(nw.Boolean)) @pytest.mark.parametrize( - ("exprs", "expected"), + ("expr", "expected"), [ (a.list.max(), EXPECTED_MAX), (a.list.mean(), EXPECTED_MEAN), (a.list.min(), EXPECTED_MIN), (a.list.sum(), EXPECTED_SUM), (a.list.median(), EXPECTED_MEDIAN), + (b.list.all(), EXPECTED_ALL), + (b.list.any(), EXPECTED_ANY), ], - ids=["max", "mean", "min", "sum", "median"], + ids=["max", "mean", "min", "sum", "median", "all", "any"], ) -def test_list_agg( - data: Data, exprs: OneOrIterable[nwp.Expr], expected: list[float | None] -) -> None: - df = dataframe(data).with_columns(cast_a) - result = df.select(exprs) - assert_equal_data(result, {"a": expected}) +def test_list_agg(data: Data, expr: nwp.Expr, expected: list[NonNestedLiteral]) -> None: + df = dataframe(data).with_columns(cast_a, cast_b) + result = df.select(result=expr) + assert_equal_data(result, {"result": expected}) def cases_scalar( - expr: nwp.Expr, expected: Sequence[float | None] + expr: nwp.Expr, + rows: Iterable[Sequence[NonNestedLiteral] | None], + expected: Sequence[NonNestedLiteral], ) -> Iterator[ParameterSet]: - for idx, row_expected 
in enumerate(zip_strict(ROWS, expected), start=1): + for idx, row_expected in enumerate(zip_strict(rows, expected), start=1): row, out = row_expected name = get_dispatch_name(expr._ir).removeprefix("list.") - yield pytest.param(row, expr, out, id=f"{name}-R{idx}") + yield pytest.param(expr, row, out, id=f"{name}-R{idx}") + + +first_n = nwp.nth(0).cast(nw.List(nw.Int32)).first() +first_b = nwp.nth(0).cast(nw.List(nw.Boolean)).first() @pytest.mark.parametrize( - ("row", "expr", "expected"), + ("expr", "row", "expected"), chain( - cases_scalar(a.first().list.max(), EXPECTED_MAX), - cases_scalar(a.first().list.mean(), EXPECTED_MEAN), - cases_scalar(a.first().list.min(), EXPECTED_MIN), - cases_scalar(a.first().list.sum(), EXPECTED_SUM), - cases_scalar(a.first().list.median(), EXPECTED_MEDIAN), + cases_scalar(first_n.list.max(), ROWS_N, EXPECTED_MAX), + cases_scalar(first_n.list.mean(), ROWS_N, EXPECTED_MEAN), + cases_scalar(first_n.list.min(), ROWS_N, EXPECTED_MIN), + cases_scalar(first_n.list.sum(), ROWS_N, EXPECTED_SUM), + cases_scalar(first_n.list.median(), ROWS_N, EXPECTED_MEDIAN), + cases_scalar(first_b.list.all(), ROWS_B, EXPECTED_ALL), + cases_scalar(first_b.list.any(), ROWS_B, EXPECTED_ANY), ), ) def test_list_agg_scalar( - row: SubListNumeric, expr: nwp.Expr, expected: float | None + expr: nwp.Expr, row: SubList[NonNestedLiteral], expected: NonNestedLiteral ) -> None: data = {"a": [row]} - df = dataframe(data).select(cast_a) - result = df.select(expr) + result = dataframe(data).select(expr) # NOTE: Doing a pure noop on `` will pass `assert_equal_data`, # but will have the wrong dtype when compared with a non-null agg assert result.collect_schema()["a"] != nw.List From 96b66385c9e699b5ae2742f66da3e72283728cf0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 15 Dec 2025 18:11:39 +0000 Subject: [PATCH 18/26] feat(expr-ir): Add `list.{first,last}` Well this is going swimmingly --- narwhals/_plan/arrow/expr.py 
| 2 ++ narwhals/_plan/arrow/group_by.py | 2 ++ narwhals/_plan/arrow/options.py | 2 ++ narwhals/_plan/compliant/accessors.py | 6 ++++++ narwhals/_plan/expressions/lists.py | 12 +++++++++++- tests/plan/list_agg_test.py | 8 +++++++- 6 files changed, 30 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index d794c1b08e..dda6cbb0f9 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1009,6 +1009,8 @@ def aggregate( sum = aggregate any = aggregate all = aggregate + first = aggregate + last = aggregate class ArrowStringNamespace( diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 43b2906436..56ded2aef2 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -58,6 +58,8 @@ ir.lists.Max: agg.Max, ir.lists.Min: agg.Min, ir.lists.Sum: agg.Sum, + ir.lists.First: agg.First, + ir.lists.Last: agg.Last, } SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { ir.Len: "hash_count_all", diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index 158a4e3323..8fa0d07ccd 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -167,6 +167,8 @@ def _generate_list_agg() -> Mapping[type[ir.lists.Aggregation], acero.AggregateO lists.Sum: scalar_aggregate(ignore_nulls=True), lists.All: scalar_aggregate(ignore_nulls=True), lists.Any: scalar_aggregate(ignore_nulls=True), + lists.First: scalar_aggregate(), + lists.Last: scalar_aggregate(), } diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 44e2624b04..5b2dc17c81 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -59,6 +59,12 @@ def any( def all( self, node: FExpr[lists.All], frame: FrameT_contra, name: str ) -> ExprT_co: ... + def first( + self, node: FExpr[lists.First], frame: FrameT_contra, name: str + ) -> ExprT_co: ... 
+ def last( + self, node: FExpr[lists.Last], frame: FrameT_contra, name: str + ) -> ExprT_co: ... class ExprStringNamespace(Protocol[FrameT_contra, ExprT_co]): diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 69787cc2d8..b74b2974fc 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -22,6 +22,8 @@ class ListFunction(Function, accessor="list", options=FunctionOptions.elementwise): ... class Any(ListFunction): ... class All(ListFunction): ... +class First(ListFunction): ... +class Last(ListFunction): ... class Min(ListFunction): ... class Max(ListFunction): ... class Mean(ListFunction): ... @@ -47,7 +49,7 @@ def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: return expr, item -Aggregation: TypeAlias = "Any | All | Min | Max | Mean | Median | Sum" +Aggregation: TypeAlias = "Any | All | First | Last | Min | Max | Mean | Median | Sum" class IRListNamespace(IRNamespace): @@ -63,6 +65,8 @@ class IRListNamespace(IRNamespace): sum: ClassVar = Sum any: ClassVar = Any all: ClassVar = All + first: ClassVar = First + last: ClassVar = Last class ExprListNamespace(ExprNamespace[IRListNamespace]): @@ -116,3 +120,9 @@ def any(self) -> Expr: def all(self) -> Expr: return self._with_unary(self._ir.all()) + + def first(self) -> Expr: + return self._with_unary(self._ir.first()) + + def last(self) -> Expr: + return self._with_unary(self._ir.last()) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index 9e49d0306b..b7f8d80bdb 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -50,6 +50,8 @@ EXPECTED_MEDIAN = [2.5, -1, None, None, None, 3.5] EXPECTED_ALL = [True, False, False, True, True, None] EXPECTED_ANY = [True, True, False, False, False, None] +EXPECTED_FIRST = [3, -1, None, None, None, 3] +EXPECTED_LAST = [None, -1, None, None, None, 3] @pytest.fixture(scope="module") @@ -73,8 +75,10 @@ def data() -> Data: (a.list.median(), 
EXPECTED_MEDIAN), (b.list.all(), EXPECTED_ALL), (b.list.any(), EXPECTED_ANY), + (a.list.first(), EXPECTED_FIRST), + (a.list.last(), EXPECTED_LAST), ], - ids=["max", "mean", "min", "sum", "median", "all", "any"], + ids=["max", "mean", "min", "sum", "median", "all", "any", "first", "last"], ) def test_list_agg(data: Data, expr: nwp.Expr, expected: list[NonNestedLiteral]) -> None: df = dataframe(data).with_columns(cast_a, cast_b) @@ -107,6 +111,8 @@ def cases_scalar( cases_scalar(first_n.list.median(), ROWS_N, EXPECTED_MEDIAN), cases_scalar(first_b.list.all(), ROWS_B, EXPECTED_ALL), cases_scalar(first_b.list.any(), ROWS_B, EXPECTED_ANY), + cases_scalar(first_n.list.first(), ROWS_N, EXPECTED_FIRST), + cases_scalar(first_n.list.last(), ROWS_N, EXPECTED_LAST), ), ) def test_list_agg_scalar( From 5b310c622dd312f4dd555225240dd8b73580354c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 15 Dec 2025 19:37:24 +0000 Subject: [PATCH 19/26] feat(expr-ir): Add `list.n_unique` --- narwhals/_plan/arrow/expr.py | 1 + narwhals/_plan/arrow/group_by.py | 17 +++++++--- narwhals/_plan/arrow/options.py | 1 + narwhals/_plan/compliant/accessors.py | 3 ++ narwhals/_plan/expressions/lists.py | 9 ++++- tests/plan/list_agg_test.py | 47 ++++++++++++++++++++------- 6 files changed, 61 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index dda6cbb0f9..31aaf08801 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1011,6 +1011,7 @@ def aggregate( all = aggregate first = aggregate last = aggregate + n_unique = aggregate class ArrowStringNamespace( diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 56ded2aef2..3246725a88 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -60,6 +60,7 @@ ir.lists.Sum: agg.Sum, ir.lists.First: agg.First, ir.lists.Last: agg.Last, + ir.lists.NUnique: agg.NUnique, 
} SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { ir.Len: "hash_count_all", @@ -225,11 +226,14 @@ def agg_list(self, native: ChunkedList | ListScalar) -> ChunkedOrScalarAny: if not scalar.is_valid: return fn.lit(None, SCALAR_OUTPUT_TYPE.get(func, scalar.type.value_type)) result = pc.call_function(func, [scalar.values], self._option) - else: - result = self.over_index( - fn.ExplodeBuilder().explode_with_indices(native), "idx" - ) - result = fn.when_then(native.is_valid(), result) + return result + result = self.over_index(fn.ExplodeBuilder().explode_with_indices(native), "idx") + result = fn.when_then(native.is_valid(), result) + if self._is_n_unique(): + # NOTE: Exploding `[]` becomes `[None]` - so we need to adjust the unique count *iff* we were unlucky + len_not_eq_0 = fn.not_eq(fn.list_len(native), fn.lit(0)) + if not fn.all_(len_not_eq_0, ignore_nulls=False).as_py(): + result = fn.when_then(fn.not_(len_not_eq_0), fn.lit(0), result) return result def over(self, native: pa.Table, keys: Iterable[acero.Field]) -> pa.Table: @@ -246,6 +250,9 @@ def over_index(self, native: pa.Table, index_column: str) -> ChunkedArrayAny: """ return acero.group_by_table(native, [index_column], [self]).column(self._name) + def _is_n_unique(self) -> bool: + return self._function == SUPPORTED_AGG[agg.NUnique] + def group_by_error( column_name: str, expr: ir.ExprIR, reason: Literal["too complex"] | None = None diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index 8fa0d07ccd..26ff708767 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -169,6 +169,7 @@ def _generate_list_agg() -> Mapping[type[ir.lists.Aggregation], acero.AggregateO lists.Any: scalar_aggregate(ignore_nulls=True), lists.First: scalar_aggregate(), lists.Last: scalar_aggregate(), + lists.NUnique: count("all"), } diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 5b2dc17c81..b028df1fc3 100644 --- 
a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -65,6 +65,9 @@ def first( def last( self, node: FExpr[lists.Last], frame: FrameT_contra, name: str ) -> ExprT_co: ... + def n_unique( + self, node: FExpr[lists.NUnique], frame: FrameT_contra, name: str + ) -> ExprT_co: ... class ExprStringNamespace(Protocol[FrameT_contra, ExprT_co]): diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index b74b2974fc..a6fb0772cb 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -28,6 +28,7 @@ class Min(ListFunction): ... class Max(ListFunction): ... class Mean(ListFunction): ... class Median(ListFunction): ... +class NUnique(ListFunction): ... class Sum(ListFunction): ... class Len(ListFunction): ... class Unique(ListFunction): ... @@ -49,7 +50,9 @@ def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: return expr, item -Aggregation: TypeAlias = "Any | All | First | Last | Min | Max | Mean | Median | Sum" +Aggregation: TypeAlias = ( + "Any | All | First | Last | Min | Max | Mean | Median | NUnique | Sum" +) class IRListNamespace(IRNamespace): @@ -67,6 +70,7 @@ class IRListNamespace(IRNamespace): all: ClassVar = All first: ClassVar = First last: ClassVar = Last + n_unique: ClassVar = NUnique class ExprListNamespace(ExprNamespace[IRListNamespace]): @@ -126,3 +130,6 @@ def first(self) -> Expr: def last(self) -> Expr: return self._with_unary(self._ir.last()) + + def n_unique(self) -> Expr: + return self._with_unary(self._ir.n_unique()) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index b7f8d80bdb..4d196dd7ec 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -21,7 +21,7 @@ from tests.plan.utils import SubList -ROWS_N: Final[tuple[SubList[int] | SubList[float], ...]] = ( +ROWS_A: Final[tuple[SubList[int] | SubList[float], ...]] = ( [3, None, 2, 2, 4, None], [-1], None, @@ -42,6 +42,14 @@ [], None, ) 
+ROWS_C: Final[tuple[SubList[float], ...]] = (
+    [1.0, None, None, 3.0],
+    [1.0, None, 4.0, 5.0, 1.1, 4.0, None, 1.0],
+    [1.0, None, None, 1.0, 2.0, 2.0, 2.0, None, 3.0],
+    [],
+    [None, None, None],
+    None,
+)
 
 EXPECTED_MAX = [4, -1, None, None, None, 4]
 EXPECTED_MEAN = [2.75, -1, None, None, None, 3.5]
@@ -52,17 +60,20 @@
 EXPECTED_ANY = [True, True, False, False, False, None]
 EXPECTED_FIRST = [3, -1, None, None, None, 3]
 EXPECTED_LAST = [None, -1, None, None, None, 3]
+EXPECTED_N_UNIQUE = [3, 5, 4, 0, 1, None]
 
 
 @pytest.fixture(scope="module")
 def data() -> Data:
-    return {"a": [*ROWS_N], "b": [*ROWS_B]}
+    return {"a": [*ROWS_A], "b": [*ROWS_B], "c": [*ROWS_C]}
 
 
 a = nwp.col("a")
 b = nwp.col("b")
+c = nwp.col("c")
 cast_a = a.cast(nw.List(nw.Int32))
 cast_b = b.cast(nw.List(nw.Boolean))
+cast_c = c.cast(nw.List(nw.Float64))
 
 
 @pytest.mark.parametrize(
@@ -77,8 +88,20 @@ def data() -> Data:
         (b.list.any(), EXPECTED_ANY),
         (a.list.first(), EXPECTED_FIRST),
         (a.list.last(), EXPECTED_LAST),
+        (c.list.n_unique(), EXPECTED_N_UNIQUE),
+    ],
+    ids=[
+        "max",
+        "mean",
+        "min",
+        "sum",
+        "median",
+        "all",
+        "any",
+        "first",
+        "last",
+        "n_unique",
     ],
-    ids=["max", "mean", "min", "sum", "median", "all", "any", "first", "last"],
 )
 def test_list_agg(data: Data, expr: nwp.Expr, expected: list[NonNestedLiteral]) -> None:
     df = dataframe(data).with_columns(cast_a, cast_b)
@@ -97,22 +120,24 @@ def cases_scalar(
     yield pytest.param(expr, row, out, id=f"{name}-R{idx}")
 
 
-first_n = nwp.nth(0).cast(nw.List(nw.Int32)).first()
+first_a = nwp.nth(0).cast(nw.List(nw.Int32)).first()
 first_b = nwp.nth(0).cast(nw.List(nw.Boolean)).first()
+first_c = nwp.nth(0).cast(nw.List(nw.Float64)).first()
 
 
 @pytest.mark.parametrize(
     ("expr", "row", "expected"),
     chain(
-        cases_scalar(first_n.list.max(), ROWS_N, EXPECTED_MAX),
-        cases_scalar(first_n.list.mean(), ROWS_N, EXPECTED_MEAN),
-        cases_scalar(first_n.list.min(), ROWS_N, EXPECTED_MIN),
-        cases_scalar(first_n.list.sum(), ROWS_N, EXPECTED_SUM),
-        
cases_scalar(first_n.list.median(), ROWS_N, EXPECTED_MEDIAN), + cases_scalar(first_a.list.max(), ROWS_A, EXPECTED_MAX), + cases_scalar(first_a.list.mean(), ROWS_A, EXPECTED_MEAN), + cases_scalar(first_a.list.min(), ROWS_A, EXPECTED_MIN), + cases_scalar(first_a.list.sum(), ROWS_A, EXPECTED_SUM), + cases_scalar(first_a.list.median(), ROWS_A, EXPECTED_MEDIAN), cases_scalar(first_b.list.all(), ROWS_B, EXPECTED_ALL), cases_scalar(first_b.list.any(), ROWS_B, EXPECTED_ANY), - cases_scalar(first_n.list.first(), ROWS_N, EXPECTED_FIRST), - cases_scalar(first_n.list.last(), ROWS_N, EXPECTED_LAST), + cases_scalar(first_a.list.first(), ROWS_A, EXPECTED_FIRST), + cases_scalar(first_a.list.last(), ROWS_A, EXPECTED_LAST), + cases_scalar(first_c.list.n_unique(), ROWS_C, EXPECTED_N_UNIQUE), ), ) def test_list_agg_scalar( From 86a3060b2b0227e87b030022e7753bc3c147fc43 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 16 Dec 2025 13:37:22 +0000 Subject: [PATCH 20/26] docs: Rephrase `explode_with_indices` Applied suggestion from @FBruzzesi Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> --- narwhals/_plan/arrow/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 5d9ef53a4a..895a771dca 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -451,7 +451,7 @@ def explode_with_indices(self, native: ChunkedList | ListArray) -> pa.Table: >>> arr = fn.array([[1, 2, 3], None, [4, 5, 6], []]) >>> fn.ExplodeBuilder().explode_with_indices(arr).to_pydict() {'idx': [0, 0, 0, 1, 2, 2, 2, 3], 'values': [1, 2, 3, None, 4, 5, 6, None]} - # ^ Which sublist we came from ^ The exploded values themselves + # ^ Which sublist values come from ^ The exploded values themselves """ safe = self._fill_with_null(native) if self.options.any() else native arrays = [_list_parent_indices(safe), 
_list_explode(safe)] From d232439b51eef94522176440c2c78b6f97051e1f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 16 Dec 2025 13:43:38 +0000 Subject: [PATCH 21/26] style: re-align --- narwhals/_plan/arrow/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 895a771dca..db153e56ea 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -451,7 +451,7 @@ def explode_with_indices(self, native: ChunkedList | ListArray) -> pa.Table: >>> arr = fn.array([[1, 2, 3], None, [4, 5, 6], []]) >>> fn.ExplodeBuilder().explode_with_indices(arr).to_pydict() {'idx': [0, 0, 0, 1, 2, 2, 2, 3], 'values': [1, 2, 3, None, 4, 5, 6, None]} - # ^ Which sublist values come from ^ The exploded values themselves + # ^ Which sublist values come from ^ The exploded values themselves """ safe = self._fill_with_null(native) if self.options.any() else native arrays = [_list_parent_indices(safe), _list_explode(safe)] From 92d0b749a0f996a27d3a1fdae26a5d36516f49b4 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 16 Dec 2025 13:45:21 +0000 Subject: [PATCH 22/26] Apply suggestions from code review --- narwhals/_plan/arrow/group_by.py | 2 +- tests/plan/list_agg_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 3246725a88..8f92c46a7d 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -106,7 +106,7 @@ For use in list aggregates, where the input was null. -*Except `"mean"` can be `Decimal`. +*Except `"mean"` will preserve `Decimal`, if that's where we started. 
""" diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py index 4d196dd7ec..83e085a068 100644 --- a/tests/plan/list_agg_test.py +++ b/tests/plan/list_agg_test.py @@ -114,9 +114,9 @@ def cases_scalar( rows: Iterable[Sequence[NonNestedLiteral] | None], expected: Sequence[NonNestedLiteral], ) -> Iterator[ParameterSet]: + name = get_dispatch_name(expr._ir).removeprefix("list.") for idx, row_expected in enumerate(zip_strict(rows, expected), start=1): row, out = row_expected - name = get_dispatch_name(expr._ir).removeprefix("list.") yield pytest.param(expr, row, out, id=f"{name}-R{idx}") From 76ba6233fe0a630b0a69fc7acfe97fb0160d4958 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 16 Dec 2025 14:01:00 +0000 Subject: [PATCH 23/26] refactor: Simplify double negations Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> --- narwhals/_plan/arrow/group_by.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 8f92c46a7d..714b7b52c5 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -231,9 +231,9 @@ def agg_list(self, native: ChunkedList | ListScalar) -> ChunkedOrScalarAny: result = fn.when_then(native.is_valid(), result) if self._is_n_unique(): # NOTE: Exploding `[]` becomes `[None]` - so we need to adjust the unique count *iff* we were unlucky - len_not_eq_0 = fn.not_eq(fn.list_len(native), fn.lit(0)) - if not fn.all_(len_not_eq_0, ignore_nulls=False).as_py(): - result = fn.when_then(fn.not_(len_not_eq_0), fn.lit(0), result) + len_eq_0 = fn.eq(fn.list_len(native), fn.lit(0)) + if fn.any_(len_eq_0, ignore_nulls=False).as_py(): + result = fn.when_then(len_eq_0, fn.lit(0), result) return result def over(self, native: pa.Table, keys: Iterable[acero.Field]) -> pa.Table: From f6de206fec8d136ef593e49e296cb1016d6e354c Mon Sep 17 00:00:00 2001 From: 
dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 16 Dec 2025 14:04:58 +0000 Subject: [PATCH 24/26] ooh nice, we don't need `ignore_nulls=False` this way! --- narwhals/_plan/arrow/group_by.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 714b7b52c5..15a21841f5 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -232,7 +232,7 @@ def agg_list(self, native: ChunkedList | ListScalar) -> ChunkedOrScalarAny: if self._is_n_unique(): # NOTE: Exploding `[]` becomes `[None]` - so we need to adjust the unique count *iff* we were unlucky len_eq_0 = fn.eq(fn.list_len(native), fn.lit(0)) - if fn.any_(len_eq_0, ignore_nulls=False).as_py(): + if fn.any_(len_eq_0).as_py(): result = fn.when_then(len_eq_0, fn.lit(0), result) return result From fc761a98923f3580999e24152a25347e3fad014f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 16 Dec 2025 14:06:18 +0000 Subject: [PATCH 25/26] refactor: Rename `len_eq_0` -> `is_sublist_empty` --- narwhals/_plan/arrow/group_by.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 15a21841f5..7bee6aa76f 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -231,9 +231,9 @@ def agg_list(self, native: ChunkedList | ListScalar) -> ChunkedOrScalarAny: result = fn.when_then(native.is_valid(), result) if self._is_n_unique(): # NOTE: Exploding `[]` becomes `[None]` - so we need to adjust the unique count *iff* we were unlucky - len_eq_0 = fn.eq(fn.list_len(native), fn.lit(0)) - if fn.any_(len_eq_0).as_py(): - result = fn.when_then(len_eq_0, fn.lit(0), result) + is_sublist_empty = fn.eq(fn.list_len(native), fn.lit(0)) + if fn.any_(is_sublist_empty).as_py(): + result = fn.when_then(is_sublist_empty, fn.lit(0), result) return 
result def over(self, native: pa.Table, keys: Iterable[acero.Field]) -> pa.Table: From f74c4ddbb1327a5e1c75ced822fc2afd403ab291 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 17 Dec 2025 13:14:25 +0000 Subject: [PATCH 26/26] docs: More clearly demo `arrow.options` lazy mappings Resolves https://github.com/narwhals-dev/narwhals/pull/3353#discussion_r2622698504 --- narwhals/_plan/arrow/options.py | 59 ++++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index 26ff708767..4ebfb3d444 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -1,20 +1,25 @@ -"""Cached `pyarrow.compute` options classes, using `polars` defaults. +"""Cached [`pyarrow.compute` options], using `polars` defaults and naming conventions. -Important: - `AGG` and `FUNCTION` mappings are constructed on first `__getattr__` access. +See `LazyOptions` for [`__getattr__`] usage. 
+ +[`pyarrow.compute` options]: https://arrow.apache.org/docs/dev/python/api/compute.html#compute-options +[`__getattr__`]: https://docs.python.org/3/reference/datamodel.html#module.__getattr__ """ from __future__ import annotations import functools -from typing import TYPE_CHECKING, Any, Literal +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Literal, TypeVar import pyarrow.compute as pc from narwhals._utils import zip_strict if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Sequence + + from typing_extensions import TypeAlias from narwhals._plan import expressions as ir from narwhals._plan.arrow import acero @@ -26,20 +31,57 @@ __all__ = [ "AGG", "FUNCTION", + "LIST_AGG", "array_sort", "count", "join", "join_replace_nulls", + "match_substring", "rank", "scalar_aggregate", "sort", + "split_pattern", "variance", ] -AGG: Mapping[type[agg.AggExpr], acero.AggregateOptions] -FUNCTION: Mapping[type[ir.Function], acero.AggregateOptions] -LIST_AGG: Mapping[type[ir.lists.Aggregation], acero.AggregateOptions] +_T = TypeVar("_T", bound="type[ir.ExprIR | ir.Function]") + +LazyOptions: TypeAlias = Mapping[_T, "acero.AggregateOptions"] +"""Lazily constructed mapping to `pyarrow.compute.FunctionOptions` instances. 
+
+Examples:
+    >>> from narwhals import _plan as nwp
+    >>> from narwhals._plan import expressions as ir
+    >>> from narwhals._plan.arrow import options
+    >>>
+    >>> expr = nwp.col("a").first()
+    >>> expr_ir = expr._ir
+    >>> expr_ir
+    col('a').first()
+    >>> if isinstance(expr_ir, ir.AggExpr):
+    ...     print(options.AGG.get(type(expr_ir)))
+    ScalarAggregateOptions(skip_nulls=false, min_count=0)
+
+    The first access to `AGG` generated the mapping
+
+    >>> lazy = {"AGG", "FUNCTION", "LIST_AGG"}
+    >>> [key for key in options.__dict__ if key in lazy]
+    ['AGG']
+
+    We *didn't* generate `FUNCTION`, but it'll be there *when* we need it
+
+    >>> options.FUNCTION.get(ir.functions.NullCount)
+    CountOptions(mode=NULLS)
+
+    >>> [key for key in options.__dict__ if key in lazy]
+    ['AGG', 'FUNCTION']
+"""
+
+AGG: LazyOptions[type[agg.AggExpr]]
+FUNCTION: LazyOptions[type[ir.Function]]
+LIST_AGG: LazyOptions[type[ir.lists.Aggregation]]
+
 
 _NULLS_LAST = True
 _NULLS_FIRST = False
@@ -185,7 +227,6 @@ def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]:
 
 
 # ruff: noqa: PLW0603
-# NOTE: Using globals for lazy-loading cache
 if not TYPE_CHECKING:
 
     def __getattr__(name: str) -> Any: