diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 74f26de213..31aaf08801 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -15,6 +15,7 @@ is_seq_column, ) from narwhals._plan.arrow import functions as fn +from narwhals._plan.arrow.group_by import AggSpec from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co from narwhals._plan.common import temp @@ -994,6 +995,24 @@ def join(self, node: FExpr[lists.Join], frame: Frame, name: str) -> Expr | Scala ) return self.with_native(result, name) + def aggregate( + self, node: FExpr[lists.Aggregation], frame: Frame, name: str + ) -> Expr | Scalar: + previous = node.input[0].dispatch(self.compliant, frame, name) + agg = AggSpec._from_list_agg(node.function, "values") + return self.with_native(agg.agg_list(previous.native), name) + + min = aggregate + max = aggregate + mean = aggregate + median = aggregate + sum = aggregate + any = aggregate + all = aggregate + first = aggregate + last = aggregate + n_unique = aggregate + class ArrowStringNamespace( ExprStringNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 42b44983b5..db153e56ea 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -442,6 +442,17 @@ def explode( return chunked_array(_list_explode(safe)) def explode_with_indices(self, native: ChunkedList | ListArray) -> pa.Table: + """Explode list elements, expanding one-level into a table indexing the origin. 
+ + Returns a 2-column table, with names `"idx"` and `"values"`: + + >>> from narwhals._plan.arrow import functions as fn + >>> + >>> arr = fn.array([[1, 2, 3], None, [4, 5, 6], []]) + >>> fn.ExplodeBuilder().explode_with_indices(arr).to_pydict() + {'idx': [0, 0, 0, 1, 2, 2, 2, 3], 'values': [1, 2, 3, None, 4, 5, 6, None]} + # ^ Which sublist values come from ^ The exploded values themselves + """ safe = self._fill_with_null(native) if self.options.any() else native arrays = [_list_parent_indices(safe), _list_explode(safe)] return concat_horizontal(arrays, ["idx", "values"]) @@ -1042,6 +1053,12 @@ def _str_zfill_compat( ) +@t.overload +def when_then( + predicate: ChunkedArray[BooleanScalar], then: ScalarAny +) -> ChunkedArrayAny: ... +@t.overload +def when_then(predicate: Array[BooleanScalar], then: ScalarAny) -> ArrayAny: ... @t.overload def when_then( predicate: Predicate, then: SameArrowT, otherwise: SameArrowT | None @@ -1059,6 +1076,11 @@ def when_then( def when_then( predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None ) -> Incomplete: + """Thin wrapper around `pyarrow.compute.if_else`. 
+ + - Supports a 2-arg form, like `pl.when(...).then(...)` + - Accepts python literals, but only in the `otherwise` position + """ if is_non_nested_literal(otherwise): otherwise = lit(otherwise, then.type) return pc.if_else(predicate, then, otherwise) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index af918a3e4c..7bee6aa76f 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import TYPE_CHECKING, Any, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, cast, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -13,7 +13,7 @@ from narwhals._plan.common import temp from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy from narwhals._plan.expressions import aggregation as agg -from narwhals._utils import Implementation +from narwhals._utils import Implementation, qualified_type_name from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: @@ -26,16 +26,17 @@ ArrayAny, ChunkedArray, ChunkedArrayAny, + ChunkedList, + ChunkedOrScalarAny, Indices, + ListScalar, + ScalarAny, ) from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq Incomplete: TypeAlias = Any -# NOTE: Unless stated otherwise, all aggregations have 2 variants: -# - `` (pc.Function.kind == "scalar_aggregate") -# - `hash_` (pc.Function.kind == "hash_aggregate") SUPPORTED_AGG: Mapping[type[agg.AggExpr], acero.Aggregation] = { agg.Sum: "hash_sum", agg.Mean: "hash_mean", @@ -51,9 +52,19 @@ agg.Last: "hash_last", fn.MinMax: "hash_min_max", } +SUPPORTED_LIST_AGG: Mapping[type[ir.lists.Aggregation], type[agg.AggExpr]] = { + ir.lists.Mean: agg.Mean, + ir.lists.Median: agg.Median, + ir.lists.Max: agg.Max, + ir.lists.Min: agg.Min, + ir.lists.Sum: agg.Sum, + ir.lists.First: agg.First, + ir.lists.Last: agg.Last, + ir.lists.NUnique: 
agg.NUnique, +} SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { ir.Len: "hash_count_all", - ir.Column: "hash_list", # `hash_aggregate` only + ir.Column: "hash_list", } _version_dependent: dict[Any, acero.Aggregation] = {} @@ -65,7 +76,7 @@ SUPPORTED_FUNCTION: Mapping[type[ir.Function], acero.Aggregation] = { ir.boolean.All: "hash_all", ir.boolean.Any: "hash_any", - ir.functions.Unique: "hash_distinct", # `hash_aggregate` only + ir.functions.Unique: "hash_distinct", ir.functions.NullCount: "hash_count", **_version_dependent, } @@ -73,8 +84,34 @@ del _version_dependent +SUPPORTED_LIST_FUNCTION: Mapping[type[ir.lists.Aggregation], type[ir.Function]] = { + ir.lists.Any: ir.boolean.Any, + ir.lists.All: ir.boolean.All, +} + +SCALAR_OUTPUT_TYPE: Mapping[acero.Aggregation, pa.DataType] = { + "all": fn.BOOL, + "any": fn.BOOL, + "approximate_median": fn.F64, + "count": fn.I64, + "count_all": fn.I64, + "count_distinct": fn.I64, + "kurtosis": fn.F64, + "mean": fn.F64, + "skew": fn.F64, + "stddev": fn.F64, + "variance": fn.F64, +} +"""Scalar aggregates that have an output type **not** dependent on input types*. + +For use in list aggregates, where the input was null. + +*Except `"mean"` will preserve `Decimal`, if that's where we started. 
+""" + + class AggSpec: - __slots__ = ("agg", "name", "option", "target") + __slots__ = ("_function", "_name", "_option", "_target") def __init__( self, @@ -83,19 +120,22 @@ def __init__( option: acero.Opts = None, name: acero.OutputName = "", ) -> None: - self.target = target - self.agg = agg - self.option = option - self.name = name or str(target) + self._target = target + self._function: acero.Aggregation = agg + self._option: acero.Opts = option + self._name: acero.OutputName = name or str(target) @property def use_threads(self) -> bool: """See https://github.com/apache/arrow/issues/36709.""" - return acero.can_thread(self.agg) + return acero.can_thread(self._function) def __iter__(self) -> Iterator[acero.Target | acero.Aggregation | acero.Opts]: """Let's us duck-type as a 4-tuple.""" - yield from (self.target, self.agg, self.option, self.name) + yield from (self._target, self._function, self._option, self._name) + + def __repr__(self) -> str: + return f"{type(self).__name__}({self._target!r}, {self._function!r}, {self._option!r}, {self._name!r})" @classmethod def from_named_ir(cls, named_ir: NamedIR) -> Self: @@ -141,6 +181,24 @@ def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self: def _from_function(cls, tp: type[ir.Function], name: str) -> Self: return cls(name, SUPPORTED_FUNCTION[tp], options.FUNCTION.get(tp), name) + @classmethod + def _from_list_agg(cls, list_agg: ir.lists.Aggregation, /, name: str) -> Self: + tp = type(list_agg) + if tp_agg := SUPPORTED_LIST_AGG.get(tp): + if tp_agg in {agg.Std, agg.Var}: + msg = ( + f"TODO: {qualified_type_name(list_agg)!r} needs access to `ddof`.\n" + "Add some sugar around mapping `ListFunction.` -> `AggExpr.`\n" + "or using `Immutable.__immutable_keys__`" + ) + raise NotImplementedError(msg) + fn_name = SUPPORTED_AGG[tp_agg] + elif tp_func := SUPPORTED_LIST_FUNCTION.get(tp): + fn_name = SUPPORTED_FUNCTION[tp_func] + else: + raise NotImplementedError(tp) + return cls(name, fn_name, 
options.LIST_AGG.get(tp), name) + @classmethod def any(cls, name: str) -> Self: return cls._from_function(ir.boolean.Any, name) @@ -155,6 +213,29 @@ def implode(cls, name: str) -> Self: # https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-plan/src/dsl/expr/mod.rs#L44 return cls(name, SUPPORTED_IR[ir.Column], None, name) + @overload + def agg_list(self, native: ChunkedList) -> ChunkedArrayAny: ... + @overload + def agg_list(self, native: ListScalar) -> ScalarAny: ... + def agg_list(self, native: ChunkedList | ListScalar) -> ChunkedOrScalarAny: + """Execute this aggregation over the values in *each* list, reducing *each* to a single value.""" + result: ChunkedOrScalarAny + if isinstance(native, pa.Scalar): + scalar = cast("pa.ListScalar[Any]", native) + func = HASH_TO_SCALAR_NAME[self._function] + if not scalar.is_valid: + return fn.lit(None, SCALAR_OUTPUT_TYPE.get(func, scalar.type.value_type)) + result = pc.call_function(func, [scalar.values], self._option) + return result + result = self.over_index(fn.ExplodeBuilder().explode_with_indices(native), "idx") + result = fn.when_then(native.is_valid(), result) + if self._is_n_unique(): + # NOTE: Exploding `[]` becomes `[None]` - so we need to adjust the unique count *iff* we were unlucky + is_sublist_empty = fn.eq(fn.list_len(native), fn.lit(0)) + if fn.any_(is_sublist_empty).as_py(): + result = fn.when_then(is_sublist_empty, fn.lit(0), result) + return result + def over(self, native: pa.Table, keys: Iterable[acero.Field]) -> pa.Table: """Sugar for `native.group_by(keys).aggregate([self])`. @@ -167,7 +248,10 @@ def over_index(self, native: pa.Table, index_column: str) -> ChunkedArrayAny: Returns a single, (unnamed) array, representing the aggregation results. 
""" - return acero.group_by_table(native, [index_column], [self]).column(self.name) + return acero.group_by_table(native, [index_column], [self]).column(self._name) + + def _is_n_unique(self) -> bool: + return self._function == SUPPORTED_AGG[agg.NUnique] def group_by_error( @@ -321,3 +405,32 @@ def _partition_by_many( # E.g, to push down column selection to *before* collection # Not needed for this task though yield acero.collect(source, acero.filter(key == v), select) + + +def _generate_hash_to_scalar_name() -> Mapping[acero.Aggregation, acero.Aggregation]: + nw_to_hash = SUPPORTED_AGG, SUPPORTED_IR, SUPPORTED_FUNCTION + only_hash = {"hash_distinct", "hash_list", "hash_one"} + targets = set[str](chain.from_iterable(m.values() for m in nw_to_hash)) - only_hash + hash_to_scalar = {hash_name: hash_name.removeprefix("hash_") for hash_name in targets} + # NOTE: Support both of these when using `AggSpec` directly for scalar aggregates + # `(..., "hash_mean", ..., ...)` + # `(..., "mean", ..., ...)` + scalar_names = hash_to_scalar.values() + scalar_to_scalar = zip(scalar_names, scalar_names) + hash_to_scalar.update(dict(scalar_to_scalar)) + return cast("Mapping[acero.Aggregation, acero.Aggregation]", hash_to_scalar) + + +# TODO @dangotbanned: Replace this with a lazier version +# Don't really want this running at import-time, but using `ModuleType.__getattr__` means +# defining it somewhere else +HASH_TO_SCALAR_NAME: Mapping[acero.Aggregation, acero.Aggregation] = ( + _generate_hash_to_scalar_name() +) +"""Mapping between [Hash aggregate] and [Scalar aggregate] names. + +Dynamically built for use in `ListScalar` aggregations, accounting for version availability. 
+ +[Hash aggregate]: https://arrow.apache.org/docs/dev/cpp/compute.html#grouped-aggregations-group-by +[Scalar aggregate]: https://arrow.apache.org/docs/dev/cpp/compute.html#aggregations +""" diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index 3d44487bc7..4ebfb3d444 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -1,20 +1,25 @@ -"""Cached `pyarrow.compute` options classes, using `polars` defaults. +"""Cached [`pyarrow.compute` options], using `polars` defaults and naming conventions. -Important: - `AGG` and `FUNCTION` mappings are constructed on first `__getattr__` access. +See `LazyOptions` for [`__getattr__`] usage. + +[`pyarrow.compute` options]: https://arrow.apache.org/docs/dev/python/api/compute.html#compute-options +[`__getattr__`]: https://docs.python.org/3/reference/datamodel.html#module.__getattr__ """ from __future__ import annotations import functools -from typing import TYPE_CHECKING, Any, Literal +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Literal, TypeVar import pyarrow.compute as pc from narwhals._utils import zip_strict if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Sequence + + from typing_extensions import TypeAlias from narwhals._plan import expressions as ir from narwhals._plan.arrow import acero @@ -26,19 +31,57 @@ __all__ = [ "AGG", "FUNCTION", + "LIST_AGG", "array_sort", "count", "join", "join_replace_nulls", + "match_substring", "rank", "scalar_aggregate", "sort", + "split_pattern", "variance", ] -AGG: Mapping[type[agg.AggExpr], acero.AggregateOptions] -FUNCTION: Mapping[type[ir.Function], acero.AggregateOptions] +_T = TypeVar("_T", bound="type[ir.ExprIR | ir.Function]") + +LazyOptions: TypeAlias = Mapping[_T, "acero.AggregateOptions"] +"""Lazily constructed mapping to `pyarrow.compute.FunctionOptions` instances. 
+
+Examples:
+    >>> from narwhals import _plan as nwp
+    >>> from narwhals._plan import expressions as ir
+    >>> from narwhals._plan.arrow import options
+    >>>
+    >>> expr = nwp.col("a").first()
+    >>> expr_ir = expr._ir
+    >>> expr_ir
+    col('a').first()
+    >>> if isinstance(expr_ir, ir.AggExpr):
+    ...     print(options.AGG.get(type(expr_ir)))
+    ScalarAggregateOptions(skip_nulls=false, min_count=0)
+
+    The first access to `AGG` generated the mapping
+
+    >>> lazy = {"AGG", "FUNCTION", "LIST_AGG"}
+    >>> [key for key in options.__dict__ if key in lazy]
+    ['AGG']
+
+    We *didn't* generate `FUNCTION`, but it'll be there *when* we need it
+
+    >>> options.FUNCTION.get(ir.functions.NullCount)
+    CountOptions(mode=NULLS)
+
+    >>> [key for key in options.__dict__ if key in lazy]
+    ['AGG', 'FUNCTION']
+"""
+
+AGG: LazyOptions[type[agg.AggExpr]]
+FUNCTION: LazyOptions[type[ir.Function]]
+LIST_AGG: LazyOptions[type[ir.lists.Aggregation]]
+
 _NULLS_LAST = True
 _NULLS_FIRST = False
@@ -159,6 +202,19 @@ def _generate_agg() -> Mapping[type[agg.AggExpr], acero.AggregateOptions]:
     }
 
 
+def _generate_list_agg() -> Mapping[type[ir.lists.Aggregation], acero.AggregateOptions]:
+    from narwhals._plan.expressions import lists
+
+    return {
+        lists.Sum: scalar_aggregate(ignore_nulls=True),
+        lists.All: scalar_aggregate(ignore_nulls=True),
+        lists.Any: scalar_aggregate(ignore_nulls=True),
+        lists.First: scalar_aggregate(),
+        lists.Last: scalar_aggregate(),
+        lists.NUnique: count("all"),
+    }
+
+
 def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]:
     from narwhals._plan.expressions import boolean, functions
 
@@ -171,7 +227,6 @@ def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]:
 
 
 # ruff: noqa: PLW0603
-# NOTE: Using globals for lazy-loading cache
 if not TYPE_CHECKING:
 
     def __getattr__(name: str) -> Any:
@@ -183,5 +238,9 @@ def __getattr__(name: str) -> Any:
             global FUNCTION
             FUNCTION = _generate_function()
             return FUNCTION
+        if name == "LIST_AGG":
+            global LIST_AGG
+            
LIST_AGG = _generate_list_agg() + return LIST_AGG msg = f"module {__name__!r} has no attribute {name!r}" raise AttributeError(msg) diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 26df3ff3ba..b028df1fc3 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -38,6 +38,36 @@ def unique( def join( self, node: FExpr[lists.Join], frame: FrameT_contra, name: str ) -> ExprT_co: ... + def min( + self, node: FExpr[lists.Min], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def max( + self, node: FExpr[lists.Max], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def mean( + self, node: FExpr[lists.Mean], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def median( + self, node: FExpr[lists.Median], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def sum( + self, node: FExpr[lists.Sum], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def any( + self, node: FExpr[lists.Any], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def all( + self, node: FExpr[lists.All], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def first( + self, node: FExpr[lists.First], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def last( + self, node: FExpr[lists.Last], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def n_unique( + self, node: FExpr[lists.NUnique], frame: FrameT_contra, name: str + ) -> ExprT_co: ... 
class ExprStringNamespace(Protocol[FrameT_contra, ExprT_co]): diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index e35b4fb41e..a6fb0772cb 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -11,7 +11,7 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._plan.expr import Expr from narwhals._plan.expressions import ExprIR, FunctionExpr as FExpr @@ -20,6 +20,16 @@ # fmt: off class ListFunction(Function, accessor="list", options=FunctionOptions.elementwise): ... +class Any(ListFunction): ... +class All(ListFunction): ... +class First(ListFunction): ... +class Last(ListFunction): ... +class Min(ListFunction): ... +class Max(ListFunction): ... +class Mean(ListFunction): ... +class Median(ListFunction): ... +class NUnique(ListFunction): ... +class Sum(ListFunction): ... class Len(ListFunction): ... class Unique(ListFunction): ... 
class Get(ListFunction): @@ -40,12 +50,27 @@ def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: return expr, item +Aggregation: TypeAlias = ( + "Any | All | First | Last | Min | Max | Mean | Median | NUnique | Sum" +) + + class IRListNamespace(IRNamespace): len: ClassVar = Len unique: ClassVar = Unique contains: ClassVar = Contains get: ClassVar = Get join: ClassVar = Join + min: ClassVar = Min + max: ClassVar = Max + mean: ClassVar = Mean + median: ClassVar = Median + sum: ClassVar = Sum + any: ClassVar = Any + all: ClassVar = All + first: ClassVar = First + last: ClassVar = Last + n_unique: ClassVar = NUnique class ExprListNamespace(ExprNamespace[IRListNamespace]): @@ -53,6 +78,21 @@ class ExprListNamespace(ExprNamespace[IRListNamespace]): def _ir_namespace(self) -> type[IRListNamespace]: return IRListNamespace + def min(self) -> Expr: + return self._with_unary(self._ir.min()) + + def max(self) -> Expr: + return self._with_unary(self._ir.max()) + + def mean(self) -> Expr: + return self._with_unary(self._ir.mean()) + + def median(self) -> Expr: + return self._with_unary(self._ir.median()) + + def sum(self) -> Expr: + return self._with_unary(self._ir.sum()) + def len(self) -> Expr: return self._with_unary(self._ir.len()) @@ -78,3 +118,18 @@ def join(self, separator: str, *, ignore_nulls: bool = True) -> Expr: return self._with_unary( self._ir.join(separator=separator, ignore_nulls=ignore_nulls) ) + + def any(self) -> Expr: + return self._with_unary(self._ir.any()) + + def all(self) -> Expr: + return self._with_unary(self._ir.all()) + + def first(self) -> Expr: + return self._with_unary(self._ir.first()) + + def last(self) -> Expr: + return self._with_unary(self._ir.last()) + + def n_unique(self) -> Expr: + return self._with_unary(self._ir.n_unique()) diff --git a/tests/plan/list_agg_test.py b/tests/plan/list_agg_test.py new file mode 100644 index 0000000000..83e085a068 --- /dev/null +++ b/tests/plan/list_agg_test.py @@ -0,0 +1,151 @@ +from 
__future__ import annotations + +from itertools import chain +from typing import TYPE_CHECKING, Final + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from narwhals._plan._dispatch import get_dispatch_name +from narwhals._utils import zip_strict +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + + from _pytest.mark import ParameterSet + + from narwhals.typing import NonNestedLiteral + from tests.conftest import Data + from tests.plan.utils import SubList + + +ROWS_A: Final[tuple[SubList[int] | SubList[float], ...]] = ( + [3, None, 2, 2, 4, None], + [-1], + None, + [None, None, None], + [], + [3, 4, None, 4, None, 3], +) +# NOTE: `pyarrow` needs at least 3 (non-null) values to calculate `median` correctly +# Otherwise it picks the lowest non-null +# https://github.com/narwhals-dev/narwhals/pull/3332#discussion_r2617508167 + + +ROWS_B: Final[tuple[SubList[bool], ...]] = ( + [True, True], + [False, True], + [False, False], + [None], + [], + None, +) +ROWS_C: Final[tuple[SubList[float], ...]] = ( + [1.0, None, None, 3.0], + [1.0, None, 4.0, 5.0, 1.1, 4.0, None, 1.0], + [1.0, None, None, 1.0, 2.0, 2.0, 2.0, None, 3.0], + [], + [None, None, None], + None, +) + +EXPECTED_MAX = [4, -1, None, None, None, 4] +EXPECTED_MEAN = [2.75, -1, None, None, None, 3.5] +EXPECTED_MIN = [2, -1, None, None, None, 3] +EXPECTED_SUM = [11, -1, None, 0, 0, 14] +EXPECTED_MEDIAN = [2.5, -1, None, None, None, 3.5] +EXPECTED_ALL = [True, False, False, True, True, None] +EXPECTED_ANY = [True, True, False, False, False, None] +EXPECTED_FIRST = [3, -1, None, None, None, 3] +EXPECTED_LAST = [None, -1, None, None, None, 3] +EXPECTED_N_UNIQUE = [3, 5, 4, 0, 1, None] + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [*ROWS_A], "b": [*ROWS_B], "c": [*ROWS_C]} + + +a = nwp.col("a") +b = nwp.col("b") +c = nwp.col("c") +cast_a = a.cast(nw.List(nw.Int32)) +cast_b = 
b.cast(nw.List(nw.Boolean))
+cast_c = c.cast(nw.List(nw.Float64))
+
+
+@pytest.mark.parametrize(
+    ("expr", "expected"),
+    [
+        (a.list.max(), EXPECTED_MAX),
+        (a.list.mean(), EXPECTED_MEAN),
+        (a.list.min(), EXPECTED_MIN),
+        (a.list.sum(), EXPECTED_SUM),
+        (a.list.median(), EXPECTED_MEDIAN),
+        (b.list.all(), EXPECTED_ALL),
+        (b.list.any(), EXPECTED_ANY),
+        (a.list.first(), EXPECTED_FIRST),
+        (a.list.last(), EXPECTED_LAST),
+        (c.list.n_unique(), EXPECTED_N_UNIQUE),
+    ],
+    ids=[
+        "max",
+        "mean",
+        "min",
+        "sum",
+        "median",
+        "all",
+        "any",
+        "first",
+        "last",
+        "n_unique",
+    ],
+)
+def test_list_agg(data: Data, expr: nwp.Expr, expected: list[NonNestedLiteral]) -> None:
+    df = dataframe(data).with_columns(cast_a, cast_b)
+    result = df.select(result=expr)
+    assert_equal_data(result, {"result": expected})
+
+
+def cases_scalar(
+    expr: nwp.Expr,
+    rows: Iterable[Sequence[NonNestedLiteral] | None],
+    expected: Sequence[NonNestedLiteral],
+) -> Iterator[ParameterSet]:
+    name = get_dispatch_name(expr._ir).removeprefix("list.")
+    for idx, row_expected in enumerate(zip_strict(rows, expected), start=1):
+        row, out = row_expected
+        yield pytest.param(expr, row, out, id=f"{name}-R{idx}")
+
+
+first_a = nwp.nth(0).cast(nw.List(nw.Int32)).first()
+first_b = nwp.nth(0).cast(nw.List(nw.Boolean)).first()
+first_c = nwp.nth(0).cast(nw.List(nw.Float64)).first()
+
+
+@pytest.mark.parametrize(
+    ("expr", "row", "expected"),
+    chain(
+        cases_scalar(first_a.list.max(), ROWS_A, EXPECTED_MAX),
+        cases_scalar(first_a.list.mean(), ROWS_A, EXPECTED_MEAN),
+        cases_scalar(first_a.list.min(), ROWS_A, EXPECTED_MIN),
+        cases_scalar(first_a.list.sum(), ROWS_A, EXPECTED_SUM),
+        cases_scalar(first_a.list.median(), ROWS_A, EXPECTED_MEDIAN),
+        cases_scalar(first_b.list.all(), ROWS_B, EXPECTED_ALL),
+        cases_scalar(first_b.list.any(), ROWS_B, EXPECTED_ANY),
+        cases_scalar(first_a.list.first(), ROWS_A, EXPECTED_FIRST),
+        cases_scalar(first_a.list.last(), ROWS_A, EXPECTED_LAST),
+        
cases_scalar(first_c.list.n_unique(), ROWS_C, EXPECTED_N_UNIQUE), + ), +) +def test_list_agg_scalar( + expr: nwp.Expr, row: SubList[NonNestedLiteral], expected: NonNestedLiteral +) -> None: + data = {"a": [row]} + result = dataframe(data).select(expr) + # NOTE: Doing a pure noop on `` will pass `assert_equal_data`, + # but will have the wrong dtype when compared with a non-null agg + assert result.collect_schema()["a"] != nw.List + assert_equal_data(result, {"a": [expected]}) diff --git a/tests/plan/list_contains_test.py b/tests/plan/list_contains_test.py index 0761434e97..186c186b3e 100644 --- a/tests/plan/list_contains_test.py +++ b/tests/plan/list_contains_test.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Final import pytest @@ -11,6 +11,7 @@ if TYPE_CHECKING: from narwhals._plan.typing import IntoExpr from tests.conftest import Data + from tests.plan.utils import SubList @pytest.fixture(scope="module") @@ -42,10 +43,10 @@ def test_list_contains(data: Data, item: IntoExpr, expected: list[bool | None]) assert_equal_data(result, {"a": expected}) -R1: Final[list[Any]] = [None, "A", "B", "A", "A", "B"] -R2: Final = None -R3: Final[list[Any]] = [] -R4: Final = [None] +R1: Final[SubList[str]] = [None, "A", "B", "A", "A", "B"] +R2: Final[SubList[str]] = None +R3: Final[SubList[str]] = [] +R4: Final[SubList[str]] = [None] @pytest.mark.parametrize( @@ -66,7 +67,7 @@ def test_list_contains(data: Data, item: IntoExpr, expected: list[bool | None]) ], ) def test_list_contains_scalar( - row: list[str | None] | None, item: IntoExpr, *, expected: bool | None + row: SubList[str], item: IntoExpr, *, expected: bool | None ) -> None: data = {"a": [row]} df = dataframe(data).select(a.cast(nw.List(nw.String))) diff --git a/tests/plan/list_join_test.py b/tests/plan/list_join_test.py index 881700f6c7..fa524f7880 100644 --- a/tests/plan/list_join_test.py +++ b/tests/plan/list_join_test.py @@ -10,14 
+10,13 @@ if TYPE_CHECKING: from collections.abc import Sequence - from typing import Final, TypeVar + from typing import Final from typing_extensions import TypeAlias from tests.conftest import Data + from tests.plan.utils import SubList - T = TypeVar("T") - SubList: TypeAlias = list[T] | list[T | None] | list[None] | None SubListStr: TypeAlias = SubList[str] diff --git a/tests/plan/list_unique_test.py b/tests/plan/list_unique_test.py index 7f82e593b5..65763b0176 100644 --- a/tests/plan/list_unique_test.py +++ b/tests/plan/list_unique_test.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from tests.conftest import Data + from tests.plan.utils import SubList @pytest.fixture(scope="module") @@ -50,9 +51,7 @@ def test_list_unique(data: Data) -> None: ([None], [None]), ], ) -def test_list_unique_scalar( - row: list[str | None] | None, expected: list[str | None] | None -) -> None: +def test_list_unique_scalar(row: SubList[str], expected: SubList[str]) -> None: data = {"a": [row]} df = dataframe(data).select(a.cast(nw.List(nw.String))) # NOTE: Don't separate `first().list.unique()` diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 92446a48cd..4b87ea5bd5 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: import sys from collections.abc import Iterable, Mapping + from typing import TypeVar from typing_extensions import LiteralString, TypeAlias @@ -31,6 +32,9 @@ else: _Flags: TypeAlias = int + T = TypeVar("T") + SubList: TypeAlias = list[T] | list[T | None] | list[None] | None + def first(*names: str | Sequence[str]) -> nwp.Expr: return nwp.col(*names).first()