From eb6a2f8cbe7aa7e13c841ea0cfbe60e1ad23e37f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Nov 2025 19:43:14 +0000 Subject: [PATCH 001/215] feat: Port some simple impls --- narwhals/_plan/arrow/expr.py | 46 +++++++++++++++++++------------ narwhals/_plan/arrow/functions.py | 4 +++ 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 31181b7a95..d51cad8ccd 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -36,6 +36,7 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.arrow.typing import ChunkedArrayAny, P, VectorFunction + from narwhals._plan.expressions import functions as F from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, @@ -195,10 +196,20 @@ def ternary_expr( result = pc.if_else(when.native, then.native, otherwise.native) return self._with_native(result, name) - exp = not_implemented() # type: ignore[misc] - log = not_implemented() # type: ignore[misc] - sqrt = not_implemented() # type: ignore[misc] - round = not_implemented() # type: ignore[misc] + def log(self, node: FExpr[F.Log], frame: Frame, name: str) -> StoresNativeT_co: + native = node.input[0].dispatch(self, frame, name).native + return self._with_native(pc.logb(native, fn.lit(node.function.base)), name) + + def exp(self, node: FExpr[F.Exp], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(pc.exp)(node, frame, name) + + def sqrt(self, node: FExpr[F.Sqrt], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(pc.sqrt)(node, frame, name) + + def round(self, node: FExpr[F.Round], frame: Frame, name: str) -> StoresNativeT_co: + native = node.input[0].dispatch(self, frame, name).native + return self._with_native(fn.round(native, node.function.decimals), name) + clip = 
not_implemented() # type: ignore[misc] drop_nulls = not_implemented() # type: ignore[misc] replace_strict = not_implemented() # type: ignore[misc] @@ -381,22 +392,10 @@ def min(self, node: Min, frame: Frame, name: str) -> Scalar: result: NativeScalar = fn.min_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def null_count(self, node: FExpr[NullCount], frame: Frame, name: str) -> Scalar: + def null_count(self, node: FExpr[F.NullCount], frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.input[0], frame, name).native return self._with_native(fn.null_count(native), name) - # TODO @dangotbanned: top-level, complex-ish nodes - # - [ ] Over - # - [x] `over_ordered` - # - [x] `group_by`, `join` - # - [x] `over` (with partitions) - # - [x] `over_ordered` (with partitions) - # - [ ] fix: join on nulls after https://github.com/narwhals-dev/narwhals/issues/3300 - # - [ ] `map_batches` - # - [x] elementwise - # - [ ] scalar - # - [ ] `rolling_expr` has 4 variants - def over( self, node: ir.WindowExpr, @@ -453,6 +452,12 @@ def _is_first_last_distinct( index = df.to_series().alias(name) return self.from_series(index.is_in(distinct.get_column(idx_name))) + # TODO @dangotbanned: top-level, complex-ish nodes + # - [ ] `map_batches` + # - [x] elementwise + # - [ ] scalar + # - [ ] `rolling_expr` has 4 variants + # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: if node.is_scalar: @@ -489,6 +494,10 @@ def _cumulative(self, node: FExpr[CumAgg], frame: Frame, name: str) -> Self: result = fn.reverse(func(fn.reverse(native))) return self._with_native(result, name) + def unique(self, node: ir.FunctionExpr[F.Unique], frame: Frame, name: str) -> Self: + result = self._dispatch_expr(node.input[0], frame, name).native.unique() + return self._with_native(result, name) + cum_count = _cumulative cum_min = _cumulative cum_max = 
_cumulative @@ -501,8 +510,9 @@ def _cumulative(self, node: FExpr[CumAgg], frame: Frame, name: str) -> Self: hist_bins = not_implemented() hist_bin_count = not_implemented() mode = not_implemented() - unique = not_implemented() fill_null_with_strategy = not_implemented() + # NOTE: `kurtosis` and `skew` will need tests + # wanna try adding a `pyarrow>=20` version kurtosis = not_implemented() skew = not_implemented() gather_every = not_implemented() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index a17cb0ebc7..b8b6d74cf4 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -230,6 +230,10 @@ def n_unique(native: Any) -> pa.Int64Scalar: return count(native, mode="all") +def round(native: ChunkedOrScalarAny, decimals: int = 0) -> ChunkedOrScalarAny: + return pc.round(native, decimals, round_mode="half_towards_infinity") + + def reverse(native: ChunkedOrArrayT) -> ChunkedOrArrayT: """Unlike other slicing ops, `[::-1]` creates a full-copy. 
From 5033437587bd14ecfa57587ff79ac880756fb750 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Nov 2025 19:59:13 +0000 Subject: [PATCH 002/215] feat: Add `gather_every` --- narwhals/_plan/arrow/expr.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index d51cad8ccd..2a76b2af7b 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -494,10 +494,16 @@ def _cumulative(self, node: FExpr[CumAgg], frame: Frame, name: str) -> Self: result = fn.reverse(func(fn.reverse(native))) return self._with_native(result, name) - def unique(self, node: ir.FunctionExpr[F.Unique], frame: Frame, name: str) -> Self: + def unique(self, node: FExpr[F.Unique], frame: Frame, name: str) -> Self: result = self._dispatch_expr(node.input[0], frame, name).native.unique() return self._with_native(result, name) + # TODO @dangotbanned: Only implement in `ArrowSeries` and reuse + def gather_every(self, node: FExpr[F.GatherEvery], frame: Frame, name: str) -> Self: + native = self._dispatch_expr(node.input[0], frame, name).native + result = native[node.function.offset :: node.function.n] + return self._with_native(result, name) + cum_count = _cumulative cum_min = _cumulative cum_max = _cumulative @@ -515,7 +521,7 @@ def unique(self, node: ir.FunctionExpr[F.Unique], frame: Frame, name: str) -> Se # wanna try adding a `pyarrow>=20` version kurtosis = not_implemented() skew = not_implemented() - gather_every = not_implemented() + is_duplicated = not_implemented() is_unique = not_implemented() From dfe91557624d806bcec2deacabeb80753865bc0a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Nov 2025 21:10:57 +0000 Subject: [PATCH 003/215] feat: Add `drop_nulls` Quite odd behavior for scalar lol --- narwhals/_plan/arrow/expr.py | 15 +++++++++++++-- narwhals/_plan/arrow/functions.py | 
9 +++++++++ narwhals/_plan/compliant/scalar.py | 12 +++++++++++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 2a76b2af7b..46796df3a0 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -198,7 +198,7 @@ def ternary_expr( def log(self, node: FExpr[F.Log], frame: Frame, name: str) -> StoresNativeT_co: native = node.input[0].dispatch(self, frame, name).native - return self._with_native(pc.logb(native, fn.lit(node.function.base)), name) + return self._with_native(fn.log(native, node.function.base), name) def exp(self, node: FExpr[F.Exp], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(pc.exp)(node, frame, name) @@ -211,7 +211,6 @@ def round(self, node: FExpr[F.Round], frame: Frame, name: str) -> StoresNativeT_ return self._with_native(fn.round(native, node.function.decimals), name) clip = not_implemented() # type: ignore[misc] - drop_nulls = not_implemented() # type: ignore[misc] replace_strict = not_implemented() # type: ignore[misc] @@ -504,6 +503,9 @@ def gather_every(self, node: FExpr[F.GatherEvery], frame: Frame, name: str) -> S result = native[node.function.offset :: node.function.n] return self._with_native(result, name) + def drop_nulls(self, node: FExpr[F.DropNulls], frame: Frame, name: str) -> Self: + return self._vector_function(fn.drop_nulls)(node, frame, name) + cum_count = _cumulative cum_min = _cumulative cum_max = _cumulative @@ -611,6 +613,15 @@ def null_count(self, node: FExpr[NullCount], frame: Frame, name: str) -> Self: native = node.input[0].dispatch(self, frame, name).native return self._with_native(pa.scalar(0 if native.is_valid else 1), name) + def drop_nulls( # type: ignore[override] + self, node: FExpr[F.DropNulls], frame: Frame, name: str + ) -> Scalar | Expr: + previous = node.input[0].dispatch(self, frame, name) + if previous.native.is_valid: + return previous + chunked = fn.chunked_array([[]], 
previous.native.type) + return ArrowExpr.from_native(chunked, name, version=self.version) + filter = not_implemented() over = not_implemented() over_ordered = not_implemented() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index b8b6d74cf4..7be357468c 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math import typing as t from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any, Final, Literal, overload @@ -55,6 +56,7 @@ StringScalar, StringType, UnaryFunction, + VectorFunction, ) from narwhals._plan.options import RankOptions, SortMultipleOptions, SortOptions from narwhals.typing import ClosedInterval, IntoArrowSchema, PythonLiteral @@ -234,6 +236,10 @@ def round(native: ChunkedOrScalarAny, decimals: int = 0) -> ChunkedOrScalarAny: return pc.round(native, decimals, round_mode="half_towards_infinity") +def log(native: ChunkedOrScalarAny, base: float = math.e) -> ChunkedOrScalarAny: + return t.cast("ChunkedOrScalarAny", pc.logb(native, lit(base))) + + def reverse(native: ChunkedOrArrayT) -> ChunkedOrArrayT: """Unlike other slicing ops, `[::-1]` creates a full-copy. 
@@ -320,6 +326,9 @@ def preserve_nulls( return after +drop_nulls = t.cast("VectorFunction[...]", pc.drop_null) + + def is_between( native: ChunkedOrScalar[ScalarT], lower: ChunkedOrScalar[ScalarT], diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index 3e86327240..586bcab7a3 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -10,7 +10,11 @@ from typing_extensions import Self from narwhals._plan import expressions as ir - from narwhals._plan.expressions import FunctionExpr, aggregation as agg + from narwhals._plan.expressions import ( + FunctionExpr, + aggregation as agg, + functions as F, + ) from narwhals._plan.expressions.functions import EwmMean, NullCount, Shift from narwhals._utils import Version from narwhals.typing import IntoDType, PythonLiteral @@ -102,6 +106,12 @@ def shift(self, node: FunctionExpr[Shift], frame: FrameT_contra, name: str) -> S return self._with_evaluated(self._evaluated, name) return self.from_python(None, name, dtype=None, version=self.version) + def drop_nulls( # type: ignore[override] + self, node: FunctionExpr[F.DropNulls], frame: FrameT_contra, name: str + ) -> Self | CompliantExpr[FrameT_contra, SeriesT_co]: + """Returns a 0-length Series if null, else noop.""" + ... 
+ arg_max = _always_zero # type: ignore[misc] arg_min = _always_zero # type: ignore[misc] is_first_distinct = _always_true # type: ignore[misc] From 6fbb7e7675559cc75d919d3e741936c54fc71f91 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 19 Nov 2025 22:40:19 +0000 Subject: [PATCH 004/215] prepping for windowed `is_{duplicated,unique}` --- narwhals/_plan/arrow/expr.py | 3 ++- narwhals/_plan/arrow/functions.py | 28 +++++++++++++++++++++++++++- narwhals/_plan/arrow/group_by.py | 1 + 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 46796df3a0..7055dcad94 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -443,7 +443,7 @@ def _is_first_last_distinct( df = df._with_native(df.native.add_column(0, idx_name, column)) else: df = df.with_row_index(idx_name) - agg = fn.IS_FIRST_LAST_DISTINCT[type(node.function)](idx_name) + agg = fn.BOOLEAN_GLUE_FUNCTIONS[type(node.function)](idx_name) if not (partition_by or sort_indices is not None): distinct = df.group_by_names((name,)).agg((ir.named_ir(idx_name, agg),)) else: @@ -524,6 +524,7 @@ def drop_nulls(self, node: FExpr[F.DropNulls], frame: Frame, name: str) -> Self: kurtosis = not_implemented() skew = not_implemented() + # I think they might both be possible in a similar way to `_is_first_last_distinct` is_duplicated = not_implemented() is_unique = not_implemented() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 7be357468c..36c78c6679 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -76,6 +76,14 @@ I64: Final = pa.int64() F64: Final = pa.float64() + +class MinMax(ir.AggExpr): + """Returns a `Struct({'min': ..., 'max': ...})`. 
+ + https://arrow.apache.org/docs/python/generated/pyarrow.compute.min_max.html#pyarrow.compute.min_max + """ + + IntoColumnAgg: TypeAlias = Callable[[str], ir.AggExpr] """Helper constructor for single-column aggregations.""" @@ -139,10 +147,28 @@ def modulus(lhs: Any, rhs: Any) -> Any: "none": (gt, lt), "both": (gt_eq, lt_eq), } -IS_FIRST_LAST_DISTINCT: Mapping[type[ir.boolean.BooleanFunction], IntoColumnAgg] = { + + +def ir_min_max(name: str, /) -> MinMax: + return MinMax(expr=ir.col(name)) + + +BOOLEAN_GLUE_FUNCTIONS: Mapping[type[ir.boolean.BooleanFunction], IntoColumnAgg] = { ir.boolean.IsFirstDistinct: ir.min, ir.boolean.IsLastDistinct: ir.max, + ir.boolean.IsUnique: ir_min_max, + ir.boolean.IsDuplicated: ir_min_max, } +"""Planning to pattern up `_is_first_last_distinct` to work for some other cases. + +The final two lines will need some tweaking, but the same concept is going on: + + index = df.to_series().alias(name) + return self.from_series(index.is_in(distinct.get_column(idx_name))) + + +Will mean at least 2 more functions with `over(*partition_by, order_by=...)` support 😄 +""" @t.overload diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index cce6463239..7ebb674d94 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -49,6 +49,7 @@ agg.NUnique: "hash_count_distinct", agg.First: "hash_first", agg.Last: "hash_last", + fn.MinMax: "hash_min_max", } SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { ir.Len: "hash_count_all", From 9c0b13eb4550a6ca799cabf8faa610bbc5354410 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 12:32:51 +0000 Subject: [PATCH 005/215] fix(typing): Tighten up `StoresNativeT_co` bound --- narwhals/_plan/arrow/expr.py | 5 ++--- narwhals/_plan/arrow/typing.py | 4 +++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 
7055dcad94..c581188193 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -154,9 +154,8 @@ def is_in_expr( ) -> StoresNativeT_co: expr, other = node.function.unwrap_input(node) right = other.dispatch(self, frame, name).native - if isinstance(right, pa.Scalar): - right = fn.array(right) - result = fn.is_in(expr.dispatch(self, frame, name).native, right) + arr = fn.array(right) if isinstance(right, pa.Scalar) else right + result = fn.is_in(expr.dispatch(self, frame, name).native, arr) return self._with_native(result, name) def is_in_series( diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index ad2d42cb16..030cc73cfb 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -156,7 +156,9 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" NativeScalar: TypeAlias = ScalarAny BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] -StoresNativeT_co = TypeVar("StoresNativeT_co", bound=StoresNative[Any], covariant=True) +StoresNativeT_co = TypeVar( + "StoresNativeT_co", bound=StoresNative[ChunkedOrScalarAny], covariant=True +) DataTypeRemap: TypeAlias = Mapping[DataType, DataType] NullPlacement: TypeAlias = Literal["at_start", "at_end"] From 618c83fbf73b3b65f2929813e4e259f8daa58cd3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 13:29:52 +0000 Subject: [PATCH 006/215] feat: Impl and update `clip` Will come back to this later to shrink --- narwhals/_plan/arrow/expr.py | 30 ++++++++++++++++- narwhals/_plan/arrow/functions.py | 20 ++++++++++++ narwhals/_plan/compliant/expr.py | 6 ++++ narwhals/_plan/expr.py | 11 +++++-- narwhals/_plan/expressions/functions.py | 13 +++++++- tests/plan/clip_test.py | 43 +++++++++++++++++++++++++ 6 files changed, 119 insertions(+), 4 deletions(-) create mode 100644 tests/plan/clip_test.py diff --git 
a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index c581188193..9d0f0eeb2d 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -209,7 +209,35 @@ def round(self, node: FExpr[F.Round], frame: Frame, name: str) -> StoresNativeT_ native = node.input[0].dispatch(self, frame, name).native return self._with_native(fn.round(native, node.function.decimals), name) - clip = not_implemented() # type: ignore[misc] + def clip(self, node: FExpr[F.Clip], frame: Frame, name: str) -> StoresNativeT_co: + expr, lower, upper = node.function.unwrap_input(node) + result = fn.clip( + expr.dispatch(self, frame, name).native, + lower.dispatch(self, frame, name).native, + upper.dispatch(self, frame, name).native, + ) + return self._with_native(result, name) + + def clip_lower( + self, node: FExpr[F.ClipLower], frame: Frame, name: str + ) -> StoresNativeT_co: + expr, other = node.function.unwrap_input(node) + result = fn.clip_lower( + expr.dispatch(self, frame, name).native, + other.dispatch(self, frame, name).native, + ) + return self._with_native(result, name) + + def clip_upper( + self, node: FExpr[F.ClipUpper], frame: Frame, name: str + ) -> StoresNativeT_co: + expr, other = node.function.unwrap_input(node) + result = fn.clip_upper( + expr.dispatch(self, frame, name).native, + other.dispatch(self, frame, name).native, + ) + return self._with_native(result, name) + replace_strict = not_implemented() # type: ignore[misc] diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 36c78c6679..7de0b1a6a8 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -243,6 +243,8 @@ def sum_(native: Any) -> NativeScalar: min_ = pc.min +# TODO @dangotbanned: Wrap horizontal functions with correct typing +# Should only return scalar if all elements are as well min_horizontal = pc.min_element_wise max_ = pc.max max_horizontal = pc.max_element_wise @@ -254,6 +256,24 @@ def sum_(native: Any) -> 
NativeScalar: quantile = pc.quantile +def clip_lower( + native: ChunkedOrScalarAny, lower: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + return max_horizontal(native, lower) + + +def clip_upper( + native: ChunkedOrScalarAny, upper: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + return min_horizontal(native, upper) + + +def clip( + native: ChunkedOrScalarAny, lower: ChunkedOrScalarAny, upper: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + return clip_lower(clip_upper(native, upper), lower) + + def n_unique(native: Any) -> pa.Int64Scalar: return count(native, mode="all") diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index aead4a8e1a..b294fddb36 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -191,6 +191,12 @@ def skew( def clip( self, node: FunctionExpr[F.Clip], frame: FrameT_contra, name: str ) -> Self: ... + def clip_lower( + self, node: FunctionExpr[F.ClipLower], frame: FrameT_contra, name: str + ) -> Self: ... + def clip_upper( + self, node: FunctionExpr[F.ClipUpper], frame: FrameT_contra, name: str + ) -> Self: ... def drop_nulls( self, node: FunctionExpr[F.DropNulls], frame: FrameT_contra, name: str ) -> Self: ... 
diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 6b5798f8b1..4d627d6474 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -253,8 +253,15 @@ def clip( lower_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, upper_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, ) -> Self: - it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) - return self._from_ir(F.Clip().to_function_expr(self._ir, *it)) + f: ir.FunctionExpr + if upper_bound is None: + f = F.ClipLower().to_function_expr(self._ir, parse_into_expr_ir(lower_bound)) + elif lower_bound is None: + f = F.ClipUpper().to_function_expr(self._ir, parse_into_expr_ir(upper_bound)) + else: + it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) + f = F.Clip().to_function_expr(self._ir, *it) + return self._from_ir(f) def cum_count(self, *, reverse: bool = False) -> Self: # pragma: no cover return self._with_unary(F.CumCount(reverse=reverse)) diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index 8e910113a1..a5bc4383ef 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -47,7 +47,18 @@ class Sqrt(Function, options=FunctionOptions.elementwise): ... class DropNulls(Function, options=FunctionOptions.row_separable): ... class Mode(Function): ... class Skew(Function, options=FunctionOptions.aggregation): ... -class Clip(Function, options=FunctionOptions.elementwise): ... 
+class Clip(Function, options=FunctionOptions.elementwise): + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR, ExprIR]: + expr, lower_bound, upper_bound = node.input + return expr, lower_bound, upper_bound +class ClipLower(Function, options=FunctionOptions.elementwise): + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, lower_bound = node.input + return expr, lower_bound +class ClipUpper(Function, options=FunctionOptions.elementwise): + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, upper_bound = node.input + return expr, upper_bound class CumCount(CumAgg): ... class CumMin(CumAgg): ... class CumMax(CumAgg): ... diff --git a/tests/plan/clip_test.py b/tests/plan/clip_test.py new file mode 100644 index 0000000000..3beb3f4db1 --- /dev/null +++ b/tests/plan/clip_test.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from narwhals import _plan as nwp +from narwhals.exceptions import MultiOutputExpressionError +from tests.plan.utils import assert_equal_data, dataframe, series + +if TYPE_CHECKING: + from narwhals._plan.typing import IntoExprColumn + from narwhals.typing import NumericLiteral, TemporalLiteral + +pytest.importorskip("pyarrow") + + +@pytest.mark.parametrize( + ("lower", "upper", "expected"), + [ + (3, 4, [3, 3, 3, 3, 4]), + (0, 4, [1, 2, 3, 0, 4]), + (None, 4, [1, 2, 3, -4, 4]), + (-2, 0, [0, 0, 0, -2, 0]), + (-2, None, [1, 2, 3, -2, 5]), + ("lb", nwp.col("ub") + 1, [3, 2, 3, 1, 3]), + (series([1, 1, 2, 4, 3]), None, [1, 2, 3, 4, 5]), + ], +) +def test_clip_expr( + lower: IntoExprColumn | NumericLiteral | TemporalLiteral | None, + upper: IntoExprColumn | NumericLiteral | TemporalLiteral | None, + expected: list[int], +) -> None: + data = {"a": [1, 2, 3, -4, 5], "lb": [3, 2, 1, 1, 1], "ub": [4, 4, 2, 2, 2]} + result = dataframe(data).select(nwp.col("a").clip(lower, upper)) + 
assert_equal_data(result, {"a": expected}) + + +def test_clip_invalid() -> None: + df = dataframe({"a": [1, 2, 3], "b": [4, 5, 6]}) + with pytest.raises(MultiOutputExpressionError): + df.select(nwp.col("a").clip(nwp.all(), nwp.col("a", "b"))) From 7da8f5b9bb74eda4eac645c874028c91564f626e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:51:55 +0000 Subject: [PATCH 007/215] fix: Don't allow `__len__` to be used for `__bool__` Discovered while adding `kurtosis` test which had an empty series --- narwhals/_plan/compliant/expr.py | 5 ++++- narwhals/_plan/compliant/scalar.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index b294fddb36..7280295937 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Literal, Protocol from narwhals._plan.compliant.column import EagerBroadcast, SupportsBroadcast from narwhals._plan.compliant.typing import ( @@ -254,6 +254,9 @@ def is_in_series( frame: FrameT_contra, name: str, ) -> Self: ... 
+ def __bool__(self) -> Literal[True]: + # NOTE: Avoids falling back to `__len__` when truth-testing on dispatch + return True class LazyExpr( diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index 586bcab7a3..1244eaa89d 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Literal, Protocol from narwhals._plan.compliant.expr import CompliantExpr, EagerExpr, LazyExpr from narwhals._plan.compliant.typing import FrameT_contra, LengthT, SeriesT, SeriesT_co @@ -146,6 +146,10 @@ class EagerScalar( def __len__(self) -> int: return 1 + def __bool__(self) -> Literal[True]: + # NOTE: Avoids falling back to `__len__` when truth-testing on dispatch + return True + def to_python(self) -> PythonLiteral: ... gather_every = not_implemented() # type: ignore[misc] From 9f23d8e36607073195bf8e58fb929d3100bb4e2f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:38:28 +0000 Subject: [PATCH 008/215] fix: Ensure `with_columns` always broadcasts Adding the `skew` test revealed this "edge case" --- narwhals/_plan/arrow/dataframe.py | 10 ++++++---- narwhals/_plan/compliant/column.py | 22 +++++++++++++++++----- narwhals/_plan/compliant/dataframe.py | 7 ++++++- tests/plan/compliant_test.py | 20 ++++++++++++++++++++ 4 files changed, 49 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 2c0b2e37e8..0497c6cc41 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -96,10 +96,12 @@ def to_polars(self) -> pl.DataFrame: return pl.DataFrame(self.native) - def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[Series]: - ns = namespace(self) - from_named_ir = ns._expr.from_named_ir - yield from 
ns._expr.align(from_named_ir(e, self) for e in nodes) + def _evaluate_irs( + self, nodes: Iterable[NamedIR[ExprIR]], /, *, length: int | None = None + ) -> Iterator[Series]: + expr = namespace(self)._expr + from_named_ir = expr.from_named_ir + yield from expr.align((from_named_ir(e, self) for e in nodes), default=length) def sort(self, by: Sequence[str], options: SortMultipleOptions | None = None) -> Self: return self.gather(fn.sort_indices(self.native, *by, options=options)) diff --git a/narwhals/_plan/compliant/column.py b/narwhals/_plan/compliant/column.py index 2669a598db..49d49035f3 100644 --- a/narwhals/_plan/compliant/column.py +++ b/narwhals/_plan/compliant/column.py @@ -42,16 +42,25 @@ def _length_max(cls, lengths: Sequence[LengthT], /) -> LengthT: @classmethod def _length_required( - cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], / + cls, + exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], + /, + default: LengthT | None = None, ) -> LengthT | None: """Return the broadcast length, if all lengths do not equal the maximum.""" @classmethod def align( - cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]] + cls, + *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]], + default: LengthT | None = None, ) -> Iterator[SeriesT]: + """Yield broadcasted `Scalar`s and unwrapped `Expr`s from `exprs`. + + `default` must be provided when operating in a `with_columns` context. 
+ """ exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs)) - length = cls._length_required(exprs) + length = cls._length_required(exprs, default) if length is None: for e in exprs: yield e.to_series() @@ -85,12 +94,15 @@ def _length_max(cls, lengths: Sequence[int], /) -> int: @classmethod def _length_required( - cls, exprs: Sequence[SupportsBroadcast[SeriesT, int]], / + cls, + exprs: Sequence[SupportsBroadcast[SeriesT, int]], + /, + default: int | None = None, ) -> int | None: lengths = cls._length_all(exprs) max_length = cls._length_max(lengths) required = any(len_ != max_length for len_ in lengths) - return max_length if required else None + return max_length if required else default class ExprDispatch(HasVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index c439cda262..5e64e105b5 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -178,6 +178,9 @@ class EagerDataFrame( def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... @property def _group_by(self) -> type[EagerDataFrameGroupBy[Self]]: ... + def _evaluate_irs( + self, nodes: Iterable[NamedIR[ir.ExprIR]], /, *, length: int | None = None + ) -> Iterator[SeriesT]: ... 
def group_by_resolver( self, resolver: GroupByResolver, / @@ -188,7 +191,9 @@ def select(self, irs: Seq[NamedIR]) -> Self: return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) def with_columns(self, irs: Seq[NamedIR]) -> Self: - return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) + return self.__narwhals_namespace__()._concat_horizontal( + self._evaluate_irs(irs, length=len(self)) + ) def to_series(self, index: int = 0) -> SeriesT: return self.get_column(self.columns[index]) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 541983246b..18eed7b49a 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -521,6 +521,26 @@ def test_with_columns( assert_equal_data(result, expected) +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.all().first(), {"a": 8, "b": 58, "c": 2.5, "d": 2, "idx": 0}), + (ncs.numeric().null_count(), {"a": 1, "b": 0, "c": 0, "d": 0, "idx": 0}), + ( + ncs.by_index(range(5)).cast(nw.Boolean).fill_null(False).all(), + {"a": False, "b": True, "c": True, "d": True, "idx": False}, + ), + ], +) +def test_with_columns_all_aggregates( + data_indexed: dict[str, Any], expr: nwp.Expr, expected: dict[str, PythonLiteral] +) -> None: + height = len(next(iter(data_indexed.values()))) + expected_full = {k: height * [v] for k, v in expected.items()} + result = dataframe(data_indexed).with_columns(expr) + assert_equal_data(result, expected_full) + + @pytest.mark.parametrize( ("agg", "expected"), [ From 039cb4a034bc5614c1146df62238212413b53945 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:46:35 +0000 Subject: [PATCH 009/215] feat: Partial impl `kurtosis`, `skew` The rest will allow them to be used in `group_by` --- narwhals/_plan/arrow/expr.py | 13 ++++++++---- narwhals/_plan/arrow/functions.py | 33 +++++++++++++++++++++++++++++++ tests/plan/kurtosis_skew_test.py | 31 
+++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 4 deletions(-) create mode 100644 tests/plan/kurtosis_skew_test.py diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 9d0f0eeb2d..94dd475d90 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -422,6 +422,15 @@ def null_count(self, node: FExpr[F.NullCount], frame: Frame, name: str) -> Scala native = self._dispatch_expr(node.input[0], frame, name).native return self._with_native(fn.null_count(native), name) + # NOTE: `kurtosis` and `skew` will need tests for checking `pyarrow>=20` behaves the same + def kurtosis(self, node: FExpr[F.Kurtosis], frame: Frame, name: str) -> Scalar: + native = self._dispatch_expr(node.input[0], frame, name).native + return self._with_native(fn.kurtosis(native), name) + + def skew(self, node: FExpr[F.Skew], frame: Frame, name: str) -> Scalar: + native = self._dispatch_expr(node.input[0], frame, name).native + return self._with_native(fn.skew(native), name) + def over( self, node: ir.WindowExpr, @@ -546,10 +555,6 @@ def drop_nulls(self, node: FExpr[F.DropNulls], frame: Frame, name: str) -> Self: hist_bin_count = not_implemented() mode = not_implemented() fill_null_with_strategy = not_implemented() - # NOTE: `kurtosis` and `skew` will need tests - # wanna try adding a `pyarrow>=20` version - kurtosis = not_implemented() - skew = not_implemented() # I think they might both be possible in a similar way to `_is_first_last_distinct` is_duplicated = not_implemented() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 7de0b1a6a8..b6dc5c65c2 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -256,6 +256,39 @@ def sum_(native: Any) -> NativeScalar: quantile = pc.quantile +# TODO @dangotbanned: Add `pyarrow>=20` paths +# Share code with `skew` +def kurtosis(native: ChunkedOrArrayAny) -> NativeScalar: + non_null = native.drop_null() + if len(non_null) == 
0: + result = lit(None, F64) + elif len(non_null) == 1: + result = lit(float("nan")) + else: + m = sub(non_null, mean(non_null)) + m2 = mean(pc.power(m, lit(2))) + m4 = mean(pc.power(m, lit(4))) + result = sub(pc.divide(m4, pc.power(m2, lit(2))), lit(3)) + return result + + +# TODO @dangotbanned: Add `pyarrow>=20` paths +def skew(native: ChunkedOrArrayAny) -> NativeScalar: + non_null = native.drop_null() + if len(non_null) == 0: + result = lit(None, F64) + elif len(non_null) == 1: + result = lit(float("nan")) + elif len(non_null) == 2: + result = lit(0.0, F64) + else: + m = sub(non_null, mean(non_null)) + m2 = mean(pc.power(m, lit(2))) + m3 = mean(pc.power(m, lit(3))) + result = pc.divide(m3, pc.power(m2, lit(1.5))) + return result + + def clip_lower( native: ChunkedOrScalarAny, lower: ChunkedOrScalarAny ) -> ChunkedOrScalarAny: diff --git a/tests/plan/kurtosis_skew_test.py b/tests/plan/kurtosis_skew_test.py new file mode 100644 index 0000000000..9e1fa6ce16 --- /dev/null +++ b/tests/plan/kurtosis_skew_test.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import pytest + +from narwhals import _plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + + +@pytest.mark.parametrize( + ("data", "expected_kurtosis", "expected_skew"), + [ + ([], None, None), + ([1], None, None), + ([1, 2], -2, 0.0), + ([0.0, 0.0, 0.0], None, None), + ([1, 2, 3, 2, 1], -1.153061, 0.343622), + ], + ids=range(5), +) +def test_kurtosis_skew_expr( + data: list[float], expected_kurtosis: float | None, expected_skew: float | None +) -> None: + df = dataframe({"a": data}) + kurtosis = nwp.col("a").kurtosis() + skew = nwp.col("a").skew() + height = len(data) + + assert_equal_data(df.select(kurtosis), {"a": [expected_kurtosis]}) + assert_equal_data(df.select(skew), {"a": [expected_skew]}) + assert_equal_data(df.with_columns(kurtosis), {"a": [expected_kurtosis] * height}) + assert_equal_data(df.with_columns(skew), {"a": [expected_skew] * height}) From 
cc1cffb5a594ca526d817ef825af5f560736aa2e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:49:51 +0000 Subject: [PATCH 010/215] chore: Make the `rolling_*` gap more visible --- narwhals/_plan/arrow/expr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 94dd475d90..0a662298a4 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -508,9 +508,6 @@ def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: result = result.cast(dtype) return self.from_series(result) - def rolling_expr(self, node: ir.RollingExpr, frame: Frame, name: str) -> Self: - raise NotImplementedError - def shift(self, node: FExpr[Shift], frame: Frame, name: str) -> Self: return self._vector_function(fn.shift, node.function.n)(node, frame, name) @@ -550,6 +547,9 @@ def drop_nulls(self, node: FExpr[F.DropNulls], frame: Frame, name: str) -> Self: is_first_distinct = _is_first_last_distinct is_last_distinct = _is_first_last_distinct + # TODO @dangotbanned: Plan composing with `functions.cum_*` + rolling_expr = not_implemented() + # ewm_mean = not_implemented() # noqa: ERA001 hist_bins = not_implemented() hist_bin_count = not_implemented() From 03dfef5e23e84db277dbf927f15afe975ffd81ee Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 17:44:17 +0000 Subject: [PATCH 011/215] feat: Use native `kurtosis` and `skew` when available --- narwhals/_plan/arrow/functions.py | 17 ++++++++++++++--- tests/plan/kurtosis_skew_test.py | 3 ++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index b6dc5c65c2..eeab7c0d0e 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -69,6 +69,9 @@ HAS_SCATTER: Final = BACKEND_VERSION >= (20,) 
"""`pyarrow.compute.scatter` added in https://github.com/apache/arrow/pull/44394""" +HAS_KURTOSIS_SKEW = BACKEND_VERSION >= (20,) +"""`pyarrow.compute.{kurtosis,skew}` added in https://github.com/apache/arrow/pull/45677""" + HAS_ARANGE: Final = BACKEND_VERSION >= (21,) """`pyarrow.arange` added in https://github.com/apache/arrow/pull/46778""" @@ -256,9 +259,13 @@ def sum_(native: Any) -> NativeScalar: quantile = pc.quantile -# TODO @dangotbanned: Add `pyarrow>=20` paths -# Share code with `skew` +# TODO @dangotbanned: Add `pyarrow>=20` support for `group_by` +# TODO @dangotbanned: Share code with `skew` def kurtosis(native: ChunkedOrArrayAny) -> NativeScalar: + if HAS_KURTOSIS_SKEW: + if pa.types.is_null(native.type): + native = native.cast(F64) + return pc.kurtosis(native) # type: ignore[attr-defined] non_null = native.drop_null() if len(non_null) == 0: result = lit(None, F64) @@ -272,8 +279,12 @@ def kurtosis(native: ChunkedOrArrayAny) -> NativeScalar: return result -# TODO @dangotbanned: Add `pyarrow>=20` paths +# TODO @dangotbanned: See `kurtosis` def skew(native: ChunkedOrArrayAny) -> NativeScalar: + if HAS_KURTOSIS_SKEW: + if pa.types.is_null(native.type): + native = native.cast(F64) + return pc.skew(native) # type: ignore[attr-defined] non_null = native.drop_null() if len(non_null) == 0: result = lit(None, F64) diff --git a/tests/plan/kurtosis_skew_test.py b/tests/plan/kurtosis_skew_test.py index 9e1fa6ce16..1dd98483d5 100644 --- a/tests/plan/kurtosis_skew_test.py +++ b/tests/plan/kurtosis_skew_test.py @@ -10,12 +10,13 @@ ("data", "expected_kurtosis", "expected_skew"), [ ([], None, None), + ([None], None, None), ([1], None, None), ([1, 2], -2, 0.0), ([0.0, 0.0, 0.0], None, None), ([1, 2, 3, 2, 1], -1.153061, 0.343622), + ([None, 1.4, 1.3, 5.9, None, 2.9], -1.014744, 0.801638), ], - ids=range(5), ) def test_kurtosis_skew_expr( data: list[float], expected_kurtosis: float | None, expected_skew: float | None From 291dffdd5b6eed385992eefb7f87c53686a04d55 Mon Sep 
17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 17:57:18 +0000 Subject: [PATCH 012/215] chore(typing): Happy mypy --- narwhals/_plan/arrow/functions.py | 48 +++++++++++++++++-------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index eeab7c0d0e..069e5af499 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -262,41 +262,45 @@ def sum_(native: Any) -> NativeScalar: # TODO @dangotbanned: Add `pyarrow>=20` support for `group_by` # TODO @dangotbanned: Share code with `skew` def kurtosis(native: ChunkedOrArrayAny) -> NativeScalar: + result: NativeScalar if HAS_KURTOSIS_SKEW: if pa.types.is_null(native.type): native = native.cast(F64) - return pc.kurtosis(native) # type: ignore[attr-defined] - non_null = native.drop_null() - if len(non_null) == 0: - result = lit(None, F64) - elif len(non_null) == 1: - result = lit(float("nan")) + result = pc.kurtosis(native) # type: ignore[attr-defined] else: - m = sub(non_null, mean(non_null)) - m2 = mean(pc.power(m, lit(2))) - m4 = mean(pc.power(m, lit(4))) - result = sub(pc.divide(m4, pc.power(m2, lit(2))), lit(3)) + non_null = native.drop_null() + if len(non_null) == 0: + result = lit(None, F64) + elif len(non_null) == 1: + result = lit(float("nan")) + else: + m = sub(non_null, mean(non_null)) + m2 = mean(pc.power(m, lit(2))) + m4 = mean(pc.power(m, lit(4))) + result = sub(pc.divide(m4, pc.power(m2, lit(2))), lit(3)) return result # TODO @dangotbanned: See `kurtosis` def skew(native: ChunkedOrArrayAny) -> NativeScalar: + result: NativeScalar if HAS_KURTOSIS_SKEW: if pa.types.is_null(native.type): native = native.cast(F64) - return pc.skew(native) # type: ignore[attr-defined] - non_null = native.drop_null() - if len(non_null) == 0: - result = lit(None, F64) - elif len(non_null) == 1: - result = lit(float("nan")) - elif len(non_null) == 2: - 
result = lit(0.0, F64) + result = pc.skew(native) # type: ignore[attr-defined] else: - m = sub(non_null, mean(non_null)) - m2 = mean(pc.power(m, lit(2))) - m3 = mean(pc.power(m, lit(3))) - result = pc.divide(m3, pc.power(m2, lit(1.5))) + non_null = native.drop_null() + if len(non_null) == 0: + result = lit(None, F64) + elif len(non_null) == 1: + result = lit(float("nan")) + elif len(non_null) == 2: + result = lit(0.0, F64) + else: + m = sub(non_null, mean(non_null)) + m2 = mean(pc.power(m, lit(2))) + m3 = mean(pc.power(m, lit(3))) + result = pc.divide(m3, pc.power(m2, lit(1.5))) return result From 4a360b5fadf79e6404a65eec9052fc814902f412 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 18:33:54 +0000 Subject: [PATCH 013/215] feat: Support `kurtosis`, `skew` in group_by Indirectly adds support for `over` too, but haven't added tests yet --- narwhals/_plan/arrow/acero.py | 6 ++++-- narwhals/_plan/arrow/functions.py | 1 - narwhals/_plan/arrow/group_by.py | 15 ++++++++------ tests/plan/group_by_test.py | 34 +++++++++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 541d38671a..82102e189d 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -18,7 +18,7 @@ import operator from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Any, Final, Union, cast +from typing import TYPE_CHECKING, Any, Final, Literal, Union, cast import pyarrow as pa # ignore-banned-import import pyarrow.acero as pac @@ -61,7 +61,9 @@ """ Target: TypeAlias = OneOrSeq[Field] -Aggregation: TypeAlias = "_Aggregation" +Aggregation: TypeAlias = Union[ + "_Aggregation", Literal["hash_kurtosis", "hash_skew", "kurtosis", "skew"] +] AggregateOptions: TypeAlias = "_AggregateOptions" Opts: TypeAlias = "AggregateOptions | None" OutputName: TypeAlias = str diff --git 
a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 069e5af499..9ebad08b65 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -259,7 +259,6 @@ def sum_(native: Any) -> NativeScalar: quantile = pc.quantile -# TODO @dangotbanned: Add `pyarrow>=20` support for `group_by` # TODO @dangotbanned: Share code with `skew` def kurtosis(native: ChunkedOrArrayAny) -> NativeScalar: result: NativeScalar diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 7ebb674d94..c9eb41fd8c 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -55,19 +55,22 @@ ir.Len: "hash_count_all", ir.Column: "hash_list", # `hash_aggregate` only } + +_version_dependent: dict[Any, acero.Aggregation] = {} +if fn.HAS_KURTOSIS_SKEW: + _version_dependent.update( + {ir.functions.Kurtosis: "hash_kurtosis", ir.functions.Skew: "hash_skew"} + ) + SUPPORTED_FUNCTION: Mapping[type[ir.Function], acero.Aggregation] = { ir.boolean.All: "hash_all", ir.boolean.Any: "hash_any", ir.functions.Unique: "hash_distinct", # `hash_aggregate` only ir.functions.NullCount: "hash_count", + **_version_dependent, } -REQUIRES_PYARROW_20: tuple[Literal["kurtosis"], Literal["skew"]] = ("kurtosis", "skew") -"""They don't show in [our version of the stubs], but are possible in [`pyarrow>=20`]. 
- -[our version of the stubs]: https://github.com/narwhals-dev/narwhals/issues/2124#issuecomment-3191374210 -[`pyarrow>=20`]: https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations -""" +del _version_dependent class AggSpec: diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index b7f5035fe2..d9bba6223e 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -8,6 +8,7 @@ import narwhals as nw from narwhals import _plan as nwp from narwhals._plan import selectors as ncs +from narwhals._utils import Implementation from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_equal_data, dataframe from tests.utils import PYARROW_VERSION, assert_equal_data as _assert_equal_data @@ -604,6 +605,39 @@ def test_group_by_agg_unique( assert_equal_data(result, expected) +def test_group_by_agg_kurtosis_skew(request: pytest.FixtureRequest) -> None: + data = { + "p1": ["a", "b", None, None, "b", "b"], + "p2": [1, 2, 1, None, None, None], + "p3": [None, 1, 1, 2, 2, None], + "a": [1, 2, 3, 4, 2, 1], + "b": [None, 9.9, 1.5, None, 1.0, 2.1], + } + expected = { + "p1": [None, "a", "b"], + "a_skew": [0.0, float("nan"), -0.707107], + "b_skew": [float("nan"), None, 0.666442], + "b_kurtosis": [float("nan"), None, -1.4999999999999996], + "a_kurtosis": [-2.0, float("nan"), -1.4999999999999998], + } + df = dataframe(data) + + request.applymarker( + pytest.mark.xfail( + (df.implementation is Implementation.PYARROW and PYARROW_VERSION < (20,)), + reason="too old for `pyarrow.compute.{kurtosis,skew}`", + ) + ) + + keys = ("p1",) + aggs = ( + nwp.col("a", "b").skew().name.suffix("_skew"), + nwp.nth(-1, -2).kurtosis().name.suffix("_kurtosis"), + ) + result = df.group_by(keys).agg(aggs).sort(keys) + assert_equal_data(result, expected) + + def test_group_by_args() -> None: """Adapted from [upstream]. 
From 33f7707c31bea8c601e8c4ada6f3f8107e216ede Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:28:39 +0000 Subject: [PATCH 014/215] refactor: De-duplicate `kurtosis`, `skew` --- narwhals/_plan/arrow/expr.py | 5 ++-- narwhals/_plan/arrow/functions.py | 38 +++++++++---------------------- 2 files changed, 13 insertions(+), 30 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 0a662298a4..799959edd1 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -422,14 +422,13 @@ def null_count(self, node: FExpr[F.NullCount], frame: Frame, name: str) -> Scala native = self._dispatch_expr(node.input[0], frame, name).native return self._with_native(fn.null_count(native), name) - # NOTE: `kurtosis` and `skew` will need tests for checking `pyarrow>=20` behaves the same def kurtosis(self, node: FExpr[F.Kurtosis], frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.input[0], frame, name).native - return self._with_native(fn.kurtosis(native), name) + return self._with_native(fn.kurtosis_skew(native, "kurtosis"), name) def skew(self, node: FExpr[F.Skew], frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.input[0], frame, name).native - return self._with_native(fn.skew(native), name) + return self._with_native(fn.kurtosis_skew(native, "skew"), name) def over( self, diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 9ebad08b65..fc1b6aabca 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -259,47 +259,31 @@ def sum_(native: Any) -> NativeScalar: quantile = pc.quantile -# TODO @dangotbanned: Share code with `skew` -def kurtosis(native: ChunkedOrArrayAny) -> NativeScalar: +def kurtosis_skew( + native: ChunkedOrArrayAny, function: Literal["kurtosis", "skew"], / +) -> NativeScalar: result: NativeScalar if HAS_KURTOSIS_SKEW: if 
pa.types.is_null(native.type): native = native.cast(F64) - result = pc.kurtosis(native) # type: ignore[attr-defined] + result = getattr(pc, function)(native) else: non_null = native.drop_null() if len(non_null) == 0: result = lit(None, F64) elif len(non_null) == 1: result = lit(float("nan")) - else: - m = sub(non_null, mean(non_null)) - m2 = mean(pc.power(m, lit(2))) - m4 = mean(pc.power(m, lit(4))) - result = sub(pc.divide(m4, pc.power(m2, lit(2))), lit(3)) - return result - - -# TODO @dangotbanned: See `kurtosis` -def skew(native: ChunkedOrArrayAny) -> NativeScalar: - result: NativeScalar - if HAS_KURTOSIS_SKEW: - if pa.types.is_null(native.type): - native = native.cast(F64) - result = pc.skew(native) # type: ignore[attr-defined] - else: - non_null = native.drop_null() - if len(non_null) == 0: - result = lit(None, F64) - elif len(non_null) == 1: - result = lit(float("nan")) - elif len(non_null) == 2: + elif function == "skew" and len(non_null) == 2: result = lit(0.0, F64) else: m = sub(non_null, mean(non_null)) m2 = mean(pc.power(m, lit(2))) - m3 = mean(pc.power(m, lit(3))) - result = pc.divide(m3, pc.power(m2, lit(1.5))) + if function == "kurtosis": + m4 = mean(pc.power(m, lit(4))) + result = sub(pc.divide(m4, pc.power(m2, lit(2))), lit(3)) + else: + m3 = mean(pc.power(m, lit(3))) + result = pc.divide(m3, pc.power(m2, lit(1.5))) return result From 2e39678954095f773cac7da9072821578122dfc5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 21:42:55 +0000 Subject: [PATCH 015/215] feat(DRAFT): Impl and update `mode` - Still needs tests - Also unsure what the scalar behavior should be for `mode_all` --- narwhals/_plan/arrow/expr.py | 8 +++++++- narwhals/_plan/arrow/functions.py | 16 ++++++++++++++++ narwhals/_plan/compliant/expr.py | 7 +++++-- narwhals/_plan/compliant/scalar.py | 2 ++ narwhals/_plan/expr.py | 8 ++++++-- narwhals/_plan/expressions/functions.py | 3 ++- tests/plan/expr_parsing_test.py | 7 
+++++++ 7 files changed, 45 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 799959edd1..673d144569 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -538,6 +538,13 @@ def gather_every(self, node: FExpr[F.GatherEvery], frame: Frame, name: str) -> S def drop_nulls(self, node: FExpr[F.DropNulls], frame: Frame, name: str) -> Self: return self._vector_function(fn.drop_nulls)(node, frame, name) + def mode_all(self, node: FExpr[F.ModeAll], frame: Frame, name: str) -> Self: + return self._vector_function(fn.mode_all)(node, frame, name) + + def mode_any(self, node: FExpr[F.ModeAny], frame: Frame, name: str) -> Scalar: + native = self._dispatch_expr(node.input[0], frame, name).native + return self._with_native(fn.mode_any(native), name) + cum_count = _cumulative cum_min = _cumulative cum_max = _cumulative @@ -552,7 +559,6 @@ def drop_nulls(self, node: FExpr[F.DropNulls], frame: Frame, name: str) -> Self: # ewm_mean = not_implemented() # noqa: ERA001 hist_bins = not_implemented() hist_bin_count = not_implemented() - mode = not_implemented() fill_null_with_strategy = not_implemented() # I think they might both be possible in a similar way to `_is_first_last_distinct` diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index fc1b6aabca..4fec670702 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -245,6 +245,14 @@ def sum_(native: Any) -> NativeScalar: return pc.sum(native, min_count=0) +def first(native: ChunkedOrArrayAny) -> NativeScalar: + return pc.first(native, options=pa_options.scalar_aggregate()) + + +def last(native: ChunkedOrArrayAny) -> NativeScalar: + return pc.last(native, options=pa_options.scalar_aggregate()) + + min_ = pc.min # TODO @dangotbanned: Wrap horizontal functions with correct typing # Should only return scalar if all elements are as well @@ -259,6 +267,14 @@ def sum_(native: Any) -> 
NativeScalar: quantile = pc.quantile +def mode_all(native: ChunkedArrayAny) -> ChunkedArrayAny: + return pa.chunked_array([pc.mode(native, n=len(native)).field("mode")]) + + +def mode_any(native: ChunkedArrayAny) -> NativeScalar: + return first(pc.mode(native, n=1).field("mode")) + + def kurtosis_skew( native: ChunkedOrArrayAny, function: Literal["kurtosis", "skew"], / ) -> NativeScalar: diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 7280295937..908982b087 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -223,9 +223,12 @@ def is_unique( self, node: FunctionExpr[boolean.IsUnique], frame: FrameT_contra, name: str ) -> Self: ... def log(self, node: FunctionExpr[F.Log], frame: FrameT_contra, name: str) -> Self: ... - def mode( - self, node: FunctionExpr[F.Mode], frame: FrameT_contra, name: str + def mode_all( + self, node: FunctionExpr[F.ModeAll], frame: FrameT_contra, name: str ) -> Self: ... + def mode_any( + self, node: FunctionExpr[F.ModeAny], frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... def replace_strict( self, node: FunctionExpr[F.ReplaceStrict], frame: FrameT_contra, name: str ) -> Self: ... 
diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index 1244eaa89d..18a265b282 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -130,6 +130,8 @@ def drop_nulls( # type: ignore[override] sum = _always_noop # type: ignore[misc] mode = _always_noop # type: ignore[misc] unique = _always_noop # type: ignore[misc] + mode_all = not_implemented() # type: ignore[misc] + mode_any = _always_noop # type: ignore[misc] kurtosis = _always_nan # type: ignore[misc] skew = _always_nan # type: ignore[misc] fill_null_with_strategy = not_implemented() # type: ignore[misc] diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 4d627d6474..9c25041111 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -46,6 +46,7 @@ ClosedInterval, FillNullStrategy, IntoDType, + ModeKeepStrategy, NumericLiteral, RankMethod, RollingInterpolationMethod, @@ -238,8 +239,11 @@ def shift(self, n: int) -> Self: def drop_nulls(self) -> Self: return self._with_unary(F.DropNulls()) - def mode(self) -> Self: - return self._with_unary(F.Mode()) + def mode(self, *, keep: ModeKeepStrategy = "all") -> Self: + if func := {"all": F.ModeAll, "any": F.ModeAny}.get(keep): + return self._with_unary(func()) + msg = f"`keep` must be one of ('all', 'any'), but got {keep!r}" + raise TypeError(msg) def skew(self) -> Self: return self._with_unary(F.Skew()) diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index a5bc4383ef..8dafc7ff00 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -45,7 +45,8 @@ class NullCount(Function, options=FunctionOptions.aggregation): ... class Exp(Function, options=FunctionOptions.elementwise): ... class Sqrt(Function, options=FunctionOptions.elementwise): ... class DropNulls(Function, options=FunctionOptions.row_separable): ... -class Mode(Function): ... +class ModeAll(Function): ... 
+class ModeAny(Function, options=FunctionOptions.aggregation): ... class Skew(Function, options=FunctionOptions.aggregation): ... class Clip(Function, options=FunctionOptions.elementwise): def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR, ExprIR]: diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index b588b966b8..e559e5165d 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -693,3 +693,10 @@ def test_replace_strict_invalid() -> None: match="`new` argument cannot be used if `old` argument is a Mapping type", ): nwp.col("a").replace_strict(old={1: 2, 3: 4}, new=[5, 6, 7]) + + +def test_mode_invalid() -> None: + with pytest.raises( + TypeError, match=r"keep.+must be one of.+all.+any.+but got 'first'" + ): + nwp.col("a").mode(keep="first") # type: ignore[arg-type] From 84ce86c618c0103cb08bc63d68a709c424da2106 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Nov 2025 23:13:55 +0000 Subject: [PATCH 016/215] fix `mode_all` and port tests --- narwhals/_plan/arrow/functions.py | 5 +++- tests/plan/mode_test.py | 49 +++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 tests/plan/mode_test.py diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 4fec670702..5c31cbcb10 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -268,7 +268,10 @@ def last(native: ChunkedOrArrayAny) -> NativeScalar: def mode_all(native: ChunkedArrayAny) -> ChunkedArrayAny: - return pa.chunked_array([pc.mode(native, n=len(native)).field("mode")]) + struct = pc.mode(native, n=len(native)) + indices: pa.Int32Array = struct.field("count").dictionary_encode().indices # type: ignore[attr-defined] + index_true_modes = lit(0) + return chunked_array(struct.field("mode").filter(pc.equal(indices, index_true_modes))) def mode_any(native: 
ChunkedArrayAny) -> NativeScalar: diff --git a/tests/plan/mode_test.py b/tests/plan/mode_test.py new file mode 100644 index 0000000000..e3cef89f2f --- /dev/null +++ b/tests/plan/mode_test.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from narwhals import _plan as nwp +from narwhals._plan import selectors as ncs +from narwhals.exceptions import ShapeError +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [1, 1, 2, 2, 3], "b": [1, 2, 3, 3, 4]} + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.col("b").mode(), {"b": [3]}), + (nwp.col("a").mode(keep="all"), {"a": [1, 2]}), + (nwp.col("b").filter(nwp.col("b") != 3).mode(), {"b": [1, 2, 4]}), + (nwp.col("a").mode().sum(), {"a": [3]}), + ], + ids=["single", "multiple-1", "multiple-2", "mutliple-agg"], +) +def test_mode_expr_keep_all(data: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data).select(expr).sort(ncs.first()) + assert_equal_data(result, expected) + + +def test_mode_expr_different_lengths_keep_all(data: Data) -> None: + df = dataframe(data) + with pytest.raises(ShapeError): + df.select(nwp.col("a", "b").mode(keep="all")) + + +def test_mode_expr_keep_any(data: Data) -> None: + result = dataframe(data).select(nwp.col("a", "b").mode(keep="any")) + try: + expected = {"a": [1], "b": [3]} + assert_equal_data(result, expected) + except AssertionError: # pragma: no cover + expected = {"a": [2], "b": [3]} + assert_equal_data(result, expected) From a0b6395d18476d66b81db43909874250d9d0b4ba Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 21 Nov 2025 15:54:52 +0000 Subject: [PATCH 017/215] test: Add tests for `is_{duplicated,unique}` --- tests/plan/is_duplicated_unique_test.py | 69 +++++++++++++++++++++++++ 1 file changed, 69 
insertions(+) create mode 100644 tests/plan/is_duplicated_unique_test.py diff --git a/tests/plan/is_duplicated_unique_test.py b/tests/plan/is_duplicated_unique_test.py new file mode 100644 index 0000000000..0563dd79e2 --- /dev/null +++ b/tests/plan/is_duplicated_unique_test.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from narwhals import _plan as nwp +from narwhals._plan import selectors as ncs +from narwhals.exceptions import InvalidOperationError +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +XFAIL_NOT_IMPL = pytest.mark.xfail( + reason="TODO: `ArrowExpr.is_{duplicated,unique}`", raises=NotImplementedError +) + +XFAIL_NOT_GROUP_BY = pytest.mark.xfail( + reason="TODO: `ArrowExpr.is_{duplicated,unique}.over(*partition_by)`", + raises=InvalidOperationError, +) + + +@pytest.fixture +def data() -> Data: + return { + "v1": [None, 2, 1, 4, 1], + "v2": ["a", "c", "c", None, None], + "p1": [2, 2, 2, 1, 1], + "i": [0, 1, 2, 3, 4], + } + + +@XFAIL_NOT_IMPL +def test_is_duplicated_unique(data: Data) -> None: + expected = { + "v1_is_unique": [True, True, False, True, False], + "v2_is_unique": [True, False, False, False, False], + "v1_is_duplicated": [False, False, True, False, True], + "v2_is_duplicated": [False, True, True, True, True], + } + vals = nwp.col("v1", "v2") + exprs = ( + vals.is_unique().name.suffix("_is_unique"), + vals.is_duplicated().name.suffix("_is_duplicated"), + ) + result = dataframe(data).select("i", *exprs).sort("i").drop("i") + assert_equal_data(result, expected) # pragma: no cover + + +# NOTE: Not supported on `main` +# Planning to adapt `is_{first,last}_distinct` idea here +@XFAIL_NOT_GROUP_BY +def test_is_duplicated_unique_partitioned(data: Data) -> None: + expected = { + "v1_is_unique": [True, True, True, True, True], + "v2_is_unique": [True, False, False, False, False], + "v1_is_duplicated": [False, False, 
False, False, False], + "v2_is_duplicated": [False, True, True, True, True], + } + vals = ncs.by_index(0, 1) + exprs = ( + vals.is_unique().name.suffix("_is_unique").over("p1"), + vals.is_duplicated().name.suffix("_is_duplicated").over("p1"), + ) + result = dataframe(data).select("i", *exprs).sort("i").drop("i") + assert_equal_data(result, expected) # pragma: no cover From 4c25b05005c1e676ca3ed7a39750032575a1f80e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 21 Nov 2025 19:19:50 +0000 Subject: [PATCH 018/215] feat: Add `is_{duplicated,unique}` + support in `over` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🥳🥳🥳 --- narwhals/_plan/arrow/expr.py | 37 ++++++++----- narwhals/_plan/arrow/functions.py | 74 +++++++++++++++++-------- narwhals/_plan/arrow/typing.py | 8 +++ tests/plan/is_duplicated_unique_test.py | 18 +----- 4 files changed, 85 insertions(+), 52 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 673d144569..bf6519499f 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -17,11 +17,13 @@ from narwhals._plan.compliant.scalar import EagerScalar from narwhals._plan.compliant.typing import namespace from narwhals._plan.expressions.boolean import ( + IsDuplicated, IsFirstDistinct, IsInExpr, IsInSeq, IsInSeries, IsLastDistinct, + IsUnique, ) from narwhals._plan.expressions.functions import NullCount from narwhals._utils import Implementation, Version, _StoresNative, not_implemented @@ -441,9 +443,9 @@ def over( expr = node.expr by = node.partition_by if is_function_expr(expr) and isinstance( - expr.function, (IsFirstDistinct, IsLastDistinct) + expr.function, (IsFirstDistinct, IsLastDistinct, IsUnique, IsDuplicated) ): - return self._is_first_last_distinct( + return self._boolean_length_preserving( expr, frame, name, by, sort_indices=sort_indices ) resolved = 
frame._grouper.by_irs(*by).agg_irs(expr.alias(name)).resolve(frame) @@ -462,15 +464,17 @@ def over_ordered( return evaluated return self.from_series(evaluated.broadcast(len(frame)).gather(indices)) - def _is_first_last_distinct( + def _boolean_length_preserving( self, - node: FExpr[IsFirstDistinct | IsLastDistinct], + node: FExpr[IsFirstDistinct | IsLastDistinct | IsUnique | IsDuplicated], frame: Frame, name: str, partition_by: Seq[ir.ExprIR] = (), *, sort_indices: pa.UInt64Array | None = None, ) -> Self: + # NOTE: This subset of functions can be expressed as a mask applied to indices + into_column_agg, mask = fn.BOOLEAN_LENGTH_PRESERVING[type(node.function)] idx_name = temp.column_name(frame) df = frame._with_columns([node.input[0].dispatch(self, frame, name)]) if sort_indices is not None: @@ -478,13 +482,16 @@ def _is_first_last_distinct( df = df._with_native(df.native.add_column(0, idx_name, column)) else: df = df.with_row_index(idx_name) - agg = fn.BOOLEAN_GLUE_FUNCTIONS[type(node.function)](idx_name) + agg_node = into_column_agg(idx_name) if not (partition_by or sort_indices is not None): - distinct = df.group_by_names((name,)).agg((ir.named_ir(idx_name, agg),)) + aggregated = df.group_by_names((name,)).agg( + (ir.named_ir(idx_name, agg_node),) + ) else: - distinct = df.group_by_agg_irs((ir.col(name), *partition_by), agg) + aggregated = df.group_by_agg_irs((ir.col(name), *partition_by), agg_node) index = df.to_series().alias(name) - return self.from_series(index.is_in(distinct.get_column(idx_name))) + final_result = mask(index.native, aggregated.get_column(idx_name).native) + return self.from_series(index._with_native(final_result)) # TODO @dangotbanned: top-level, complex-ish nodes # - [ ] `map_batches` @@ -550,20 +557,22 @@ def mode_any(self, node: FExpr[F.ModeAny], frame: Frame, name: str) -> Scalar: cum_max = _cumulative cum_prod = _cumulative cum_sum = _cumulative - is_first_distinct = _is_first_last_distinct - is_last_distinct = _is_first_last_distinct + 
is_first_distinct = _boolean_length_preserving + is_last_distinct = _boolean_length_preserving + is_duplicated = _boolean_length_preserving + is_unique = _boolean_length_preserving # TODO @dangotbanned: Plan composing with `functions.cum_*` rolling_expr = not_implemented() - # ewm_mean = not_implemented() # noqa: ERA001 + # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_compliant/series.py#L349-L415 + # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L1060-L1076 + # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L1130-L1215 hist_bins = not_implemented() hist_bin_count = not_implemented() fill_null_with_strategy = not_implemented() - # I think they might both be possible in a similar way to `_is_first_last_distinct` - is_duplicated = not_implemented() - is_unique = not_implemented() + # ewm_mean = not_implemented() # noqa: ERA001 class ArrowScalar( diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 5c31cbcb10..b5ef564b93 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -27,6 +27,7 @@ from typing_extensions import TypeAlias, TypeIs from narwhals._arrow.typing import Incomplete, PromoteOptions + from narwhals._plan.arrow.acero import Field from narwhals._plan.arrow.typing import ( Array, ArrayAny, @@ -35,6 +36,7 @@ BinaryLogical, BinaryNumericTemporal, BinOp, + BooleanLengthPreserving, ChunkedArray, ChunkedArrayAny, ChunkedOrArray, @@ -42,6 +44,7 @@ ChunkedOrArrayT, ChunkedOrScalar, ChunkedOrScalarAny, + ChunkedStruct, DataType, DataTypeRemap, DataTypeT, @@ -55,10 +58,12 @@ ScalarT, StringScalar, StringType, + StructArray, UnaryFunction, VectorFunction, ) from narwhals._plan.options import RankOptions, SortMultipleOptions, SortOptions + from narwhals._plan.typing import OneOrSeq, Seq from narwhals.typing 
import ClosedInterval, IntoArrowSchema, PythonLiteral BACKEND_VERSION = Implementation.PYARROW._backend_version() @@ -152,28 +157,6 @@ def modulus(lhs: Any, rhs: Any) -> Any: } -def ir_min_max(name: str, /) -> MinMax: - return MinMax(expr=ir.col(name)) - - -BOOLEAN_GLUE_FUNCTIONS: Mapping[type[ir.boolean.BooleanFunction], IntoColumnAgg] = { - ir.boolean.IsFirstDistinct: ir.min, - ir.boolean.IsLastDistinct: ir.max, - ir.boolean.IsUnique: ir_min_max, - ir.boolean.IsDuplicated: ir_min_max, -} -"""Planning to pattern up `_is_first_last_distinct` to work for some other cases. - -The final two lines will need some tweaking, but the same concept is going on: - - index = df.to_series().alias(name) - return self.from_series(index.is_in(distinct.get_column(idx_name))) - - -Will mean at least 2 more functions with `over(*partition_by, order_by=...)` support 😄 -""" - - @t.overload def cast( native: Scalar[Any], target_type: DataTypeT, *, safe: bool | None = ... @@ -233,6 +216,26 @@ def string_type(data_types: Iterable[DataType] = (), /) -> StringType | LargeStr return pa.large_string() if has_large_string(data_types) else pa.string() +@t.overload +def struct_field(native: ChunkedStruct, field: Field, /) -> ChunkedArrayAny: ... +@t.overload +def struct_field( + native: ChunkedStruct, field: Field, *fields: Field +) -> Seq[ChunkedArrayAny]: ... +@t.overload +def struct_field(native: StructArray, field: Field, /) -> ArrayAny: ... +@t.overload +def struct_field(native: StructArray, field: Field, *fields: Field) -> Seq[ArrayAny]: ... 
+def struct_field( + native: ChunkedOrArrayAny, field: Field, *fields: Field +) -> OneOrSeq[ChunkedOrArrayAny]: + """Retrieve one or multiple `Struct` field(s) as `(Chunked)Array`(s).""" + func = t.cast("Callable[[Any,Any], ChunkedOrArrayAny]", pc.struct_field) + if not fields: + return func(native, field) + return tuple(func(native, name) for name in (field, *fields)) + + def any_(native: Any) -> pa.BooleanScalar: return pc.any(native, min_count=0) @@ -456,6 +459,33 @@ def is_in(values: ArrowAny, /, other: ChunkedOrArrayAny) -> ArrowAny: return is_in_(values, other) # type: ignore[no-any-return] +def ir_min_max(name: str, /) -> MinMax: + return MinMax(expr=ir.col(name)) + + +def _boolean_is_unique( + indices: ChunkedArrayAny, aggregated: ChunkedStruct, / +) -> ChunkedArrayAny: + min, max = struct_field(aggregated, "min", "max") + return and_(is_in(indices, min), is_in(indices, max)) + + +def _boolean_is_duplicated( + indices: ChunkedArrayAny, aggregated: ChunkedStruct, / +) -> ChunkedArrayAny: + return pc.invert(_boolean_is_unique(indices, aggregated)) + + +BOOLEAN_LENGTH_PRESERVING: Mapping[ + type[ir.boolean.BooleanFunction], tuple[IntoColumnAgg, BooleanLengthPreserving] +] = { + ir.boolean.IsFirstDistinct: (ir.min, is_in), + ir.boolean.IsLastDistinct: (ir.max, is_in), + ir.boolean.IsUnique: (ir_min_max, _boolean_is_unique), + ir.boolean.IsDuplicated: (ir_min_max, _boolean_is_duplicated), +} + + def binary( lhs: ChunkedOrScalarAny, op: type[ops.Operator], rhs: ChunkedOrScalarAny ) -> ChunkedOrScalarAny: diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index 030cc73cfb..b648e7071b 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -49,6 +49,11 @@ def __call__( self, native: ChunkedArrayAny, *args: P.args, **kwds: P.kwargs ) -> ChunkedArrayAny: ... + class BooleanLengthPreserving(Protocol): + def __call__( + self, indices: ChunkedArrayAny, aggregated: ChunkedArrayAny, / + ) -> ChunkedArrayAny: ... 
+ ScalarT = TypeVar("ScalarT", bound="pa.Scalar[Any]", default="pa.Scalar[Any]") ScalarPT_contra = TypeVar( @@ -152,6 +157,9 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot ChunkedOrArrayT = TypeVar("ChunkedOrArrayT", ChunkedArrayAny, ArrayAny) Indices: TypeAlias = "_SizedMultiIndexSelector[ChunkedOrArray[pc.IntegerScalar]]" +ChunkedStruct: TypeAlias = "ChunkedArray[pa.StructScalar]" +StructArray: TypeAlias = "pa.StructArray | Array[pa.StructScalar]" + Arrow: TypeAlias = "ChunkedOrScalar[ScalarT_co] | Array[ScalarT_co]" ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" NativeScalar: TypeAlias = ScalarAny diff --git a/tests/plan/is_duplicated_unique_test.py b/tests/plan/is_duplicated_unique_test.py index 0563dd79e2..7f5ceac074 100644 --- a/tests/plan/is_duplicated_unique_test.py +++ b/tests/plan/is_duplicated_unique_test.py @@ -6,23 +6,12 @@ from narwhals import _plan as nwp from narwhals._plan import selectors as ncs -from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_equal_data, dataframe if TYPE_CHECKING: from tests.conftest import Data -XFAIL_NOT_IMPL = pytest.mark.xfail( - reason="TODO: `ArrowExpr.is_{duplicated,unique}`", raises=NotImplementedError -) - -XFAIL_NOT_GROUP_BY = pytest.mark.xfail( - reason="TODO: `ArrowExpr.is_{duplicated,unique}.over(*partition_by)`", - raises=InvalidOperationError, -) - - @pytest.fixture def data() -> Data: return { @@ -33,7 +22,6 @@ def data() -> Data: } -@XFAIL_NOT_IMPL def test_is_duplicated_unique(data: Data) -> None: expected = { "v1_is_unique": [True, True, False, True, False], @@ -47,12 +35,10 @@ def test_is_duplicated_unique(data: Data) -> None: vals.is_duplicated().name.suffix("_is_duplicated"), ) result = dataframe(data).select("i", *exprs).sort("i").drop("i") - assert_equal_data(result, expected) # pragma: no cover + assert_equal_data(result, expected) # NOTE: Not supported on `main` -# Planning to adapt `is_{first,last}_distinct` idea here 
-@XFAIL_NOT_GROUP_BY def test_is_duplicated_unique_partitioned(data: Data) -> None: expected = { "v1_is_unique": [True, True, True, True, True], @@ -66,4 +52,4 @@ def test_is_duplicated_unique_partitioned(data: Data) -> None: vals.is_duplicated().name.suffix("_is_duplicated").over("p1"), ) result = dataframe(data).select("i", *exprs).sort("i").drop("i") - assert_equal_data(result, expected) # pragma: no cover + assert_equal_data(result, expected) From b880b21e1b2fd7ac0cb3ab44b3fa19f0e639ea35 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 21 Nov 2025 21:56:52 +0000 Subject: [PATCH 019/215] feat(DRAFT): Port `fill_null_with_strategy` I wanna try rewriting this without `numpy` after getting the tests in place --- narwhals/_plan/arrow/expr.py | 9 +++++++- narwhals/_plan/arrow/functions.py | 37 ++++++++++++++++++++++++++++++- narwhals/_plan/arrow/typing.py | 14 +++++++----- 3 files changed, 52 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index bf6519499f..9f2cd4a975 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -570,7 +570,14 @@ def mode_any(self, node: FExpr[F.ModeAny], frame: Frame, name: str) -> Scalar: # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L1130-L1215 hist_bins = not_implemented() hist_bin_count = not_implemented() - fill_null_with_strategy = not_implemented() + + def fill_null_with_strategy( + self, node: FExpr[F.FillNullWithStrategy], frame: Frame, name: str + ) -> Self: + native = self._dispatch_expr(node.input[0], frame, name).native + strategy, limit = node.function.strategy, node.function.limit + func = fn.fill_null_with_strategy + return self._with_native(func(native, strategy, limit), name) # ewm_mean = not_implemented() # noqa: ERA001 diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 
b5ef564b93..370a250545 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -64,7 +64,12 @@ ) from narwhals._plan.options import RankOptions, SortMultipleOptions, SortOptions from narwhals._plan.typing import OneOrSeq, Seq - from narwhals.typing import ClosedInterval, IntoArrowSchema, PythonLiteral + from narwhals.typing import ( + ClosedInterval, + FillNullStrategy, + IntoArrowSchema, + PythonLiteral, + ) BACKEND_VERSION = Implementation.PYARROW._backend_version() """Static backend version for `pyarrow`.""" @@ -427,6 +432,36 @@ def preserve_nulls( drop_nulls = t.cast("VectorFunction[...]", pc.drop_null) +_FILL_NULL_STRATEGY: Mapping[FillNullStrategy, UnaryFunction] = { + "forward": pc.fill_null_forward, + "backward": pc.fill_null_backward, +} + + +def fill_null_with_strategy( + native: ChunkedArrayAny, strategy: FillNullStrategy, limit: int | None = None +) -> ChunkedArrayAny: + if limit is None: + return _FILL_NULL_STRATEGY[strategy](native) + import numpy as np # ignore-banned-import + + arr = native + valid_mask = pc.is_valid(arr) + indices = pa.array(np.arange(len(arr)), type=pa.int64()) + if strategy == "forward": + valid_index = np.maximum.accumulate(np.where(valid_mask, indices, -1)) + distance = indices - valid_index + else: + valid_index = np.minimum.accumulate( + np.where(valid_mask[::-1], indices[::-1], len(arr)) + )[::-1] + distance = valid_index - indices + return pc.if_else( # type: ignore[no-any-return] + pc.and_(pc.is_null(arr), pc.less_equal(distance, lit(limit))), # pyright: ignore[reportArgumentType, reportCallIssue] + arr.take(valid_index), + arr, + ) + def is_between( native: ChunkedOrScalar[ScalarT], diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index b648e7071b..ac75cc5453 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -75,28 +75,30 @@ def __call__( class UnaryFunction(Protocol[ScalarPT_contra, ScalarRT_co]): @overload - def 
__call__(self, data: ScalarPT_contra, *args: Any, **kwds: Any) -> ScalarRT_co: ... + def __call__( + self, data: ScalarPT_contra, /, *args: Any, **kwds: Any + ) -> ScalarRT_co: ... @overload def __call__( - self, data: ChunkedArray[ScalarPT_contra], *args: Any, **kwds: Any + self, data: ChunkedArray[ScalarPT_contra], /, *args: Any, **kwds: Any ) -> ChunkedArray[ScalarRT_co]: ... @overload def __call__( - self, data: ChunkedOrScalar[ScalarPT_contra], *args: Any, **kwds: Any + self, data: ChunkedOrScalar[ScalarPT_contra], /, *args: Any, **kwds: Any ) -> ChunkedOrScalar[ScalarRT_co]: ... @overload def __call__( - self, data: Array[ScalarPT_contra], *args: Any, **kwds: Any + self, data: Array[ScalarPT_contra], /, *args: Any, **kwds: Any ) -> Array[ScalarRT_co]: ... @overload def __call__( - self, data: ChunkedOrArray[ScalarPT_contra], *args: Any, **kwds: Any + self, data: ChunkedOrArray[ScalarPT_contra], /, *args: Any, **kwds: Any ) -> ChunkedOrArray[ScalarRT_co]: ... def __call__( - self, data: Arrow[ScalarPT_contra], *args: Any, **kwds: Any + self, data: Arrow[ScalarPT_contra], /, *args: Any, **kwds: Any ) -> Arrow[ScalarRT_co]: ... 
From df0d9b823913b981bfa0646f29a758ef60f79c72 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 21 Nov 2025 22:24:38 +0000 Subject: [PATCH 020/215] test: Add `fill_null_test` --- tests/plan/fill_null_test.py | 88 ++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 tests/plan/fill_null_test.py diff --git a/tests/plan/fill_null_test.py b/tests/plan/fill_null_test.py new file mode 100644 index 0000000000..cc01eeb799 --- /dev/null +++ b/tests/plan/fill_null_test.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from narwhals import _plan as nwp +from narwhals._plan import selectors as ncs +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from narwhals._plan.typing import OneOrIterable + from tests.conftest import Data + +DATA_1 = { + "a": [0.0, None, 2.0, 3.0, 4.0], + "b": [1.0, None, None, 5.0, 3.0], + "c": [5.0, None, 3.0, 2.0, 1.0], +} +DATA_2 = { + "a": [0.0, None, 2.0, 3.0, 4.0], + "b": [1.0, None, None, 5.0, 3.0], + "c": [5.0, 2.0, None, 2.0, 1.0], +} +DATA_LIMITS = { + "a": [1, None, None, None, 5, 6, None, None, None, 10], + "b": ["a", None, None, None, "b", "c", None, None, None, "d"], + "idx": list(range(10)), +} + + +@pytest.mark.parametrize( + ("data", "exprs", "expected"), + [ + ( # test_fill_null + DATA_1, + nwp.all().fill_null(value=99), + {"a": [0.0, 99, 2, 3, 4], "b": [1.0, 99, 99, 5, 3], "c": [5.0, 99, 3, 2, 1]}, + ), + ( # test_fill_null_w_aggregate + {"a": [0.5, None, 2.0, 3.0, 4.5], "b": ["xx", "yy", "zz", None, "yy"]}, + [nwp.col("a").fill_null(nwp.col("a").mean()), nwp.col("b").fill_null("a")], + {"a": [0.5, 2.5, 2.0, 3.0, 4.5], "b": ["xx", "yy", "zz", "a", "yy"]}, + ), + ( # test_fill_null_series_expression + DATA_2, + nwp.nth(0, 1).fill_null(nwp.col("c")), + {"a": [0.0, 2, 2, 3, 4], "b": [1.0, 2, None, 5, 3]}, + ), + ( # 
test_fill_null_strategies_with_limit_as_none (1) + DATA_LIMITS, + ncs.by_index(0, 1).fill_null(strategy="forward").over(order_by="idx"), + { + "a": [1, 1, 1, 1, 5, 6, 6, 6, 6, 10], + "b": ["a", "a", "a", "a", "b", "c", "c", "c", "c", "d"], + }, + ), + ( # test_fill_null_strategies_with_limit_as_none (2) + DATA_LIMITS, + nwp.exclude("idx").fill_null(strategy="backward").over(order_by="idx"), + { + "a": [1, 5, 5, 5, 5, 6, 10, 10, 10, 10], + "b": ["a", "b", "b", "b", "b", "c", "d", "d", "d", "d"], + }, + ), + ( # test_fill_null_limits (1) + DATA_LIMITS, + nwp.col("a", "b").fill_null(strategy="forward", limit=2).over(order_by="idx"), + { + "a": [1, 1, 1, None, 5, 6, 6, 6, None, 10], + "b": ["a", "a", "a", None, "b", "c", "c", "c", None, "d"], + }, + ), + ( # test_fill_null_limits (2) + DATA_LIMITS, + nwp.col("a", "b") + .fill_null(strategy="backward", limit=2) + .over(order_by="idx"), + { + "a": [1, None, 5, 5, 5, 6, None, 10, 10, 10], + "b": ["a", None, "b", "b", "b", "c", None, "d", "d", "d"], + }, + ), + ], +) +def test_fill_null(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None: + df = dataframe(data) + assert_equal_data(df.select(exprs), expected) From 58c18e48b9610788d8ea0f92055f4dd01a385a3c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 15:50:47 +0000 Subject: [PATCH 021/215] feat: Avoid `numpy` dependency in `fill_null_with_strategy` Got quite a few more ideas to experiment with --- narwhals/_plan/arrow/functions.py | 94 +++++++++++++++++++++---------- narwhals/_plan/arrow/typing.py | 46 +++++++++++---- 2 files changed, 98 insertions(+), 42 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 370a250545..f46d097a57 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -16,6 +16,7 @@ floordiv_compat as floordiv, ) from narwhals._plan import expressions as ir +from narwhals._plan._guards 
import is_non_nested_literal from narwhals._plan.arrow import options as pa_options from narwhals._plan.expressions import functions as F, operators as ops from narwhals._utils import Implementation @@ -32,7 +33,9 @@ Array, ArrayAny, ArrowAny, + ArrowT, BinaryComp, + BinaryFunction, BinaryLogical, BinaryNumericTemporal, BinOp, @@ -53,6 +56,8 @@ IntegerType, LargeStringType, NativeScalar, + Predicate, + SameArrowT, Scalar, ScalarAny, ScalarT, @@ -68,6 +73,7 @@ ClosedInterval, FillNullStrategy, IntoArrowSchema, + NonNestedLiteral, PythonLiteral, ) @@ -118,8 +124,9 @@ class MinMax(ir.AggExpr): add = t.cast("BinaryNumericTemporal", pc.add) -sub = pc.subtract +sub = t.cast("BinaryNumericTemporal", pc.subtract) multiply = pc.multiply +power = t.cast("BinaryFunction[pc.NumericScalar, pc.NumericScalar]", pc.power) def truediv(lhs: Any, rhs: Any) -> Any: @@ -241,6 +248,26 @@ def struct_field( return tuple(func(native, name) for name in (field, *fields)) +@t.overload +def when_then( + predicate: Predicate, then: SameArrowT, otherwise: SameArrowT +) -> SameArrowT: ... +@t.overload +def when_then( + predicate: Predicate, then: ArrowT, otherwise: NonNestedLiteral = ... +) -> ArrowT: ... +@t.overload +def when_then( + predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None +) -> Incomplete: ... 
+def when_then( + predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None +) -> Incomplete: + if is_non_nested_literal(otherwise): + otherwise = lit(otherwise, then.type) + return pc.if_else(predicate, then, otherwise) + + def any_(native: Any) -> pa.BooleanScalar: return pc.any(native, min_count=0) @@ -267,7 +294,7 @@ def last(native: ChunkedOrArrayAny) -> NativeScalar: min_horizontal = pc.min_element_wise max_ = pc.max max_horizontal = pc.max_element_wise -mean = pc.mean +mean = t.cast("Callable[[ChunkedOrArray[pc.NumericScalar]], pa.DoubleScalar]", pc.mean) count = pc.count median = pc.approximate_median std = pc.stddev @@ -287,7 +314,7 @@ def mode_any(native: ChunkedArrayAny) -> NativeScalar: def kurtosis_skew( - native: ChunkedOrArrayAny, function: Literal["kurtosis", "skew"], / + native: ChunkedArray[pc.NumericScalar], function: Literal["kurtosis", "skew"], / ) -> NativeScalar: result: NativeScalar if HAS_KURTOSIS_SKEW: @@ -304,13 +331,13 @@ def kurtosis_skew( result = lit(0.0, F64) else: m = sub(non_null, mean(non_null)) - m2 = mean(pc.power(m, lit(2))) + m2 = mean(power(m, lit(2))) if function == "kurtosis": - m4 = mean(pc.power(m, lit(4))) - result = sub(pc.divide(m4, pc.power(m2, lit(2))), lit(3)) + m4 = mean(power(m, lit(4))) + result = sub(pc.divide(m4, power(m2, lit(2))), lit(3)) else: - m3 = mean(pc.power(m, lit(3))) - result = pc.divide(m3, pc.power(m2, lit(1.5))) + m3 = mean(power(m, lit(3))) + result = pc.divide(m3, power(m2, lit(1.5))) return result @@ -418,16 +445,10 @@ def null_count(native: ChunkedOrArrayAny) -> pa.Int64Scalar: return pc.count(native, mode="only_null") -def has_nulls(native: ChunkedOrArrayAny) -> bool: - return bool(native.null_count) - - def preserve_nulls( before: ChunkedOrArrayAny, after: ChunkedOrArrayT, / ) -> ChunkedOrArrayT: - if has_nulls(before): - after = pc.if_else(before.is_null(), lit(None, after.type), after) - return after + return when_then(is_not_null(before), after) if 
before.null_count else after drop_nulls = t.cast("VectorFunction[...]", pc.drop_null) @@ -443,23 +464,34 @@ def fill_null_with_strategy( ) -> ChunkedArrayAny: if limit is None: return _FILL_NULL_STRATEGY[strategy](native) - import numpy as np # ignore-banned-import - - arr = native - valid_mask = pc.is_valid(arr) - indices = pa.array(np.arange(len(arr)), type=pa.int64()) + # NOTE: Original impl comment by @IsaiasGutierrezCruz + # > this algorithm first finds the indices of the valid values to fill all the null value positions + # > then it calculates the distance of each new index and the original index + # > if the distance is equal to or less than the limit and the original value is null, it is replaced + + # TODO @dangotbanned: Return early if we don't have nulls + # TODO @dangotbanned: Fastpaths for length 1 (and 0 if not covered by above) + # TODO @dangotbanned: Can we do this *without* generating a range? + # TODO @dangotbanned: Can we do this *without* using a cumulative function? + valid_mask = native.is_valid() + length = len(native) + indices = int_range(length, chunked=False) if strategy == "forward": - valid_index = np.maximum.accumulate(np.where(valid_mask, indices, -1)) - distance = indices - valid_index + valid_index_or_sentinel = when_then(valid_mask, indices, -1) + valid_index = cum_max(valid_index_or_sentinel) + distance = sub(indices, valid_index) else: - valid_index = np.minimum.accumulate( - np.where(valid_mask[::-1], indices[::-1], len(arr)) - )[::-1] - distance = valid_index - indices - return pc.if_else( # type: ignore[no-any-return] - pc.and_(pc.is_null(arr), pc.less_equal(distance, lit(limit))), # pyright: ignore[reportArgumentType, reportCallIssue] - arr.take(valid_index), - arr, + # TODO @dangotbanned: Every reverse is a full-copy, try to avoid it + # - Does this really need 3x `reverse`? + # - Can we generate any of these in the desired to start with? 
+ valid_index_or_sentinel = when_then(reverse(valid_mask), reverse(indices), length) # type: ignore[assignment] + valid_index = reverse(cum_min(valid_index_or_sentinel)) + distance = sub(valid_index, indices) + # TODO @dangotbanned: Rewrite this to reuse the `is_valid` we have already as the predicate + return when_then( + and_(native.is_null(), lt_eq(distance, lit(limit))), + native.take(valid_index), + native, ) @@ -470,7 +502,7 @@ def is_between( closed: ClosedInterval, ) -> ChunkedOrScalar[pa.BooleanScalar]: fn_lhs, fn_rhs = _IS_BETWEEN[closed] - return and_(fn_lhs(native, lower), fn_rhs(native, upper)) + return and_(fn_lhs(native, lower), fn_rhs(native, upper)) # type: ignore[no-any-return] @t.overload diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index ac75cc5453..77c6eae71c 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -67,9 +67,7 @@ def __call__( ) NumericOrTemporalScalar: TypeAlias = "pc.NumericOrTemporalScalar" NumericOrTemporalScalarT = TypeVar( - "NumericOrTemporalScalarT", - bound=NumericOrTemporalScalar, - default=NumericOrTemporalScalar, + "NumericOrTemporalScalarT", bound=NumericOrTemporalScalar, default="pc.NumericScalar" ) @@ -103,32 +101,53 @@ def __call__( class BinaryFunction(Protocol[ScalarPT_contra, ScalarRT_co]): + @overload + def __call__( + self, x: ChunkedArray[ScalarPT_contra], y: ChunkedArray[ScalarPT_contra], / + ) -> ChunkedArray[ScalarRT_co]: ... + @overload + def __call__( + self, x: Array[ScalarPT_contra], y: Array[ScalarPT_contra], / + ) -> Array[ScalarRT_co]: ... @overload def __call__(self, x: ScalarPT_contra, y: ScalarPT_contra, /) -> ScalarRT_co: ... - @overload def __call__( - self, x: ChunkedArray[ScalarPT_contra], y: ChunkedArray[ScalarPT_contra], / + self, x: ChunkedArray[ScalarPT_contra], y: ScalarPT_contra, / ) -> ChunkedArray[ScalarRT_co]: ... 
- + @overload + def __call__( + self, x: Array[ScalarPT_contra], y: ScalarPT_contra, / + ) -> Array[ScalarRT_co]: ... @overload def __call__( self, x: ScalarPT_contra, y: ChunkedArray[ScalarPT_contra], / ) -> ChunkedArray[ScalarRT_co]: ... - @overload def __call__( - self, x: ChunkedArray[ScalarPT_contra], y: ScalarPT_contra, / + self, x: ScalarPT_contra, y: Array[ScalarPT_contra], / + ) -> Array[ScalarRT_co]: ... + @overload + def __call__( + self, x: ChunkedArray[ScalarPT_contra], y: Array[ScalarPT_contra], / + ) -> ChunkedArray[ScalarRT_co]: ... + @overload + def __call__( + self, x: Array[ScalarPT_contra], y: ChunkedArray[ScalarPT_contra], / ) -> ChunkedArray[ScalarRT_co]: ... - @overload def __call__( self, x: ChunkedOrScalar[ScalarPT_contra], y: ChunkedOrScalar[ScalarPT_contra], / ) -> ChunkedOrScalar[ScalarRT_co]: ... + @overload def __call__( - self, x: ChunkedOrScalar[ScalarPT_contra], y: ChunkedOrScalar[ScalarPT_contra], / - ) -> ChunkedOrScalar[ScalarRT_co]: ... + self, x: Arrow[ScalarPT_contra], y: Arrow[ScalarPT_contra], / + ) -> Arrow[ScalarRT_co]: ... + + def __call__( + self, x: Arrow[ScalarPT_contra], y: Arrow[ScalarPT_contra], / + ) -> Arrow[ScalarRT_co]: ... 
class BinaryComp( @@ -164,6 +183,11 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot Arrow: TypeAlias = "ChunkedOrScalar[ScalarT_co] | Array[ScalarT_co]" ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" +SameArrowT = TypeVar("SameArrowT", ChunkedArrayAny, ArrayAny, ScalarAny) +ArrowT = TypeVar("ArrowT", bound=ArrowAny) +Predicate: TypeAlias = "Arrow[pa.BooleanScalar]" +"""Any `pyarrow` container that wraps boolean.""" + NativeScalar: TypeAlias = ScalarAny BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] StoresNativeT_co = TypeVar( From 2f632b49e4e032c9fa9c9b721b5fc1563873b98a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 16:32:54 +0000 Subject: [PATCH 022/215] test: Reveal index out-of-bounds errors --- tests/plan/fill_null_test.py | 38 +++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/tests/plan/fill_null_test.py b/tests/plan/fill_null_test.py index cc01eeb799..058bb58e23 100644 --- a/tests/plan/fill_null_test.py +++ b/tests/plan/fill_null_test.py @@ -8,6 +8,10 @@ from narwhals._plan import selectors as ncs from tests.plan.utils import assert_equal_data, dataframe +pytest.importorskip("pyarrow") + +import pyarrow as pa + if TYPE_CHECKING: from narwhals._plan.typing import OneOrIterable from tests.conftest import Data @@ -25,10 +29,20 @@ DATA_LIMITS = { "a": [1, None, None, None, 5, 6, None, None, None, 10], "b": ["a", None, None, None, "b", "c", None, None, None, "d"], + "c": [None, 2.5, None, None, None, None, 3.6, None, 2.2, 3.0], + "d": [1, None, None, None, None, None, None, None, 2, None], "idx": list(range(10)), } +# TODO @dangotbanned: Fix this in the new version +# Then open an issue demonstrating the bug +XFAIL_INHERITED_INDEX_ERROR = pytest.mark.xfail( + reason="Bug in the implementation on `main` for `fill_null(limit=...)`.", + raises=pa.ArrowIndexError, +) + + @pytest.mark.parametrize( 
("data", "exprs", "expected"), [ @@ -61,6 +75,8 @@ { "a": [1, 5, 5, 5, 5, 6, 10, 10, 10, 10], "b": ["a", "b", "b", "b", "b", "c", "d", "d", "d", "d"], + "c": [2.5, 2.5, 3.6, 3.6, 3.6, 3.6, 3.6, 2.2, 2.2, 3.0], + "d": [1, 2, 2, 2, 2, 2, 2, 2, 2, None], }, ), ( # test_fill_null_limits (1) @@ -73,14 +89,30 @@ ), ( # test_fill_null_limits (2) DATA_LIMITS, - nwp.col("a", "b") - .fill_null(strategy="backward", limit=2) - .over(order_by="idx"), + [ + nwp.col("a", "b") + .fill_null(strategy="backward", limit=2) + .over(order_by="idx"), + nwp.col("c").fill_null(strategy="backward", limit=3).over(order_by="idx"), + ], { "a": [1, None, 5, 5, 5, 6, None, 10, 10, 10], "b": ["a", None, "b", "b", "b", "c", None, "d", "d", "d"], + "c": [2.5, 2.5, None, 3.6, 3.6, 3.6, 3.6, 2.2, 2.2, 3.0], }, ), + pytest.param( + DATA_LIMITS, + nwp.col("c").fill_null(strategy="forward", limit=3).over(order_by="idx"), + {"c": [None, 2.5, 2.5, 2.5, 2.5, None, 3.6, 3.6, 2.2, 3.0]}, + marks=XFAIL_INHERITED_INDEX_ERROR, + ), + pytest.param( + DATA_LIMITS, + nwp.col("d").fill_null(strategy="backward", limit=3).over(order_by="idx"), + {"d": [1, None, None, None, None, 2, 2, 2, 2, None]}, + marks=XFAIL_INHERITED_INDEX_ERROR, + ), ], ) def test_fill_null(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None: From d65a73b1b4a92fe4cf2c9a01a63b11304f113490 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 17:29:54 +0000 Subject: [PATCH 023/215] fix: Avoid index OOB error --- narwhals/_plan/arrow/functions.py | 10 ++++++++-- tests/plan/fill_null_test.py | 22 ++++++---------------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index f46d097a57..9beea56864 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -478,14 +478,20 @@ def fill_null_with_strategy( indices = int_range(length, chunked=False) if strategy == 
"forward": valid_index_or_sentinel = when_then(valid_mask, indices, -1) - valid_index = cum_max(valid_index_or_sentinel) + almost_valid_index = cum_max(valid_index_or_sentinel) + # NOTE: The correction here is for nulls at either end of the array + # They should be preserved when the fill direction would need an extra element + valid_index = when_then(not_eq(almost_valid_index, lit(-1)), almost_valid_index) distance = sub(indices, valid_index) else: # TODO @dangotbanned: Every reverse is a full-copy, try to avoid it # - Does this really need 3x `reverse`? # - Can we generate any of these in the desired to start with? valid_index_or_sentinel = when_then(reverse(valid_mask), reverse(indices), length) # type: ignore[assignment] - valid_index = reverse(cum_min(valid_index_or_sentinel)) + almost_valid_index = reverse(cum_min(valid_index_or_sentinel)) + valid_index = when_then( + not_eq(almost_valid_index, lit(length)), almost_valid_index + ) distance = sub(valid_index, indices) # TODO @dangotbanned: Rewrite this to reuse the `is_valid` we have already as the predicate return when_then( diff --git a/tests/plan/fill_null_test.py b/tests/plan/fill_null_test.py index 058bb58e23..ca6d89d89f 100644 --- a/tests/plan/fill_null_test.py +++ b/tests/plan/fill_null_test.py @@ -8,10 +8,6 @@ from narwhals._plan import selectors as ncs from tests.plan.utils import assert_equal_data, dataframe -pytest.importorskip("pyarrow") - -import pyarrow as pa - if TYPE_CHECKING: from narwhals._plan.typing import OneOrIterable from tests.conftest import Data @@ -35,14 +31,10 @@ } -# TODO @dangotbanned: Fix this in the new version -# Then open an issue demonstrating the bug -XFAIL_INHERITED_INDEX_ERROR = pytest.mark.xfail( - reason="Bug in the implementation on `main` for `fill_null(limit=...)`.", - raises=pa.ArrowIndexError, -) - - +# TODO @dangotbanned: Address index out-of-bounds error +# - [x] Fix this in the new version +# - [ ] Open an issue demonstrating the bug +# - Same problem impacts 
`main` for `fill_null(limit=...)` @pytest.mark.parametrize( ("data", "exprs", "expected"), [ @@ -101,17 +93,15 @@ "c": [2.5, 2.5, None, 3.6, 3.6, 3.6, 3.6, 2.2, 2.2, 3.0], }, ), - pytest.param( + ( DATA_LIMITS, nwp.col("c").fill_null(strategy="forward", limit=3).over(order_by="idx"), {"c": [None, 2.5, 2.5, 2.5, 2.5, None, 3.6, 3.6, 2.2, 3.0]}, - marks=XFAIL_INHERITED_INDEX_ERROR, ), - pytest.param( + ( DATA_LIMITS, nwp.col("d").fill_null(strategy="backward", limit=3).over(order_by="idx"), {"d": [1, None, None, None, None, 2, 2, 2, 2, None]}, - marks=XFAIL_INHERITED_INDEX_ERROR, ), ], ) From bd2b75a8a7911464558ced6071d4bbc405f7d354 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 17:44:08 +0000 Subject: [PATCH 024/215] perf: Reuse `is_valid` mask --- narwhals/_plan/arrow/functions.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 9beea56864..a8124431fb 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -493,12 +493,8 @@ def fill_null_with_strategy( not_eq(almost_valid_index, lit(length)), almost_valid_index ) distance = sub(valid_index, indices) - # TODO @dangotbanned: Rewrite this to reuse the `is_valid` we have already as the predicate - return when_then( - and_(native.is_null(), lt_eq(distance, lit(limit))), - native.take(valid_index), - native, - ) + preserve = or_(valid_mask, gt(distance, lit(limit))) + return when_then(preserve, native, native.take(valid_index)) def is_between( From 3f7900bc77bad73f3fa2acfea90131aaed27e351 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 18:03:20 +0000 Subject: [PATCH 025/215] perf: Avoid 1/3 reverses for `fill_null("backward",limit=...)` Each of these are expensive + this version is simpler --- narwhals/_plan/arrow/functions.py | 48 
++++++++++++++----------------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index a8124431fb..80503cb559 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -459,44 +459,38 @@ def preserve_nulls( } -def fill_null_with_strategy( - native: ChunkedArrayAny, strategy: FillNullStrategy, limit: int | None = None -) -> ChunkedArrayAny: - if limit is None: - return _FILL_NULL_STRATEGY[strategy](native) +def _fill_null_forward_limit(native: ChunkedArrayAny, limit: int) -> ChunkedArrayAny: # NOTE: Original impl comment by @IsaiasGutierrezCruz # > this algorithm first finds the indices of the valid values to fill all the null value positions # > then it calculates the distance of each new index and the original index # > if the distance is equal to or less than the limit and the original value is null, it is replaced - - # TODO @dangotbanned: Return early if we don't have nulls - # TODO @dangotbanned: Fastpaths for length 1 (and 0 if not covered by above) - # TODO @dangotbanned: Can we do this *without* generating a range? - # TODO @dangotbanned: Can we do this *without* using a cumulative function? valid_mask = native.is_valid() length = len(native) + # TODO @dangotbanned: Can we do this *without* generating a range? indices = int_range(length, chunked=False) - if strategy == "forward": - valid_index_or_sentinel = when_then(valid_mask, indices, -1) - almost_valid_index = cum_max(valid_index_or_sentinel) - # NOTE: The correction here is for nulls at either end of the array - # They should be preserved when the fill direction would need an extra element - valid_index = when_then(not_eq(almost_valid_index, lit(-1)), almost_valid_index) - distance = sub(indices, valid_index) - else: - # TODO @dangotbanned: Every reverse is a full-copy, try to avoid it - # - Does this really need 3x `reverse`? 
- # - Can we generate any of these in the desired to start with? - valid_index_or_sentinel = when_then(reverse(valid_mask), reverse(indices), length) # type: ignore[assignment] - almost_valid_index = reverse(cum_min(valid_index_or_sentinel)) - valid_index = when_then( - not_eq(almost_valid_index, lit(length)), almost_valid_index - ) - distance = sub(valid_index, indices) + valid_index_or_sentinel = when_then(valid_mask, indices, -1) + # TODO @dangotbanned: Can we do this *without* using a cumulative function? + almost_valid_index = cum_max(valid_index_or_sentinel) + # NOTE: The correction here is for nulls at either end of the array + # They should be preserved when the fill direction would need an extra element + valid_index = when_then(not_eq(almost_valid_index, lit(-1)), almost_valid_index) + distance = sub(indices, valid_index) preserve = or_(valid_mask, gt(distance, lit(limit))) return when_then(preserve, native, native.take(valid_index)) +def fill_null_with_strategy( + native: ChunkedArrayAny, strategy: FillNullStrategy, limit: int | None = None +) -> ChunkedArrayAny: + # TODO @dangotbanned: Return early if we don't have nulls + # TODO @dangotbanned: Fastpaths for length 1 (and 0 if not covered by above) + if limit is None: + return _FILL_NULL_STRATEGY[strategy](native) + if strategy == "forward": + return _fill_null_forward_limit(native, limit) + return reverse(_fill_null_forward_limit(reverse(native), limit)) + + def is_between( native: ChunkedOrScalar[ScalarT], lower: ChunkedOrScalar[ScalarT], From 818471a8c9c38113f7ba910d1af6926e237b0300 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 18:26:31 +0000 Subject: [PATCH 026/215] perf: Noop when we have nothing to fill/fill with https://github.com/pola-rs/polars/blob/e1d6f294218a36497255e2d872c223e19a47e2ec/crates/polars-core/src/chunked_array/ops/fill_null.rs#L58-L73 --- narwhals/_plan/arrow/functions.py | 5 +++-- tests/plan/fill_null_test.py 
| 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 80503cb559..65984c91ea 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -482,8 +482,9 @@ def _fill_null_forward_limit(native: ChunkedArrayAny, limit: int) -> ChunkedArra def fill_null_with_strategy( native: ChunkedArrayAny, strategy: FillNullStrategy, limit: int | None = None ) -> ChunkedArrayAny: - # TODO @dangotbanned: Return early if we don't have nulls - # TODO @dangotbanned: Fastpaths for length 1 (and 0 if not covered by above) + null_count = native.null_count + if null_count == 0 or (null_count == len(native)): + return native if limit is None: return _FILL_NULL_STRATEGY[strategy](native) if strategy == "forward": diff --git a/tests/plan/fill_null_test.py b/tests/plan/fill_null_test.py index ca6d89d89f..1a752cc6d0 100644 --- a/tests/plan/fill_null_test.py +++ b/tests/plan/fill_null_test.py @@ -108,3 +108,20 @@ def test_fill_null(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None: df = dataframe(data) assert_equal_data(df.select(exprs), expected) + + +@pytest.mark.parametrize( + "expr", + [ + (~ncs.last()).fill_null(strategy="forward"), + (~ncs.last()).fill_null(strategy="backward"), + (~ncs.last()).fill_null(strategy="forward", limit=100), + (~ncs.last()).fill_null(strategy="backward", limit=20), + ], +) +def test_fill_null_strategy_noop(expr: nwp.Expr) -> None: + data = {"a": [1, 2, 3], "b": [None, None, None], "i": [0, 1, 2]} + expected = {"a": [1, 2, 3], "b": [None, None, None]} + df = dataframe(data) + assert_equal_data(df.select(expr), expected) + assert_equal_data(df.select(expr.over(order_by=ncs.last())), expected) From 0e82e2090f6966bff6884ebd83c48355f548dc69 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 18:41:43 +0000 Subject: [PATCH 027/215] chore: renaming --- 
narwhals/_plan/arrow/functions.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 65984c91ea..4408d6ae7f 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -464,19 +464,20 @@ def _fill_null_forward_limit(native: ChunkedArrayAny, limit: int) -> ChunkedArra # > this algorithm first finds the indices of the valid values to fill all the null value positions # > then it calculates the distance of each new index and the original index # > if the distance is equal to or less than the limit and the original value is null, it is replaced - valid_mask = native.is_valid() - length = len(native) + SENTINEL = lit(-1) # noqa: N806 + is_not_null = native.is_valid() # TODO @dangotbanned: Can we do this *without* generating a range? - indices = int_range(length, chunked=False) - valid_index_or_sentinel = when_then(valid_mask, indices, -1) + index = int_range(len(native), chunked=False) # TODO @dangotbanned: Can we do this *without* using a cumulative function? 
- almost_valid_index = cum_max(valid_index_or_sentinel) + index_not_null_almost = cum_max(when_then(is_not_null, index, SENTINEL)) # NOTE: The correction here is for nulls at either end of the array # They should be preserved when the fill direction would need an extra element - valid_index = when_then(not_eq(almost_valid_index, lit(-1)), almost_valid_index) - distance = sub(indices, valid_index) - preserve = or_(valid_mask, gt(distance, lit(limit))) - return when_then(preserve, native, native.take(valid_index)) + index_not_null = when_then( + not_eq(index_not_null_almost, SENTINEL), index_not_null_almost + ) + distance = sub(index, index_not_null) + preserve = or_(is_not_null, gt(distance, lit(limit))) + return when_then(preserve, native, native.take(index_not_null)) def fill_null_with_strategy( From 50e95af5f9a7d3f661fdb2516c5c89a5e5e84482 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 21:44:33 +0000 Subject: [PATCH 028/215] refactor: Finalize `_fill_null_forward_limit` Managed to write it with one less `if_else`, but the readability suffered so this will do --- narwhals/_plan/arrow/functions.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 4408d6ae7f..480df9360c 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -460,24 +460,16 @@ def preserve_nulls( def _fill_null_forward_limit(native: ChunkedArrayAny, limit: int) -> ChunkedArrayAny: - # NOTE: Original impl comment by @IsaiasGutierrezCruz - # > this algorithm first finds the indices of the valid values to fill all the null value positions - # > then it calculates the distance of each new index and the original index - # > if the distance is equal to or less than the limit and the original value is null, it is replaced SENTINEL = lit(-1) # noqa: N806 is_not_null = native.is_valid() - # 
TODO @dangotbanned: Can we do this *without* generating a range? index = int_range(len(native), chunked=False) - # TODO @dangotbanned: Can we do this *without* using a cumulative function? - index_not_null_almost = cum_max(when_then(is_not_null, index, SENTINEL)) + index_not_null = cum_max(when_then(is_not_null, index, SENTINEL)) # NOTE: The correction here is for nulls at either end of the array - # They should be preserved when the fill direction would need an extra element - index_not_null = when_then( - not_eq(index_not_null_almost, SENTINEL), index_not_null_almost - ) - distance = sub(index, index_not_null) - preserve = or_(is_not_null, gt(distance, lit(limit))) - return when_then(preserve, native, native.take(index_not_null)) + # They should be preserved when the `strategy` would need an out-of-bounds index + not_oob = not_eq(index_not_null, SENTINEL) + index_not_null = when_then(not_oob, index_not_null) + beyond_limit = gt(sub(index, index_not_null), lit(limit)) + return when_then(or_(is_not_null, beyond_limit), native, native.take(index_not_null)) def fill_null_with_strategy( From 591b700760f1252d433f526c9bcd9438753dbf5e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 21:45:31 +0000 Subject: [PATCH 029/215] tidy --- narwhals/_plan/arrow/expr.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 9f2cd4a975..8accb891ba 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -552,6 +552,14 @@ def mode_any(self, node: FExpr[F.ModeAny], frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.input[0], frame, name).native return self._with_native(fn.mode_any(native), name) + def fill_null_with_strategy( + self, node: FExpr[F.FillNullWithStrategy], frame: Frame, name: str + ) -> Self: + native = self._dispatch_expr(node.input[0], frame, name).native + strategy, 
limit = node.function.strategy, node.function.limit + func = fn.fill_null_with_strategy + return self._with_native(func(native, strategy, limit), name) + cum_count = _cumulative cum_min = _cumulative cum_max = _cumulative @@ -571,14 +579,6 @@ def mode_any(self, node: FExpr[F.ModeAny], frame: Frame, name: str) -> Scalar: hist_bins = not_implemented() hist_bin_count = not_implemented() - def fill_null_with_strategy( - self, node: FExpr[F.FillNullWithStrategy], frame: Frame, name: str - ) -> Self: - native = self._dispatch_expr(node.input[0], frame, name).native - strategy, limit = node.function.strategy, node.function.limit - func = fn.fill_null_with_strategy - return self._with_native(func(native, strategy, limit), name) - # ewm_mean = not_implemented() # noqa: ERA001 From 103e4d4437d82c1968b300146a8d8f8348e38bd1 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 22:37:53 +0000 Subject: [PATCH 030/215] start adding `replace_strict` --- narwhals/_plan/arrow/expr.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 8accb891ba..e8787251e8 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -240,7 +240,35 @@ def clip_upper( ) return self._with_native(result, name) - replace_strict = not_implemented() # type: ignore[misc] + # https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L772-L812 + # TODO @dangotbanned: Handle `Scalar` + # TODO @dangotbanned: Update `F.ReplaceStrict` to have `default` + handle it + def replace_strict( + self, node: FExpr[F.ReplaceStrict], frame: Frame, name: str + ) -> StoresNativeT_co: + compliant = node.input[0].dispatch(self, frame, name) + native = compliant.native + old, new = node.function.old, node.function.new + if isinstance(native, pa.Scalar): + msg = "TODO: 
`scalar.replace_strict`" + raise NotImplementedError(msg) + idxs = pc.index_in(native, pa.array(old)) + result = pa.array(new).take(idxs) + if dtype := node.function.return_dtype: + result = result.cast(narwhals_to_native_dtype(dtype, self.version)) + if result.null_count != native.null_count: + replace_failed = ( + native.filter(fn.and_(fn.is_not_null(native), result.is_null())) + .unique() + .to_pylist() + ) + msg = ( + "replace_strict did not replace all non-null values.\n\n" + "The following did not get replaced: " + f"{replace_failed}" + ) + raise ValueError(msg) + return self._with_native(fn.chunked_array(result), name) class ArrowExpr( # type: ignore[misc] From a19d38a5e6e4599a9a89be5ea7646292f396071a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 22 Nov 2025 22:56:01 +0000 Subject: [PATCH 031/215] update `Expr.replace_strict` signature, partial parsing - Maybe split this up into 2-3 versions? - polars supports `Expr` for all 3 parameters --- narwhals/_plan/expr.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 9c25041111..1da0d7cc08 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -24,7 +24,7 @@ SortOptions, rolling_options, ) -from narwhals._utils import Version +from narwhals._utils import Version, no_default from narwhals.exceptions import ComputeError if TYPE_CHECKING: @@ -42,6 +42,7 @@ from narwhals._plan.expressions.temporal import ExprDateTimeNamespace from narwhals._plan.meta import MetaNamespace from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq, Udf + from narwhals._typing import NoDefault from narwhals.typing import ( ClosedInterval, FillNullStrategy, @@ -347,16 +348,18 @@ def ewm_mean( ) return self._with_unary(F.EwmMean(options=options)) + # TODO @dangotbanned: Update to support `default` def replace_strict( self, old: Sequence[Any] | Mapping[Any, 
Any], - new: Sequence[Any] | None = None, + new: Sequence[Any] | NoDefault = no_default, *, + default: IntoExpr | NoDefault = no_default, return_dtype: IntoDType | None = None, ) -> Self: before: Seq[Any] after: Seq[Any] - if new is None: + if new is no_default: if not isinstance(old, Mapping): msg = "`new` argument is required if `old` argument is not a Mapping type" raise TypeError(msg) @@ -370,6 +373,14 @@ def replace_strict( after = tuple(new) if return_dtype is not None: return_dtype = common.into_dtype(return_dtype) + + if default is no_default: + ... + else: + default_ir = parse_into_expr_ir(default, str_as_lit=True) + msg = f"replace_strict(default={default_ir!r})" + raise NotImplementedError(msg) + function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) return self._with_unary(function) From ce57dcc30913fc9503c2adc7ea4f9a1dd6e7083f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 12:50:12 +0000 Subject: [PATCH 032/215] test: Port non-default tests Covering more than just a single (column) dtype here --- tests/plan/replace_strict_test.py | 96 +++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 tests/plan/replace_strict_test.py diff --git a/tests/plan/replace_strict_test.py b/tests/plan/replace_strict_test.py new file mode 100644 index 0000000000..fbf8e717c1 --- /dev/null +++ b/tests/plan/replace_strict_test.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from narwhals.exceptions import InvalidOperationError +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping, Sequence + + from _pytest.mark import ParameterSet + from typing_extensions import TypeAlias + + from narwhals._typing import NoDefault + from narwhals.typing import IntoDType + 
from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +Old: TypeAlias = "Sequence[Any] | Mapping[Any, Any]" +New: TypeAlias = "Sequence[Any] | NoDefault" + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "str": ["one", "two", "three", "four"], + "int": [1, 2, 3, 4], + "str-null": ["one", None, "three", "four"], + "int-null": [1, 2, None, 4], + "str-alt": ["beluga", "narwhal", "orca", "vaquita"], + } + + +def basic_cases( + column: str, + replacements: Mapping[Any, Any], + return_dtypes: Iterable[IntoDType | None], +) -> Iterator[ParameterSet]: + old, new = list(replacements), tuple(replacements.values()) + values = list(new) + base = nwp.col(column) + alt_name = f"{column}_seqs" + alt = nwp.col(column).alias(alt_name) + expected = {column: values, alt_name: values} + for dtype in return_dtypes: + exprs = ( + base.replace_strict(replacements, return_dtype=dtype), + alt.replace_strict(old, new, return_dtype=dtype), + ) + schema = {column: dtype, alt_name: dtype} if dtype is not None else None + yield pytest.param(exprs, expected, schema, id=f"{column}-{dtype}") + + +@pytest.mark.parametrize( + ("exprs", "expected", "schema"), + [ + *basic_cases( + "str", + {"one": 1, "two": 2, "three": 3, "four": 4}, + [nw.Int8, nw.Float32, None], + ), + *basic_cases( + "int", {1: "one", 2: "two", 3: "three", 4: "four"}, [nw.String(), None] + ), + ], +) +def test_replace_strict_expr_basic( + data: Data, + exprs: Iterable[nwp.Expr], + expected: Data, + schema: Mapping[str, IntoDType] | None, +) -> None: + result = dataframe(data).select(exprs) + assert_equal_data(result, expected) + if schema is not None: + assert result.collect_schema() == schema + + +@pytest.mark.parametrize( + "expr", + [ + nwp.col("int").replace_strict([1, 3], [3, 4]), + nwp.col("str-null").replace_strict({"one": "two", "four": "five"}), + ], +) +def test_replace_strict_expr_non_full(data: Data, expr: nwp.Expr) -> None: + with pytest.raises( + (ValueError, 
InvalidOperationError), match=r"did not replace all non-null" + ): + dataframe(data).select(expr) From 2eab41a3ca0e14f80b8f127f01f180a86ffb37d4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 14:06:17 +0000 Subject: [PATCH 033/215] test: Port the default tests --- tests/plan/replace_strict_test.py | 73 +++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/tests/plan/replace_strict_test.py b/tests/plan/replace_strict_test.py index fbf8e717c1..f42956352a 100644 --- a/tests/plan/replace_strict_test.py +++ b/tests/plan/replace_strict_test.py @@ -94,3 +94,76 @@ def test_replace_strict_expr_non_full(data: Data, expr: nwp.Expr) -> None: (ValueError, InvalidOperationError), match=r"did not replace all non-null" ): dataframe(data).select(expr) + + +XFAIL_DEFAULT = pytest.mark.xfail( + reason="Not Implemented `replace_strict(default=...)` yet", raises=ValueError +) + + +# TODO @dangotbanned: Share more of the case generation logic from `basic_cases` +@pytest.mark.parametrize( + ("expr", "expected"), + [ + # test_replace_strict_expr_with_default + pytest.param( + nwp.col("int").replace_strict( + [1, 2], ["one", "two"], default=nwp.lit("other"), return_dtype=nw.String + ), + {"int": ["one", "two", "other", "other"]}, + marks=XFAIL_DEFAULT, + id="non-null-1", + ), + pytest.param( + nwp.col("int").replace_strict([1, 2], ["one", "two"], default="other"), + {"int": ["one", "two", "other", "other"]}, + marks=XFAIL_DEFAULT, + id="non-null-2", + ), + # test_replace_strict_with_default_and_nulls + pytest.param( + nwp.col("int-null").replace_strict( + [1, 2], [10, 20], default=99, return_dtype=nw.Int64 + ), + {"int-null": [10, 20, 99, 99]}, + marks=XFAIL_DEFAULT, + id="null-1", + ), + pytest.param( + nwp.col("int-null").replace_strict([1, 2], [10, 20], default=99), + {"int-null": [10, 20, 99, 99]}, + marks=XFAIL_DEFAULT, + id="null-2", + ), + # test_replace_strict_with_default_mapping + 
pytest.param( + nwp.col("int").replace_strict( + {1: "one", 2: "two", 3: None}, default="other", return_dtype=nw.String() + ), + {"int": ["one", "two", None, "other"]}, + marks=XFAIL_DEFAULT, + # shouldn't be an independent case, the mapping isn't the default + id="replace_strict_with_default_mapping", + ), + # test_replace_strict_with_expressified_default + pytest.param( + nwp.col("int").replace_strict( + {1: "one", 2: "two"}, default=nwp.col("str-alt"), return_dtype=nw.String + ), + {"int": ["one", "two", "orca", "vaquita"]}, + marks=XFAIL_DEFAULT, + id="column", + ), + # test_mapping_key_not_in_expr + pytest.param( + nwp.col("int").replace_strict( + {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"}, default="hundred" + ), + {"int": ["one", "two", "three", "four"]}, + id="mapping_key_not_in_expr", + ), + ], +) +def test_replace_strict_expr_default(data: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data).select(expr) + assert_equal_data(result, expected) From 7ec10f777f84512a7e2f3a1e26cc79733a526e4a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 14:29:08 +0000 Subject: [PATCH 034/215] feat(DRAFT): Support `replace_strict(default=...)` --- narwhals/_plan/arrow/expr.py | 20 ++++++++++++++++++-- narwhals/_plan/compliant/expr.py | 3 +++ narwhals/_plan/expr.py | 15 +++++++-------- narwhals/_plan/expressions/functions.py | 6 ++++++ tests/plan/replace_strict_test.py | 11 ----------- 5 files changed, 34 insertions(+), 21 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index e8787251e8..3f158873a6 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -240,9 +240,7 @@ def clip_upper( ) return self._with_native(result, name) - # https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L772-L812 # TODO @dangotbanned: Handle `Scalar` - # TODO @dangotbanned: 
Update `F.ReplaceStrict` to have `default` + handle it def replace_strict( self, node: FExpr[F.ReplaceStrict], frame: Frame, name: str ) -> StoresNativeT_co: @@ -270,6 +268,24 @@ def replace_strict( raise ValueError(msg) return self._with_native(fn.chunked_array(result), name) + # TODO @dangotbanned: Move everything to `functions`, share most things with `replace_strict` + def replace_strict_default( + self, node: FExpr[F.ReplaceStrictDefault], frame: Frame, name: str + ) -> StoresNativeT_co: + expr, default_ = node.function.unwrap_input(node) + native = expr.dispatch(self, frame, name).native + default = default_.dispatch(self, frame, name).native + old, new = node.function.old, node.function.new + if isinstance(native, pa.Scalar): + msg = "TODO: `scalar.replace_strict`" + raise NotImplementedError(msg) + idxs = pc.index_in(native, pa.array(old)) + result = pa.array(new).take(idxs) + if dtype := node.function.return_dtype: + result = result.cast(narwhals_to_native_dtype(dtype, self.version)) + result = fn.when_then(idxs.is_valid(), result, default) + return self._with_native(fn.chunked_array(result), name) + class ArrowExpr( # type: ignore[misc] _ArrowDispatch["ArrowExpr | ArrowScalar"], diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 908982b087..faa5892923 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -232,6 +232,9 @@ def mode_any( def replace_strict( self, node: FunctionExpr[F.ReplaceStrict], frame: FrameT_contra, name: str ) -> Self: ... + def replace_strict_default( + self, node: FunctionExpr[F.ReplaceStrictDefault], frame: FrameT_contra, name: str + ) -> Self: ... def round( self, node: FunctionExpr[F.Round], frame: FrameT_contra, name: str ) -> Self: ... 
diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 1da0d7cc08..9fd200c009 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -375,14 +375,13 @@ def replace_strict( return_dtype = common.into_dtype(return_dtype) if default is no_default: - ... - else: - default_ir = parse_into_expr_ir(default, str_as_lit=True) - msg = f"replace_strict(default={default_ir!r})" - raise NotImplementedError(msg) - - function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) - return self._with_unary(function) + function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) + return self._with_unary(function) + function = F.ReplaceStrictDefault( + old=before, new=after, return_dtype=return_dtype + ) + default_ir = parse_into_expr_ir(default, str_as_lit=True) + return self._from_ir(function.to_function_expr(self._ir, default_ir)) def gather_every(self, n: int, offset: int = 0) -> Self: return self._with_unary(F.GatherEvery(n=n, offset=offset)) diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index 8dafc7ff00..b148c8843f 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -175,6 +175,12 @@ class ReplaceStrict(Function, options=FunctionOptions.elementwise): return_dtype: DType | None +class ReplaceStrictDefault(ReplaceStrict): + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, default = node.input + return expr, default + + class GatherEvery(Function): __slots__ = ("n", "offset") n: int diff --git a/tests/plan/replace_strict_test.py b/tests/plan/replace_strict_test.py index f42956352a..0896d18129 100644 --- a/tests/plan/replace_strict_test.py +++ b/tests/plan/replace_strict_test.py @@ -96,11 +96,6 @@ def test_replace_strict_expr_non_full(data: Data, expr: nwp.Expr) -> None: dataframe(data).select(expr) -XFAIL_DEFAULT = pytest.mark.xfail( - reason="Not Implemented `replace_strict(default=...)` 
yet", raises=ValueError -) - - # TODO @dangotbanned: Share more of the case generation logic from `basic_cases` @pytest.mark.parametrize( ("expr", "expected"), @@ -111,13 +106,11 @@ def test_replace_strict_expr_non_full(data: Data, expr: nwp.Expr) -> None: [1, 2], ["one", "two"], default=nwp.lit("other"), return_dtype=nw.String ), {"int": ["one", "two", "other", "other"]}, - marks=XFAIL_DEFAULT, id="non-null-1", ), pytest.param( nwp.col("int").replace_strict([1, 2], ["one", "two"], default="other"), {"int": ["one", "two", "other", "other"]}, - marks=XFAIL_DEFAULT, id="non-null-2", ), # test_replace_strict_with_default_and_nulls @@ -126,13 +119,11 @@ def test_replace_strict_expr_non_full(data: Data, expr: nwp.Expr) -> None: [1, 2], [10, 20], default=99, return_dtype=nw.Int64 ), {"int-null": [10, 20, 99, 99]}, - marks=XFAIL_DEFAULT, id="null-1", ), pytest.param( nwp.col("int-null").replace_strict([1, 2], [10, 20], default=99), {"int-null": [10, 20, 99, 99]}, - marks=XFAIL_DEFAULT, id="null-2", ), # test_replace_strict_with_default_mapping @@ -141,7 +132,6 @@ def test_replace_strict_expr_non_full(data: Data, expr: nwp.Expr) -> None: {1: "one", 2: "two", 3: None}, default="other", return_dtype=nw.String() ), {"int": ["one", "two", None, "other"]}, - marks=XFAIL_DEFAULT, # shouldn't be an independent case, the mapping isn't the default id="replace_strict_with_default_mapping", ), @@ -151,7 +141,6 @@ def test_replace_strict_expr_non_full(data: Data, expr: nwp.Expr) -> None: {1: "one", 2: "two"}, default=nwp.col("str-alt"), return_dtype=nw.String ), {"int": ["one", "two", "orca", "vaquita"]}, - marks=XFAIL_DEFAULT, id="column", ), # test_mapping_key_not_in_expr From 801f2837fff8eb4298dc9b35047b0e6258bec8c5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 16:00:23 +0000 Subject: [PATCH 035/215] split out ops logic --- narwhals/_plan/arrow/expr.py | 45 +++++---------------- narwhals/_plan/arrow/functions.py 
| 66 ++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 36 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 3f158873a6..47a05f7a8b 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -244,47 +244,22 @@ def clip_upper( def replace_strict( self, node: FExpr[F.ReplaceStrict], frame: Frame, name: str ) -> StoresNativeT_co: - compliant = node.input[0].dispatch(self, frame, name) - native = compliant.native - old, new = node.function.old, node.function.new - if isinstance(native, pa.Scalar): - msg = "TODO: `scalar.replace_strict`" - raise NotImplementedError(msg) - idxs = pc.index_in(native, pa.array(old)) - result = pa.array(new).take(idxs) - if dtype := node.function.return_dtype: - result = result.cast(narwhals_to_native_dtype(dtype, self.version)) - if result.null_count != native.null_count: - replace_failed = ( - native.filter(fn.and_(fn.is_not_null(native), result.is_null())) - .unique() - .to_pylist() - ) - msg = ( - "replace_strict did not replace all non-null values.\n\n" - "The following did not get replaced: " - f"{replace_failed}" - ) - raise ValueError(msg) - return self._with_native(fn.chunked_array(result), name) + func = node.function + native = node.input[0].dispatch(self, frame, name).native + dtype = fn.dtype_native(func.return_dtype, self.version) + result = fn.replace_strict(native, func.old, func.new, dtype) + return self._with_native(result, name) - # TODO @dangotbanned: Move everything to `functions`, share most things with `replace_strict` def replace_strict_default( self, node: FExpr[F.ReplaceStrictDefault], frame: Frame, name: str ) -> StoresNativeT_co: - expr, default_ = node.function.unwrap_input(node) + func = node.function + expr, default_ = func.unwrap_input(node) native = expr.dispatch(self, frame, name).native default = default_.dispatch(self, frame, name).native - old, new = node.function.old, node.function.new - if isinstance(native, pa.Scalar): - 
msg = "TODO: `scalar.replace_strict`" - raise NotImplementedError(msg) - idxs = pc.index_in(native, pa.array(old)) - result = pa.array(new).take(idxs) - if dtype := node.function.return_dtype: - result = result.cast(narwhals_to_native_dtype(dtype, self.version)) - result = fn.when_then(idxs.is_valid(), result, default) - return self._with_native(fn.chunked_array(result), name) + dtype = fn.dtype_native(func.return_dtype, self.version) + result = fn.replace_strict_default(native, func.old, func.new, default, dtype) + return self._with_native(result, name) class ArrowExpr( # type: ignore[misc] diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 480df9360c..10af78c0eb 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -14,12 +14,13 @@ cast_for_truediv, chunked_array as _chunked_array, floordiv_compat as floordiv, + narwhals_to_native_dtype as _dtype_native, ) from narwhals._plan import expressions as ir from narwhals._plan._guards import is_non_nested_literal from narwhals._plan.arrow import options as pa_options from narwhals._plan.expressions import functions as F, operators as ops -from narwhals._utils import Implementation +from narwhals._utils import Implementation, Version if TYPE_CHECKING: import datetime as dt @@ -73,6 +74,7 @@ ClosedInterval, FillNullStrategy, IntoArrowSchema, + IntoDType, NonNestedLiteral, PythonLiteral, ) @@ -169,6 +171,16 @@ def modulus(lhs: Any, rhs: Any) -> Any: } +@t.overload +def dtype_native(dtype: IntoDType, version: Version) -> pa.DataType: ... +@t.overload +def dtype_native(dtype: None, version: Version) -> None: ... +@t.overload +def dtype_native(dtype: IntoDType | None, version: Version) -> pa.DataType | None: ... 
+def dtype_native(dtype: IntoDType | None, version: Version) -> pa.DataType | None: + return dtype if dtype is None else _dtype_native(dtype, version) + + @t.overload def cast( native: Scalar[Any], target_type: DataTypeT, *, safe: bool | None = ... @@ -485,6 +497,58 @@ def fill_null_with_strategy( return reverse(_fill_null_forward_limit(reverse(native), limit)) +def _replace_strict( + native: ChunkedOrScalarAny, + old: Seq[Any], + new: Seq[Any], + dtype: pa.DataType | None = None, +) -> Incomplete: + if isinstance(native, pa.Scalar): + msg = "TODO: `scalar.replace_strict`" + raise NotImplementedError(msg) + idxs = pc.index_in(native, pa.array(old)) + result = pa.array(new).take(idxs) + if dtype: + result = result.cast(dtype) + return result, idxs + + +def replace_strict( + native: ChunkedOrScalarAny, + old: Seq[Any], + new: Seq[Any], + dtype: pa.DataType | None = None, +) -> ChunkedOrScalarAny: + result, _ = _replace_strict(native, old, new, dtype) + if isinstance(native, pa.Scalar): + msg = "TODO: `scalar.replace_strict`" + raise NotImplementedError(msg) + if result.null_count == native.null_count: + return chunked_array(result) + + replace_failed = ( + native.filter(and_(is_not_null(native), result.is_null())).unique().to_pylist() + ) + msg = ( + "replace_strict did not replace all non-null values.\n\n" + "The following did not get replaced: " + f"{replace_failed}" + ) + raise ValueError(msg) + + +def replace_strict_default( + native: ChunkedOrScalarAny, + old: Seq[Any], + new: Seq[Any], + default: Incomplete, + dtype: pa.DataType | None = None, +) -> ChunkedOrScalarAny: + result, idxs = _replace_strict(native, old, new, dtype) + result = when_then(idxs.is_valid(), result, default) + return chunked_array(result) + + def is_between( native: ChunkedOrScalar[ScalarT], lower: ChunkedOrScalar[ScalarT], From 050b60c4a02daeb78d65a81d4a06bbbd8cbd7650 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 
16:01:58 +0000 Subject: [PATCH 036/215] add tests for scalar Need to specialize this because `null_count`, `index_in`, `filter` and `unique` won't work --- tests/plan/replace_strict_test.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/plan/replace_strict_test.py b/tests/plan/replace_strict_test.py index 0896d18129..eaa9608d6d 100644 --- a/tests/plan/replace_strict_test.py +++ b/tests/plan/replace_strict_test.py @@ -156,3 +156,23 @@ def test_replace_strict_expr_non_full(data: Data, expr: nwp.Expr) -> None: def test_replace_strict_expr_default(data: Data, expr: nwp.Expr, expected: Data) -> None: result = dataframe(data).select(expr) assert_equal_data(result, expected) + + +@pytest.mark.xfail(reason="TODO: `scalar.replace_strict`", raises=NotImplementedError) +def test_replace_strict_scalar(data: Data) -> None: # pragma: no cover + df = dataframe(data) + expr = ( + nwp.col("str-null") + .first() + .replace_strict({"one": 1, "two": 2, "three": 3, "four": 4}) + ) + assert_equal_data(df.select(expr), {"str-null": [1]}) + + int_null = nwp.col("int-null") + repl_ints = {1: 10, 2: 20, 4: 40} + + expr = int_null.last().replace_strict(repl_ints, default=999) + assert_equal_data(df.select(expr), {"int-null": [40]}) + + expr = int_null.sort(nulls_last=True).last().replace_strict(repl_ints, default=999) + assert_equal_data(df.select(expr), {"int-null": [999]}) From c2faad07cd6d3b77afbbab5e33180709d2d7f126 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 17:47:09 +0000 Subject: [PATCH 037/215] feat: Add scalar paths for `replace_strict` --- narwhals/_plan/arrow/expr.py | 1 - narwhals/_plan/arrow/functions.py | 56 +++++++++++++------------------ tests/plan/replace_strict_test.py | 3 +- 3 files changed, 24 insertions(+), 36 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 47a05f7a8b..65cb55f605 100644 --- a/narwhals/_plan/arrow/expr.py +++ 
b/narwhals/_plan/arrow/expr.py @@ -240,7 +240,6 @@ def clip_upper( ) return self._with_native(result, name) - # TODO @dangotbanned: Handle `Scalar` def replace_strict( self, node: FExpr[F.ReplaceStrict], frame: Frame, name: str ) -> StoresNativeT_co: diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 10af78c0eb..aecec16cd6 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -497,20 +497,16 @@ def fill_null_with_strategy( return reverse(_fill_null_forward_limit(reverse(native), limit)) -def _replace_strict( - native: ChunkedOrScalarAny, - old: Seq[Any], - new: Seq[Any], - dtype: pa.DataType | None = None, -) -> Incomplete: - if isinstance(native, pa.Scalar): - msg = "TODO: `scalar.replace_strict`" - raise NotImplementedError(msg) - idxs = pc.index_in(native, pa.array(old)) - result = pa.array(new).take(idxs) - if dtype: - result = result.cast(dtype) - return result, idxs +def _ensure_all_replaced( + native: ChunkedOrScalarAny, unmatched: ArrowAny +) -> ValueError | None: + if not any_(unmatched).as_py(): + return None + msg = ( + "replace_strict did not replace all non-null values.\n\n" + f"The following did not get replaced: {chunked_array(native).filter(array(unmatched)).unique().to_pylist()}" + ) + return ValueError(msg) def replace_strict( @@ -519,34 +515,28 @@ def replace_strict( new: Seq[Any], dtype: pa.DataType | None = None, ) -> ChunkedOrScalarAny: - result, _ = _replace_strict(native, old, new, dtype) if isinstance(native, pa.Scalar): - msg = "TODO: `scalar.replace_strict`" - raise NotImplementedError(msg) - if result.null_count == native.null_count: - return chunked_array(result) - - replace_failed = ( - native.filter(and_(is_not_null(native), result.is_null())).unique().to_pylist() - ) - msg = ( - "replace_strict did not replace all non-null values.\n\n" - "The following did not get replaced: " - f"{replace_failed}" - ) - raise ValueError(msg) + idxs: ArrayAny = 
array(pc.index_in(native, pa.array(old))) + result: ChunkedOrScalarAny = pa.array(new).take(idxs)[0] + else: + idxs = pc.index_in(native, pa.array(old)) + result = chunked_array(pa.array(new).take(idxs)) + if err := _ensure_all_replaced(native, and_(is_not_null(native), is_null(idxs))): + raise err + return result.cast(dtype) if dtype else result def replace_strict_default( native: ChunkedOrScalarAny, old: Seq[Any], new: Seq[Any], - default: Incomplete, + default: ChunkedOrScalarAny, dtype: pa.DataType | None = None, ) -> ChunkedOrScalarAny: - result, idxs = _replace_strict(native, old, new, dtype) - result = when_then(idxs.is_valid(), result, default) - return chunked_array(result) + idxs = pc.index_in(native, pa.array(old)) + result = pa.array(new).take(array(idxs)) + result = when_then(is_null(idxs), default, result.cast(dtype) if dtype else result) + return chunked_array(result) if isinstance(native, pa.ChunkedArray) else result[0] def is_between( diff --git a/tests/plan/replace_strict_test.py b/tests/plan/replace_strict_test.py index eaa9608d6d..b460764b45 100644 --- a/tests/plan/replace_strict_test.py +++ b/tests/plan/replace_strict_test.py @@ -158,8 +158,7 @@ def test_replace_strict_expr_default(data: Data, expr: nwp.Expr, expected: Data) assert_equal_data(result, expected) -@pytest.mark.xfail(reason="TODO: `scalar.replace_strict`", raises=NotImplementedError) -def test_replace_strict_scalar(data: Data) -> None: # pragma: no cover +def test_replace_strict_scalar(data: Data) -> None: df = dataframe(data) expr = ( nwp.col("str-null") From cec550c26825c29fccae7e9df629a0fa537830cd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 19:44:38 +0000 Subject: [PATCH 038/215] finally start working on `rolling_expr` --- narwhals/_plan/arrow/expr.py | 51 ++++++++++++++++++++++++++++-- narwhals/_plan/arrow/functions.py | 9 +++++- narwhals/_plan/expressions/expr.py | 6 +++- 3 files changed, 62 
insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 65cb55f605..16ce9cdb10 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -16,6 +16,7 @@ from narwhals._plan.compliant.expr import EagerExpr from narwhals._plan.compliant.scalar import EagerScalar from narwhals._plan.compliant.typing import namespace +from narwhals._plan.expressions import functions as F from narwhals._plan.expressions.boolean import ( IsDuplicated, IsFirstDistinct, @@ -38,7 +39,6 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.arrow.typing import ChunkedArrayAny, P, VectorFunction - from narwhals._plan.expressions import functions as F from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, @@ -589,7 +589,54 @@ def fill_null_with_strategy( is_unique = _boolean_length_preserving # TODO @dangotbanned: Plan composing with `functions.cum_*` - rolling_expr = not_implemented() + # Waaaaaay more of this needs to be shared + # https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L930-L1034 + + def rolling_expr( # noqa: PLR0914 + self, node: ir.RollingExpr[F.RollingWindow], frame: Frame, name: str + ) -> Self: + function = node.function + if not isinstance(function, F.RollingSum): + msg = f"TODO: {node!r}" + raise NotImplementedError(msg) + roll_options = function.options + window_size = roll_options.window_size + compliant = self._dispatch_expr(node.input[0], frame, name) + native = compliant.native + + if roll_options.center: + offset_left = window_size // 2 + # subtract one if window_size is even + offset_right = offset_left - (window_size % 2 == 0) + chunks = native.chunks + arrays = ( + fn.nulls_like(offset_left, native), + *chunks, + fn.nulls_like(offset_right, native), + ) + native = fn.concat_vertical_chunked(arrays) + offset = offset_left + 
offset_right + else: + offset = 0 + # NOTE: Implementing `rolling_sum` first, then will see what the others look like + # this'll be easier to read as `ArrowSeries` methods + cum_sum = fn.fill_null_with_strategy(fn.cum_sum(native), "forward") + if window_size != 0: + rolling_sum = fn.sub(cum_sum, fn.fill_null(fn.shift(cum_sum, window_size), 0)) + else: + rolling_sum = cum_sum + + valid_count = fn.cum_count(native) + count_in_window = fn.sub( + valid_count, fn.fill_null(fn.shift(valid_count, window_size), 0) + ) + result_native = fn.when_then( + fn.gt_eq(count_in_window, fn.lit(roll_options.min_samples)), rolling_sum + ) + result = compliant._with_native(result_native) + if offset: + result = result.slice(offset) + return self.from_series(result) # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_compliant/series.py#L349-L415 # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L1060-L1076 diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index aecec16cd6..0de6e8e037 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -484,6 +484,13 @@ def _fill_null_forward_limit(native: ChunkedArrayAny, limit: int) -> ChunkedArra return when_then(or_(is_not_null, beyond_limit), native, native.take(index_not_null)) +def fill_null( + native: ChunkedOrArrayT, value: ScalarAny | NonNestedLiteral +) -> ChunkedOrArrayT: + fill_value: Incomplete = value + return pc.fill_null(native, fill_value) + + def fill_null_with_strategy( native: ChunkedArrayAny, strategy: FillNullStrategy, limit: int | None = None ) -> ChunkedArrayAny: @@ -777,7 +784,7 @@ def chunked_array( def concat_vertical_chunked( - arrays: Iterable[ChunkedArrayAny], dtype: DataType | None = None, / + arrays: Iterable[ChunkedOrArrayAny], dtype: DataType | None = None, / ) -> ChunkedArrayAny: v_concat: Incomplete = pa.chunked_array return 
v_concat(arrays, dtype) # type: ignore[no-any-return] diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 62c211500e..eae6dfbbf5 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -273,7 +273,11 @@ def dispatch( return self.function.__expr_ir_dispatch__(self, ctx, frame, name) -class RollingExpr(FunctionExpr[RollingT_co]): ... +class RollingExpr(FunctionExpr[RollingT_co]): + def dispatch( + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + ) -> R_co: + return self.__expr_ir_dispatch__(self, ctx, frame, name) class AnonymousExpr( From 08477b588d66478ab0df9079122c3758386f9cbb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:44:18 +0000 Subject: [PATCH 039/215] feat: Add `rolling_{sum,mean}` + start of the tests --- narwhals/_plan/arrow/expr.py | 5 ++- tests/plan/rolling_expr_test.py | 65 +++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 tests/plan/rolling_expr_test.py diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 16ce9cdb10..3bed797bd8 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -596,7 +596,7 @@ def rolling_expr( # noqa: PLR0914 self, node: ir.RollingExpr[F.RollingWindow], frame: Frame, name: str ) -> Self: function = node.function - if not isinstance(function, F.RollingSum): + if not isinstance(function, (F.RollingSum, F.RollingMean)): msg = f"TODO: {node!r}" raise NotImplementedError(msg) roll_options = function.options @@ -633,6 +633,9 @@ def rolling_expr( # noqa: PLR0914 result_native = fn.when_then( fn.gt_eq(count_in_window, fn.lit(roll_options.min_samples)), rolling_sum ) + # NOTE: `rolling_mean` just adds has this extra linevs `rolling_sum` + if isinstance(function, F.RollingMean): + result_native = fn.truediv(result_native, count_in_window) result = 
compliant._with_native(result_native) if offset: result = result.slice(offset) diff --git a/tests/plan/rolling_expr_test.py b/tests/plan/rolling_expr_test.py new file mode 100644 index 0000000000..fb9a44567c --- /dev/null +++ b/tests/plan/rolling_expr_test.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from narwhals.typing import NonNestedLiteral + from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [None, 1, 2, None, 4, 6, 11]} + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "expected"), + [ + (3, None, False, [None, None, None, None, None, None, 21]), + (3, 1, False, [None, 1.0, 3.0, 3.0, 6.0, 10.0, 21.0]), + (2, 1, False, [None, 1.0, 3.0, 2.0, 4.0, 10.0, 17.0]), + (5, 1, True, [3.0, 3.0, 7.0, 13.0, 23.0, 21.0, 21.0]), + (4, 1, True, [1.0, 3.0, 3.0, 7.0, 12.0, 21.0, 21.0]), + ], +) +def test_rolling_sum_expr( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + expected: list[NonNestedLiteral], +) -> None: + expr = nwp.col("a").rolling_sum(window_size, min_samples=min_samples, center=center) + result = dataframe(data).select(expr) + assert_equal_data(result, {"a": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "expected"), + [ + (3, None, False, [None, None, None, None, None, None, 7.0]), + (3, 1, False, [None, 1.0, 1.5, 1.5, 3.0, 5.0, 7.0]), + (2, 1, False, [None, 1.0, 1.5, 2.0, 4.0, 5.0, 8.5]), + (5, 1, True, [1.5, 1.5, 7 / 3, 3.25, 5.75, 7.0, 7.0]), + (4, 1, True, [1.0, 1.5, 1.5, 7 / 3, 4.0, 7.0, 7.0]), + ], +) +def test_rolling_mean_expr( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + expected: list[NonNestedLiteral], +) -> None: + expr = 
nwp.col("a").rolling_mean(window_size, min_samples=min_samples, center=center) + result = dataframe(data).select(expr) + assert_equal_data(result, {"a": expected}) From d98d5bf610a644322b58afdacc42aac84351fbcf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:57:18 +0000 Subject: [PATCH 040/215] test: Add `rolling_{sum,mean}.over(order_by=...)` --- tests/plan/rolling_expr_test.py | 69 +++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 3 deletions(-) diff --git a/tests/plan/rolling_expr_test.py b/tests/plan/rolling_expr_test.py index fb9a44567c..89f31e29db 100644 --- a/tests/plan/rolling_expr_test.py +++ b/tests/plan/rolling_expr_test.py @@ -16,7 +16,12 @@ @pytest.fixture(scope="module") def data() -> Data: - return {"a": [None, 1, 2, None, 4, 6, 11]} + return { + "a": [None, 1, 2, None, 4, 6, 11], + "b": [1, None, 2, None, 4, 6, 11], + "c": [1, None, 2, 3, 4, 5, 6], + "i": list(range(7)), + } @pytest.mark.parametrize( @@ -29,7 +34,7 @@ def data() -> Data: (4, 1, True, [1.0, 3.0, 3.0, 7.0, 12.0, 21.0, 21.0]), ], ) -def test_rolling_sum_expr( +def test_rolling_sum( data: Data, window_size: int, *, @@ -52,7 +57,7 @@ def test_rolling_sum_expr( (4, 1, True, [1.0, 1.5, 1.5, 7 / 3, 4.0, 7.0, 7.0]), ], ) -def test_rolling_mean_expr( +def test_rolling_mean( data: Data, window_size: int, *, @@ -63,3 +68,61 @@ def test_rolling_mean_expr( expr = nwp.col("a").rolling_mean(window_size, min_samples=min_samples, center=center) result = dataframe(data).select(expr) assert_equal_data(result, {"a": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "expected"), + [ + (2, None, False, [None, None, 3, None, None, 10, 17]), + (2, 2, False, [None, None, 3, None, None, 10, 17]), + (3, 2, False, [None, None, 3, 3, 6, 10, 21]), + (3, 1, False, [1, None, 3, 3, 6, 10, 21]), + (3, 1, True, [3, 1, 3, 6, 10, 21, 17]), + (4, 1, True, [3, 1, 3, 7, 12, 21, 21]), + (5, 1, True, 
[3, 3, 7, 13, 23, 21, 21]), + ], +) +def test_rolling_sum_order_by( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + expected: list[NonNestedLiteral], +) -> None: + expr = ( + nwp.col("b") + .rolling_sum(window_size, min_samples=min_samples, center=center) + .over(order_by="c") + ) + result = dataframe(data).with_columns(expr).select("b", "i").sort("i").drop("i") + assert_equal_data(result, {"b": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "expected"), + [ + (2, None, False, [None, None, 1.5, None, None, 5, 8.5]), + (2, 2, False, [None, None, 1.5, None, None, 5, 8.5]), + (3, 2, False, [None, None, 1.5, 1.5, 3, 5, 7]), + (3, 1, False, [1, None, 1.5, 1.5, 3, 5, 7]), + (3, 1, True, [1.5, 1, 1.5, 3, 5, 7, 8.5]), + (4, 1, True, [1.5, 1, 1.5, 2.3333333333333335, 4, 7, 7]), + (5, 1, True, [1.5, 1.5, 2.3333333333333335, 3.25, 5.75, 7.0, 7.0]), + ], +) +def test_rolling_mean_order_by( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + expected: list[NonNestedLiteral], +) -> None: + expr = ( + nwp.col("b") + .rolling_mean(window_size, min_samples=min_samples, center=center) + .over(order_by="c") + ) + result = dataframe(data).with_columns(expr).select("b", "i").sort("i").drop("i") + assert_equal_data(result, {"b": expected}) From 69ec056c2c73bd1e98ef1ed88895a06af5b99d19 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:59:41 +0000 Subject: [PATCH 041/215] chore: Update cov/todos --- narwhals/_plan/expr.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 9fd200c009..8a3d8b1be8 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -285,7 +285,7 @@ def cum_sum(self, *, reverse: bool = False) -> Self: def rolling_sum( self, window_size: int, *, min_samples: int | None = None, center: bool = False - ) -> Self: # 
pragma: no cover + ) -> Self: options = rolling_options(window_size, min_samples, center=center) return self._with_unary(F.RollingSum(options=options)) @@ -348,7 +348,6 @@ def ewm_mean( ) return self._with_unary(F.EwmMean(options=options)) - # TODO @dangotbanned: Update to support `default` def replace_strict( self, old: Sequence[Any] | Mapping[Any, Any], From db262b57cf4dfcbf73e74a46c44fcc17074795d3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 22:30:47 +0000 Subject: [PATCH 042/215] feat(DRAFT): Add `rolling_{var,std}` --- narwhals/_plan/arrow/expr.py | 45 ++++++++++++++++++++++--------- narwhals/_plan/arrow/functions.py | 4 +++ 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 3bed797bd8..633d83060e 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -592,13 +592,12 @@ def fill_null_with_strategy( # Waaaaaay more of this needs to be shared # https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L930-L1034 - def rolling_expr( # noqa: PLR0914 + # yes ruff, i know this is too complicated! 
+ # but we need to start somewhere + def rolling_expr( # noqa: PLR0912, PLR0914 self, node: ir.RollingExpr[F.RollingWindow], frame: Frame, name: str ) -> Self: function = node.function - if not isinstance(function, (F.RollingSum, F.RollingMean)): - msg = f"TODO: {node!r}" - raise NotImplementedError(msg) roll_options = function.options window_size = roll_options.window_size compliant = self._dispatch_expr(node.input[0], frame, name) @@ -618,9 +617,8 @@ def rolling_expr( # noqa: PLR0914 offset = offset_left + offset_right else: offset = 0 - # NOTE: Implementing `rolling_sum` first, then will see what the others look like - # this'll be easier to read as `ArrowSeries` methods - cum_sum = fn.fill_null_with_strategy(fn.cum_sum(native), "forward") + # NOTE: this'll be easier to read as `ArrowSeries` methods + cum_sum = fn.fill_null_forward(fn.cum_sum(native)) if window_size != 0: rolling_sum = fn.sub(cum_sum, fn.fill_null(fn.shift(cum_sum, window_size), 0)) else: @@ -630,12 +628,33 @@ def rolling_expr( # noqa: PLR0914 count_in_window = fn.sub( valid_count, fn.fill_null(fn.shift(valid_count, window_size), 0) ) - result_native = fn.when_then( - fn.gt_eq(count_in_window, fn.lit(roll_options.min_samples)), rolling_sum - ) - # NOTE: `rolling_mean` just adds has this extra linevs `rolling_sum` - if isinstance(function, F.RollingMean): - result_native = fn.truediv(result_native, count_in_window) + predicate = fn.gt_eq(count_in_window, fn.lit(roll_options.min_samples)) + if isinstance(function, (F.RollingVar, F.RollingStd)): + if fn_params := roll_options.fn_params: + ddof = fn_params.ddof + else: + msg = f"Expected `ddof` for {function!r}" + raise TypeError(msg) + cum_sum_sq = fn.fill_null_forward(fn.cum_sum(fn.power(native, fn.lit(2)))) + if window_size != 0: + rolling_sum_sq = fn.sub( + cum_sum_sq, fn.fill_null(fn.shift(cum_sum_sq, window_size), 0) + ) + else: + rolling_sum_sq = cum_sum_sq + rolling_something = fn.sub( + rolling_sum_sq, + fn.truediv(fn.power(rolling_sum, 
fn.lit(2)), count_in_window), + ) + i_dunno_man = fn.when_then(predicate, rolling_something) + denom = fn.max_horizontal(fn.sub(count_in_window, fn.lit(ddof)), fn.lit(0)) + result_native = fn.truediv(i_dunno_man, denom) + if isinstance(function, (F.RollingStd)): + result_native = fn.power(result_native, fn.lit(0.5)) + else: + result_native = fn.when_then(predicate, rolling_sum) + if isinstance(function, F.RollingMean): + result_native = fn.truediv(result_native, count_in_window) result = compliant._with_native(result_native) if offset: result = result.slice(offset) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 0de6e8e037..d5b2cb1c92 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -491,6 +491,10 @@ def fill_null( return pc.fill_null(native, fill_value) +def fill_null_forward(native: ChunkedArrayAny) -> ChunkedArrayAny: + return fill_null_with_strategy(native, "forward") + + def fill_null_with_strategy( native: ChunkedArrayAny, strategy: FillNullStrategy, limit: int | None = None ) -> ChunkedArrayAny: From f0d182dc177b06a46920accf0d7b18cd5e51732b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Nov 2025 23:00:03 +0000 Subject: [PATCH 043/215] perf: Don't create nulls and then replace --- narwhals/_plan/arrow/expr.py | 9 +++++---- narwhals/_plan/arrow/functions.py | 24 ++++++++++++++++++++---- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 633d83060e..32b36546a3 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -618,15 +618,15 @@ def rolling_expr( # noqa: PLR0912, PLR0914 else: offset = 0 # NOTE: this'll be easier to read as `ArrowSeries` methods - cum_sum = fn.fill_null_forward(fn.cum_sum(native)) + cum_sum = fn.fill_null_forward(fn.cum_sum(native)).fill_null(fn.lit(0)) if window_size != 0: - rolling_sum = 
fn.sub(cum_sum, fn.fill_null(fn.shift(cum_sum, window_size), 0)) + rolling_sum = fn.sub(cum_sum, fn.shift(cum_sum, window_size, fill_value=0)) else: rolling_sum = cum_sum valid_count = fn.cum_count(native) count_in_window = fn.sub( - valid_count, fn.fill_null(fn.shift(valid_count, window_size), 0) + valid_count, fn.shift(valid_count, window_size, fill_value=0) ) predicate = fn.gt_eq(count_in_window, fn.lit(roll_options.min_samples)) if isinstance(function, (F.RollingVar, F.RollingStd)): @@ -635,10 +635,11 @@ def rolling_expr( # noqa: PLR0912, PLR0914 else: msg = f"Expected `ddof` for {function!r}" raise TypeError(msg) + # NOTE: Once this has coverage, probably need to add the `fill_null(0)` for the ends cum_sum_sq = fn.fill_null_forward(fn.cum_sum(fn.power(native, fn.lit(2)))) if window_size != 0: rolling_sum_sq = fn.sub( - cum_sum_sq, fn.fill_null(fn.shift(cum_sum_sq, window_size), 0) + cum_sum_sq, fn.shift(cum_sum_sq, window_size, fill_value=0) ) else: rolling_sum_sq = cum_sum_sq diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index d5b2cb1c92..b934e7f931 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -429,14 +429,18 @@ def diff(native: ChunkedOrArrayT) -> ChunkedOrArrayT: ) -def shift(native: ChunkedArrayAny, n: int) -> ChunkedArrayAny: +def shift( + native: ChunkedArrayAny, n: int, *, fill_value: NonNestedLiteral = None +) -> ChunkedArrayAny: if n == 0: return native arr = native if n > 0: - arrays = [nulls_like(n, arr), *arr.slice(length=arr.length() - n).chunks] + filled = repeat_like(fill_value, n, arr) + arrays = [filled, *arr.slice(length=arr.length() - n).chunks] else: - arrays = [*arr.slice(offset=-n).chunks, nulls_like(-n, arr)] + filled = repeat_like(fill_value, -n, arr) + arrays = [*arr.slice(offset=-n).chunks, filled] return pa.chunked_array(arrays) @@ -746,12 +750,24 @@ def date_range( return ca.cast(pa.date32()) +def repeat(value: ScalarAny | NonNestedLiteral, n: int) -> 
ArrayAny: + repeat_: Incomplete = pa.repeat + value = value if isinstance(value, pa.Scalar) else lit(value) + result: ArrayAny = repeat_(value, n) + return result + + +def repeat_like(value: NonNestedLiteral, n: int, native: ArrowAny) -> ArrayAny: + return repeat(lit(value, native.type), n) + + def nulls_like(n: int, native: ArrowAny) -> ArrayAny: """Create a strongly-typed Array instance with all elements null. Uses the type of `native`. """ - return pa.nulls(n, native.type) # type: ignore[no-any-return] + result: ArrayAny = pa.nulls(n, native.type) + return result def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: From f351758c849151368ab810318d39a7566562410a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 12:16:49 +0000 Subject: [PATCH 044/215] feat: Add `CompliantSeries` dunders Fully lost myself in finding the cause of a `rolling_{var,std}` bug Implementing these ops + the other methods will make it more visible --- narwhals/_plan/compliant/series.py | 52 +++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index f9c33523ff..6ea082f555 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -15,7 +15,16 @@ from narwhals._plan.series import Series from narwhals._typing import _EagerAllowedImpl from narwhals.dtypes import DType - from narwhals.typing import Into1DArray, IntoDType, SizedMultiIndexSelector, _1DArray + from narwhals.typing import ( + FillNullStrategy, + Into1DArray, + IntoDType, + NonNestedLiteral, + NumericLiteral, + SizedMultiIndexSelector, + TemporalLiteral, + _1DArray, + ) Incomplete: TypeAlias = Any @@ -28,9 +37,43 @@ class CompliantSeries(HasVersion, Protocol[NativeSeriesT]): def __len__(self) -> int: return len(self.native) + def __add__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... 
+ def __and__(self, other: bool | Self) -> Self: ... + def __eq__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... # type: ignore[override] + def __floordiv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __ge__(self, other: NonNestedLiteral | Self) -> Self: ... + def __gt__(self, other: NonNestedLiteral | Self) -> Self: ... + def __invert__(self) -> Self: ... + def __le__(self, other: NonNestedLiteral | Self) -> Self: ... + def __lt__(self, other: NonNestedLiteral | Self) -> Self: ... + def __mod__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __mul__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __ne__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... # type: ignore[override] + def __or__(self, other: bool | Self) -> Self: ... + def __pow__(self, other: float | Self) -> Self: ... + def __rfloordiv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __radd__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __rand__(self, other: bool | Self) -> Self: ... + def __rmod__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __rmul__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __ror__(self, other: bool | Self) -> Self: ... + def __rpow__(self, other: float | Self) -> Self: ... + def __rsub__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __rtruediv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __rxor__(self, other: bool | Self) -> Self: ... + def __sub__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __truediv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... + def __xor__(self, other: bool | Self) -> Self: ... 
+ def len(self) -> int: return len(self.native) + def not_(self) -> Self: + return self.__invert__() + + def pow(self, exponent: float | Self) -> Self: + return self.__pow__(exponent) + def __narwhals_namespace__(self) -> Incomplete: ... def __narwhals_series__(self) -> Self: return self @@ -75,6 +118,13 @@ def alias(self, name: str) -> Self: return self.from_native(self.native, name, version=self.version) def cast(self, dtype: IntoDType) -> Self: ... + def cum_sum(self, *, reverse: bool = False) -> Self: ... + def cum_count(self, *, reverse: bool = False) -> Self: ... + def fill_null(self, value: Self | NonNestedLiteral) -> Self: ... + def fill_null_with_strategy( + self, strategy: FillNullStrategy, limit: int | None = None + ) -> Self: ... + def shift(self, n: int, *, fill_value: NonNestedLiteral = None) -> Self: ... def gather( self, indices: SizedMultiIndexSelector[NativeSeriesT] | _StoresNative[NativeSeriesT], From 299d9d10bd72352abdb0a1cdff2df5101c888e48 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 13:01:42 +0000 Subject: [PATCH 045/215] start adding `ArrowSeries` dunders fixed the typing as well --- narwhals/_plan/arrow/functions.py | 23 +++++++ narwhals/_plan/arrow/series.py | 105 ++++++++++++++++++++++++++++- narwhals/_plan/compliant/series.py | 56 +++++++-------- 3 files changed, 156 insertions(+), 28 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index b934e7f931..a482eb7988 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -68,6 +68,7 @@ UnaryFunction, VectorFunction, ) + from narwhals._plan.compliant.typing import SeriesT from narwhals._plan.options import RankOptions, SortMultipleOptions, SortOptions from narwhals._plan.typing import OneOrSeq, Seq from narwhals.typing import ( @@ -163,6 +164,28 @@ def modulus(lhs: Any, rhs: Any) -> Any: ops.ExclusiveOr: xor, } + +def bin_op( + function: 
Callable[[Any, Any], Any], /, *, reflect: bool = False +) -> Callable[[SeriesT, Any], SeriesT]: + """Attach a binary operator to `ArrowSeries`.""" + + def f(self: SeriesT, other: SeriesT | Any, /) -> SeriesT: + right = other.native if isinstance(other, type(self)) else lit(other) + return self._with_native(function(self.native, right)) + + def f_reflect(self: SeriesT, other: SeriesT | Any, /) -> SeriesT: + if isinstance(other, type(self)): + name = other.name + right: ArrowAny = other.native + else: + name = "literal" + right = lit(other) + return self.from_native(function(right, self.native), name, version=self.version) + + return f_reflect if reflect else f + + _IS_BETWEEN: Mapping[ClosedInterval, tuple[BinaryComp, BinaryComp]] = { "left": (gt_eq, lt), "right": (gt, lt_eq), diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index ffd68b660a..8c547b59b8 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -21,7 +21,15 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame as DataFrame from narwhals._plan.arrow.typing import ChunkedArrayAny from narwhals.dtypes import DType - from narwhals.typing import Into1DArray, IntoDType, _1DArray + from narwhals.typing import ( + FillNullStrategy, + Into1DArray, + IntoDType, + NonNestedLiteral, + NumericLiteral, + TemporalLiteral, + _1DArray, + ) class ArrowSeries(FrameSeries["ChunkedArrayAny"], CompliantSeries["ChunkedArrayAny"]): @@ -94,3 +102,98 @@ def is_in(self, other: Self) -> Self: def has_nulls(self) -> bool: return bool(self.native.null_count) + + __add__ = fn.bin_op(fn.add) + __and__ = fn.bin_op(fn.and_) + + def __eq__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: # type: ignore[override] + raise NotImplementedError + + def __floordiv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise NotImplementedError + + def __ge__(self, other: NonNestedLiteral | Self) -> Self: + raise NotImplementedError + + def __gt__(self, 
other: NonNestedLiteral | Self) -> Self: + raise NotImplementedError + + def __invert__(self) -> Self: + raise NotImplementedError + + def __le__(self, other: NonNestedLiteral | Self) -> Self: + raise NotImplementedError + + def __lt__(self, other: NonNestedLiteral | Self) -> Self: + raise NotImplementedError + + def __mod__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise NotImplementedError + + def __mul__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise NotImplementedError + + def __ne__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: # type: ignore[override] + raise NotImplementedError + + def __or__(self, other: bool | Self) -> Self: + raise NotImplementedError + + def __pow__(self, other: float | Self) -> Self: + raise NotImplementedError + + def __rfloordiv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise NotImplementedError + + def __radd__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise NotImplementedError + + def __rand__(self, other: bool | Self) -> Self: + raise NotImplementedError + + def __rmod__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise NotImplementedError + + def __rmul__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise NotImplementedError + + def __ror__(self, other: bool | Self) -> Self: + raise NotImplementedError + + def __rpow__(self, other: float | Self) -> Self: + raise NotImplementedError + + def __rsub__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise NotImplementedError + + def __rtruediv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise NotImplementedError + + def __rxor__(self, other: bool | Self) -> Self: + raise NotImplementedError + + def __sub__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise NotImplementedError + + def __truediv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: + raise 
NotImplementedError + + def __xor__(self, other: bool | Self) -> Self: + raise NotImplementedError + + def cum_sum(self, *, reverse: bool = False) -> Self: + raise NotImplementedError + + def cum_count(self, *, reverse: bool = False) -> Self: + raise NotImplementedError + + def fill_null(self, value: NonNestedLiteral | Self) -> Self: + raise NotImplementedError + + def fill_null_with_strategy( + self, strategy: FillNullStrategy, limit: int | None = None + ) -> Self: + raise NotImplementedError + + def shift(self, n: int, *, fill_value: NonNestedLiteral = None) -> Self: + raise NotImplementedError diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 6ea082f555..3ae8b71fbb 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -37,33 +37,35 @@ class CompliantSeries(HasVersion, Protocol[NativeSeriesT]): def __len__(self) -> int: return len(self.native) - def __add__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __and__(self, other: bool | Self) -> Self: ... - def __eq__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... # type: ignore[override] - def __floordiv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __ge__(self, other: NonNestedLiteral | Self) -> Self: ... - def __gt__(self, other: NonNestedLiteral | Self) -> Self: ... + def __add__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __and__(self, other: bool | Self, /) -> Self: ... + def __eq__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... # type: ignore[override] + def __floordiv__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __ge__(self, other: NonNestedLiteral | Self, /) -> Self: ... + def __gt__(self, other: NonNestedLiteral | Self, /) -> Self: ... def __invert__(self) -> Self: ... - def __le__(self, other: NonNestedLiteral | Self) -> Self: ... 
- def __lt__(self, other: NonNestedLiteral | Self) -> Self: ... - def __mod__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __mul__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __ne__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... # type: ignore[override] - def __or__(self, other: bool | Self) -> Self: ... - def __pow__(self, other: float | Self) -> Self: ... - def __rfloordiv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __radd__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __rand__(self, other: bool | Self) -> Self: ... - def __rmod__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __rmul__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __ror__(self, other: bool | Self) -> Self: ... - def __rpow__(self, other: float | Self) -> Self: ... - def __rsub__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __rtruediv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __rxor__(self, other: bool | Self) -> Self: ... - def __sub__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __truediv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: ... - def __xor__(self, other: bool | Self) -> Self: ... + def __le__(self, other: NonNestedLiteral | Self, /) -> Self: ... + def __lt__(self, other: NonNestedLiteral | Self, /) -> Self: ... + def __mod__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __mul__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __ne__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... # type: ignore[override] + def __or__(self, other: bool | Self, /) -> Self: ... + def __pow__(self, other: float | Self, /) -> Self: ... + def __rfloordiv__( + self, other: NumericLiteral | TemporalLiteral | Self, / + ) -> Self: ... 
+ def __radd__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __rand__(self, other: bool | Self, /) -> Self: ... + def __rmod__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __rmul__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __ror__(self, other: bool | Self, /) -> Self: ... + def __rpow__(self, other: float | Self, /) -> Self: ... + def __rsub__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __rtruediv__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __rxor__(self, other: bool | Self, /) -> Self: ... + def __sub__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __truediv__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __xor__(self, other: bool | Self, /) -> Self: ... def len(self) -> int: return len(self.native) @@ -120,7 +122,7 @@ def alias(self, name: str) -> Self: def cast(self, dtype: IntoDType) -> Self: ... def cum_sum(self, *, reverse: bool = False) -> Self: ... def cum_count(self, *, reverse: bool = False) -> Self: ... - def fill_null(self, value: Self | NonNestedLiteral) -> Self: ... + def fill_null(self, value: NonNestedLiteral | Self) -> Self: ... def fill_null_with_strategy( self, strategy: FillNullStrategy, limit: int | None = None ) -> Self: ... 
From afb195b4095c73a882529652458a8ed8bfa04f90 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 13:45:22 +0000 Subject: [PATCH 046/215] feat: Impl `ArrowSeries` dunders --- narwhals/_plan/arrow/functions.py | 3 +- narwhals/_plan/arrow/series.py | 100 ++++++++---------------------- 2 files changed, 27 insertions(+), 76 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index a482eb7988..8a4efa9804 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -13,7 +13,7 @@ from narwhals._arrow.utils import ( cast_for_truediv, chunked_array as _chunked_array, - floordiv_compat as floordiv, + floordiv_compat as _floordiv, narwhals_to_native_dtype as _dtype_native, ) from narwhals._plan import expressions as ir @@ -130,6 +130,7 @@ class MinMax(ir.AggExpr): sub = t.cast("BinaryNumericTemporal", pc.subtract) multiply = pc.multiply power = t.cast("BinaryFunction[pc.NumericScalar, pc.NumericScalar]", pc.power) +floordiv = _floordiv def truediv(lhs: Any, rhs: Any) -> Any: diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 8c547b59b8..862c447aea 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -26,8 +26,6 @@ Into1DArray, IntoDType, NonNestedLiteral, - NumericLiteral, - TemporalLiteral, _1DArray, ) @@ -105,81 +103,33 @@ def has_nulls(self) -> bool: __add__ = fn.bin_op(fn.add) __and__ = fn.bin_op(fn.and_) - - def __eq__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: # type: ignore[override] - raise NotImplementedError - - def __floordiv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: - raise NotImplementedError - - def __ge__(self, other: NonNestedLiteral | Self) -> Self: - raise NotImplementedError - - def __gt__(self, other: NonNestedLiteral | Self) -> Self: - raise NotImplementedError + __eq__ = fn.bin_op(fn.eq) + __floordiv__ = 
fn.bin_op(fn.floordiv) + __ge__ = fn.bin_op(fn.gt_eq) + __gt__ = fn.bin_op(fn.gt) + __le__ = fn.bin_op(fn.lt_eq) + __lt__ = fn.bin_op(fn.lt) + __mod__ = fn.bin_op(fn.modulus) + __mul__ = fn.bin_op(fn.multiply) + __ne__ = fn.bin_op(fn.not_eq) + __or__ = fn.bin_op(fn.or_) + __pow__ = fn.bin_op(fn.power) + __rfloordiv__ = fn.bin_op(fn.floordiv, reflect=True) + __radd__ = fn.bin_op(fn.add, reflect=True) + __rand__ = fn.bin_op(fn.and_, reflect=True) + __rmod__ = fn.bin_op(fn.modulus, reflect=True) + __rmul__ = fn.bin_op(fn.multiply, reflect=True) + __ror__ = fn.bin_op(fn.or_, reflect=True) + __rpow__ = fn.bin_op(fn.power, reflect=True) + __rsub__ = fn.bin_op(fn.sub, reflect=True) + __rtruediv__ = fn.bin_op(fn.truediv, reflect=True) + __rxor__ = fn.bin_op(fn.xor, reflect=True) + __sub__ = fn.bin_op(fn.sub) + __truediv__ = fn.bin_op(fn.truediv) + __xor__ = fn.bin_op(fn.xor) def __invert__(self) -> Self: - raise NotImplementedError - - def __le__(self, other: NonNestedLiteral | Self) -> Self: - raise NotImplementedError - - def __lt__(self, other: NonNestedLiteral | Self) -> Self: - raise NotImplementedError - - def __mod__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: - raise NotImplementedError - - def __mul__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: - raise NotImplementedError - - def __ne__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: # type: ignore[override] - raise NotImplementedError - - def __or__(self, other: bool | Self) -> Self: - raise NotImplementedError - - def __pow__(self, other: float | Self) -> Self: - raise NotImplementedError - - def __rfloordiv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: - raise NotImplementedError - - def __radd__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: - raise NotImplementedError - - def __rand__(self, other: bool | Self) -> Self: - raise NotImplementedError - - def __rmod__(self, other: NumericLiteral | TemporalLiteral | Self) -> 
Self: - raise NotImplementedError - - def __rmul__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: - raise NotImplementedError - - def __ror__(self, other: bool | Self) -> Self: - raise NotImplementedError - - def __rpow__(self, other: float | Self) -> Self: - raise NotImplementedError - - def __rsub__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: - raise NotImplementedError - - def __rtruediv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: - raise NotImplementedError - - def __rxor__(self, other: bool | Self) -> Self: - raise NotImplementedError - - def __sub__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: - raise NotImplementedError - - def __truediv__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: - raise NotImplementedError - - def __xor__(self, other: bool | Self) -> Self: - raise NotImplementedError + return self._with_native(pc.invert(self.native)) def cum_sum(self, *, reverse: bool = False) -> Self: raise NotImplementedError From a17c3122f8adeae8e5c9a468018fe7b6f1cd8ede Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 14:12:41 +0000 Subject: [PATCH 047/215] feat: Impl the other `ArrowSeries` methods --- narwhals/_plan/arrow/expr.py | 7 +------ narwhals/_plan/arrow/functions.py | 9 +++++++-- narwhals/_plan/arrow/series.py | 31 +++++++++++++++++++++++++----- narwhals/_plan/compliant/series.py | 5 ++++- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 32b36546a3..e70a4b2d76 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -543,12 +543,7 @@ def rank(self, node: FExpr[Rank], frame: Frame, name: str) -> Self: def _cumulative(self, node: FExpr[CumAgg], frame: Frame, name: str) -> Self: native = self._dispatch_expr(node.input[0], frame, name).native - func = fn.CUMULATIVE[type(node.function)] - if not 
node.function.reverse: - result = func(native) - else: - result = fn.reverse(func(fn.reverse(native))) - return self._with_native(result, name) + return self._with_native(fn.cumulative(native, node.function), name) def unique(self, node: FExpr[F.Unique], frame: Frame, name: str) -> Self: result = self._dispatch_expr(node.input[0], frame, name).native.unique() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 8a4efa9804..00aa21d893 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -435,7 +435,7 @@ def cum_count(native: ChunkedArrayAny) -> ChunkedArrayAny: return cum_sum(is_not_null(native).cast(pa.uint32())) -CUMULATIVE: Mapping[type[F.CumAgg], Callable[[ChunkedArrayAny], ChunkedArrayAny]] = { +_CUMULATIVE: Mapping[type[F.CumAgg], Callable[[ChunkedArrayAny], ChunkedArrayAny]] = { F.CumSum: cum_sum, F.CumCount: cum_count, F.CumMin: cum_min, @@ -444,6 +444,11 @@ def cum_count(native: ChunkedArrayAny) -> ChunkedArrayAny: } +def cumulative(native: ChunkedArrayAny, f: F.CumAgg, /) -> ChunkedArrayAny: + func = _CUMULATIVE[type(f)] + return func(native) if not f.reverse else reverse(func(reverse(native))) + + def diff(native: ChunkedOrArrayT) -> ChunkedOrArrayT: # pyarrow.lib.ArrowInvalid: Vector kernel cannot execute chunkwise and no chunked exec function was defined return ( @@ -513,7 +518,7 @@ def _fill_null_forward_limit(native: ChunkedArrayAny, limit: int) -> ChunkedArra def fill_null( - native: ChunkedOrArrayT, value: ScalarAny | NonNestedLiteral + native: ChunkedOrArrayT, value: ScalarAny | NonNestedLiteral | ChunkedOrArrayT ) -> ChunkedOrArrayT: fill_value: Incomplete = value return pc.fill_null(native, fill_value) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 862c447aea..ce4bbbc7af 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -9,6 +9,7 @@ from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries 
from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.compliant.typing import namespace +from narwhals._plan.expressions import functions as F from narwhals._utils import Version, generate_repr from narwhals.dependencies import is_numpy_array_1d @@ -132,18 +133,38 @@ def __invert__(self) -> Self: return self._with_native(pc.invert(self.native)) def cum_sum(self, *, reverse: bool = False) -> Self: - raise NotImplementedError + if not reverse: + return self._with_native(fn.cum_sum(self.native)) + return self._with_native(fn.cumulative(self.native, F.CumSum(reverse=reverse))) def cum_count(self, *, reverse: bool = False) -> Self: - raise NotImplementedError + if not reverse: + return self._with_native(fn.cum_count(self.native)) + return self._with_native(fn.cumulative(self.native, F.CumCount(reverse=reverse))) + + def cum_max(self, *, reverse: bool = False) -> Self: + if not reverse: + return self._with_native(fn.cum_max(self.native)) + return self._with_native(fn.cumulative(self.native, F.CumMax(reverse=reverse))) + + def cum_min(self, *, reverse: bool = False) -> Self: + if not reverse: + return self._with_native(fn.cum_min(self.native)) + return self._with_native(fn.cumulative(self.native, F.CumMin(reverse=reverse))) + + def cum_prod(self, *, reverse: bool = False) -> Self: + if not reverse: + return self._with_native(fn.cum_prod(self.native)) + return self._with_native(fn.cumulative(self.native, F.CumProd(reverse=reverse))) def fill_null(self, value: NonNestedLiteral | Self) -> Self: - raise NotImplementedError + fill_value = value.native if isinstance(value, ArrowSeries) else value + return self._with_native(fn.fill_null(self.native, fill_value)) def fill_null_with_strategy( self, strategy: FillNullStrategy, limit: int | None = None ) -> Self: - raise NotImplementedError + return self._with_native(fn.fill_null_with_strategy(self.native, strategy, limit)) def shift(self, n: int, *, fill_value: NonNestedLiteral = None) -> Self: - raise 
NotImplementedError + return self._with_native(fn.shift(self.native, n, fill_value=fill_value)) diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 3ae8b71fbb..ac61989bba 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -120,8 +120,11 @@ def alias(self, name: str) -> Self: return self.from_native(self.native, name, version=self.version) def cast(self, dtype: IntoDType) -> Self: ... - def cum_sum(self, *, reverse: bool = False) -> Self: ... def cum_count(self, *, reverse: bool = False) -> Self: ... + def cum_max(self, *, reverse: bool = False) -> Self: ... + def cum_min(self, *, reverse: bool = False) -> Self: ... + def cum_prod(self, *, reverse: bool = False) -> Self: ... + def cum_sum(self, *, reverse: bool = False) -> Self: ... def fill_null(self, value: NonNestedLiteral | Self) -> Self: ... def fill_null_with_strategy( self, strategy: FillNullStrategy, limit: int | None = None From 8e5c156fca936634b3d29248bd5ce66a5aeba5c7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 15:10:50 +0000 Subject: [PATCH 048/215] "fix" `rolling_var` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `rolling_std` (may) have ended up being incorrect but `rolling_var` was just testing on the wrong data 😭 --- narwhals/_plan/arrow/expr.py | 57 ++++++++++++--------------------- narwhals/_plan/arrow/series.py | 19 +++++++++++ tests/plan/rolling_expr_test.py | 29 +++++++++++++++++ 3 files changed, 69 insertions(+), 36 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index e70a4b2d76..a459620cab 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -521,6 +521,7 @@ def _boolean_length_preserving( def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: if node.is_scalar: # NOTE: Just trying to avoid redoing the whole API 
for `Series` + # https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_compliant/expr.py#L738-L755 msg = "Only elementwise is currently supported" raise NotImplementedError(msg) series = self._dispatch_expr(node.input[0], frame, name) @@ -596,34 +597,21 @@ def rolling_expr( # noqa: PLR0912, PLR0914 roll_options = function.options window_size = roll_options.window_size compliant = self._dispatch_expr(node.input[0], frame, name) - native = compliant.native + # Read up on polars impl to get some names for what this is if roll_options.center: - offset_left = window_size // 2 - # subtract one if window_size is even - offset_right = offset_left - (window_size % 2 == 0) - chunks = native.chunks - arrays = ( - fn.nulls_like(offset_left, native), - *chunks, - fn.nulls_like(offset_right, native), - ) - native = fn.concat_vertical_chunked(arrays) - offset = offset_left + offset_right + compliant, offset = compliant._rolling_center(window_size) else: offset = 0 - # NOTE: this'll be easier to read as `ArrowSeries` methods - cum_sum = fn.fill_null_forward(fn.cum_sum(native)).fill_null(fn.lit(0)) + cum_sum = compliant.cum_sum().fill_null_with_strategy("forward") if window_size != 0: - rolling_sum = fn.sub(cum_sum, fn.shift(cum_sum, window_size, fill_value=0)) + rolling_sum = cum_sum - cum_sum.shift(window_size).fill_null(0) else: rolling_sum = cum_sum - valid_count = fn.cum_count(native) - count_in_window = fn.sub( - valid_count, fn.shift(valid_count, window_size, fill_value=0) - ) - predicate = fn.gt_eq(count_in_window, fn.lit(roll_options.min_samples)) + valid_count = compliant.cum_count() + count_in_window = valid_count - valid_count.shift(window_size).fill_null(0) + predicate = count_in_window >= roll_options.min_samples if isinstance(function, (F.RollingVar, F.RollingStd)): if fn_params := roll_options.fn_params: ddof = fn_params.ddof @@ -631,29 +619,26 @@ def rolling_expr( # noqa: PLR0912, PLR0914 msg = f"Expected `ddof` for 
{function!r}" raise TypeError(msg) # NOTE: Once this has coverage, probably need to add the `fill_null(0)` for the ends - cum_sum_sq = fn.fill_null_forward(fn.cum_sum(fn.power(native, fn.lit(2)))) + cum_sum_sq = compliant.pow(2).cum_sum().fill_null_with_strategy("forward") if window_size != 0: - rolling_sum_sq = fn.sub( - cum_sum_sq, fn.shift(cum_sum_sq, window_size, fill_value=0) - ) + rolling_sum_sq = cum_sum_sq - cum_sum_sq.shift(window_size).fill_null(0) else: rolling_sum_sq = cum_sum_sq - rolling_something = fn.sub( - rolling_sum_sq, - fn.truediv(fn.power(rolling_sum, fn.lit(2)), count_in_window), - ) - i_dunno_man = fn.when_then(predicate, rolling_something) - denom = fn.max_horizontal(fn.sub(count_in_window, fn.lit(ddof)), fn.lit(0)) - result_native = fn.truediv(i_dunno_man, denom) - if isinstance(function, (F.RollingStd)): - result_native = fn.power(result_native, fn.lit(0.5)) + + rolling_something = rolling_sum_sq - (rolling_sum**2 / count_in_window) + i_dunno_man = fn.when_then(predicate.native, rolling_something.native) + denominator = fn.max_horizontal((count_in_window - ddof).native, 0) + result = compliant._with_native(fn.truediv(i_dunno_man, denominator)) else: - result_native = fn.when_then(predicate, rolling_sum) + result = compliant._with_native( + fn.when_then(predicate.native, rolling_sum.native) + ) if isinstance(function, F.RollingMean): - result_native = fn.truediv(result_native, count_in_window) - result = compliant._with_native(result_native) + result = result / count_in_window if offset: result = result.slice(offset) + if isinstance(function, (F.RollingStd)): + result = result**0.5 return self.from_series(result) # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_compliant/series.py#L349-L415 diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index ce4bbbc7af..af17ddb1c5 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -168,3 
+168,22 @@ def fill_null_with_strategy( def shift(self, n: int, *, fill_value: NonNestedLiteral = None) -> Self: return self._with_native(fn.shift(self.native, n, fill_value=fill_value)) + + def _rolling_center(self, window_size: int) -> tuple[Self, int]: + """Think this is similar to [`polars_core::chunked_array::ops::rolling_window::inner_mod::window_edges`]. + + On `main`, this is `narwhals._arrow.utils.pad_series`. + + [`polars_core::chunked_array::ops::rolling_window::inner_mod::window_edges`]: https://github.com/pola-rs/polars/blob/e1d6f294218a36497255e2d872c223e19a47e2ec/crates/polars-core/src/chunked_array/ops/rolling_window.rs#L64-L77 + """ + offset_left = window_size // 2 + # subtract one if window_size is even + offset_right = offset_left - (window_size % 2 == 0) + native = self.native + arrays = ( + fn.nulls_like(offset_left, native), + *native.chunks, + fn.nulls_like(offset_right, native), + ) + offset = offset_left + offset_right + return self._with_native(fn.concat_vertical_chunked(arrays)), offset diff --git a/tests/plan/rolling_expr_test.py b/tests/plan/rolling_expr_test.py index 89f31e29db..8a413313b4 100644 --- a/tests/plan/rolling_expr_test.py +++ b/tests/plan/rolling_expr_test.py @@ -20,10 +20,39 @@ def data() -> Data: "a": [None, 1, 2, None, 4, 6, 11], "b": [1, None, 2, None, 4, 6, 11], "c": [1, None, 2, 3, 4, 5, 6], + "d": [1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0], "i": list(range(7)), } +# TODO @dangotbanned: Just reuse `rolling_options` for the tests? 
+@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "ddof", "expected"), + [ + (3, None, False, 1, [None, None, 1 / 3, 1, 4 / 3, 7 / 3, 3]), + (3, 1, False, 1, [None, 0.5, 1 / 3, 1.0, 4 / 3, 7 / 3, 3]), + (2, 1, False, 1, [None, 0.5, 0.5, 2.0, 2.0, 4.5, 4.5]), + (5, 1, True, 1, [1 / 3, 11 / 12, 4 / 5, 17 / 10, 2.0, 2.25, 3]), + (4, 1, True, 1, [0.5, 1 / 3, 11 / 12, 11 / 12, 2.25, 2.25, 3]), + (3, None, False, 2, [None, None, 2 / 3, 2.0, 8 / 3, 14 / 3, 6.0]), + ], +) +def test_rolling_var( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + ddof: int, + expected: list[NonNestedLiteral], +) -> None: + expr = nwp.col("d").rolling_var( + window_size, min_samples=min_samples, center=center, ddof=ddof + ) + result = dataframe(data).select(expr) + assert_equal_data(result, {"d": expected}) + + @pytest.mark.parametrize( ("window_size", "min_samples", "center", "expected"), [ From 9bd8bcdd1fbb66ad14e88459737924de87ed12a4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 15:45:58 +0000 Subject: [PATCH 049/215] get sugary, with an extended `zip_with` --- narwhals/_plan/arrow/expr.py | 14 +++++++------- narwhals/_plan/arrow/series.py | 5 +++++ narwhals/_plan/compliant/series.py | 1 + 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index a459620cab..950b3cd784 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -624,15 +624,15 @@ def rolling_expr( # noqa: PLR0912, PLR0914 rolling_sum_sq = cum_sum_sq - cum_sum_sq.shift(window_size).fill_null(0) else: rolling_sum_sq = cum_sum_sq - + # TODO @dangotbanned: Better name? 
rolling_something = rolling_sum_sq - (rolling_sum**2 / count_in_window) - i_dunno_man = fn.when_then(predicate.native, rolling_something.native) - denominator = fn.max_horizontal((count_in_window - ddof).native, 0) - result = compliant._with_native(fn.truediv(i_dunno_man, denominator)) - else: - result = compliant._with_native( - fn.when_then(predicate.native, rolling_sum.native) + # TODO @dangotbanned: Better name? + denominator = compliant._with_native( + fn.max_horizontal((count_in_window - ddof).native, 0) ) + result = rolling_something.zip_with(predicate, None) / denominator + else: + result = rolling_sum.zip_with(predicate, None) if isinstance(function, F.RollingMean): result = result / count_in_window if offset: diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index af17ddb1c5..f97f40a5b5 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -187,3 +187,8 @@ def _rolling_center(self, window_size: int) -> tuple[Self, int]: ) offset = offset_left + offset_right return self._with_native(fn.concat_vertical_chunked(arrays)), offset + + def zip_with(self, mask: Self, other: Self | None) -> Self: + predicate = mask.native.combine_chunks() + right = other.native if other is not None else other + return self._with_native(fn.when_then(predicate, self.native, right)) diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index ac61989bba..c82f456ebe 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -151,3 +151,4 @@ def to_narwhals(self) -> Series[NativeSeriesT]: def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ... def to_polars(self) -> pl.Series: ... + def zip_with(self, mask: Self, other: Self) -> Self: ... 
From 988694d788d33f44b055a9163eccefee7a716203 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:03:07 +0000 Subject: [PATCH 050/215] add `test_rolling_std` --- tests/plan/rolling_expr_test.py | 39 ++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/tests/plan/rolling_expr_test.py b/tests/plan/rolling_expr_test.py index 8a413313b4..8a4fc38b60 100644 --- a/tests/plan/rolling_expr_test.py +++ b/tests/plan/rolling_expr_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math from typing import TYPE_CHECKING import pytest @@ -14,13 +15,17 @@ pytest.importorskip("pyarrow") +def sqrt_or_null(*values: float | None) -> list[float | None]: + return [el if el is None else math.sqrt(el) for el in values] + + @pytest.fixture(scope="module") def data() -> Data: return { "a": [None, 1, 2, None, 4, 6, 11], "b": [1, None, 2, None, 4, 6, 11], "c": [1, None, 2, 3, 4, 5, 6], - "d": [1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0], + "var_std": [1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0], "i": list(range(7)), } @@ -46,11 +51,39 @@ def test_rolling_var( ddof: int, expected: list[NonNestedLiteral], ) -> None: - expr = nwp.col("d").rolling_var( + expr = nwp.col("var_std").rolling_var( + window_size, min_samples=min_samples, center=center, ddof=ddof + ) + result = dataframe(data).select(expr) + assert_equal_data(result, {"var_std": expected}) + + +# TODO @dangotbanned: Just reuse `rolling_options` for the tests? 
+@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "ddof", "expected"), + [ + (3, None, False, 1, sqrt_or_null(None, None, 1 / 3, 1, 4 / 3, 7 / 3, 3)), + (3, 1, False, 1, sqrt_or_null(None, 0.5, 1 / 3, 1.0, 4 / 3, 7 / 3, 3)), + (2, 1, False, 1, sqrt_or_null(None, 0.5, 0.5, 2.0, 2.0, 4.5, 4.5)), + (5, 1, True, 1, sqrt_or_null(1 / 3, 11 / 12, 4 / 5, 17 / 10, 2.0, 2.25, 3)), + (4, 1, True, 1, sqrt_or_null(0.5, 1 / 3, 11 / 12, 11 / 12, 2.25, 2.25, 3)), + (3, None, False, 2, sqrt_or_null(None, None, 2 / 3, 2.0, 8 / 3, 14 / 3, 6.0)), + ], +) +def test_rolling_std( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + ddof: int, + expected: list[NonNestedLiteral], +) -> None: + expr = nwp.col("var_std").rolling_std( window_size, min_samples=min_samples, center=center, ddof=ddof ) result = dataframe(data).select(expr) - assert_equal_data(result, {"d": expected}) + assert_equal_data(result, {"var_std": expected}) @pytest.mark.parametrize( From 7dbb8828d6be85cfe62ab0e594e20b68b0907555 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:18:20 +0000 Subject: [PATCH 051/215] Add `test_rolling_{var,std}_order_by` --- tests/plan/rolling_expr_test.py | 144 ++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/tests/plan/rolling_expr_test.py b/tests/plan/rolling_expr_test.py index 8a4fc38b60..f8e3ee08c8 100644 --- a/tests/plan/rolling_expr_test.py +++ b/tests/plan/rolling_expr_test.py @@ -188,3 +188,147 @@ def test_rolling_mean_order_by( ) result = dataframe(data).with_columns(expr).select("b", "i").sort("i").drop("i") assert_equal_data(result, {"b": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "ddof", "expected"), + [ + (2, None, False, 0, [None, None, 0.25, None, None, 1, 6.25]), + (2, 2, False, 1, [None, None, 0.5, None, None, 2, 12.5]), + (3, 2, False, 1, [None, None, 0.5, 0.5, 2, 2, 13]), + (3, 1, 
False, 0, [0, None, 0.25, 0.25, 1, 1, 8.666666666666666]), + (3, 1, True, 1, [0.5, None, 0.5, 2, 2, 13, 12.5]), + (4, 1, True, 1, [0.5, None, 0.5, 2.333333333333333, 4, 13, 13]), + ( + 5, + 1, + True, + 0, + [ + 0.25, + 0.25, + 1.5555555555555554, + 3.6874999999999996, + 11.1875, + 8.666666666666666, + 8.666666666666666, + ], + ), + ], +) +def test_rolling_var_order_by( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + ddof: int, + expected: list[NonNestedLiteral], +) -> None: + expr = ( + nwp.col("b") + .rolling_var(window_size, min_samples=min_samples, center=center, ddof=ddof) + .over(order_by="c") + ) + result = dataframe(data).with_columns(expr).select("b", "i").sort("i").drop("i") + assert_equal_data(result, {"b": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "ddof", "expected"), + [ + (2, None, False, 0, [None, None, 0.5, None, None, 1, 2.5]), + ( + 2, + 2, + False, + 1, + [ + None, + None, + 0.7071067811865476, + None, + None, + 1.4142135623730951, + 3.5355339059327378, + ], + ), + ( + 3, + 2, + False, + 1, + [ + None, + None, + 0.7071067811865476, + 0.7071067811865476, + 1.4142135623730951, + 1.4142135623730951, + 3.605551275463989, + ], + ), + (3, 1, False, 0, [0.0, None, 0.5, 0.5, 1.0, 1.0, 2.943920288775949]), + ( + 3, + 1, + True, + 1, + [ + 0.7071067811865476, + None, + 0.7071067811865476, + 1.4142135623730951, + 1.4142135623730951, + 3.605551275463989, + 3.5355339059327378, + ], + ), + ( + 4, + 1, + True, + 1, + [ + 0.7071067811865476, + None, + 0.7071067811865476, + 1.5275252316519465, + 2.0, + 3.605551275463989, + 3.605551275463989, + ], + ), + ( + 5, + 1, + True, + 0, + [ + 0.5, + 0.5, + 1.247219128924647, + 1.920286436967152, + 3.344772040064913, + 2.943920288775949, + 2.943920288775949, + ], + ), + ], +) +def test_rolling_std_order_by( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + ddof: int, + expected: list[NonNestedLiteral], +) -> 
None: + expr = ( + nwp.col("b") + .rolling_std(window_size, min_samples=min_samples, center=center, ddof=ddof) + .over(order_by="c") + ) + result = dataframe(data).with_columns(expr).select("b", "i").sort("i").drop("i") + assert_equal_data(result, {"b": expected}) From fdd901c61e4a4cecc78f990290916afd28ea4fff Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:29:14 +0000 Subject: [PATCH 052/215] test: Limit values to `abs_tol=1e-6` They are not tested beyond that level of precision --- tests/plan/rolling_expr_test.py | 100 ++++---------------------------- 1 file changed, 11 insertions(+), 89 deletions(-) diff --git a/tests/plan/rolling_expr_test.py b/tests/plan/rolling_expr_test.py index f8e3ee08c8..3c07d0cd44 100644 --- a/tests/plan/rolling_expr_test.py +++ b/tests/plan/rolling_expr_test.py @@ -169,8 +169,8 @@ def test_rolling_sum_order_by( (3, 2, False, [None, None, 1.5, 1.5, 3, 5, 7]), (3, 1, False, [1, None, 1.5, 1.5, 3, 5, 7]), (3, 1, True, [1.5, 1, 1.5, 3, 5, 7, 8.5]), - (4, 1, True, [1.5, 1, 1.5, 2.3333333333333335, 4, 7, 7]), - (5, 1, True, [1.5, 1.5, 2.3333333333333335, 3.25, 5.75, 7.0, 7.0]), + (4, 1, True, [1.5, 1, 1.5, 2.333333, 4, 7, 7]), + (5, 1, True, [1.5, 1.5, 2.333333, 3.25, 5.75, 7.0, 7.0]), ], ) def test_rolling_mean_order_by( @@ -196,24 +196,10 @@ def test_rolling_mean_order_by( (2, None, False, 0, [None, None, 0.25, None, None, 1, 6.25]), (2, 2, False, 1, [None, None, 0.5, None, None, 2, 12.5]), (3, 2, False, 1, [None, None, 0.5, 0.5, 2, 2, 13]), - (3, 1, False, 0, [0, None, 0.25, 0.25, 1, 1, 8.666666666666666]), + (3, 1, False, 0, [0, None, 0.25, 0.25, 1, 1, 8.666666]), (3, 1, True, 1, [0.5, None, 0.5, 2, 2, 13, 12.5]), - (4, 1, True, 1, [0.5, None, 0.5, 2.333333333333333, 4, 13, 13]), - ( - 5, - 1, - True, - 0, - [ - 0.25, - 0.25, - 1.5555555555555554, - 3.6874999999999996, - 11.1875, - 8.666666666666666, - 8.666666666666666, - ], - ), + (4, 1, True, 1, [0.5, None, 0.5, 
2.333333, 4, 13, 13]), + (5, 1, True, 0, [0.25, 0.25, 1.555555, 3.6875, 11.1875, 8.666666, 8.666666]), ], ) def test_rolling_var_order_by( @@ -238,82 +224,18 @@ def test_rolling_var_order_by( ("window_size", "min_samples", "center", "ddof", "expected"), [ (2, None, False, 0, [None, None, 0.5, None, None, 1, 2.5]), - ( - 2, - 2, - False, - 1, - [ - None, - None, - 0.7071067811865476, - None, - None, - 1.4142135623730951, - 3.5355339059327378, - ], - ), - ( - 3, - 2, - False, - 1, - [ - None, - None, - 0.7071067811865476, - 0.7071067811865476, - 1.4142135623730951, - 1.4142135623730951, - 3.605551275463989, - ], - ), - (3, 1, False, 0, [0.0, None, 0.5, 0.5, 1.0, 1.0, 2.943920288775949]), + (2, 2, False, 1, [None, None, 0.707107, None, None, 1.414214, 3.535534]), + (3, 2, False, 1, [None, None, 0.707107, 0.707107, 1.414214, 1.414214, 3.605551]), + (3, 1, False, 0, [0.0, None, 0.5, 0.5, 1.0, 1.0, 2.943920]), ( 3, 1, True, 1, - [ - 0.7071067811865476, - None, - 0.7071067811865476, - 1.4142135623730951, - 1.4142135623730951, - 3.605551275463989, - 3.5355339059327378, - ], - ), - ( - 4, - 1, - True, - 1, - [ - 0.7071067811865476, - None, - 0.7071067811865476, - 1.5275252316519465, - 2.0, - 3.605551275463989, - 3.605551275463989, - ], - ), - ( - 5, - 1, - True, - 0, - [ - 0.5, - 0.5, - 1.247219128924647, - 1.920286436967152, - 3.344772040064913, - 2.943920288775949, - 2.943920288775949, - ], + [0.707107, None, 0.707107, 1.414214, 1.414214, 3.605551, 3.535534], ), + (4, 1, True, 1, [0.707107, None, 0.707107, 1.527525, 2.0, 3.605551, 3.605551]), + (5, 1, True, 0, [0.5, 0.5, 1.247219, 1.920286, 3.344772, 2.943920, 2.943920]), ], ) def test_rolling_std_order_by( From 99bb1b4a9ec373a6487318bb77b0200598d96b4b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:36:22 +0000 Subject: [PATCH 053/215] pull out `ddof` handling --- narwhals/_plan/arrow/expr.py | 8 ++------ narwhals/_plan/options.py | 4 ++++ 2 files 
changed, 6 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 950b3cd784..aa841b9c0f 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -590,7 +590,7 @@ def fill_null_with_strategy( # yes ruff, i know this is too complicated! # but we need to start somewhere - def rolling_expr( # noqa: PLR0912, PLR0914 + def rolling_expr( # noqa: PLR0914 self, node: ir.RollingExpr[F.RollingWindow], frame: Frame, name: str ) -> Self: function = node.function @@ -613,11 +613,7 @@ def rolling_expr( # noqa: PLR0912, PLR0914 count_in_window = valid_count - valid_count.shift(window_size).fill_null(0) predicate = count_in_window >= roll_options.min_samples if isinstance(function, (F.RollingVar, F.RollingStd)): - if fn_params := roll_options.fn_params: - ddof = fn_params.ddof - else: - msg = f"Expected `ddof` for {function!r}" - raise TypeError(msg) + ddof = roll_options.ddof # NOTE: Once this has coverage, probably need to add the `fill_null(0)` for the ends cum_sum_sq = compliant.pow(2).cum_sum().fill_null_with_strategy("forward") if window_size != 0: diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 4a31389170..2d4f6d8c6e 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -246,6 +246,10 @@ class RollingOptionsFixedWindow(Immutable): center: bool fn_params: RollingVarParams | None + @property + def ddof(self) -> int: + return 1 if self.fn_params is None else self.fn_params.ddof + def rolling_options( window_size: int, min_samples: int | None, /, *, center: bool, ddof: int | None = None From 50d7f0deba1d2f5ced532cea5dca37d963d645e9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:46:09 +0000 Subject: [PATCH 054/215] add `shift(fill_value=0)` back --- narwhals/_plan/arrow/expr.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py 
b/narwhals/_plan/arrow/expr.py index aa841b9c0f..b4122e94b7 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -605,19 +605,20 @@ def rolling_expr( # noqa: PLR0914 offset = 0 cum_sum = compliant.cum_sum().fill_null_with_strategy("forward") if window_size != 0: - rolling_sum = cum_sum - cum_sum.shift(window_size).fill_null(0) + rolling_sum = cum_sum - cum_sum.shift(window_size, fill_value=0).fill_null(0) else: rolling_sum = cum_sum valid_count = compliant.cum_count() - count_in_window = valid_count - valid_count.shift(window_size).fill_null(0) + count_in_window = valid_count - valid_count.shift(window_size, fill_value=0) predicate = count_in_window >= roll_options.min_samples if isinstance(function, (F.RollingVar, F.RollingStd)): ddof = roll_options.ddof - # NOTE: Once this has coverage, probably need to add the `fill_null(0)` for the ends cum_sum_sq = compliant.pow(2).cum_sum().fill_null_with_strategy("forward") if window_size != 0: - rolling_sum_sq = cum_sum_sq - cum_sum_sq.shift(window_size).fill_null(0) + rolling_sum_sq = cum_sum_sq - cum_sum_sq.shift( + window_size, fill_value=0 + ).fill_null(0) else: rolling_sum_sq = cum_sum_sq # TODO @dangotbanned: Better name? 
From 2e745f20911ff98f083d24d6e5ae02cf470add93 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:56:46 +0000 Subject: [PATCH 055/215] remove unreachable `window_size == 0` branches --- narwhals/_plan/arrow/expr.py | 17 ++++++----------- narwhals/_plan/options.py | 13 ++++++++++++- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index b4122e94b7..9efa08c1e9 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -603,24 +603,19 @@ def rolling_expr( # noqa: PLR0914 compliant, offset = compliant._rolling_center(window_size) else: offset = 0 - cum_sum = compliant.cum_sum().fill_null_with_strategy("forward") - if window_size != 0: - rolling_sum = cum_sum - cum_sum.shift(window_size, fill_value=0).fill_null(0) - else: - rolling_sum = cum_sum valid_count = compliant.cum_count() count_in_window = valid_count - valid_count.shift(window_size, fill_value=0) predicate = count_in_window >= roll_options.min_samples + cum_sum = compliant.cum_sum().fill_null_with_strategy("forward") + rolling_sum = cum_sum - cum_sum.shift(window_size, fill_value=0).fill_null(0) + if isinstance(function, (F.RollingVar, F.RollingStd)): ddof = roll_options.ddof cum_sum_sq = compliant.pow(2).cum_sum().fill_null_with_strategy("forward") - if window_size != 0: - rolling_sum_sq = cum_sum_sq - cum_sum_sq.shift( - window_size, fill_value=0 - ).fill_null(0) - else: - rolling_sum_sq = cum_sum_sq + rolling_sum_sq = cum_sum_sq - cum_sum_sq.shift( + window_size, fill_value=0 + ).fill_null(0) # TODO @dangotbanned: Better name? rolling_something = rolling_sum_sq - (rolling_sum**2 / count_in_window) # TODO @dangotbanned: Better name? 
diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 2d4f6d8c6e..a527b51071 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -254,9 +254,20 @@ def ddof(self) -> int: def rolling_options( window_size: int, min_samples: int | None, /, *, center: bool, ddof: int | None = None ) -> RollingOptionsFixedWindow: + if window_size < 1: + msg = "window_size must be greater or equal than 1" + raise ValueError(msg) + if min_samples is None: + min_samples = window_size + elif min_samples < 1: + msg = "min_samples must be greater or equal than 1" + raise ValueError(msg) + elif min_samples > window_size: + msg = "`min_samples` must be less or equal than `window_size`" + raise InvalidOperationError(msg) return RollingOptionsFixedWindow( window_size=window_size, - min_samples=window_size if min_samples is None else min_samples, + min_samples=min_samples, center=center, fn_params=ddof if ddof is None else RollingVarParams(ddof=ddof), ) From 1b6370fb8152dd6a893bc87013bf43e2ecf89e1b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 17:00:37 +0000 Subject: [PATCH 056/215] consistently use `ArrowSeries.pow` --- narwhals/_plan/arrow/expr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 9efa08c1e9..2c05de1fb8 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -617,7 +617,7 @@ def rolling_expr( # noqa: PLR0914 window_size, fill_value=0 ).fill_null(0) # TODO @dangotbanned: Better name? - rolling_something = rolling_sum_sq - (rolling_sum**2 / count_in_window) + rolling_something = rolling_sum_sq - (rolling_sum.pow(2) / count_in_window) # TODO @dangotbanned: Better name? 
denominator = compliant._with_native( fn.max_horizontal((count_in_window - ddof).native, 0) @@ -630,7 +630,7 @@ def rolling_expr( # noqa: PLR0914 if offset: result = result.slice(offset) if isinstance(function, (F.RollingStd)): - result = result**0.5 + result = result.pow(0.5) return self.from_series(result) # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_compliant/series.py#L349-L415 From 7a24c51971ffa790ae39cd03d21009a6454507db Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 17:26:00 +0000 Subject: [PATCH 057/215] Add `test_rolling_expr_invalid` --- narwhals/_plan/expr.py | 4 ++-- narwhals/_plan/options.py | 14 +++++++------ tests/plan/expr_parsing_test.py | 36 +++++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 8a3d8b1be8..8ecfe96912 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -302,7 +302,7 @@ def rolling_var( min_samples: int | None = None, center: bool = False, ddof: int = 1, - ) -> Self: # pragma: no cover + ) -> Self: options = rolling_options(window_size, min_samples, center=center, ddof=ddof) return self._with_unary(F.RollingVar(options=options)) @@ -313,7 +313,7 @@ def rolling_std( min_samples: int | None = None, center: bool = False, ddof: int = 1, - ) -> Self: # pragma: no cover + ) -> Self: options = rolling_options(window_size, min_samples, center=center, ddof=ddof) return self._with_unary(F.RollingStd(options=options)) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index a527b51071..e34cadd05d 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from narwhals._plan._immutable import Immutable -from narwhals._utils import Implementation +from narwhals._utils import Implementation, ensure_type from narwhals.exceptions import 
InvalidOperationError if TYPE_CHECKING: @@ -254,16 +254,18 @@ def ddof(self) -> int: def rolling_options( window_size: int, min_samples: int | None, /, *, center: bool, ddof: int | None = None ) -> RollingOptionsFixedWindow: + ensure_type(window_size, int, param_name="window_size") + ensure_type(min_samples, int, type(None), param_name="min_samples") if window_size < 1: - msg = "window_size must be greater or equal than 1" - raise ValueError(msg) + msg = "`window_size` must be >= 1" + raise InvalidOperationError(msg) if min_samples is None: min_samples = window_size elif min_samples < 1: - msg = "min_samples must be greater or equal than 1" - raise ValueError(msg) + msg = "`min_samples` must be >= 1" + raise InvalidOperationError(msg) elif min_samples > window_size: - msg = "`min_samples` must be less or equal than `window_size`" + msg = "`min_samples` must be <= `window_size`" raise InvalidOperationError(msg) return RollingOptionsFixedWindow( window_size=window_size, diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index e559e5165d..0daccd2da6 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -700,3 +700,39 @@ def test_mode_invalid() -> None: TypeError, match=r"keep.+must be one of.+all.+any.+but got 'first'" ): nwp.col("a").mode(keep="first") # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "context"), + [ + (-1, None, pytest.raises(ValueError, match=r"window_size.+>= 1")), + (2, -1, pytest.raises(ValueError, match=r"min_samples.+>= 1")), + ( + 1, + 2, + pytest.raises(InvalidOperationError, match=r"min_samples.+<=.+window_size"), + ), + ( + 4.2, + None, + pytest.raises(TypeError, match=r"Expected.+int.+got.+float.+\s+window_size="), + ), + ( + 2, + 4.2, + pytest.raises(TypeError, match=r"Expected.+int.+got.+float.+\s+min_samples="), + ), + ], +) +def test_rolling_expr_invalid( + window_size: int, min_samples: int | None, context: pytest.RaisesExc[Any] +) -> 
None: + a = nwp.col("a") + with context: + a.rolling_sum(window_size, min_samples=min_samples) + with context: + a.rolling_mean(window_size, min_samples=min_samples) + with context: + a.rolling_var(window_size, min_samples=min_samples) + with context: + a.rolling_std(window_size, min_samples=min_samples) From 582da649e477dd43671a66142261e0540e199d0e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 21:29:58 +0000 Subject: [PATCH 058/215] Move rolling to `ArrowSeries`, tidy up --- narwhals/_plan/arrow/expr.py | 66 +++++++++--------------------- narwhals/_plan/arrow/series.py | 56 +++++++++++++++++++++++++ narwhals/_plan/compliant/series.py | 12 ++++++ 3 files changed, 88 insertions(+), 46 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 2c05de1fb8..2746901067 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol, overload +from typing import TYPE_CHECKING, Any, ClassVar, Protocol, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -31,7 +31,7 @@ from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Callable, Mapping, Sequence from typing_extensions import Self, TypeAlias @@ -584,53 +584,27 @@ def fill_null_with_strategy( is_duplicated = _boolean_length_preserving is_unique = _boolean_length_preserving - # TODO @dangotbanned: Plan composing with `functions.cum_*` - # Waaaaaay more of this needs to be shared - # https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L930-L1034 + _ROLLING: ClassVar[Mapping[type[F.RollingWindow], Callable[..., Series]]] = { + F.RollingSum: Series.rolling_sum, + F.RollingMean: 
Series.rolling_mean, + F.RollingVar: Series.rolling_var, + F.RollingStd: Series.rolling_std, + } - # yes ruff, i know this is too complicated! - # but we need to start somewhere - def rolling_expr( # noqa: PLR0914 + def rolling_expr( self, node: ir.RollingExpr[F.RollingWindow], frame: Frame, name: str ) -> Self: - function = node.function - roll_options = function.options - window_size = roll_options.window_size - compliant = self._dispatch_expr(node.input[0], frame, name) - - # Read up on polars impl to get some names for what this is - if roll_options.center: - compliant, offset = compliant._rolling_center(window_size) - else: - offset = 0 - - valid_count = compliant.cum_count() - count_in_window = valid_count - valid_count.shift(window_size, fill_value=0) - predicate = count_in_window >= roll_options.min_samples - cum_sum = compliant.cum_sum().fill_null_with_strategy("forward") - rolling_sum = cum_sum - cum_sum.shift(window_size, fill_value=0).fill_null(0) - - if isinstance(function, (F.RollingVar, F.RollingStd)): - ddof = roll_options.ddof - cum_sum_sq = compliant.pow(2).cum_sum().fill_null_with_strategy("forward") - rolling_sum_sq = cum_sum_sq - cum_sum_sq.shift( - window_size, fill_value=0 - ).fill_null(0) - # TODO @dangotbanned: Better name? - rolling_something = rolling_sum_sq - (rolling_sum.pow(2) / count_in_window) - # TODO @dangotbanned: Better name? 
- denominator = compliant._with_native( - fn.max_horizontal((count_in_window - ddof).native, 0) - ) - result = rolling_something.zip_with(predicate, None) / denominator - else: - result = rolling_sum.zip_with(predicate, None) - if isinstance(function, F.RollingMean): - result = result / count_in_window - if offset: - result = result.slice(offset) - if isinstance(function, (F.RollingStd)): - result = result.pow(0.5) + s = self._dispatch_expr(node.input[0], frame, name) + roll_options = node.function.options + size = roll_options.window_size + samples = roll_options.min_samples + center = roll_options.center + op = type(node.function) + method = self._ROLLING[op] + if op in {F.RollingSum, F.RollingMean}: + return self.from_series(method(s, size, min_samples=samples, center=center)) + ddof = roll_options.ddof + result = method(s, size, min_samples=samples, center=center, ddof=ddof) return self.from_series(result) # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_compliant/series.py#L349-L415 diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index f97f40a5b5..f381cfae3b 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -188,6 +188,62 @@ def _rolling_center(self, window_size: int) -> tuple[Self, int]: offset = offset_left + offset_right return self._with_native(fn.concat_vertical_chunked(arrays)), offset + def _rolling_sum(self, window_size: int, /) -> Self: + cum_sum = self.cum_sum().fill_null_with_strategy("forward") + return cum_sum - cum_sum.shift(window_size, fill_value=0).fill_null(0) + + def _rolling_count(self, window_size: int, /) -> Self: + cum_count = self.cum_count() + return cum_count - cum_count.shift(window_size, fill_value=0) + + def rolling_sum( + self, window_size: int, *, min_samples: int, center: bool = False + ) -> Self: + s, offset = self, 0 + if center: + s, offset = self._rolling_center(window_size) + rolling_count = 
s._rolling_count(window_size) + keep = rolling_count >= min_samples + result = s._rolling_sum(window_size).zip_with(keep, None) + return result.slice(offset) if offset else result + + def rolling_mean( + self, window_size: int, *, min_samples: int, center: bool = False + ) -> Self: + s, offset = self, 0 + if center: + s, offset = self._rolling_center(window_size) + rolling_count = s._rolling_count(window_size) + keep = rolling_count >= min_samples + result = (s._rolling_sum(window_size).zip_with(keep, None)) / rolling_count + return result.slice(offset) if offset else result + + def rolling_var( + self, window_size: int, *, min_samples: int, center: bool = False, ddof: int = 1 + ) -> Self: + s, offset = self, 0 + if center: + s, offset = self._rolling_center(window_size) + rolling_count = s._rolling_count(window_size) + keep = rolling_count >= min_samples + + # NOTE: Yes, these two are different + sq_rolling_sum = s.pow(2)._rolling_sum(window_size) + rolling_sum_sq = s._rolling_sum(window_size).pow(2) + + # NOTE: Please somebody rename these two to *something else*! 
+ rolling_something = sq_rolling_sum - (rolling_sum_sq / rolling_count) + denominator = s._with_native(fn.max_horizontal((rolling_count - ddof).native, 0)) + result = rolling_something.zip_with(keep, None) / denominator + return result.slice(offset) if offset else result + + def rolling_std( + self, window_size: int, *, min_samples: int, center: bool = False, ddof: int = 1 + ) -> Self: + return self.rolling_var( + window_size, min_samples=min_samples, center=center, ddof=ddof + ).pow(0.5) + def zip_with(self, mask: Self, other: Self | None) -> Self: predicate = mask.native.combine_chunks() right = other.native if other is not None else other diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index c82f456ebe..34f79516d5 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -139,6 +139,18 @@ def is_empty(self) -> bool: return len(self) == 0 def is_in(self, other: Self) -> Self: ... + def rolling_mean( + self, window_size: int, *, min_samples: int, center: bool = False + ) -> Self: ... + def rolling_std( + self, window_size: int, *, min_samples: int, center: bool = False, ddof: int = 1 + ) -> Self: ... + def rolling_sum( + self, window_size: int, *, min_samples: int, center: bool = False + ) -> Self: ... + def rolling_var( + self, window_size: int, *, min_samples: int, center: bool = False, ddof: int = 1 + ) -> Self: ... def scatter(self, indices: Self, values: Self) -> Self: ... def slice(self, offset: int, length: int | None = None) -> Self: ... def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: ... 
From ef2a4f51f303bd7e0f0dea6315be62e5f08b22b6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 21:32:09 +0000 Subject: [PATCH 059/215] to-done --- narwhals/_plan/arrow/expr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 2746901067..d1e147babe 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -515,7 +515,6 @@ def _boolean_length_preserving( # - [ ] `map_batches` # - [x] elementwise # - [ ] scalar - # - [ ] `rolling_expr` has 4 variants # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: From c519f706ecc1a09c03cdfa04fd8235a0391968b5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 24 Nov 2025 22:52:20 +0000 Subject: [PATCH 060/215] quick+dirty `map_batches(returns_scalar=True)` Need to follow-up on a question about `main` deviating from `polars` --- narwhals/_plan/_guards.py | 7 +++++- narwhals/_plan/arrow/expr.py | 40 +++++++++++++++++++++----------- narwhals/_plan/compliant/expr.py | 3 ++- tests/plan/compliant_test.py | 4 ---- 4 files changed, 35 insertions(+), 19 deletions(-) diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index 8231af57ac..30503c55c3 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -24,7 +24,7 @@ NativeSeriesT, Seq, ) - from narwhals.typing import NonNestedLiteral + from narwhals.typing import NonNestedLiteral, PythonLiteral T = TypeVar("T") @@ -38,6 +38,7 @@ bytes, Decimal, ) +_PYTHON_LITERAL_TPS = (*_NON_NESTED_LITERAL_TPS, list, tuple, type(None)) def _ir(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 @@ -68,6 +69,10 @@ def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS) +def 
is_python_literal(obj: Any) -> TypeIs[PythonLiteral]: + return isinstance(obj, _PYTHON_LITERAL_TPS) + + def is_expr(obj: Any) -> TypeIs[Expr]: return isinstance(obj, _expr().Expr) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index d1e147babe..891a8b4e8c 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -7,7 +7,7 @@ from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan import expressions as ir -from narwhals._plan._guards import is_function_expr, is_seq_column +from narwhals._plan._guards import is_function_expr, is_python_literal, is_seq_column from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co @@ -511,21 +511,18 @@ def _boolean_length_preserving( final_result = mask(index.native, aggregated.get_column(idx_name).native) return self.from_series(index._with_native(final_result)) - # TODO @dangotbanned: top-level, complex-ish nodes - # - [ ] `map_batches` - # - [x] elementwise - # - [ ] scalar - - # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` - def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: - if node.is_scalar: - # NOTE: Just trying to avoid redoing the whole API for `Series` - # https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_compliant/expr.py#L738-L755 - msg = "Only elementwise is currently supported" - raise NotImplementedError(msg) + # NOTE: Can't implement in `EagerExpr` (like on `main`) + # The version here is missing `__narwhals_namespace__` + def map_batches( + self, node: ir.AnonymousExpr, frame: Frame, name: str + ) -> Self | Scalar: series = self._dispatch_expr(node.input[0], frame, name) udf = node.function.function result: Series | Into1DArray = udf(series) + if node.is_scalar: + return ArrowScalar.from_unknown( 
+ result, name, dtype=node.function.return_dtype, version=self.version + ) if not isinstance(result, Series): result = Series.from_numpy(result, name, version=self.version) if dtype := node.function.return_dtype: @@ -664,6 +661,23 @@ def from_series(cls, series: Series) -> Self: msg = f"Too long {len(series)!r}" raise InvalidOperationError(msg) + @classmethod + def from_unknown( + cls, + value: Any, + name: str = "literal", + /, + *, + dtype: IntoDType | None = None, + version: Version = Version.MAIN, + ) -> Self: + if isinstance(value, pa.Scalar): + return cls.from_native(value, name, version) + if is_python_literal(value): + return cls.from_python(value, name, dtype=dtype, version=version) + native = fn.lit(value, fn.dtype_native(dtype, version)) + return cls.from_native(native, name, version) + def _dispatch_expr(self, node: ir.ExprIR, frame: Frame, name: str) -> Series: msg = f"Expected unreachable, but hit at: {node!r}" raise InvalidOperationError(msg) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index faa5892923..bd1cdf4049 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -80,9 +80,10 @@ def is_nan( def is_null( self, node: FunctionExpr[IsNull], frame: FrameT_contra, name: str ) -> Self: ... + # NOTE: `Scalar` when using `returns_scalar=True` def map_batches( self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str - ) -> Self: ... + ) -> Self | CompliantScalar[FrameT_contra, SeriesT_co]: ... def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... 
# NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 18eed7b49a..d917968ab1 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -427,10 +427,6 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: .map_batches(lambda s: (s.to_numpy().max()), returns_scalar=True), {"j": [15], "k": [42]}, id="map_batches-return_scalar", - marks=pytest.mark.xfail( - reason="not implemented `map_batches(returns_scalar=True)` for `pyarrow`", - raises=NotImplementedError, - ), ), pytest.param( [nwp.col("g").len(), nwp.col("m").last(), nwp.col("h").count()], From aa52033fcf30ae7e7aad692c8eedab3f9bdb4267 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 12:05:07 +0000 Subject: [PATCH 061/215] test: Move existing `map_batches` tests --- tests/plan/compliant_test.py | 40 --------------- tests/plan/map_batches_test.py | 91 ++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 40 deletions(-) create mode 100644 tests/plan/map_batches_test.py diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index d917968ab1..ba9c95b0f0 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -13,12 +13,10 @@ pytest.importorskip("numpy") import datetime as dt -import numpy as np import pyarrow as pa import narwhals as nw from narwhals import _plan as nwp -from narwhals._utils import Version from tests.plan.utils import assert_equal_data, dataframe, first, last, series if TYPE_CHECKING: @@ -390,44 +388,6 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: {"literal": ["a|b|c|d|20"]}, id="concat_str-all-lit", ), - pytest.param( - [ - nwp.col("a") - .alias("...") - .map_batches( - lambda s: s.from_iterable( - [*((len(s) - 1) * [type(s.dtype).__name__.lower()]), "last"], - version=Version.MAIN, - name="funky", - ), - is_elementwise=True, - ), - nwp.col("a"), - ], - {"funky": 
["string", "string", "last"], "a": ["A", "B", "A"]}, - id="map_batches-series", - ), - pytest.param( - nwp.col("b") - .map_batches(lambda s: s.to_numpy() + 1, nw.Float64(), is_elementwise=True) - .sum(), - {"b": [9.0]}, - id="map_batches-numpy", - ), - pytest.param( - ncs.by_name("b", "c", "d") - .map_batches(lambda s: np.append(s.to_numpy(), [10, 2]), is_elementwise=True) - .sort(), - {"b": [1, 2, 2, 3, 10], "c": [2, 2, 4, 9, 10], "d": [2, 7, 8, 8, 10]}, - id="map_batches-selector", - ), - pytest.param( - nwp.col("j", "k") - .fill_null(15) - .map_batches(lambda s: (s.to_numpy().max()), returns_scalar=True), - {"j": [15], "k": [42]}, - id="map_batches-return_scalar", - ), pytest.param( [nwp.col("g").len(), nwp.col("m").last(), nwp.col("h").count()], {"g": [3], "m": [2], "h": [1]}, diff --git a/tests/plan/map_batches_test.py b/tests/plan/map_batches_test.py new file mode 100644 index 0000000000..7e2589c1a1 --- /dev/null +++ b/tests/plan/map_batches_test.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from narwhals._plan import selectors as ncs +from narwhals._utils import Version +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from collections.abc import Sequence + + from tests.conftest import Data + +pytest.importorskip("pyarrow") +pytest.importorskip("numpy") +import numpy as np + + +@pytest.fixture +def data() -> Data: + return { + "a": ["A", "B", "A"], + "b": [1, 2, 3], + "c": [9, 2, 4], + "d": [8, 7, 8], + "e": [None, 9, 7], + "f": [True, False, None], + "g": [False, None, False], + "h": [None, None, True], + "i": [None, None, None], + "j": [12.1, None, 4.0], + "k": [42, 10, None], + "l": [4, 5, 6], + "m": [0, 1, 2], + "n": ["dogs", "cats", None], + "o": ["play", "swim", "walk"], + } + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + pytest.param( + [ + nwp.col("a") + .alias("...") + .map_batches( + 
lambda s: s.from_iterable( + [*((len(s) - 1) * [type(s.dtype).__name__.lower()]), "last"], + version=Version.MAIN, + name="funky", + ), + is_elementwise=True, + ), + nwp.col("a"), + ], + {"funky": ["string", "string", "last"], "a": ["A", "B", "A"]}, + id="series", + ), + pytest.param( + nwp.col("b") + .map_batches(lambda s: s.to_numpy() + 1, nw.Float64(), is_elementwise=True) + .sum(), + {"b": [9.0]}, + id="numpy", + ), + pytest.param( + ncs.by_name("b", "c", "d") + .map_batches(lambda s: np.append(s.to_numpy(), [10, 2]), is_elementwise=True) + .sort(), + {"b": [1, 2, 2, 3, 10], "c": [2, 2, 4, 9, 10], "d": [2, 7, 8, 8, 10]}, + id="selector", + ), + pytest.param( + nwp.col("j", "k") + .fill_null(15) + .map_batches(lambda s: (s.to_numpy().max()), returns_scalar=True), + {"j": [15], "k": [42]}, + id="returns_scalar", + ), + ], +) +def test_map_batches( + data: Data, expr: nwp.Expr | Sequence[nwp.Expr], expected: Data +) -> None: + result = dataframe(data).select(expr) + assert_equal_data(result, expected) From 100d82e7b6375a25304c192712947a24fb6767fc Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 14:11:58 +0000 Subject: [PATCH 062/215] test: Clean up and doc `map_batches` --- tests/plan/map_batches_test.py | 94 ++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 34 deletions(-) diff --git a/tests/plan/map_batches_test.py b/tests/plan/map_batches_test.py index 7e2589c1a1..41e4ce2e41 100644 --- a/tests/plan/map_batches_test.py +++ b/tests/plan/map_batches_test.py @@ -7,12 +7,16 @@ import narwhals as nw import narwhals._plan as nwp from narwhals._plan import selectors as ncs -from narwhals._utils import Version from tests.plan.utils import assert_equal_data, dataframe if TYPE_CHECKING: from collections.abc import Sequence + from narwhals._plan.compliant.typing import ( + SeriesAny as CompliantSeriesAny, + SeriesT as CompliantSeriesT, + ) + from narwhals.typing import _1DArray, 
_NumpyScalar from tests.conftest import Data pytest.importorskip("pyarrow") @@ -23,64 +27,86 @@ @pytest.fixture def data() -> Data: return { - "a": ["A", "B", "A"], - "b": [1, 2, 3], + "a": [1, 2, 3], + "b": [4, 5, 6], "c": [9, 2, 4], "d": [8, 7, 8], - "e": [None, 9, 7], - "f": [True, False, None], - "g": [False, None, False], - "h": [None, None, True], - "i": [None, None, None], + "e": ["A", "B", "A"], "j": [12.1, None, 4.0], "k": [42, 10, None], - "l": [4, 5, 6], - "m": [0, 1, 2], - "n": ["dogs", "cats", None], - "o": ["play", "swim", "walk"], + "z": [7.0, 8.0, 9.0], } +def elementwise_series(s: CompliantSeriesT, /) -> CompliantSeriesT: + dtype_name = type(s.dtype).__name__.lower() + repeat_name = (dtype_name,) * (len(s) - 1) + values = [*repeat_name, "last"] + return s.from_iterable(values, version=s.version, name="funky") + + +def elementwise_1d_array(s: CompliantSeriesAny, /) -> _1DArray: + return s.to_numpy() + 1 + + +def to_numpy(s: CompliantSeriesAny, /) -> _1DArray: + return s.to_numpy() + + +def groupwise_1d_array(s: CompliantSeriesAny, /) -> _1DArray: + return np.append(s.to_numpy(), [10, 2]) + + +def aggregation_np_scalar(s: CompliantSeriesAny, /) -> _NumpyScalar: + return s.to_numpy().max() + + @pytest.mark.parametrize( ("expr", "expected"), [ pytest.param( [ - nwp.col("a") + nwp.col("e") .alias("...") - .map_batches( - lambda s: s.from_iterable( - [*((len(s) - 1) * [type(s.dtype).__name__.lower()]), "last"], - version=Version.MAIN, - name="funky", - ), - is_elementwise=True, - ), - nwp.col("a"), + .map_batches(elementwise_series, is_elementwise=True), + nwp.col("e"), ], - {"funky": ["string", "string", "last"], "a": ["A", "B", "A"]}, - id="series", + {"funky": ["string", "string", "last"], "e": ["A", "B", "A"]}, + id="is_elementwise-series", ), pytest.param( - nwp.col("b") - .map_batches(lambda s: s.to_numpy() + 1, nw.Float64(), is_elementwise=True) + nwp.col("a", "b", "z").map_batches(to_numpy), + {"a": [1, 2, 3], "b": [4, 5, 6], "z": [7.0, 8.0, 
9.0]}, + id="to-numpy", + ), + pytest.param( + nwp.col("a") + .map_batches(elementwise_1d_array, nw.Float64, is_elementwise=True) .sum(), - {"b": [9.0]}, - id="numpy", + {"a": [9.0]}, + id="is_elementwise-1d-array", + ), + pytest.param( + nwp.col("a").map_batches(elementwise_1d_array, nw.Float64).sum(), + {"a": [9.0]}, + id="unknown-1d-array", ), pytest.param( - ncs.by_name("b", "c", "d") - .map_batches(lambda s: np.append(s.to_numpy(), [10, 2]), is_elementwise=True) + ncs.by_index(0, 2, 3) + .map_batches(groupwise_1d_array, is_elementwise=True) .sort(), - {"b": [1, 2, 2, 3, 10], "c": [2, 2, 4, 9, 10], "d": [2, 7, 8, 8, 10]}, - id="selector", + {"a": [1, 2, 2, 3, 10], "c": [2, 2, 4, 9, 10], "d": [2, 7, 8, 8, 10]}, + # NOTE: Maybe this should be rejected because of the length change? + # It doesn't break broadcasting rules, but uses an optional argument incorrectly + # and we only know *after* execution + id="is_elementwise-1d-array-groupwise", ), pytest.param( nwp.col("j", "k") .fill_null(15) - .map_batches(lambda s: (s.to_numpy().max()), returns_scalar=True), + .map_batches(aggregation_np_scalar, returns_scalar=True), {"j": [15], "k": [42]}, - id="returns_scalar", + id="returns_scalar-np-scalar", ), ], ) From 3f3884f297238535ab8fa07d4a01df11bd74bcbb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 14:20:20 +0000 Subject: [PATCH 063/215] test: Add `test_map_batches_invalid` --- tests/plan/map_batches_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/plan/map_batches_test.py b/tests/plan/map_batches_test.py index 41e4ce2e41..41579c840b 100644 --- a/tests/plan/map_batches_test.py +++ b/tests/plan/map_batches_test.py @@ -115,3 +115,19 @@ def test_map_batches( ) -> None: result = dataframe(data).select(expr) assert_equal_data(result, expected) + + +@pytest.mark.xfail( + reason="TODO: Need to raise when `returns_scalar=False` does not return a Series" +) +def 
test_map_batches_invalid(data: Data) -> None: + df = dataframe(data) + expr = nwp.col("a", "b", "z").map_batches(aggregation_np_scalar) + + msg = ( + r"`map(?:_batches)?` with `returns_scalar=False` must return a Series; found " + "'numpy.int64'.\n\nIf `returns_scalar` is set to `True`, a returned value can be " + "a scalar value." + ) + with pytest.raises(TypeError, match=msg): + df.select(expr) From 67079192b11fb0380b11bab2a5eae9fe367066d8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 14:23:16 +0000 Subject: [PATCH 064/215] test(typing): happier mypy --- tests/plan/map_batches_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/plan/map_batches_test.py b/tests/plan/map_batches_test.py index 41579c840b..2076471244 100644 --- a/tests/plan/map_batches_test.py +++ b/tests/plan/map_batches_test.py @@ -54,11 +54,13 @@ def to_numpy(s: CompliantSeriesAny, /) -> _1DArray: def groupwise_1d_array(s: CompliantSeriesAny, /) -> _1DArray: - return np.append(s.to_numpy(), [10, 2]) + result: _1DArray = np.append(s.to_numpy(), [10, 2]) + return result def aggregation_np_scalar(s: CompliantSeriesAny, /) -> _NumpyScalar: - return s.to_numpy().max() + result: _NumpyScalar = s.to_numpy().max() + return result @pytest.mark.parametrize( From c87fb7477819434a59ffabd8374e08b9206567fe Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:37:55 +0000 Subject: [PATCH 065/215] fix: raise when `returns_scalar=False` does not return a `Series` --- narwhals/_plan/arrow/expr.py | 35 ++++++++++++++++++++------ narwhals/_plan/arrow/series.py | 2 +- tests/plan/map_batches_test.py | 46 ++++++++++++++++++++++------------ 3 files changed, 59 insertions(+), 24 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 891a8b4e8c..e49b2d8d8f 100644 --- a/narwhals/_plan/arrow/expr.py +++ 
b/narwhals/_plan/arrow/expr.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, ClassVar, Protocol, overload import pyarrow as pa # ignore-banned-import @@ -7,7 +8,12 @@ from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan import expressions as ir -from narwhals._plan._guards import is_function_expr, is_python_literal, is_seq_column +from narwhals._plan._guards import ( + is_function_expr, + is_iterable_reject, + is_python_literal, + is_seq_column, +) from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co @@ -27,7 +33,13 @@ IsUnique, ) from narwhals._plan.expressions.functions import NullCount -from narwhals._utils import Implementation, Version, _StoresNative, not_implemented +from narwhals._utils import ( + Implementation, + Version, + _StoresNative, + not_implemented, + qualified_type_name, +) from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: @@ -76,7 +88,7 @@ Shift, ) from narwhals._plan.typing import Seq - from narwhals.typing import Into1DArray, IntoDType, PythonLiteral + from narwhals.typing import IntoDType, PythonLiteral Expr: TypeAlias = "ArrowExpr" Scalar: TypeAlias = "ArrowScalar" @@ -518,13 +530,22 @@ def map_batches( ) -> Self | Scalar: series = self._dispatch_expr(node.input[0], frame, name) udf = node.function.function - result: Series | Into1DArray = udf(series) + udf_result: Series | Iterable[Any] | Any = udf(series) if node.is_scalar: return ArrowScalar.from_unknown( - result, name, dtype=node.function.return_dtype, version=self.version + udf_result, name, dtype=node.function.return_dtype, version=self.version + ) + if isinstance(udf_result, Series): + result = udf_result + elif isinstance(udf_result, Iterable) and not is_iterable_reject(udf_result): + result = 
Series.from_iterable(udf_result, name=name, version=self.version) + else: + msg = ( + "`map_batches` with `returns_scalar=False` must return a Series; " + f"found '{qualified_type_name(udf_result)}'.\n\nIf `returns_scalar` " + "is set to `True`, a returned value can be a scalar value." ) - if not isinstance(result, Series): - result = Series.from_numpy(result, name, version=self.version) + raise TypeError(msg) if dtype := node.function.return_dtype: result = result.cast(dtype) return self.from_series(result) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index f381cfae3b..023c95485a 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -79,7 +79,7 @@ def from_iterable( name: str = "", dtype: IntoDType | None = None, ) -> Self: - dtype_pa = narwhals_to_native_dtype(dtype, version) if dtype else None + dtype_pa = fn.dtype_native(dtype, version) return cls.from_native(fn.chunked_array([data], dtype_pa), name, version=version) def cast(self, dtype: IntoDType) -> Self: diff --git a/tests/plan/map_batches_test.py b/tests/plan/map_batches_test.py index 2076471244..bfdad27303 100644 --- a/tests/plan/map_batches_test.py +++ b/tests/plan/map_batches_test.py @@ -1,16 +1,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest import narwhals as nw import narwhals._plan as nwp from narwhals._plan import selectors as ncs -from tests.plan.utils import assert_equal_data, dataframe +from tests.plan.utils import assert_equal_data, dataframe, re_compile if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence + + import pyarrow as pa from narwhals._plan.compliant.typing import ( SeriesAny as CompliantSeriesAny, @@ -19,7 +21,7 @@ from narwhals.typing import _1DArray, _NumpyScalar from tests.conftest import Data -pytest.importorskip("pyarrow") + pytest.importorskip("numpy") import numpy as np @@ -63,6 +65,14 
@@ def aggregation_np_scalar(s: CompliantSeriesAny, /) -> _NumpyScalar: return result +def aggregation_pa_scalar(s: CompliantSeriesAny) -> pa.Scalar[Any]: + pytest.importorskip("pyarrow") + import pyarrow as pa + + result: pa.Scalar[Any] = pa.array(s.to_list())[0] + return result + + @pytest.mark.parametrize( ("expr", "expected"), [ @@ -119,17 +129,21 @@ def test_map_batches( assert_equal_data(result, expected) -@pytest.mark.xfail( - reason="TODO: Need to raise when `returns_scalar=False` does not return a Series" +@pytest.mark.parametrize( + ("udf", "result_type_name"), + [ + (aggregation_np_scalar, "'numpy.int64'"), + (aggregation_pa_scalar, ".+pyarrow.+scalar.+"), + (len, "'int'"), + (str, "'str'"), + ], ) -def test_map_batches_invalid(data: Data) -> None: - df = dataframe(data) - expr = nwp.col("a", "b", "z").map_batches(aggregation_np_scalar) - - msg = ( - r"`map(?:_batches)?` with `returns_scalar=False` must return a Series; found " - "'numpy.int64'.\n\nIf `returns_scalar` is set to `True`, a returned value can be " - "a scalar value." 
+def test_map_batches_invalid( + data: Data, udf: Callable[[Any], Any], result_type_name: str +) -> None: + expr = nwp.col("a", "b", "z").map_batches(udf) + pattern = re_compile( + rf"map.+ with `returns_scalar=False` must return a Series.+{result_type_name}" ) - with pytest.raises(TypeError, match=msg): - df.select(expr) + with pytest.raises(TypeError, match=pattern): + dataframe(data).select(expr) From 24fb35035a1292e26650552ecc3b00f424f86dc1 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 16:07:32 +0000 Subject: [PATCH 066/215] =?UTF-8?q?fix:=20Check=20the=20correct=20flags=20?= =?UTF-8?q?for=20mutual=20exclusivity=20=F0=9F=98=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- narwhals/_plan/options.py | 12 ++++++++---- tests/plan/expr_parsing_test.py | 8 ++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index e34cadd05d..4914b8b59b 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -76,6 +76,9 @@ def __str__(self) -> str: return name.replace("|", " | ") +_INVALID = FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING + + class FunctionOptions(Immutable): """https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs""" # noqa: D415 @@ -101,11 +104,12 @@ def is_input_wildcard_expansion(self) -> bool: return self.flags.is_input_wildcard_expansion() def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions: - if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags: - msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." 
# pragma: no cover - raise TypeError(msg) # pragma: no cover + new_flags = self.flags | flags + if _INVALID in new_flags: + msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." + raise TypeError(msg) obj = FunctionOptions.__new__(FunctionOptions) - object.__setattr__(obj, "flags", self.flags | flags) + object.__setattr__(obj, "flags", new_flags) return obj def with_elementwise(self) -> FunctionOptions: diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 0daccd2da6..aabf84798e 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -377,6 +377,14 @@ def test_binary_expr_shape_invalid() -> None: a.fill_null(1) // b.rolling_mean(5) +def test_map_batches_invalid() -> None: + with pytest.raises( + TypeError, + match=r"A function cannot both return a scalar and preserve length, they are mutually exclusive", + ): + nwp.col("a").map_batches(lambda x: x, is_elementwise=True, returns_scalar=True) + + @pytest.mark.parametrize("into_iter", [list, tuple, deque, iter, dict.fromkeys, set]) def test_is_in_seq(into_iter: IntoIterable) -> None: expected = 1, 2, 3 From 632932f3378a26dcb067c14551d4ce456c25325f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 17:37:01 +0000 Subject: [PATCH 067/215] test: Also check list scalar --- tests/plan/map_batches_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/plan/map_batches_test.py b/tests/plan/map_batches_test.py index bfdad27303..6c8f9776e8 100644 --- a/tests/plan/map_batches_test.py +++ b/tests/plan/map_batches_test.py @@ -120,6 +120,18 @@ def aggregation_pa_scalar(s: CompliantSeriesAny) -> pa.Scalar[Any]: {"j": [15], "k": [42]}, id="returns_scalar-np-scalar", ), + pytest.param( + [ + nwp.col("a").map_batches( + lambda _: [1, 2], + returns_scalar=True, + return_dtype=nw.List(nw.Int64()), + ), + nwp.col("b").last(), + ], + {"a": [[1, 2]], 
"b": [6]}, + id="returns_scalar-list", + ), ], ) def test_map_batches( From 951272c263f6a7cf986bb1e9ec831bbfebd30204 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 17:52:47 +0000 Subject: [PATCH 068/215] dont repeat yourself, dont repeat yourself, ... --- narwhals/_plan/arrow/expr.py | 9 +++------ narwhals/_plan/arrow/functions.py | 8 ++++++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index e49b2d8d8f..7c501c12da 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -47,7 +47,6 @@ from typing_extensions import Self, TypeAlias - from narwhals._arrow.typing import Incomplete from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.arrow.typing import ChunkedArrayAny, P, VectorFunction @@ -714,17 +713,15 @@ def to_series(self) -> Series: return self.broadcast(1) def to_python(self) -> PythonLiteral: - return self.native.as_py() # type: ignore[no-any-return] + result: PythonLiteral = self.native.as_py() + return result def broadcast(self, length: int) -> Series: scalar = self.native if length == 1: chunked = fn.chunked_array(scalar) else: - # NOTE: Same issue as `pa.scalar` overlapping overloads - # https://github.com/zen-xu/pyarrow-stubs/pull/209 - pa_repeat: Incomplete = pa.repeat - chunked = fn.chunked_array(pa_repeat(scalar, length)) + chunked = fn.chunked_array(fn.repeat_unchecked(scalar, length)) return Series.from_native(chunked, self.name, version=self.version) def count(self, node: Count, frame: Frame, name: str) -> Scalar: diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 00aa21d893..c5e36673b6 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -780,14 +780,18 @@ def date_range( def repeat(value: ScalarAny | 
NonNestedLiteral, n: int) -> ArrayAny: - repeat_: Incomplete = pa.repeat value = value if isinstance(value, pa.Scalar) else lit(value) + return repeat_unchecked(value, n) + + +def repeat_unchecked(value: ScalarAny, /, n: int) -> ArrayAny: + repeat_: Incomplete = pa.repeat result: ArrayAny = repeat_(value, n) return result def repeat_like(value: NonNestedLiteral, n: int, native: ArrowAny) -> ArrayAny: - return repeat(lit(value, native.type), n) + return repeat_unchecked(lit(value, native.type), n) def nulls_like(n: int, native: ArrowAny) -> ArrayAny: From 2685791981b6b4dfc8d273c2dfac02568905bd07 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 17:57:07 +0000 Subject: [PATCH 069/215] remove dead code --- narwhals/_plan/arrow/group_by.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index c9eb41fd8c..29d87f866e 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -17,7 +17,7 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Collection, Iterator, Mapping, Sequence + from collections.abc import Iterator, Mapping, Sequence from typing_extensions import Self, TypeAlias @@ -149,15 +149,6 @@ def group_by_error( return InvalidOperationError(msg) -def multiple_null_partitions_error(column_names: Collection[str]) -> NotImplementedError: - backend = Implementation.PYARROW - msg = ( - f"`over(*partition_by)` where multiple columns contain null values is not yet supported for {backend!r}\n" - f"Got: {list(column_names)!r}" - ) - return NotImplementedError(msg) - - class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): _df: Frame _keys: Seq[NamedIR] From 11ed36aefd1e52cf50e92274ee0f3b1bf7dc38cd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:06:13 +0000 
Subject: [PATCH 070/215] chore(typing): Move `map_batches` to `EagerExpr` Relies on a `CompliantSeries` existing, which can't be expected in `CompliantExpr` --- narwhals/_plan/compliant/expr.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index bd1cdf4049..33a01bba39 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -16,7 +16,7 @@ from typing_extensions import Self, TypeAlias from narwhals._plan import expressions as ir - from narwhals._plan.compliant.scalar import CompliantScalar + from narwhals._plan.compliant.scalar import CompliantScalar, EagerScalar from narwhals._plan.expressions import ( BinaryExpr, FunctionExpr, @@ -80,10 +80,6 @@ def is_nan( def is_null( self, node: FunctionExpr[IsNull], frame: FrameT_contra, name: str ) -> Self: ... - # NOTE: `Scalar` when using `returns_scalar=True` - def map_batches( - self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str - ) -> Self | CompliantScalar[FrameT_contra, SeriesT_co]: ... def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` @@ -188,7 +184,7 @@ def skew( self, node: FunctionExpr[F.Skew], frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - # mixed/todo + # TODO @dangotbanned: Reorder these def clip( self, node: FunctionExpr[F.Clip], frame: FrameT_contra, name: str ) -> Self: ... @@ -261,6 +257,10 @@ def is_in_series( frame: FrameT_contra, name: str, ) -> Self: ... + # NOTE: `Scalar` when using `returns_scalar=True` + def map_batches( + self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str + ) -> Self | EagerScalar[FrameT_contra, SeriesT]: ... 
def __bool__(self) -> Literal[True]: # NOTE: Avoids falling back to `__len__` when truth-testing on dispatch return True From 2f1ca620bc9bc5087a7c9c1ffd931e309b4fa65c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:56:35 +0000 Subject: [PATCH 071/215] feat: Add and use `ArrowSeries.diff(n=...)` for rolling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit > Computes the first order difference of an array, It internally calls the scalar function “subtract” to compute differences, so its behavior and supported types are the same as “subtract”. https://arrow.apache.org/docs/python/generated/pyarrow.compute.pairwise_diff.html#pyarrow.compute.pairwise_diff --- narwhals/_plan/arrow/functions.py | 6 +++--- narwhals/_plan/arrow/series.py | 6 +++++- narwhals/_plan/compliant/series.py | 1 + 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index c5e36673b6..e037c492fd 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -449,12 +449,12 @@ def cumulative(native: ChunkedArrayAny, f: F.CumAgg, /) -> ChunkedArrayAny: return func(native) if not f.reverse else reverse(func(reverse(native))) -def diff(native: ChunkedOrArrayT) -> ChunkedOrArrayT: +def diff(native: ChunkedOrArrayT, n: int = 1) -> ChunkedOrArrayT: # pyarrow.lib.ArrowInvalid: Vector kernel cannot execute chunkwise and no chunked exec function was defined return ( - pc.pairwise_diff(native) + pc.pairwise_diff(native, n) if isinstance(native, pa.Array) - else chunked_array(pc.pairwise_diff(native.combine_chunks())) + else chunked_array(pc.pairwise_diff(native.combine_chunks(), n)) ) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 023c95485a..e0bf71b0d0 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -166,6 +166,9 @@ def 
fill_null_with_strategy( ) -> Self: return self._with_native(fn.fill_null_with_strategy(self.native, strategy, limit)) + def diff(self, n: int = 1) -> Self: + return self._with_native(fn.diff(self.native, n)) + def shift(self, n: int, *, fill_value: NonNestedLiteral = None) -> Self: return self._with_native(fn.shift(self.native, n, fill_value=fill_value)) @@ -188,13 +191,14 @@ def _rolling_center(self, window_size: int) -> tuple[Self, int]: offset = offset_left + offset_right return self._with_native(fn.concat_vertical_chunked(arrays)), offset + # TODO @dangotbanned: Try rewriting with `diff(window_size)`? def _rolling_sum(self, window_size: int, /) -> Self: cum_sum = self.cum_sum().fill_null_with_strategy("forward") return cum_sum - cum_sum.shift(window_size, fill_value=0).fill_null(0) def _rolling_count(self, window_size: int, /) -> Self: cum_count = self.cum_count() - return cum_count - cum_count.shift(window_size, fill_value=0) + return cum_count.diff(window_size).fill_null(cum_count) def rolling_sum( self, window_size: int, *, min_samples: int, center: bool = False diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 34f79516d5..6d1a587974 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -125,6 +125,7 @@ def cum_max(self, *, reverse: bool = False) -> Self: ... def cum_min(self, *, reverse: bool = False) -> Self: ... def cum_prod(self, *, reverse: bool = False) -> Self: ... def cum_sum(self, *, reverse: bool = False) -> Self: ... + def diff(self, n: int = 1) -> Self: ... def fill_null(self, value: NonNestedLiteral | Self) -> Self: ... 
def fill_null_with_strategy( self, strategy: FillNullStrategy, limit: int | None = None From d586d55a083c8571de58faee17003657ffcdc3b8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 19:46:44 +0000 Subject: [PATCH 072/215] perf?: Use `diff(window_size)` in `_rolling_sum` too --- narwhals/_plan/arrow/series.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index e0bf71b0d0..7344c5bdd3 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -191,10 +191,9 @@ def _rolling_center(self, window_size: int) -> tuple[Self, int]: offset = offset_left + offset_right return self._with_native(fn.concat_vertical_chunked(arrays)), offset - # TODO @dangotbanned: Try rewriting with `diff(window_size)`? def _rolling_sum(self, window_size: int, /) -> Self: cum_sum = self.cum_sum().fill_null_with_strategy("forward") - return cum_sum - cum_sum.shift(window_size, fill_value=0).fill_null(0) + return cum_sum.diff(window_size).fill_null(cum_sum) def _rolling_count(self, window_size: int, /) -> Self: cum_count = self.cum_count() From 7169fef4755eb04015716cca36283c61cf5bd39d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 20:22:41 +0000 Subject: [PATCH 073/215] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/plan/mode_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plan/mode_test.py b/tests/plan/mode_test.py index e3cef89f2f..1a647d2b68 100644 --- a/tests/plan/mode_test.py +++ b/tests/plan/mode_test.py @@ -26,7 +26,7 @@ def data() -> Data: (nwp.col("b").filter(nwp.col("b") != 3).mode(), {"b": [1, 2, 4]}), (nwp.col("a").mode().sum(), {"a": [3]}), ], - ids=["single", "multiple-1", "multiple-2", "mutliple-agg"], + ids=["single", "multiple-1", 
"multiple-2", "multiple-agg"], ) def test_mode_expr_keep_all(data: Data, expr: nwp.Expr, expected: Data) -> None: result = dataframe(data).select(expr).sort(ncs.first()) From 6c836c5a5b1327b69ec32a5c620462fb3f133248 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 25 Nov 2025 22:45:54 +0000 Subject: [PATCH 074/215] feat(DRAFT): Add `struct.field` Forgot that I implemented an extended version already aha --- narwhals/_plan/arrow/expr.py | 37 ++++++++++++++++++++++++++++- narwhals/_plan/arrow/functions.py | 38 ++++++++++++++++++------------ narwhals/_plan/compliant/expr.py | 4 ++++ narwhals/_plan/compliant/struct.py | 15 ++++++++++++ 4 files changed, 78 insertions(+), 16 deletions(-) create mode 100644 narwhals/_plan/compliant/struct.py diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 7c501c12da..9d6ec92351 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, ClassVar, Protocol, overload +from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -21,6 +21,7 @@ from narwhals._plan.compliant.column import ExprDispatch from narwhals._plan.compliant.expr import EagerExpr from narwhals._plan.compliant.scalar import EagerScalar +from narwhals._plan.compliant.struct import ExprStructNamespace from narwhals._plan.compliant.typing import namespace from narwhals._plan.expressions import functions as F from narwhals._plan.expressions.boolean import ( @@ -86,6 +87,7 @@ Rank, Shift, ) + from narwhals._plan.expressions.struct import FieldByName from narwhals._plan.typing import Seq from narwhals.typing import IntoDType, PythonLiteral @@ -631,6 +633,10 @@ def rolling_expr( # ewm_mean = not_implemented() # noqa: ERA001 + @property + 
def struct(self) -> ArrowStructNamespace[Expr]: + return ArrowStructNamespace(self) + class ArrowScalar( _ArrowDispatch["ArrowScalar"], @@ -741,6 +747,10 @@ def drop_nulls( # type: ignore[override] chunked = fn.chunked_array([[]], previous.native.type) return ArrowExpr.from_native(chunked, name, version=self.version) + @property + def struct(self) -> ArrowStructNamespace[Scalar]: + return ArrowStructNamespace(self) + filter = not_implemented() over = not_implemented() over_ordered = not_implemented() @@ -754,3 +764,28 @@ def drop_nulls( # type: ignore[override] cum_min = not_implemented() cum_max = not_implemented() cum_prod = not_implemented() + + +ExprOrScalarT = TypeVar("ExprOrScalarT", ArrowExpr, ArrowScalar) + + +class ArrowStructNamespace(ExprStructNamespace["Frame", ExprOrScalarT]): + def __narwhals_namespace__(self) -> ArrowNamespace: + return namespace(self._compliant) + + @property + def version(self) -> Version: + return self._compliant.version + + def __init__(self, compliant: ExprOrScalarT, /) -> None: + self._compliant: ExprOrScalarT = compliant + + def with_native(self, native: Any, name: str, /) -> ExprOrScalarT: + return self._compliant.from_native(native, name, self.version) + + def field( + self, node: ir.FunctionExpr[FieldByName], frame: Frame, name: str + ) -> ExprOrScalarT: + native = node.input[0].dispatch(self._compliant, frame, name).native + field_name = node.function.name + return self.with_native(fn.struct_field(native, field_name), field_name) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index e037c492fd..7283902b12 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -70,7 +70,7 @@ ) from narwhals._plan.compliant.typing import SeriesT from narwhals._plan.options import RankOptions, SortMultipleOptions, SortOptions - from narwhals._plan.typing import OneOrSeq, Seq + from narwhals._plan.typing import Seq from narwhals.typing import ( ClosedInterval, 
FillNullStrategy, @@ -267,21 +267,29 @@ def string_type(data_types: Iterable[DataType] = (), /) -> StringType | LargeStr @t.overload def struct_field(native: ChunkedStruct, field: Field, /) -> ChunkedArrayAny: ... @t.overload -def struct_field( - native: ChunkedStruct, field: Field, *fields: Field -) -> Seq[ChunkedArrayAny]: ... -@t.overload def struct_field(native: StructArray, field: Field, /) -> ArrayAny: ... @t.overload -def struct_field(native: StructArray, field: Field, *fields: Field) -> Seq[ArrayAny]: ... -def struct_field( - native: ChunkedOrArrayAny, field: Field, *fields: Field -) -> OneOrSeq[ChunkedOrArrayAny]: - """Retrieve one or multiple `Struct` field(s) as `(Chunked)Array`(s).""" - func = t.cast("Callable[[Any,Any], ChunkedOrArrayAny]", pc.struct_field) - if not fields: - return func(native, field) - return tuple(func(native, name) for name in (field, *fields)) +def struct_field(native: pa.StructScalar, field: Field, /) -> ScalarAny: ... +@t.overload +def struct_field(native: SameArrowT, field: Field, /) -> SameArrowT: ... +def struct_field(native: ArrowAny, field: Field, /) -> ArrowAny: + """Retrieve one `Struct` field.""" + func = t.cast("Callable[[Any,Any], ArrowAny]", pc.struct_field) + return func(native, field) + + +@t.overload +def struct_fields(native: ChunkedStruct, *fields: Field) -> Seq[ChunkedArrayAny]: ... +@t.overload +def struct_fields(native: StructArray, *fields: Field) -> Seq[ArrayAny]: ... +@t.overload +def struct_fields(native: pa.StructScalar, *fields: Field) -> Seq[ScalarAny]: ... +@t.overload +def struct_fields(native: SameArrowT, *fields: Field) -> Seq[SameArrowT]: ... 
+def struct_fields(native: ArrowAny, *fields: Field) -> Seq[ArrowAny]: + """Retrieve multiple `Struct` fields.""" + func = t.cast("Callable[[Any,Any], ArrowAny]", pc.struct_field) + return tuple(func(native, name) for name in fields) @t.overload @@ -621,7 +629,7 @@ def ir_min_max(name: str, /) -> MinMax: def _boolean_is_unique( indices: ChunkedArrayAny, aggregated: ChunkedStruct, / ) -> ChunkedArrayAny: - min, max = struct_field(aggregated, "min", "max") + min, max = struct_fields(aggregated, "min", "max") return and_(is_in(indices, min), is_in(indices, max)) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 33a01bba39..933c0790c7 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -17,6 +17,7 @@ from narwhals._plan import expressions as ir from narwhals._plan.compliant.scalar import CompliantScalar, EagerScalar + from narwhals._plan.compliant.struct import ExprStructNamespace from narwhals._plan.expressions import ( BinaryExpr, FunctionExpr, @@ -242,6 +243,9 @@ def unique( self, node: FunctionExpr[F.Unique], frame: FrameT_contra, name: str ) -> Self: ... + @property + def struct(self) -> ExprStructNamespace[FrameT_contra, Self]: ... + class EagerExpr( EagerBroadcast[SeriesT], diff --git a/narwhals/_plan/compliant/struct.py b/narwhals/_plan/compliant/struct.py new file mode 100644 index 0000000000..2afc58afb7 --- /dev/null +++ b/narwhals/_plan/compliant/struct.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +from narwhals._plan.compliant.typing import ExprT_co, FrameT_contra + +if TYPE_CHECKING: + from narwhals._plan.expressions import FunctionExpr as FExpr + from narwhals._plan.expressions.struct import FieldByName + + +class ExprStructNamespace(Protocol[FrameT_contra, ExprT_co]): + def field( + self, node: FExpr[FieldByName], frame: FrameT_contra, name: str + ) -> ExprT_co: ... 
From b17b343cc5aba070068ae9f158c110029e7e6b18 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 14:09:44 +0000 Subject: [PATCH 075/215] chore: Remove finished todo Opened #3327 --- tests/plan/fill_null_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/plan/fill_null_test.py b/tests/plan/fill_null_test.py index 1a752cc6d0..9015cc0656 100644 --- a/tests/plan/fill_null_test.py +++ b/tests/plan/fill_null_test.py @@ -31,10 +31,6 @@ } -# TODO @dangotbanned: Address index out-of-bounds error -# - [x] Fix this in the new version -# - [ ] Open an issue demonstrating the bug -# - Same problem impacts `main` for `fill_null(limit=...)` @pytest.mark.parametrize( ("data", "exprs", "expected"), [ From ce3ec2194e1f67e998c3ea1f989e624aadb05b55 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 14:14:32 +0000 Subject: [PATCH 076/215] test: Move comments to test ids --- tests/plan/fill_null_test.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/plan/fill_null_test.py b/tests/plan/fill_null_test.py index 9015cc0656..450eca14d5 100644 --- a/tests/plan/fill_null_test.py +++ b/tests/plan/fill_null_test.py @@ -34,30 +34,34 @@ @pytest.mark.parametrize( ("data", "exprs", "expected"), [ - ( # test_fill_null + pytest.param( DATA_1, nwp.all().fill_null(value=99), {"a": [0.0, 99, 2, 3, 4], "b": [1.0, 99, 99, 5, 3], "c": [5.0, 99, 3, 2, 1]}, + id="literal", ), - ( # test_fill_null_w_aggregate + pytest.param( {"a": [0.5, None, 2.0, 3.0, 4.5], "b": ["xx", "yy", "zz", None, "yy"]}, [nwp.col("a").fill_null(nwp.col("a").mean()), nwp.col("b").fill_null("a")], {"a": [0.5, 2.5, 2.0, 3.0, 4.5], "b": ["xx", "yy", "zz", "a", "yy"]}, + id="expr-aggregate", ), - ( # test_fill_null_series_expression + pytest.param( DATA_2, nwp.nth(0, 1).fill_null(nwp.col("c")), {"a": [0.0, 2, 2, 3, 4], "b": [1.0, 2, None, 5, 
3]}, + id="expr-column", ), - ( # test_fill_null_strategies_with_limit_as_none (1) + pytest.param( DATA_LIMITS, ncs.by_index(0, 1).fill_null(strategy="forward").over(order_by="idx"), { "a": [1, 1, 1, 1, 5, 6, 6, 6, 6, 10], "b": ["a", "a", "a", "a", "b", "c", "c", "c", "c", "d"], }, + id="forward", ), - ( # test_fill_null_strategies_with_limit_as_none (2) + pytest.param( DATA_LIMITS, nwp.exclude("idx").fill_null(strategy="backward").over(order_by="idx"), { @@ -66,16 +70,18 @@ "c": [2.5, 2.5, 3.6, 3.6, 3.6, 3.6, 3.6, 2.2, 2.2, 3.0], "d": [1, 2, 2, 2, 2, 2, 2, 2, 2, None], }, + id="backward", ), - ( # test_fill_null_limits (1) + pytest.param( DATA_LIMITS, nwp.col("a", "b").fill_null(strategy="forward", limit=2).over(order_by="idx"), { "a": [1, 1, 1, None, 5, 6, 6, 6, None, 10], "b": ["a", "a", "a", None, "b", "c", "c", "c", None, "d"], }, + id="forward-limit", ), - ( # test_fill_null_limits (2) + pytest.param( DATA_LIMITS, [ nwp.col("a", "b") @@ -88,16 +94,19 @@ "b": ["a", None, "b", "b", "b", "c", None, "d", "d", "d"], "c": [2.5, 2.5, None, 3.6, 3.6, 3.6, 3.6, 2.2, 2.2, 3.0], }, + id="backward-limit", ), - ( + pytest.param( DATA_LIMITS, nwp.col("c").fill_null(strategy="forward", limit=3).over(order_by="idx"), {"c": [None, 2.5, 2.5, 2.5, 2.5, None, 3.6, 3.6, 2.2, 3.0]}, + id="forward-limit-nulls-first", ), - ( + pytest.param( DATA_LIMITS, nwp.col("d").fill_null(strategy="backward", limit=3).over(order_by="idx"), {"d": [1, None, None, None, None, 2, 2, 2, 2, None]}, + id="backward-limit-nulls-last", ), ], ) From cd04606fad0df82d5de000b0a062f9f5b488d6c8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:20:54 +0000 Subject: [PATCH 077/215] test: Tidy up `replace_strict` Much clearer what each case is testing, covers a slightly more, + generates nice test ids --- tests/plan/replace_strict_test.py | 145 ++++++++++++++---------------- 1 file changed, 66 insertions(+), 79 deletions(-) diff --git 
a/tests/plan/replace_strict_test.py b/tests/plan/replace_strict_test.py index b460764b45..dc79ecb87f 100644 --- a/tests/plan/replace_strict_test.py +++ b/tests/plan/replace_strict_test.py @@ -1,22 +1,25 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from itertools import chain +from typing import TYPE_CHECKING, Any, Literal import pytest import narwhals as nw import narwhals._plan as nwp +from narwhals._utils import no_default from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_equal_data, dataframe if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping, Sequence + from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence from _pytest.mark import ParameterSet from typing_extensions import TypeAlias + from narwhals._plan.typing import IntoExpr from narwhals._typing import NoDefault - from narwhals.typing import IntoDType + from narwhals.typing import IntoDType, NonNestedLiteral from tests.conftest import Data pytest.importorskip("pyarrow") @@ -37,40 +40,86 @@ def data() -> Data: } -def basic_cases( - column: str, +def cases( + column: Literal["str", "int", "str-null", "int-null", "str-alt"], replacements: Mapping[Any, Any], return_dtypes: Iterable[IntoDType | None], + *, + default: IntoExpr | NoDefault = no_default, + expected: list[NonNestedLiteral] | None = None, + marks: pytest.MarkDecorator | Collection[pytest.MarkDecorator | pytest.Mark] = (), ) -> Iterator[ParameterSet]: old, new = list(replacements), tuple(replacements.values()) - values = list(new) base = nwp.col(column) alt_name = f"{column}_seqs" alt = nwp.col(column).alias(alt_name) - expected = {column: values, alt_name: values} + if expected: + expected_m = {column: expected, alt_name: expected} + else: + expected_m = {column: list(new), alt_name: list(new)} + if default is no_default: + suffix = "" + else: + tp = type(default._ir) if isinstance(default, nwp.Expr) else type(default) + suffix = 
f"-default-{tp.__name__}" + for dtype in return_dtypes: exprs = ( - base.replace_strict(replacements, return_dtype=dtype), - alt.replace_strict(old, new, return_dtype=dtype), + base.replace_strict(replacements, default=default, return_dtype=dtype), + alt.replace_strict(old, new, default=default, return_dtype=dtype), ) - schema = {column: dtype, alt_name: dtype} if dtype is not None else None - yield pytest.param(exprs, expected, schema, id=f"{column}-{dtype}") + schema = {column: dtype, alt_name: dtype} if dtype else None + id = f"{column}-{dtype}{suffix}" + yield pytest.param(exprs, expected_m, schema, id=id, marks=marks) @pytest.mark.parametrize( ("exprs", "expected", "schema"), - [ - *basic_cases( + chain( + cases( "str", {"one": 1, "two": 2, "three": 3, "four": 4}, [nw.Int8, nw.Float32, None], ), - *basic_cases( - "int", {1: "one", 2: "two", 3: "three", 4: "four"}, [nw.String(), None] + cases("int", {1: "one", 2: "two", 3: "three", 4: "four"}, [nw.String(), None]), + cases( + "int", + {1: "one", 2: "two"}, + [nw.String, None], + default=nwp.lit("other"), + expected=["one", "two", "other", "other"], ), - ], + cases( + "int-null", + {1: 10, 2: 20}, + [nw.Int64, None], + default=99, + expected=[10, 20, 99, 99], + ), + cases( + "int", + {1: "one", 2: "two", 3: None}, + [nw.String, None], + default="other", + expected=["one", "two", None, "other"], + ), + cases( + "int", + {1: "one", 2: "two"}, + [nw.String, None], + default=nwp.col("str-alt"), + expected=["one", "two", "orca", "vaquita"], + ), + cases( + "int", + {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"}, + [None], + default="hundred", + expected=["one", "two", "three", "four"], + ), + ), ) -def test_replace_strict_expr_basic( +def test_replace_strict_expr( data: Data, exprs: Iterable[nwp.Expr], expected: Data, @@ -96,68 +145,6 @@ def test_replace_strict_expr_non_full(data: Data, expr: nwp.Expr) -> None: dataframe(data).select(expr) -# TODO @dangotbanned: Share more of the case generation logic from 
`basic_cases` -@pytest.mark.parametrize( - ("expr", "expected"), - [ - # test_replace_strict_expr_with_default - pytest.param( - nwp.col("int").replace_strict( - [1, 2], ["one", "two"], default=nwp.lit("other"), return_dtype=nw.String - ), - {"int": ["one", "two", "other", "other"]}, - id="non-null-1", - ), - pytest.param( - nwp.col("int").replace_strict([1, 2], ["one", "two"], default="other"), - {"int": ["one", "two", "other", "other"]}, - id="non-null-2", - ), - # test_replace_strict_with_default_and_nulls - pytest.param( - nwp.col("int-null").replace_strict( - [1, 2], [10, 20], default=99, return_dtype=nw.Int64 - ), - {"int-null": [10, 20, 99, 99]}, - id="null-1", - ), - pytest.param( - nwp.col("int-null").replace_strict([1, 2], [10, 20], default=99), - {"int-null": [10, 20, 99, 99]}, - id="null-2", - ), - # test_replace_strict_with_default_mapping - pytest.param( - nwp.col("int").replace_strict( - {1: "one", 2: "two", 3: None}, default="other", return_dtype=nw.String() - ), - {"int": ["one", "two", None, "other"]}, - # shouldn't be an independent case, the mapping isn't the default - id="replace_strict_with_default_mapping", - ), - # test_replace_strict_with_expressified_default - pytest.param( - nwp.col("int").replace_strict( - {1: "one", 2: "two"}, default=nwp.col("str-alt"), return_dtype=nw.String - ), - {"int": ["one", "two", "orca", "vaquita"]}, - id="column", - ), - # test_mapping_key_not_in_expr - pytest.param( - nwp.col("int").replace_strict( - {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"}, default="hundred" - ), - {"int": ["one", "two", "three", "four"]}, - id="mapping_key_not_in_expr", - ), - ], -) -def test_replace_strict_expr_default(data: Data, expr: nwp.Expr, expected: Data) -> None: - result = dataframe(data).select(expr) - assert_equal_data(result, expected) - - def test_replace_strict_scalar(data: Data) -> None: df = dataframe(data) expr = ( From e71a3bc87020a31c3bdf2282a47355ff269c9904 Mon Sep 17 00:00:00 2001 From: dangotbanned 
<125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:26:42 +0000 Subject: [PATCH 078/215] chore: remove unused --- tests/plan/replace_strict_test.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/plan/replace_strict_test.py b/tests/plan/replace_strict_test.py index dc79ecb87f..f4195dbd69 100644 --- a/tests/plan/replace_strict_test.py +++ b/tests/plan/replace_strict_test.py @@ -12,10 +12,9 @@ from tests.plan.utils import assert_equal_data, dataframe if TYPE_CHECKING: - from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence + from collections.abc import Collection, Iterable, Iterator, Mapping from _pytest.mark import ParameterSet - from typing_extensions import TypeAlias from narwhals._plan.typing import IntoExpr from narwhals._typing import NoDefault @@ -25,10 +24,6 @@ pytest.importorskip("pyarrow") -Old: TypeAlias = "Sequence[Any] | Mapping[Any, Any]" -New: TypeAlias = "Sequence[Any] | NoDefault" - - @pytest.fixture(scope="module") def data() -> Data: return { From 0c2d4256f2f82777f93c7fe53f9cd85eaae02ba9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:45:36 +0000 Subject: [PATCH 079/215] test: Reveal some `struct.field` bugs --- tests/plan/struct_field_test.py | 43 +++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/plan/struct_field_test.py diff --git a/tests/plan/struct_field_test.py b/tests/plan/struct_field_test.py new file mode 100644 index 0000000000..3659836ca6 --- /dev/null +++ b/tests/plan/struct_field_test.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from narwhals.exceptions import DuplicateError +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from collections.abc import Iterable + + from tests.conftest import Data + 
+pytest.importorskip("pyarrow") + + +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + pytest.param( + [nwp.col("user").struct.field("id"), nwp.col("user").struct.field("name")], + {"id": ["0", "1"], "name": ["john", "jane"]}, + marks=pytest.mark.xfail( + raises=DuplicateError, + reason="TODO: Handle `FieldByName` correctly during `Expr` expansion", + ), + ), + pytest.param( + nwp.col("user").struct.field("id").name.keep(), + {"user": ["0", "1"]}, + marks=pytest.mark.xfail( + raises=NotImplementedError, + reason="BUG: Attempting to call `ArrowExpr.field` instead of `ArrowExpr.struct.field`", + ), + ), + ], +) +def test_struct_field(exprs: nwp.Expr | Iterable[nwp.Expr], expected: Data) -> None: + data = {"user": [{"id": "0", "name": "john"}, {"id": "1", "name": "jane"}]} + result = dataframe(data).select(exprs) + assert_equal_data(result, expected) # pragma: no cover From bcf197b75672461500cfd439588f8fc808de032e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:58:17 +0000 Subject: [PATCH 080/215] fix: Ensure generated method names preserve accessors Using `repr` was making the tests *look* like they worked, but the reality was hairier --- narwhals/_plan/_dispatch.py | 21 +++++++++++++++++---- narwhals/_plan/_function.py | 4 ++-- narwhals/_plan/expressions/categorical.py | 4 ++-- narwhals/_plan/expressions/lists.py | 4 ++-- narwhals/_plan/expressions/strings.py | 14 +++++++------- narwhals/_plan/expressions/temporal.py | 2 +- tests/plan/dispatch_test.py | 12 ++++++++++++ tests/plan/struct_field_test.py | 9 +++++++-- 8 files changed, 50 insertions(+), 20 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index 7676de82bb..98eb51bec9 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -182,7 +182,20 @@ def _method_name(tp: type[ExprIRT | FunctionT]) -> str: def get_dispatch_name(expr: ExprIR | type[Function], /) -> str: - """Return 
the synthesized method name for `expr`.""" - return ( - repr(expr.function) if is_function_expr(expr) else expr.__expr_ir_dispatch__.name - ) + """Return the synthesized method name for `expr`. + + Note: + Refers to the `Compliant*` method name, which may be *either* more general + or more specialized than what the user called. + """ + dispatch: Dispatcher[Any] + if is_function_expr(expr): + from narwhals._plan import expressions as ir + + if isinstance(expr, (ir.RollingExpr, ir.AnonymousExpr)): + dispatch = expr.__expr_ir_dispatch__ + else: + dispatch = expr.function.__expr_ir_dispatch__ + else: + dispatch = expr.__expr_ir_dispatch__ + return dispatch.name diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py index 8b71a4dd8d..8a8caafda2 100644 --- a/narwhals/_plan/_function.py +++ b/narwhals/_plan/_function.py @@ -52,8 +52,8 @@ def __init_subclass__( **kwds: Any, ) -> None: super().__init_subclass__(*args, **kwds) - if accessor: - config = replace(config or FEOptions.default(), accessor_name=accessor) + if accessor_name := accessor or cls.__expr_ir_config__.accessor_name: + config = replace(config or FEOptions.default(), accessor_name=accessor_name) if options: cls._function_options = staticmethod(options) if config: diff --git a/narwhals/_plan/expressions/categorical.py b/narwhals/_plan/expressions/categorical.py index 5bb7157f5d..7c59fd4443 100644 --- a/narwhals/_plan/expressions/categorical.py +++ b/narwhals/_plan/expressions/categorical.py @@ -20,7 +20,7 @@ class IRCatNamespace(IRNamespace): class ExprCatNamespace(ExprNamespace[IRCatNamespace]): @property def _ir_namespace(self) -> type[IRCatNamespace]: - return IRCatNamespace # pragma: no cover + return IRCatNamespace def get_categories(self) -> Expr: - return self._with_unary(self._ir.get_categories()) # pragma: no cover + return self._with_unary(self._ir.get_categories()) diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index b14090b985..604e054a5e 
100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -21,7 +21,7 @@ class IRListNamespace(IRNamespace): class ExprListNamespace(ExprNamespace[IRListNamespace]): @property def _ir_namespace(self) -> type[IRListNamespace]: - return IRListNamespace # pragma: no cover + return IRListNamespace def len(self) -> Expr: - return self._with_unary(self._ir.len()) # pragma: no cover + return self._with_unary(self._ir.len()) diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 5478a7154c..5d08412fdb 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -100,13 +100,13 @@ def strip_chars( def contains(self, pattern: str, *, literal: bool = False) -> Contains: return Contains(pattern=pattern, literal=literal) - def slice(self, offset: int, length: int | None = None) -> Slice: # pragma: no cover + def slice(self, offset: int, length: int | None = None) -> Slice: return Slice(offset=offset, length=length) - def head(self, n: int = 5) -> Slice: # pragma: no cover + def head(self, n: int = 5) -> Slice: return self.slice(0, n) - def tail(self, n: int = 5) -> Slice: # pragma: no cover + def tail(self, n: int = 5) -> Slice: return self.slice(-n) def to_datetime(self, format: str | None = None) -> ToDatetime: # pragma: no cover @@ -134,7 +134,7 @@ def replace_all( def strip_chars(self, characters: str | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.strip_chars(characters)) - def starts_with(self, prefix: str) -> Expr: # pragma: no cover + def starts_with(self, prefix: str) -> Expr: return self._with_unary(self._ir.starts_with(prefix=prefix)) def ends_with(self, suffix: str) -> Expr: # pragma: no cover @@ -143,13 +143,13 @@ def ends_with(self, suffix: str) -> Expr: # pragma: no cover def contains(self, pattern: str, *, literal: bool = False) -> Expr: return self._with_unary(self._ir.contains(pattern, literal=literal)) - def 
slice(self, offset: int, length: int | None = None) -> Expr: # pragma: no cover + def slice(self, offset: int, length: int | None = None) -> Expr: return self._with_unary(self._ir.slice(offset, length)) - def head(self, n: int = 5) -> Expr: # pragma: no cover + def head(self, n: int = 5) -> Expr: return self._with_unary(self._ir.head(n)) - def tail(self, n: int = 5) -> Expr: # pragma: no cover + def tail(self, n: int = 5) -> Expr: return self._with_unary(self._ir.tail(n)) def split(self, by: str) -> Expr: # pragma: no cover diff --git a/narwhals/_plan/expressions/temporal.py b/narwhals/_plan/expressions/temporal.py index 35a622ebd2..2ba22aa925 100644 --- a/narwhals/_plan/expressions/temporal.py +++ b/narwhals/_plan/expressions/temporal.py @@ -169,7 +169,7 @@ def total_nanoseconds(self) -> Expr: # pragma: no cover def to_string(self, format: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_string(format=format)) - def replace_time_zone(self, time_zone: str | None) -> Expr: # pragma: no cover + def replace_time_zone(self, time_zone: str | None) -> Expr: return self._with_unary(self._ir.replace_time_zone(time_zone=time_zone)) def convert_time_zone(self, time_zone: str) -> Expr: # pragma: no cover diff --git a/tests/plan/dispatch_test.py b/tests/plan/dispatch_test.py index 8d8f8d994e..ddfac3051e 100644 --- a/tests/plan/dispatch_test.py +++ b/tests/plan/dispatch_test.py @@ -78,11 +78,23 @@ def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: (nwp.int_range(10), "int_range"), (nwp.col("a") + nwp.col("b") + 10, "binary_expr"), (nwp.when(nwp.col("c")).then(5).when(nwp.col("d")).then(20), "ternary_expr"), + (nwp.col("a").rolling_sum(2), "rolling_expr"), + (nwp.col("a").cum_sum(), "cum_sum"), + (nwp.col("a").cat.get_categories(), "cat.get_categories"), + (nwp.col("a").dt.timestamp(), "dt.timestamp"), + (nwp.col("a").dt.replace_time_zone(None), "dt.replace_time_zone"), + (nwp.col("a").list.len(), "list.len"), 
(nwp.col("a").cast(nw.String).str.starts_with("something"), ("str.starts_with")), + (nwp.col("a").str.slice(1), ("str.slice")), + (nwp.col("a").str.head(), ("str.slice")), + (nwp.col("a").str.tail(), ("str.slice")), + (nwp.col("a").struct.field("b"), "struct.field"), (nwp.mean("a"), "mean"), (nwp.nth(1).first(), "first"), (nwp.col("a").sum(), "sum"), + (~nwp.col("a"), "not_"), (nwp.col("a").drop_nulls().arg_min(), "arg_min"), + (nwp.col("a").map_batches(lambda x: x), "map_batches"), pytest.param(nwp.col("a").alias("b"), "Alias", id="no_dispatch-Alias"), pytest.param(ncs.string(), "RootSelector", id="no_dispatch-RootSelector"), ], diff --git a/tests/plan/struct_field_test.py b/tests/plan/struct_field_test.py index 3659836ca6..b432532eaa 100644 --- a/tests/plan/struct_field_test.py +++ b/tests/plan/struct_field_test.py @@ -19,6 +19,9 @@ @pytest.mark.parametrize( ("exprs", "expected"), [ + pytest.param( + nwp.col("user").struct.field("id"), {"id": ["0", "1"]}, id="field-single" + ), pytest.param( [nwp.col("user").struct.field("id"), nwp.col("user").struct.field("name")], {"id": ["0", "1"], "name": ["john", "jane"]}, @@ -26,14 +29,16 @@ raises=DuplicateError, reason="TODO: Handle `FieldByName` correctly during `Expr` expansion", ), + id="multiple-fields-same-root", ), pytest.param( nwp.col("user").struct.field("id").name.keep(), {"user": ["0", "1"]}, marks=pytest.mark.xfail( - raises=NotImplementedError, - reason="BUG: Attempting to call `ArrowExpr.field` instead of `ArrowExpr.struct.field`", + raises=KeyError, + reason="TODO: Handle `FieldByName` correctly during `Expr` expansion", ), + id="field-single-keep-root", ), ], ) From 6188c963842b5b8fe40b54a75bcbe36693d1182a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 18:10:17 +0000 Subject: [PATCH 081/215] fix: Handle `struct.field(...)` in `ExprIR` expansion Solved *some* of my woes - but it looks like there may be another issue? 
--- narwhals/_plan/expressions/__init__.py | 2 ++ narwhals/_plan/expressions/expr.py | 16 ++++++++++++++++ narwhals/_plan/expressions/struct.py | 19 ++++++++++++++++++- narwhals/_plan/meta.py | 3 +++ narwhals/_plan/typing.py | 4 ++++ tests/plan/meta_test.py | 5 ----- tests/plan/selectors_test.py | 13 ++++++++++++- tests/plan/struct_field_test.py | 9 ++------- 8 files changed, 57 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 0d97c20288..3b9f0364ab 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -38,6 +38,7 @@ RootSelector, Sort, SortBy, + StructExpr, TernaryExpr, WindowExpr, col, @@ -71,6 +72,7 @@ "SelectorIR", "Sort", "SortBy", + "StructExpr", "TernaryExpr", "WindowExpr", "aggregation", diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index eae6dfbbf5..26b475b51d 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -27,6 +27,7 @@ SelectorOperatorT, SelectorT, Seq, + StructT_co, ) from narwhals.exceptions import InvalidOperationError @@ -60,6 +61,7 @@ "SelectorIR", "Sort", "SortBy", + "StructExpr", "TernaryExpr", "WindowExpr", "col", @@ -320,6 +322,20 @@ def __repr__(self) -> str: return f"{self.function!r}({list(self.input)!r})" +class StructExpr(FunctionExpr[StructT_co]): + """E.g. `col("a").struct.field(...)`. + + Requires special handling during expression expansion. 
+ """ + + def needs_expansion(self) -> bool: + return self.function.needs_expansion or super().needs_expansion() + + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield self + yield from super().iter_output_name() # pragma: no cover + + class Filter(ExprIR, child=("expr", "by")): __slots__ = ("expr", "by") # noqa: RUF023 expr: ExprIR diff --git a/narwhals/_plan/expressions/struct.py b/narwhals/_plan/expressions/struct.py index e3625adb8a..6350f0668e 100644 --- a/narwhals/_plan/expressions/struct.py +++ b/narwhals/_plan/expressions/struct.py @@ -7,10 +7,23 @@ from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._plan._expr_ir import ExprIR from narwhals._plan.expr import Expr + from narwhals._plan.expressions.expr import StructExpr + + +class StructFunction(Function, accessor="struct"): + def to_function_expr(self, *inputs: ExprIR) -> StructExpr[Self]: + from narwhals._plan.expressions.expr import StructExpr + return StructExpr(input=inputs, function=self, options=self.function_options) -class StructFunction(Function, accessor="struct"): ... 
+ @property + def needs_expansion(self) -> bool: + msg = f"{type(self).__name__}.needs_expansion" + raise NotImplementedError(msg) class FieldByName( @@ -22,6 +35,10 @@ class FieldByName( def __repr__(self) -> str: return f"{super().__repr__()}({self.name!r})" + @property + def needs_expansion(self) -> bool: + return True + class IRStructNamespace(IRNamespace): field: ClassVar = FieldByName diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 41487b503b..5938795a13 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -14,6 +14,7 @@ from narwhals._plan.expressions import selectors as cs from narwhals._plan.expressions.literal import is_literal_scalar from narwhals._plan.expressions.namespace import IRNamespace +from narwhals._plan.expressions.struct import FieldByName from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.utils import Version @@ -113,6 +114,8 @@ def _expr_output_name(expr: ir.ExprIR, /) -> str | ComputeError: isinstance(e.selector, cs.ByName) and len(e.selector.names) == 1 ): return e.selector.names[0] + if isinstance(e, ir.StructExpr) and isinstance(e.function, FieldByName): + return e.function.name msg = ( f"unable to find root column name for expr '{expr!r}' when calling 'output_name'" ) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index c0973d0487..4fdf78dba2 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -21,6 +21,7 @@ from narwhals._plan.expressions.functions import RollingWindow from narwhals._plan.expressions.namespace import IRNamespace from narwhals._plan.expressions.ranges import RangeFunction + from narwhals._plan.expressions.struct import StructFunction from narwhals._plan.selectors import Selector from narwhals._plan.series import Series from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -61,6 +62,9 @@ RangeT_co = TypeVar( "RangeT_co", bound="RangeFunction", default="RangeFunction", covariant=True ) +StructT_co = TypeVar( + 
"StructT_co", bound="StructFunction", default="StructFunction", covariant=True +) LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator") RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index f7738dbf5e..36e1e12160 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -325,11 +325,6 @@ def test_literal_output_name() -> None: assert e.meta.output_name() == "" -# NOTE: Very low-priority -@pytest.mark.xfail( - reason="TODO: `Expr.struct.field` influences `meta.output_name`.", - raises=AssertionError, -) def test_struct_field_output_name_24003() -> None: assert nwp.col("ball").struct.field("radius").meta.output_name() == "radius" diff --git a/tests/plan/selectors_test.py b/tests/plan/selectors_test.py index 56ba3b1251..a9565521f8 100644 --- a/tests/plan/selectors_test.py +++ b/tests/plan/selectors_test.py @@ -18,7 +18,7 @@ from narwhals._plan import Selector, selectors as ncs from narwhals._plan._guards import is_expr, is_selector from narwhals._utils import zip_strict -from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError +from narwhals.exceptions import ColumnNotFoundError, DuplicateError, InvalidOperationError from tests.plan.utils import ( Frame, assert_expr_ir_equal, @@ -749,3 +749,14 @@ def test_when_then_keep_map_13858() -> None: ) df.assert_selects(aliased, "b_other") df.assert_selects(when_keep_chain, "b_other") + + +def test_keep_name_struct_field_23669() -> None: + df = Frame.from_mapping( + {"foo": nw.Struct({"x": nw.Int64}), "bar": nw.Struct({"x": nw.Int64})} + ) + + with pytest.raises(DuplicateError): + df.project(nwp.all().struct.field("x")) + + df.assert_selects(nwp.all().struct.field("x").name.keep(), "foo", "bar") diff --git a/tests/plan/struct_field_test.py b/tests/plan/struct_field_test.py index b432532eaa..60ee52c81a 100644 --- 
a/tests/plan/struct_field_test.py +++ b/tests/plan/struct_field_test.py @@ -5,7 +5,6 @@ import pytest import narwhals._plan as nwp -from narwhals.exceptions import DuplicateError from tests.plan.utils import assert_equal_data, dataframe if TYPE_CHECKING: @@ -25,10 +24,6 @@ pytest.param( [nwp.col("user").struct.field("id"), nwp.col("user").struct.field("name")], {"id": ["0", "1"], "name": ["john", "jane"]}, - marks=pytest.mark.xfail( - raises=DuplicateError, - reason="TODO: Handle `FieldByName` correctly during `Expr` expansion", - ), id="multiple-fields-same-root", ), pytest.param( @@ -36,7 +31,7 @@ {"user": ["0", "1"]}, marks=pytest.mark.xfail( raises=KeyError, - reason="TODO: Handle `FieldByName` correctly during `Expr` expansion", + reason="TODO: Handle `FieldByName` correctly in `ArrowExpr`?", ), id="field-single-keep-root", ), @@ -45,4 +40,4 @@ def test_struct_field(exprs: nwp.Expr | Iterable[nwp.Expr], expected: Data) -> None: data = {"user": [{"id": "0", "name": "john"}, {"id": "1", "name": "jane"}]} result = dataframe(data).select(exprs) - assert_equal_data(result, expected) # pragma: no cover + assert_equal_data(result, expected) From 72f17e2422a7f7187f60a59b4e5a7819405743a9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 18:17:51 +0000 Subject: [PATCH 082/215] fix: Use the right names for select/alias At least the expansion worked! 
--- narwhals/_plan/arrow/expr.py | 3 +-- tests/plan/struct_field_test.py | 4 ---- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 9d6ec92351..a8bd2adde6 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -787,5 +787,4 @@ def field( self, node: ir.FunctionExpr[FieldByName], frame: Frame, name: str ) -> ExprOrScalarT: native = node.input[0].dispatch(self._compliant, frame, name).native - field_name = node.function.name - return self.with_native(fn.struct_field(native, field_name), field_name) + return self.with_native(fn.struct_field(native, node.function.name), name) diff --git a/tests/plan/struct_field_test.py b/tests/plan/struct_field_test.py index 60ee52c81a..4c7b3b1d90 100644 --- a/tests/plan/struct_field_test.py +++ b/tests/plan/struct_field_test.py @@ -29,10 +29,6 @@ pytest.param( nwp.col("user").struct.field("id").name.keep(), {"user": ["0", "1"]}, - marks=pytest.mark.xfail( - raises=KeyError, - reason="TODO: Handle `FieldByName` correctly in `ArrowExpr`?", - ), id="field-single-keep-root", ), ], From 2f48f10e1db6e7512f183bc079ff6252ce143153 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 20:11:29 +0000 Subject: [PATCH 083/215] test: `{kurtosis,skew}.over(*partition_by)` Was already working, but always good to be sure --- tests/plan/over_test.py | 80 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/tests/plan/over_test.py b/tests/plan/over_test.py index ccea0eec77..42f61b5467 100644 --- a/tests/plan/over_test.py +++ b/tests/plan/over_test.py @@ -5,18 +5,20 @@ import pytest +from tests.utils import PYARROW_VERSION + pytest.importorskip("pyarrow") import narwhals as nw import narwhals._plan as nwp from narwhals._plan import selectors as ncs -from narwhals._utils import zip_strict +from narwhals._utils import Implementation, 
zip_strict from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_equal_data, dataframe, re_compile if TYPE_CHECKING: - from collections.abc import Callable, Mapping, Sequence + from collections.abc import Callable, Iterable, Mapping, Sequence from _pytest.mark import ParameterSet from typing_extensions import TypeAlias @@ -274,6 +276,80 @@ def test_null_count_over() -> None: assert_equal_data(result, expected) +@pytest.fixture(scope="module") +def data_kurtosis_skew() -> Data: + return { + "p1": ["a", "a", "b", "a", "b", "b"], + "p2": ["d", "e", "e", "e", "d", "d"], + "v1": [0.2, 5.0, 1.0, 0.7, 0.5, 1.0], + "v2": [-1.0, 0.8, 0.6, 0.0, 1.1, 19.0], + "v3": [None, 1.2, 2.1, 0.4, 5.0, 3.2], + } + + +EXPECTED_SKEW = { + "v1_p1": [0.678654, 0.678654, -0.707107, 0.678654, -0.707107, -0.707107], + "v2_p1": [-0.135062, -0.135062, 0.705297, -0.135062, 0.705297, 0.705297], + "v3_p1": [-4.33681e-16, -4.33681e-16, 0.285361, -4.33681e-16, 0.285361, 0.285361], + "v1_p1_p2": [float("nan"), -2.68106e-16, float("nan"), -2.68106e-16, 0.0, 0.0], + "v2_p1_p2": [float("nan"), 0.0, float("nan"), 0.0, -2.37866e-16, -2.37866e-16], + "v3_p1_p2": [ + None, + -4.33681e-16, + float("nan"), + -4.33681e-16, + 1.44679e-15, + 1.44679e-15, + ], +} +EXPECTED_KURTOSIS = { + "v1_p1": [-1.5, -1.5, -1.5, -1.5, -1.5, -1.5], + "v2_p1": [-1.5, -1.5, -1.5, -1.5, -1.5, -1.5], + "v3_p1": [-2.0, -2.0, -1.5, -2.0, -1.5, -1.5], + "v1_p1_p2": [float("nan"), -2.0, float("nan"), -2.0, -2.0, -2.0], + "v2_p1_p2": [float("nan"), -2.0, float("nan"), -2.0, -2.0, -2.0], + "v3_p1_p2": [None, -2.0, float("nan"), -2.0, -2.0, -2.0], +} +string = ncs.string() +not_string = ~string + + +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + ( + [ + not_string.skew().over("p1").name.suffix("_p1"), + not_string.skew().over(string).name.suffix("_p1_p2"), + ], + EXPECTED_SKEW, + ), + ( + [ + not_string.kurtosis().over("p1").name.suffix("_p1"), + 
not_string.kurtosis().over(string).name.suffix("_p1_p2"), + ], + EXPECTED_KURTOSIS, + ), + ], +) +def test_kurtosis_over_skew( + data_kurtosis_skew: Data, + request: pytest.FixtureRequest, + exprs: Iterable[nwp.Expr], + expected: Data, +) -> None: + df = dataframe(data_kurtosis_skew) + request.applymarker( + pytest.mark.xfail( + (df.implementation is Implementation.PYARROW and PYARROW_VERSION < (20,)), + reason="too old for `pyarrow.compute.{kurtosis,skew}`", + ) + ) + result = df.select(exprs) + assert_equal_data(result, expected) + + @pytest.fixture def data_groups() -> Data: return { From 84c01acf70f0952deface7b4b495a68894e47fbf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 22:30:24 +0000 Subject: [PATCH 084/215] feat: Add `cat.get_categories` - Expecting a failure for `<15.0`, just not sure what it'll be - `unify_dictionaries` may also be new? --- narwhals/_plan/arrow/expr.py | 42 +++++++++++++++---- narwhals/_plan/arrow/functions.py | 9 ++++ .../compliant/{struct.py => accessors.py} | 7 ++++ narwhals/_plan/compliant/expr.py | 6 ++- tests/plan/cat_get_categories_test.py | 22 ++++++++++ 5 files changed, 76 insertions(+), 10 deletions(-) rename narwhals/_plan/compliant/{struct.py => accessors.py} (65%) create mode 100644 tests/plan/cat_get_categories_test.py diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index a8bd2adde6..65372bafce 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -18,10 +18,10 @@ from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing 
import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co from narwhals._plan.common import temp +from narwhals._plan.compliant.accessors import ExprCatNamespace, ExprStructNamespace from narwhals._plan.compliant.column import ExprDispatch from narwhals._plan.compliant.expr import EagerExpr from narwhals._plan.compliant.scalar import EagerScalar -from narwhals._plan.compliant.struct import ExprStructNamespace from narwhals._plan.compliant.typing import namespace from narwhals._plan.expressions import functions as F from narwhals._plan.expressions.boolean import ( @@ -76,6 +76,7 @@ IsNull, Not, ) + from narwhals._plan.expressions.categorical import GetCategories from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr as FExpr from narwhals._plan.expressions.functions import ( Abs, @@ -632,6 +633,9 @@ def rolling_expr( hist_bin_count = not_implemented() # ewm_mean = not_implemented() # noqa: ERA001 + @property + def cat(self) -> ArrowCatNamespace[Expr]: + return ArrowCatNamespace(self) @property def struct(self) -> ArrowStructNamespace[Expr]: @@ -747,6 +751,10 @@ def drop_nulls( # type: ignore[override] chunked = fn.chunked_array([[]], previous.native.type) return ArrowExpr.from_native(chunked, name, version=self.version) + @property + def cat(self) -> ArrowCatNamespace[Scalar]: + return ArrowCatNamespace(self) + @property def struct(self) -> ArrowStructNamespace[Scalar]: return ArrowStructNamespace(self) @@ -769,20 +777,36 @@ def struct(self) -> ArrowStructNamespace[Scalar]: ExprOrScalarT = TypeVar("ExprOrScalarT", ArrowExpr, ArrowScalar) -class ArrowStructNamespace(ExprStructNamespace["Frame", ExprOrScalarT]): +class ArrowAccessor(Generic[ExprOrScalarT]): + def __init__(self, compliant: ExprOrScalarT, /) -> None: + self._compliant: ExprOrScalarT = compliant + + @property + def compliant(self) -> ExprOrScalarT: + return self._compliant + def __narwhals_namespace__(self) -> ArrowNamespace: - return namespace(self._compliant) + return 
namespace(self.compliant) @property def version(self) -> Version: - return self._compliant.version - - def __init__(self, compliant: ExprOrScalarT, /) -> None: - self._compliant: ExprOrScalarT = compliant + return self.compliant.version def with_native(self, native: Any, name: str, /) -> ExprOrScalarT: - return self._compliant.from_native(native, name, self.version) + return self.compliant.from_native(native, name, self.version) + +class ArrowCatNamespace(ExprCatNamespace["Frame", "Expr"], ArrowAccessor[ExprOrScalarT]): + def get_categories( + self, node: ir.FunctionExpr[GetCategories], frame: Frame, name: str + ) -> Expr: + native = node.input[0].dispatch(self._compliant, frame, name).native + return ArrowExpr.from_native(fn.get_categories(native), name, self.version) + + +class ArrowStructNamespace( + ExprStructNamespace["Frame", ExprOrScalarT], ArrowAccessor[ExprOrScalarT] +): def field( self, node: ir.FunctionExpr[FieldByName], frame: Frame, name: str ) -> ExprOrScalarT: diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 7283902b12..2895ed6d7e 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -292,6 +292,15 @@ def struct_fields(native: ArrowAny, *fields: Field) -> Seq[ArrowAny]: return tuple(func(native, name) for name in fields) +def get_categories(native: ArrowAny) -> ChunkedArrayAny: + da: Incomplete + if isinstance(native, pa.ChunkedArray): + da = native.unify_dictionaries().chunk(0) + else: + da = native + return chunked_array(da.dictionary) + + @t.overload def when_then( predicate: Predicate, then: SameArrowT, otherwise: SameArrowT diff --git a/narwhals/_plan/compliant/struct.py b/narwhals/_plan/compliant/accessors.py similarity index 65% rename from narwhals/_plan/compliant/struct.py rename to narwhals/_plan/compliant/accessors.py index 2afc58afb7..6df20d495c 100644 --- a/narwhals/_plan/compliant/struct.py +++ b/narwhals/_plan/compliant/accessors.py @@ -6,9 +6,16 @@ if 
TYPE_CHECKING: from narwhals._plan.expressions import FunctionExpr as FExpr + from narwhals._plan.expressions.categorical import GetCategories from narwhals._plan.expressions.struct import FieldByName +class ExprCatNamespace(Protocol[FrameT_contra, ExprT_co]): + def get_categories( + self, node: FExpr[GetCategories], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + + class ExprStructNamespace(Protocol[FrameT_contra, ExprT_co]): def field( self, node: FExpr[FieldByName], frame: FrameT_contra, name: str diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 933c0790c7..3d6e1d4ec1 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -16,8 +16,8 @@ from typing_extensions import Self, TypeAlias from narwhals._plan import expressions as ir + from narwhals._plan.compliant.accessors import ExprCatNamespace, ExprStructNamespace from narwhals._plan.compliant.scalar import CompliantScalar, EagerScalar - from narwhals._plan.compliant.struct import ExprStructNamespace from narwhals._plan.expressions import ( BinaryExpr, FunctionExpr, @@ -243,6 +243,10 @@ def unique( self, node: FunctionExpr[F.Unique], frame: FrameT_contra, name: str ) -> Self: ... + @property + def cat( + self, + ) -> ExprCatNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... @property def struct(self) -> ExprStructNamespace[FrameT_contra, Self]: ... 
diff --git a/tests/plan/cat_get_categories_test.py b/tests/plan/cat_get_categories_test.py new file mode 100644 index 0000000000..31ae416626 --- /dev/null +++ b/tests/plan/cat_get_categories_test.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import pytest + +import narwhals as nw +import narwhals._plan as nwp # noqa: F401 +import narwhals._plan.selectors as ncs +from tests.plan.utils import assert_equal_data, dataframe + +pytest.importorskip("pyarrow") + + +@pytest.mark.parametrize( + ("values", "expected"), + [(["one", "two", "two"], ["one", "two"]), (["A", "B", None, "D"], ["A", "B", "D"])], + ids=["full", "nulls"], +) +def test_get_categories(values: list[str], expected: list[str]) -> None: + data = {"a": values} + df = dataframe(data).select(ncs.first().cast(nw.Categorical)) + result = df.select(ncs.first().cat.get_categories()) + assert_equal_data(result, {"a": expected}) From 895cbe2c6beae344ff29d3a41c9c0d381c0b7746 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 22:36:24 +0000 Subject: [PATCH 085/215] test: xfail `pyarrow<15.0` --- tests/plan/cat_get_categories_test.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/plan/cat_get_categories_test.py b/tests/plan/cat_get_categories_test.py index 31ae416626..6d8fd8ae9b 100644 --- a/tests/plan/cat_get_categories_test.py +++ b/tests/plan/cat_get_categories_test.py @@ -5,7 +5,9 @@ import narwhals as nw import narwhals._plan as nwp # noqa: F401 import narwhals._plan.selectors as ncs +from narwhals._utils import Implementation from tests.plan.utils import assert_equal_data, dataframe +from tests.utils import PYARROW_VERSION pytest.importorskip("pyarrow") @@ -15,8 +17,18 @@ [(["one", "two", "two"], ["one", "two"]), (["A", "B", None, "D"], ["A", "B", "D"])], ids=["full", "nulls"], ) -def test_get_categories(values: list[str], expected: list[str]) -> None: +def test_get_categories( + values: 
list[str], expected: list[str], request: pytest.FixtureRequest +) -> None: data = {"a": values} + df = dataframe(data) + request.applymarker( + pytest.mark.xfail( + (df.implementation is Implementation.PYARROW and PYARROW_VERSION < (15,)), + reason="Unsupported cast from string to dictionary using function cast_dictionary", + ) + ) + df = dataframe(data) df = dataframe(data).select(ncs.first().cast(nw.Categorical)) result = df.select(ncs.first().cat.get_categories()) assert_equal_data(result, {"a": expected}) From 95f2d1c9e205e6b02fc2754dacb390f41faf288a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 26 Nov 2025 22:37:41 +0000 Subject: [PATCH 086/215] tidy --- tests/plan/cat_get_categories_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/plan/cat_get_categories_test.py b/tests/plan/cat_get_categories_test.py index 6d8fd8ae9b..fa13f3c40b 100644 --- a/tests/plan/cat_get_categories_test.py +++ b/tests/plan/cat_get_categories_test.py @@ -28,7 +28,5 @@ def test_get_categories( reason="Unsupported cast from string to dictionary using function cast_dictionary", ) ) - df = dataframe(data) - df = dataframe(data).select(ncs.first().cast(nw.Categorical)) - result = df.select(ncs.first().cat.get_categories()) + result = df.select(ncs.first().cast(nw.Categorical).cat.get_categories()) assert_equal_data(result, {"a": expected}) From 7674b604ecadd54962acc87bdfdc13219dabadca Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 27 Nov 2025 12:14:03 +0000 Subject: [PATCH 087/215] feat: Add `list.len` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Needed to rethink the typing a bit - Accessor methods can return `Expr` or `Scalar` and may not depend on the input shape - Also, probably found a bug in `polars` - Nice to see things held together over here 😄 --- narwhals/_plan/arrow/expr.py | 42 
+++++++++++++------- narwhals/_plan/arrow/functions.py | 17 +++++++++ narwhals/_plan/arrow/typing.py | 3 ++ narwhals/_plan/compliant/accessors.py | 8 +++- narwhals/_plan/compliant/expr.py | 14 ++++++- tests/plan/list_len_test.py | 55 +++++++++++++++++++++++++++ 6 files changed, 123 insertions(+), 16 deletions(-) create mode 100644 tests/plan/list_len_test.py diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 65372bafce..7ee0abdda3 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -18,7 +18,11 @@ from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co from narwhals._plan.common import temp -from narwhals._plan.compliant.accessors import ExprCatNamespace, ExprStructNamespace +from narwhals._plan.compliant.accessors import ( + ExprCatNamespace, + ExprListNamespace, + ExprStructNamespace, +) from narwhals._plan.compliant.column import ExprDispatch from narwhals._plan.compliant.expr import EagerExpr from narwhals._plan.compliant.scalar import EagerScalar @@ -51,6 +55,7 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.arrow.typing import ChunkedArrayAny, P, VectorFunction + from narwhals._plan.expressions import BinaryExpr, FunctionExpr as FExpr, lists from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, @@ -77,7 +82,6 @@ Not, ) from narwhals._plan.expressions.categorical import GetCategories - from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr as FExpr from narwhals._plan.expressions.functions import ( Abs, CumAgg, @@ -637,6 +641,10 @@ def rolling_expr( def cat(self) -> ArrowCatNamespace[Expr]: return ArrowCatNamespace(self) + @property + def list(self) -> ArrowListNamespace[Expr]: + return ArrowListNamespace(self) + @property def struct(self) -> ArrowStructNamespace[Expr]: return 
ArrowStructNamespace(self) @@ -755,6 +763,10 @@ def drop_nulls( # type: ignore[override] def cat(self) -> ArrowCatNamespace[Scalar]: return ArrowCatNamespace(self) + @property + def list(self) -> ArrowListNamespace[Scalar]: + return ArrowListNamespace(self) + @property def struct(self) -> ArrowStructNamespace[Scalar]: return ArrowStructNamespace(self) @@ -792,23 +804,27 @@ def __narwhals_namespace__(self) -> ArrowNamespace: def version(self) -> Version: return self.compliant.version - def with_native(self, native: Any, name: str, /) -> ExprOrScalarT: - return self.compliant.from_native(native, name, self.version) + def with_native(self, native: ChunkedOrScalarAny, name: str, /) -> Expr | Scalar: + return self.compliant._with_native(native, name) class ArrowCatNamespace(ExprCatNamespace["Frame", "Expr"], ArrowAccessor[ExprOrScalarT]): - def get_categories( - self, node: ir.FunctionExpr[GetCategories], frame: Frame, name: str - ) -> Expr: - native = node.input[0].dispatch(self._compliant, frame, name).native + def get_categories(self, node: FExpr[GetCategories], frame: Frame, name: str) -> Expr: + native = node.input[0].dispatch(self.compliant, frame, name).native return ArrowExpr.from_native(fn.get_categories(native), name, self.version) +class ArrowListNamespace( + ExprListNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] +): + def len(self, node: FExpr[lists.Len], frame: Frame, name: str) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + return self.with_native(fn.list_len(native), name) + + class ArrowStructNamespace( - ExprStructNamespace["Frame", ExprOrScalarT], ArrowAccessor[ExprOrScalarT] + ExprStructNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): - def field( - self, node: ir.FunctionExpr[FieldByName], frame: Frame, name: str - ) -> ExprOrScalarT: - native = node.input[0].dispatch(self._compliant, frame, name).native + def field(self, node: FExpr[FieldByName], frame: Frame, name: 
str) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native return self.with_native(fn.struct_field(native, node.function.name), name) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 2895ed6d7e..d9240d836a 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -43,6 +43,7 @@ BooleanLengthPreserving, ChunkedArray, ChunkedArrayAny, + ChunkedList, ChunkedOrArray, ChunkedOrArrayAny, ChunkedOrArrayT, @@ -56,6 +57,8 @@ IntegerScalar, IntegerType, LargeStringType, + ListArray, + ListScalar, NativeScalar, Predicate, SameArrowT, @@ -301,6 +304,20 @@ def get_categories(native: ArrowAny) -> ChunkedArrayAny: return chunked_array(da.dictionary) +@t.overload +def list_len(native: ChunkedList) -> ChunkedArray[pa.UInt32Scalar]: ... +@t.overload +def list_len(native: ListArray) -> pa.UInt32Array: ... +@t.overload +def list_len(native: ListScalar) -> pa.UInt32Scalar: ... +@t.overload +def list_len(native: SameArrowT) -> SameArrowT: ... 
+def list_len(native: ArrowAny) -> ArrowAny: + length: Incomplete = pc.list_value_length + result: ArrowAny = length(native).cast(pa.uint32()) + return result + + @t.overload def when_then( predicate: Predicate, then: SameArrowT, otherwise: SameArrowT diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index 77c6eae71c..ce55cf4b95 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -32,6 +32,7 @@ IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type" IntegerScalar: TypeAlias = "Scalar[IntegerType]" DateScalar: TypeAlias = "Scalar[Date32Type]" + ListScalar: TypeAlias = "Scalar[pa.ListType[Any]]" class NativeArrowSeries(NativeSeries, Protocol): @property @@ -180,6 +181,8 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot ChunkedStruct: TypeAlias = "ChunkedArray[pa.StructScalar]" StructArray: TypeAlias = "pa.StructArray | Array[pa.StructScalar]" +ChunkedList: TypeAlias = "ChunkedArray[ListScalar]" +ListArray: TypeAlias = "Array[ListScalar]" Arrow: TypeAlias = "ChunkedOrScalar[ScalarT_co] | Array[ScalarT_co]" ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 6df20d495c..d46eb46b42 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -5,7 +5,7 @@ from narwhals._plan.compliant.typing import ExprT_co, FrameT_contra if TYPE_CHECKING: - from narwhals._plan.expressions import FunctionExpr as FExpr + from narwhals._plan.expressions import FunctionExpr as FExpr, lists from narwhals._plan.expressions.categorical import GetCategories from narwhals._plan.expressions.struct import FieldByName @@ -16,6 +16,12 @@ def get_categories( ) -> ExprT_co: ... 
+class ExprListNamespace(Protocol[FrameT_contra, ExprT_co]): + def len( + self, node: FExpr[lists.Len], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + + class ExprStructNamespace(Protocol[FrameT_contra, ExprT_co]): def field( self, node: FExpr[FieldByName], frame: FrameT_contra, name: str diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 3d6e1d4ec1..bf66e5d94f 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -16,7 +16,11 @@ from typing_extensions import Self, TypeAlias from narwhals._plan import expressions as ir - from narwhals._plan.compliant.accessors import ExprCatNamespace, ExprStructNamespace + from narwhals._plan.compliant.accessors import ( + ExprCatNamespace, + ExprListNamespace, + ExprStructNamespace, + ) from narwhals._plan.compliant.scalar import CompliantScalar, EagerScalar from narwhals._plan.expressions import ( BinaryExpr, @@ -248,7 +252,13 @@ def cat( self, ) -> ExprCatNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... @property - def struct(self) -> ExprStructNamespace[FrameT_contra, Self]: ... + def list( + self, + ) -> ExprListNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... + @property + def struct( + self, + ) -> ExprStructNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... 
class EagerExpr( diff --git a/tests/plan/list_len_test.py b/tests/plan/list_len_test.py new file mode 100644 index 0000000000..f23c8c61be --- /dev/null +++ b/tests/plan/list_len_test.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from narwhals._plan.typing import OneOrIterable + from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [[1, 2], [3, 4, None], None, [], [None]], "i": [4, 3, 2, 1, 0]} + + +a = nwp.nth(0) + + +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + (a.list.len(), {"a": [2, 3, None, 0, 1]}), + ( + [a.first().list.len().alias("first"), a.last().list.len().alias("last")], + {"first": [2], "last": [1]}, + ), + ( # NOTE: `polars` produces nulls following the `over(order_by=...)` + # That's either a bug, or something that won't be ported to `narwhals` + [ + a.first().over(order_by="i").list.len().alias("first_order_i"), + a.last().over(order_by="i").list.len().alias("last_order_i"), + ], + {"first_order_i": [1], "last_order_i": [2]}, + ), + ( + # NOTE: This does work already in `polars` + [ + a.sort_by("i").first().list.len().alias("sort_by_i_first"), + a.sort_by("i").last().list.len().alias("sort_by_i_last"), + ], + {"sort_by_i_first": [1], "sort_by_i_last": [2]}, + ), + ], +) +def test_list_len(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None: + df = dataframe(data).with_columns(a.cast(nw.List(nw.Int32()))) + result = df.select(exprs) + assert_equal_data(result, expected) From e6e2045dc7424f66a282d2cf6c6b94687fca41ea Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 27 Nov 2025 12:19:38 +0000 Subject: [PATCH 088/215] cov --- narwhals/_plan/expr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 8ecfe96912..8f2e752d9b 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -575,7 +575,7 @@ def name(self) -> ExprNameNamespace: return ExprNameNamespace(_expr=self) @property - def cat(self) -> ExprCatNamespace: # pragma: no cover + def cat(self) -> ExprCatNamespace: from narwhals._plan.expressions.categorical import ExprCatNamespace return ExprCatNamespace(_expr=self) @@ -593,7 +593,7 @@ def dt(self) -> ExprDateTimeNamespace: return ExprDateTimeNamespace(_expr=self) @property - def list(self) -> ExprListNamespace: # pragma: no cover + def list(self) -> ExprListNamespace: from narwhals._plan.expressions.lists import ExprListNamespace return ExprListNamespace(_expr=self) From d45f45cd6659cbdba82129228a3ed2cd0f6f375b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 27 Nov 2025 14:51:22 +0000 Subject: [PATCH 089/215] feat: Add `{DataFrame,Series}.gather_every` --- narwhals/_plan/arrow/common.py | 3 ++ narwhals/_plan/arrow/expr.py | 7 ++- narwhals/_plan/compliant/dataframe.py | 1 + narwhals/_plan/compliant/series.py | 1 + narwhals/_plan/dataframe.py | 3 ++ narwhals/_plan/series.py | 3 ++ tests/plan/gather_test.py | 72 +++++++++++++++++++++++++++ 7 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 tests/plan/gather_test.py diff --git a/narwhals/_plan/arrow/common.py b/narwhals/_plan/arrow/common.py index fdbe173f2c..217f0a193d 100644 --- a/narwhals/_plan/arrow/common.py +++ b/narwhals/_plan/arrow/common.py @@ -57,5 +57,8 @@ def gather(self, indices: Indices | _StoresNative[ChunkedArrayAny]) -> Self: ca = self._gather(indices.native if is_series(indices) else indices) return self._with_native(ca) + def gather_every(self, n: int, offset: int = 0) -> Self: + return self._with_native(self.native[offset::n]) + def slice(self, offset: int, length: int | None = None) -> Self: return 
self._with_native(self.native.slice(offset=offset, length=length)) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 7ee0abdda3..46d5659ab4 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -573,11 +573,10 @@ def unique(self, node: FExpr[F.Unique], frame: Frame, name: str) -> Self: result = self._dispatch_expr(node.input[0], frame, name).native.unique() return self._with_native(result, name) - # TODO @dangotbanned: Only implement in `ArrowSeries` and reuse def gather_every(self, node: FExpr[F.GatherEvery], frame: Frame, name: str) -> Self: - native = self._dispatch_expr(node.input[0], frame, name).native - result = native[node.function.offset :: node.function.n] - return self._with_native(result, name) + series = self._dispatch_expr(node.input[0], frame, name) + n, offset = node.function.n, node.function.offset + return self.from_series(series.gather_every(n=n, offset=offset)) def drop_nulls(self, node: FExpr[F.DropNulls], frame: Frame, name: str) -> Self: return self._vector_function(fn.drop_nulls)(node, frame, name) diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index 5e64e105b5..f3be22775a 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -105,6 +105,7 @@ def native(self) -> NativeDataFrameT: def from_dict( cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: ... + def gather_every(self, n: int, offset: int = 0) -> Self: ... def get_column(self, name: str) -> SeriesT: ... def group_by_agg( self, by: OneOrIterable[IntoExpr], aggs: OneOrIterable[IntoExpr], / diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 6d1a587974..aed58350f2 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -135,6 +135,7 @@ def gather( self, indices: SizedMultiIndexSelector[NativeSeriesT] | _StoresNative[NativeSeriesT], ) -> Self: ... 
+ def gather_every(self, n: int, offset: int = 0) -> Self: ... def has_nulls(self) -> bool: ... def is_empty(self) -> bool: return len(self) == 0 diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 59a22c5438..ccf78d3978 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -217,6 +217,9 @@ def to_series(self, index: int = 0) -> Series[NativeSeriesT]: # pragma: no cove def to_polars(self) -> pl.DataFrame: return self._compliant.to_polars() + def gather_every(self, n: int, offset: int = 0) -> Self: + return self._with_compliant(self._compliant.gather_every(n, offset)) + def get_column(self, name: str) -> Series[NativeSeriesT]: # pragma: no cover return self._series(self._compliant.get_column(name)) diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 0220253087..9b12165525 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -122,6 +122,9 @@ def gather(self, indices: SizedMultiIndexSelector[Self]) -> Self: # pragma: no rows = indices._compliant if isinstance(indices, Series) else indices return type(self)(self._compliant.gather(rows)) + def gather_every(self, n: int, offset: int = 0) -> Self: + return type(self)(self._compliant.gather_every(n, offset)) + def has_nulls(self) -> bool: # pragma: no cover return self._compliant.has_nulls() diff --git a/tests/plan/gather_test.py b/tests/plan/gather_test.py new file mode 100644 index 0000000000..f2e3453509 --- /dev/null +++ b/tests/plan/gather_test.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs +from narwhals.exceptions import ShapeError +from tests.plan.utils import assert_equal_data, assert_equal_series, dataframe, series + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "idx": list(range(10)), + 
"name": ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"], + } + + +@pytest.mark.parametrize("n", [1, 2, 3]) +@pytest.mark.parametrize("offset", [0, 1, 2, 3]) +@pytest.mark.parametrize("column", ["idx", "name"]) +def test_gather_every_series(data: Data, n: int, offset: int, column: str) -> None: + ser = series(data[column]).alias(column) + result = ser.gather_every(n, offset) + expected = data[column][offset::n] + assert_equal_series(result, expected, column) + + +@pytest.mark.parametrize("n", [1, 2, 3]) +@pytest.mark.parametrize("offset", [0, 1, 2, 3]) +def test_gather_every_dataframe(data: Data, n: int, offset: int) -> None: + result = dataframe(data).gather_every(n, offset) + indices = slice(offset, None, n) + expected = {"idx": data["idx"][indices], "name": data["name"][indices]} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize("n", [1, 2, 3]) +@pytest.mark.parametrize("offset", [0, 1, 2, 3]) +def test_gather_every_expr(data: Data, n: int, offset: int) -> None: + df = dataframe(data) + indices = slice(offset, None, n) + v_idx, v_name = data["idx"][indices], data["name"][indices] + e_idx, e_name = nwp.col("idx"), nwp.col("name") + gather = partial(nwp.Expr.gather_every, n=n, offset=offset) + + result = df.select(gather(nwp.col("idx", "name"))) + expected = {"idx": v_idx, "name": v_name} + assert_equal_data(result, expected) + expected = {"name": v_name} + assert_equal_data(df.select(gather(e_name)), expected) + expected = {"name": v_name, "idx": v_idx} + assert_equal_data(df.select(gather(nwp.nth(1, 0))), expected) + expected = {"idx": v_idx, "name": v_name} + assert_equal_data(df.select(gather(e_idx), gather(ncs.last())), expected) + + if n == 1 and offset == 0: + result = df.select(gather(e_name), e_idx) + expected = {"name": data["name"], "idx": data["idx"]} + assert_equal_data(result, expected) + else: + with pytest.raises(ShapeError): + df.select(gather(e_name), e_idx) + result = df.select(gather(e_name), e_idx.first()) + expected = 
{"name": v_name, "idx": [0] * len(result)} + assert_equal_data(result, expected) From f493547d9a4c52c77c1c7bf07079b53aff9869f6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 27 Nov 2025 16:49:52 +0000 Subject: [PATCH 090/215] refactor: Clear out the cobwebs Mostly typing, but generally things that made sense in an earlier version --- narwhals/_plan/arrow/options.py | 19 +++++++------------ narwhals/_plan/compliant/typing.py | 7 ------- narwhals/_plan/expressions/boolean.py | 4 ---- narwhals/_plan/typing.py | 3 +-- 4 files changed, 8 insertions(+), 25 deletions(-) diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index 83e73eff87..254f6ca6da 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -109,21 +109,16 @@ def _sort_keys_every( return tuple((key, order) for key in by) -def _sort_keys( - by: tuple[str, ...], *, descending: bool | Sequence[bool] -) -> Seq[tuple[str, Order]]: - if not isinstance(descending, bool) and len(descending) == 1: - descending = descending[0] - if isinstance(descending, bool): - return _sort_keys_every(by, descending=descending) - it = zip_strict(by, descending) - return tuple(_sort_key(key, descending=desc) for (key, desc) in it) - - def sort( *by: str, descending: bool | Sequence[bool] = False, nulls_last: bool = False ) -> pc.SortOptions: - keys = _sort_keys(by, descending=descending) + if not isinstance(descending, bool) and len(descending) == 1: + descending = descending[0] + if isinstance(descending, bool): + keys = _sort_keys_every(by, descending=descending) + else: + it = zip_strict(by, descending) + keys = tuple(_sort_key(key, descending=desc) for (key, desc) in it) return pc.SortOptions(sort_keys=keys, null_placement=NULL_PLACEMENT[nulls_last]) diff --git a/narwhals/_plan/compliant/typing.py b/narwhals/_plan/compliant/typing.py index 91ad9320ed..ed9af83a50 100644 --- a/narwhals/_plan/compliant/typing.py +++ 
b/narwhals/_plan/compliant/typing.py @@ -20,17 +20,11 @@ from narwhals._plan.compliant.series import CompliantSeries from narwhals._utils import Version -T = TypeVar("T") R_co = TypeVar("R_co", covariant=True) LengthT = TypeVar("LengthT") -NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) - ConcatT1 = TypeVar("ConcatT1") ConcatT2 = TypeVar("ConcatT2", default=ConcatT1) - -ColumnT = TypeVar("ColumnT") ColumnT_co = TypeVar("ColumnT_co", covariant=True) - ResolverT_co = TypeVar("ResolverT_co", bound="GroupByResolver", covariant=True) ExprAny: TypeAlias = "CompliantExpr[Any, Any]" @@ -48,7 +42,6 @@ LazyScalarAny: TypeAlias = "LazyScalar[Any, Any, Any]" ExprT_co = TypeVar("ExprT_co", bound=ExprAny, covariant=True) -ScalarT = TypeVar("ScalarT", bound=ScalarAny) ScalarT_co = TypeVar("ScalarT_co", bound=ScalarAny, covariant=True) SeriesT = TypeVar("SeriesT", bound=SeriesAny) SeriesT_co = TypeVar("SeriesT_co", bound=SeriesAny, covariant=True) diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index b6bc93dd88..4aa956559e 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -8,7 +8,6 @@ from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.options import FEOptions, FunctionOptions from narwhals._plan.typing import NativeSeriesT -from narwhals._typing_compat import TypeVar if TYPE_CHECKING: from typing_extensions import Self @@ -19,9 +18,6 @@ from narwhals._plan.typing import Seq from narwhals.typing import ClosedInterval -OtherT = TypeVar("OtherT") -ExprT = TypeVar("ExprT", bound="ExprIR", default="ExprIR") - # fmt: off class BooleanFunction(Function, options=FunctionOptions.elementwise): ... 
diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 4fdf78dba2..f8d212432a 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -96,7 +96,6 @@ NativeSeriesT_co = TypeVar( "NativeSeriesT_co", bound="NativeSeries", covariant=True, default="NativeSeries" ) -NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame", default="NativeFrame") NativeFrameT_co = TypeVar( "NativeFrameT_co", bound="NativeFrame", covariant=True, default="NativeFrame" ) @@ -127,7 +126,7 @@ IntoExprColumn: TypeAlias = "Expr | Series[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" ColumnNameOrSelector: TypeAlias = "str | Selector" -OneOrIterable: TypeAlias = "T | t.Iterable[T]" +OneOrIterable: TypeAlias = "T | Iterable[T]" OneOrSeq: TypeAlias = t.Union[T, Seq[T]] DataFrameT = TypeVar("DataFrameT", bound="DataFrame[t.Any, t.Any]") Order: TypeAlias = t.Literal["ascending", "descending"] From 94b676c291e507938cd946ba30ef35b4c5b889ed Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 27 Nov 2025 18:08:24 +0000 Subject: [PATCH 091/215] remove unused keywords from `kurtosis` --- narwhals/_plan/expr.py | 4 ++-- narwhals/_plan/expressions/functions.py | 7 +------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 8f2e752d9b..65a85a12bd 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -217,8 +217,8 @@ def exp(self) -> Self: def sqrt(self) -> Self: return self._with_unary(F.Sqrt()) - def kurtosis(self, *, fisher: bool = True, bias: bool = True) -> Self: - return self._with_unary(F.Kurtosis(fisher=fisher, bias=bias)) + def kurtosis(self) -> Self: + return self._with_unary(F.Kurtosis()) def null_count(self) -> Self: return self._with_unary(F.NullCount()) diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index b148c8843f..ea232d4724 100644 --- 
a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -47,6 +47,7 @@ class Sqrt(Function, options=FunctionOptions.elementwise): ... class DropNulls(Function, options=FunctionOptions.row_separable): ... class ModeAll(Function): ... class ModeAny(Function, options=FunctionOptions.aggregation): ... +class Kurtosis(Function, options=FunctionOptions.aggregation): ... class Skew(Function, options=FunctionOptions.aggregation): ... class Clip(Function, options=FunctionOptions.elementwise): def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR, ExprIR]: @@ -128,12 +129,6 @@ def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: return base, exponent -class Kurtosis(Function, options=FunctionOptions.aggregation): - __slots__ = ("bias", "fisher") - fisher: bool - bias: bool - - class FillNull(Function, options=FunctionOptions.elementwise): """N-ary (expr, value).""" From f3aacf0ffadaac15303ac2b37ceb2da9f8d25fa9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 27 Nov 2025 22:32:36 +0000 Subject: [PATCH 092/215] feat: More `nan`s and `null`s - `fill_nan`, `null_count` are on `main` - the other two are 2/3 of #3028 --- narwhals/_plan/arrow/expr.py | 21 +++++++- narwhals/_plan/arrow/functions.py | 24 +++++++-- narwhals/_plan/arrow/series.py | 7 +++ narwhals/_plan/arrow/typing.py | 1 + narwhals/_plan/compliant/expr.py | 11 ++++ narwhals/_plan/compliant/series.py | 2 + narwhals/_plan/expr.py | 15 ++++++ narwhals/_plan/expressions/__init__.py | 2 + narwhals/_plan/expressions/boolean.py | 2 + narwhals/_plan/expressions/expr.py | 5 ++ narwhals/_plan/expressions/functions.py | 8 +++ narwhals/_plan/series.py | 7 +++ narwhals/_plan/when_then.py | 9 +--- tests/plan/compliant_test.py | 2 +- tests/plan/fill_nan_test.py | 67 +++++++++++++++++++++++++ 15 files changed, 171 insertions(+), 12 deletions(-) create mode 100644 tests/plan/fill_nan_test.py diff 
--git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 46d5659ab4..bc6e0ff5d0 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -35,6 +35,8 @@ IsInSeq, IsInSeries, IsLastDistinct, + IsNotNan, + IsNotNull, IsUnique, ) from narwhals._plan.expressions.functions import NullCount @@ -86,6 +88,7 @@ Abs, CumAgg, Diff, + FillNan, FillNull, NullCount, Pow, @@ -131,6 +134,12 @@ def fill_null( value_ = value.dispatch(self, frame, "value").native return self._with_native(pc.fill_null(native, value_), name) + def fill_nan(self, node: FExpr[FillNan], frame: Frame, name: str) -> StoresNativeT_co: + expr, value = node.function.unwrap_input(node) + native = expr.dispatch(self, frame, name).native + value_ = value.dispatch(self, frame, "value").native + return self._with_native(fn.fill_nan(native, value_), name) + def is_between( self, node: FExpr[IsBetween], frame: Frame, name: str ) -> StoresNativeT_co: @@ -154,7 +163,7 @@ def abs(self, node: FExpr[Abs], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(pc.abs)(node, frame, name) def not_(self, node: FExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: - return self._unary_function(pc.invert)(node, frame, name) + return self._unary_function(fn.not_)(node, frame, name) def all(self, node: FExpr[All], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.all_)(node, frame, name) @@ -198,6 +207,16 @@ def is_nan(self, node: FExpr[IsNan], frame: Frame, name: str) -> StoresNativeT_c def is_null(self, node: FExpr[IsNull], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.is_null)(node, frame, name) + def is_not_nan( + self, node: FExpr[IsNotNan], frame: Frame, name: str + ) -> StoresNativeT_co: + return self._unary_function(fn.is_not_nan)(node, frame, name) + + def is_not_null( + self, node: FExpr[IsNotNull], frame: Frame, name: str + ) -> StoresNativeT_co: + return self._unary_function(fn.is_not_null)(node, 
frame, name) + def binary_expr(self, node: BinaryExpr, frame: Frame, name: str) -> StoresNativeT_co: lhs, rhs = ( node.left.dispatch(self, frame, name), diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index d9240d836a..3c9be52abf 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -33,6 +33,7 @@ from narwhals._plan.arrow.typing import ( Array, ArrayAny, + Arrow, ArrowAny, ArrowT, BinaryComp, @@ -49,6 +50,7 @@ ChunkedOrArrayT, ChunkedOrScalar, ChunkedOrScalarAny, + ChunkedOrScalarT, ChunkedStruct, DataType, DataTypeRemap, @@ -114,8 +116,14 @@ class MinMax(ir.AggExpr): is_null = pc.is_null is_not_null = t.cast("UnaryFunction[ScalarAny,pa.BooleanScalar]", pc.is_valid) -is_nan = pc.is_nan -is_finite = pc.is_finite +is_nan = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_nan) +is_finite = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_finite) +not_ = t.cast("UnaryFunction[pa.BooleanScalar ,pa.BooleanScalar]", pc.invert) + + +def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: + return not_(is_nan(native)) + and_ = t.cast("BinaryLogical", pc.and_kleene) or_ = t.cast("BinaryLogical", pc.or_kleene) @@ -558,6 +566,16 @@ def fill_null( return pc.fill_null(native, fill_value) +@t.overload +def fill_nan( + native: ChunkedOrScalarT, value: NonNestedLiteral | ArrowAny +) -> ChunkedOrScalarT: ... +@t.overload +def fill_nan(native: SameArrowT, value: NonNestedLiteral | ArrowAny) -> SameArrowT: ... 
+def fill_nan(native: ArrowAny, value: NonNestedLiteral | ArrowAny) -> Incomplete: + return when_then(is_not_nan(native), native, value) + + def fill_null_forward(native: ChunkedArrayAny) -> ChunkedArrayAny: return fill_null_with_strategy(native, "forward") @@ -662,7 +680,7 @@ def _boolean_is_unique( def _boolean_is_duplicated( indices: ChunkedArrayAny, aggregated: ChunkedStruct, / ) -> ChunkedArrayAny: - return pc.invert(_boolean_is_unique(indices, aggregated)) + return not_(_boolean_is_unique(indices, aggregated)) BOOLEAN_LENGTH_PRESERVING: Mapping[ diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 7344c5bdd3..f54b5eec72 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -102,6 +102,9 @@ def is_in(self, other: Self) -> Self: def has_nulls(self) -> bool: return bool(self.native.null_count) + def null_count(self) -> int: + return self.native.null_count + __add__ = fn.bin_op(fn.add) __and__ = fn.bin_op(fn.and_) __eq__ = fn.bin_op(fn.eq) @@ -157,6 +160,10 @@ def cum_prod(self, *, reverse: bool = False) -> Self: return self._with_native(fn.cum_prod(self.native)) return self._with_native(fn.cumulative(self.native, F.CumProd(reverse=reverse))) + def fill_nan(self, value: float | Self | None) -> Self: + fill_value = value.native if isinstance(value, ArrowSeries) else value + return self._with_native(fn.fill_nan(self.native, fill_value)) + def fill_null(self, value: NonNestedLiteral | Self) -> Self: fill_value = value.native if isinstance(value, ArrowSeries) else value return self._with_native(fn.fill_null(self.native, fill_value)) diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index ce55cf4b95..af67f7f1e2 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -177,6 +177,7 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot ChunkedOrScalarAny: TypeAlias = "ChunkedOrScalar[ScalarAny]" ChunkedOrArrayAny: TypeAlias = 
"ChunkedOrArray[ScalarAny]" ChunkedOrArrayT = TypeVar("ChunkedOrArrayT", ChunkedArrayAny, ArrayAny) +ChunkedOrScalarT = TypeVar("ChunkedOrScalarT", ChunkedArrayAny, ScalarAny) Indices: TypeAlias = "_SizedMultiIndexSelector[ChunkedOrArray[pc.IntegerScalar]]" ChunkedStruct: TypeAlias = "ChunkedArray[pa.StructScalar]" diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index bf66e5d94f..2e57521eb2 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -35,6 +35,8 @@ IsFirstDistinct, IsLastDistinct, IsNan, + IsNotNan, + IsNotNull, IsNull, Not, ) @@ -67,6 +69,9 @@ def ewm_mean( def fill_null( self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str ) -> Self: ... + def fill_nan( + self, node: FunctionExpr[F.FillNan], frame: FrameT_contra, name: str + ) -> Self: ... def is_between( self, node: FunctionExpr[IsBetween], frame: FrameT_contra, name: str ) -> Self: ... @@ -85,6 +90,12 @@ def is_nan( def is_null( self, node: FunctionExpr[IsNull], frame: FrameT_contra, name: str ) -> Self: ... + def is_not_nan( + self, node: FunctionExpr[IsNotNan], frame: FrameT_contra, name: str + ) -> Self: ... + def is_not_null( + self, node: FunctionExpr[IsNotNull], frame: FrameT_contra, name: str + ) -> Self: ... def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index aed58350f2..a9b3ee3893 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -126,6 +126,7 @@ def cum_min(self, *, reverse: bool = False) -> Self: ... def cum_prod(self, *, reverse: bool = False) -> Self: ... def cum_sum(self, *, reverse: bool = False) -> Self: ... def diff(self, n: int = 1) -> Self: ... 
+ def fill_nan(self, value: float | Self | None) -> Self: ... def fill_null(self, value: NonNestedLiteral | Self) -> Self: ... def fill_null_with_strategy( self, strategy: FillNullStrategy, limit: int | None = None @@ -137,6 +138,7 @@ def gather( ) -> Self: ... def gather_every(self, n: int, offset: int = 0) -> Self: ... def has_nulls(self) -> bool: ... + def null_count(self) -> int: ... def is_empty(self) -> bool: return len(self) == 0 diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 65a85a12bd..af21766576 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -223,6 +223,15 @@ def kurtosis(self) -> Self: def null_count(self) -> Self: return self._with_unary(F.NullCount()) + def fill_nan(self, value: float | Self | None) -> Self: + fill_value = parse_into_expr_ir(value, str_as_lit=True) + root = self._ir + if any(e.meta.has_multiple_outputs() for e in (root, fill_value)): + return self._from_ir(F.FillNan().to_function_expr(root, fill_value)) + # https://github.com/pola-rs/polars/blob/e1d6f294218a36497255e2d872c223e19a47e2ec/crates/polars-plan/src/dsl/mod.rs#L894-L902 + predicate = self.is_not_nan() | self.is_null() + return self._from_ir(ir.ternary_expr(predicate._ir, root, fill_value)) + def fill_null( self, value: IntoExpr = None, @@ -422,6 +431,12 @@ def is_nan(self) -> Self: def is_null(self) -> Self: return self._with_unary(ir.boolean.IsNull()) + def is_not_nan(self) -> Self: + return self._with_unary(ir.boolean.IsNotNan()) + + def is_not_null(self) -> Self: + return self._with_unary(ir.boolean.IsNotNull()) + def is_first_distinct(self) -> Self: return self._with_unary(ir.boolean.IsFirstDistinct()) diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 3b9f0364ab..56e4f5d4ce 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -42,6 +42,7 @@ TernaryExpr, WindowExpr, col, + ternary_expr, ) from narwhals._plan.expressions.name import 
KeepName, RenameAlias from narwhals._plan.expressions.window import over, over_ordered @@ -91,4 +92,5 @@ "strings", "struct", "temporal", + "ternary_expr", ] diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index 4aa956559e..97d33fa31e 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -31,6 +31,8 @@ class IsFirstDistinct(BooleanFunction, options=FunctionOptions.length_preserving class IsLastDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... class IsNan(BooleanFunction): ... class IsNull(BooleanFunction): ... +class IsNotNan(BooleanFunction): ... +class IsNotNull(BooleanFunction): ... class IsUnique(BooleanFunction, options=FunctionOptions.length_preserving): ... class Not(BooleanFunction, config=FEOptions.renamed("not_")): ... # fmt: on diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 26b475b51d..7f173d2b1b 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -65,6 +65,7 @@ "TernaryExpr", "WindowExpr", "col", + "ternary_expr", ] @@ -544,3 +545,7 @@ def matches(self, dtype: IntoDType) -> bool: def to_dtype_selector(self) -> Self: return replace(self, selector=self.selector.to_dtype_selector()) + + +def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: + return TernaryExpr(predicate=predicate, truthy=truthy, falsy=falsy) diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index ea232d4724..f29ef5f554 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -137,6 +137,14 @@ def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: return expr, value +class FillNan(Function, options=FunctionOptions.elementwise): + """N-ary (expr, value).""" + + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, value = 
node.input + return expr, value + + class FillNullWithStrategy(Function): __slots__ = ("limit", "strategy") strategy: FillNullStrategy diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 9b12165525..53b9806052 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -176,6 +176,13 @@ def scatter( def is_in(self, other: Iterable[Any]) -> Self: return type(self)(self._compliant.is_in(self._parse_into_compliant(other))) + def null_count(self) -> int: + return self._compliant.null_count() + + def fill_nan(self, value: float | Self | None) -> Self: + other = self._unwrap_compliant(value) if is_series(value) else value + return type(self)(self._compliant.fill_nan(other)) + class SeriesV1(Series[NativeSeriesT_co]): _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index ce51e19087..63347ee429 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -8,10 +8,11 @@ parse_predicates_constraints_into_expr_ir, ) from narwhals._plan.expr import Expr +from narwhals._plan.expressions import ternary_expr from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: - from narwhals._plan.expressions import ExprIR, TernaryExpr + from narwhals._plan.expressions import ExprIR from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq @@ -114,9 +115,3 @@ def _from_ir(cls, expr_ir: ExprIR, /) -> Expr: # type: ignore[override] def __eq__(self, other: IntoExpr) -> Expr: # type: ignore[override] return Expr.__eq__(self, other) - - -def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: - from narwhals._plan.expressions.expr import TernaryExpr - - return TernaryExpr(predicate=predicate, truthy=truthy, falsy=falsy) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index ba9c95b0f0..d4e55cae93 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -203,7 +203,7 @@ 
def _ids_ir(expr: nwp.Expr | Any) -> str: ), (nwp.col("e", "d").is_null().any(), {"e": [True], "d": [False]}), ( - [(~nwp.col("e", "d").is_null()).all(), "b"], + [(nwp.col("e", "d").is_not_null()).all(), "b"], {"e": [False, False, False], "d": [True, True, True], "b": [1, 2, 3]}, ), pytest.param( diff --git a/tests/plan/fill_nan_test.py b/tests/plan/fill_nan_test.py new file mode 100644 index 0000000000..a0c57fc90a --- /dev/null +++ b/tests/plan/fill_nan_test.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs +from tests.plan.utils import assert_equal_data, assert_equal_series, dataframe, series + +if TYPE_CHECKING: + from narwhals._plan.typing import OneOrIterable + from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"int": [-1, 1, None]} + + +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + ( + [nwp.col("no_nan").fill_nan(None), nwp.col("float_nan").fill_nan(None)], + [None, 1.0, None], + ), + ( + [nwp.col("no_nan").fill_nan(3.0), nwp.col("float_nan").fill_nan(3.0)], + [3.0, 1.0, None], + ), + (nwp.all().fill_nan(None), [None, 1.0, None]), + (nwp.all().fill_nan(3.0), [3.0, 1.0, None]), + ( + ncs.numeric().as_expr().fill_nan(nwp.lit(series([55.5, -100, -200]))), + [55.5, 1.0, None], + ), + ( + [ + nwp.col("no_nan"), + nwp.col("float_nan").fill_nan(nwp.col("no_nan").max() * 6), + ], + [6.0, 1.0, None], + ), + ], +) +def test_fill_nan( + data: Data, exprs: OneOrIterable[nwp.Expr], expected: list[Any] +) -> None: + base = nwp.col("int") + df = dataframe(data).select( + base.cast(nw.Float64).alias("no_nan"), (base**0.5).alias("float_nan") + ) + result = df.select(exprs) + assert_equal_data(result, {"no_nan": [-1.0, 1.0, None], "float_nan": expected}) + assert result.get_column("float_nan").null_count() == 
expected.count(None) + + +def test_fill_nan_series(data: Data) -> None: + ser = dataframe(data).select(float_nan=nwp.col("int") ** 0.5).get_column("float_nan") + result = ser.fill_nan(999) + assert_equal_series(result, [999.0, 1.0, None], "float_nan") + result = ser.fill_nan(series([1.23, None, None])) + assert_equal_series(result, [1.23, 1.0, None], "float_nan") From 05e5c7aa5c44e2a4ba79d1b6e4423867194c8223 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 27 Nov 2025 22:44:56 +0000 Subject: [PATCH 093/215] Mark more features as not implemented a mix of things: - added after the dsl was written - or deprecated before - or not supported in `pyarrow` - or I missed it oops --- narwhals/_plan/expr.py | 9 ++++++++- narwhals/_plan/expressions/functions.py | 2 ++ narwhals/_plan/expressions/lists.py | 5 +++++ narwhals/_plan/expressions/strings.py | 7 +++++++ narwhals/_plan/expressions/temporal.py | 3 +++ narwhals/_plan/functions.py | 12 ++++++++++++ 6 files changed, 37 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index af21766576..3e56f85639 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -24,7 +24,7 @@ SortOptions, rolling_options, ) -from narwhals._utils import Version, no_default +from narwhals._utils import Version, no_default, not_implemented from narwhals.exceptions import ComputeError if TYPE_CHECKING: @@ -619,6 +619,13 @@ def str(self) -> ExprStringNamespace: return ExprStringNamespace(_expr=self) + is_close = not_implemented() + sample = not_implemented() + head = not_implemented() + tail = not_implemented() + ceil = not_implemented() + floor = not_implemented() + class ExprV1(Expr): _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index f29ef5f554..379cb7b2f9 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@
-44,6 +44,8 @@ class Abs(Function, options=FunctionOptions.elementwise): ... class NullCount(Function, options=FunctionOptions.aggregation): ... class Exp(Function, options=FunctionOptions.elementwise): ... class Sqrt(Function, options=FunctionOptions.elementwise): ... +class Ceil(Function, options=FunctionOptions.elementwise): ... +class Floor(Function, options=FunctionOptions.elementwise): ... class DropNulls(Function, options=FunctionOptions.row_separable): ... class ModeAll(Function): ... class ModeAny(Function, options=FunctionOptions.aggregation): ... diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 604e054a5e..8c4306378d 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -5,6 +5,7 @@ from narwhals._plan._function import Function from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions +from narwhals._utils import not_implemented if TYPE_CHECKING: from narwhals._plan.expr import Expr @@ -25,3 +26,7 @@ def _ir_namespace(self) -> type[IRListNamespace]: def len(self) -> Expr: return self._with_unary(self._ir.len()) + + get = not_implemented() + contains = not_implemented() + unique = not_implemented() diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 5d08412fdb..ebbaeb0177 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -5,6 +5,7 @@ from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions +from narwhals._utils import not_implemented if TYPE_CHECKING: from narwhals._plan.expr import Expr @@ -121,11 +122,13 @@ def _ir_namespace(self) -> type[IRStringNamespace]: def len_chars(self) -> Expr: return self._with_unary(self._ir.len_chars()) + # TODO @dangotbanned: Support `value: 
IntoExpr` def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 ) -> Expr: # pragma: no cover return self._with_unary(self._ir.replace(pattern, value, literal=literal, n=n)) + # TODO @dangotbanned: Support `value: IntoExpr` def replace_all( self, pattern: str, value: str, *, literal: bool = False ) -> Expr: # pragma: no cover @@ -163,3 +166,7 @@ def to_lowercase(self) -> Expr: # pragma: no cover def to_uppercase(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_uppercase()) + + to_date = not_implemented() + to_titlecase = not_implemented() + zfill = not_implemented() diff --git a/narwhals/_plan/expressions/temporal.py b/narwhals/_plan/expressions/temporal.py index 2ba22aa925..6e703899ba 100644 --- a/narwhals/_plan/expressions/temporal.py +++ b/narwhals/_plan/expressions/temporal.py @@ -6,6 +6,7 @@ from narwhals._plan._function import Function from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions +from narwhals._utils import not_implemented if TYPE_CHECKING: from typing_extensions import TypeAlias, TypeIs @@ -180,3 +181,5 @@ def timestamp(self, time_unit: TimeUnit = "us") -> Expr: def truncate(self, every: str) -> Expr: return self._with_unary(self._ir.truncate(every)) + + offset_by = not_implemented() diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 6feb2f1700..fcec956ff1 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -53,6 +53,16 @@ ] +def format() -> Expr: + msg = "nwp.format" + raise NotImplementedError(msg) + + +def coalesce() -> Expr: + msg = "nwp.coalesce" + raise NotImplementedError(msg) + + def col(*names: str | t.Iterable[str]) -> Expr: flat = tuple(flatten(names)) return ( @@ -113,11 +123,13 @@ def sum(*columns: str) -> Expr: return col(columns).sum() +# TODO @dangotbanned: Support `ignore_nulls=...` def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: it = 
_parse.parse_into_seq_of_expr_ir(*exprs) return ir.boolean.AllHorizontal().to_function_expr(*it).to_narwhals() +# TODO @dangotbanned: Support `ignore_nulls=...` def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: it = _parse.parse_into_seq_of_expr_ir(*exprs) return ir.boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() From 33d151162525cb9386f4ebbf872b2cee12a4b0a4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:39:48 +0000 Subject: [PATCH 094/215] feat: Add `Expr.{ceil,floor}` --- narwhals/_plan/arrow/expr.py | 6 ++++++ narwhals/_plan/compliant/expr.py | 7 ++++++- narwhals/_plan/expr.py | 8 ++++++-- tests/plan/ceil_floor_test.py | 31 +++++++++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 tests/plan/ceil_floor_test.py diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index bc6e0ff5d0..c65f056dc1 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -248,6 +248,12 @@ def round(self, node: FExpr[F.Round], frame: Frame, name: str) -> StoresNativeT_ native = node.input[0].dispatch(self, frame, name).native return self._with_native(fn.round(native, node.function.decimals), name) + def ceil(self, node: FExpr[F.Ceil], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(pc.ceil)(node, frame, name) + + def floor(self, node: FExpr[F.Floor], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(pc.floor)(node, frame, name) + def clip(self, node: FExpr[F.Clip], frame: Frame, name: str) -> StoresNativeT_co: expr, lower, upper = node.function.unwrap_input(node) result = fn.clip( diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 2e57521eb2..6a7a996090 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -257,7 +257,12 @@ def sqrt( def unique( self, node: 
FunctionExpr[F.Unique], frame: FrameT_contra, name: str ) -> Self: ... - + def ceil( + self, node: FunctionExpr[F.Ceil], frame: FrameT_contra, name: str + ) -> Self: ... + def floor( + self, node: FunctionExpr[F.Floor], frame: FrameT_contra, name: str + ) -> Self: ... @property def cat( self, diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 3e56f85639..2e055cb3d9 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -335,6 +335,12 @@ def unique(self) -> Self: def round(self, decimals: int = 0) -> Self: return self._with_unary(F.Round(decimals=decimals)) + def ceil(self) -> Self: + return self._with_unary(F.Ceil()) + + def floor(self) -> Self: + return self._with_unary(F.Floor()) + def ewm_mean( self, *, @@ -623,8 +629,6 @@ def str(self) -> ExprStringNamespace: sample = not_implemented() head = not_implemented() tail = not_implemented() - ceil = not_implemented() - floor = not_implemented() class ExprV1(Expr): diff --git a/tests/plan/ceil_floor_test.py b/tests/plan/ceil_floor_test.py new file mode 100644 index 0000000000..ff08fce41a --- /dev/null +++ b/tests/plan/ceil_floor_test.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [1.12345, 2.56789, 3.901234, -0.5], "b": [1.045, None, 2.221, -5.9446]} + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.col("a").ceil(), [2.0, 3.0, 4.0, 0.0]), + (nwp.col("a").floor(), [1.0, 2.0, 3.0, -1.0]), + (nwp.col("b").ceil(), [2.0, None, 3.0, -5.0]), + (nwp.col("b").floor(), [1.0, None, 2.0, -6.0]), + ], + ids=["ceil", "floor", "ceil-nulls", "floor-nulls"], +) +def test_ceil_floor(data: Data, expr: nwp.Expr, expected: list[Any]) -> None: + result = dataframe(data).select(result=expr) + 
assert_equal_data(result, {"result": expected}) From 237e0f2d0ec89f4d3ff8d0ebb355ad4df4df1384 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:20:23 +0000 Subject: [PATCH 095/215] start adding missing `list` ops --- narwhals/_plan/arrow/expr.py | 7 +++++ narwhals/_plan/compliant/accessors.py | 9 ++++++ narwhals/_plan/exceptions.py | 7 +++++ narwhals/_plan/expressions/lists.py | 45 +++++++++++++++++++++++---- 4 files changed, 62 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index c65f056dc1..aa44e52a1a 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -845,6 +845,13 @@ def len(self, node: FExpr[lists.Len], frame: Frame, name: str) -> Expr | Scalar: native = node.input[0].dispatch(self.compliant, frame, name).native return self.with_native(fn.list_len(native), name) + def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: + msg = "TODO: `ArrowExpr.list.get`" + raise NotImplementedError(msg) + + unique = not_implemented() + contains = not_implemented() + class ArrowStructNamespace( ExprStructNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index d46eb46b42..3f89a49c8e 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -17,9 +17,18 @@ def get_categories( class ExprListNamespace(Protocol[FrameT_contra, ExprT_co]): + def contains( + self, node: FExpr[lists.Contains], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def get( + self, node: FExpr[lists.Get], frame: FrameT_contra, name: str + ) -> ExprT_co: ... def len( self, node: FExpr[lists.Len], frame: FrameT_contra, name: str ) -> ExprT_co: ... + def unique( + self, node: FExpr[lists.Unique], frame: FrameT_contra, name: str + ) -> ExprT_co: ... 
class ExprStructNamespace(Protocol[FrameT_contra, ExprT_co]): diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 05348372c0..23aa8af096 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -48,6 +48,13 @@ def function_expr_invalid_operation_error( return InvalidOperationError(msg) +def function_arg_non_scalar_error( + function: Function, arg_name: str, arg_value: Any +) -> InvalidOperationError: + msg = f"`{function!r}({arg_name}=...)` does not support non-scalar expression `{arg_value!r}`." + return InvalidOperationError(msg) + + # TODO @dangotbanned: Use arguments in error message def hist_bins_monotonic_error(bins: Seq[float]) -> ComputeError: # noqa: ARG001 msg = "bins must increase monotonically" diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 8c4306378d..0ae34f901a 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -3,20 +3,42 @@ from typing import TYPE_CHECKING, ClassVar from narwhals._plan._function import Function +from narwhals._plan._parse import parse_into_expr_ir +from narwhals._plan.exceptions import function_arg_non_scalar_error from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions -from narwhals._utils import not_implemented if TYPE_CHECKING: + from typing_extensions import Self + from narwhals._plan.expr import Expr + from narwhals._plan.expressions import ExprIR, FunctionExpr as FExpr + from narwhals._plan.typing import IntoExpr # fmt: off -class ListFunction(Function, accessor="list"): ... -class Len(ListFunction, options=FunctionOptions.elementwise): ... +class ListFunction(Function, accessor="list", options=FunctionOptions.elementwise): ... +class Len(ListFunction): ... +class Unique(ListFunction): ... 
+class Get(ListFunction): + __slots__ = ("index",) + index: int # fmt: on +class Contains(ListFunction): + """N-ary (expr, item).""" + + def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, item = node.input + return expr, item + + class IRListNamespace(IRNamespace): len: ClassVar = Len + unique: ClassVar = Unique + contains: ClassVar = Contains + + def get(self, index: int) -> Get: + return Get(index=index) class ExprListNamespace(ExprNamespace[IRListNamespace]): @@ -27,6 +49,17 @@ def _ir_namespace(self) -> type[IRListNamespace]: def len(self) -> Expr: return self._with_unary(self._ir.len()) - get = not_implemented() - contains = not_implemented() - unique = not_implemented() + def unique(self) -> Expr: + return self._with_unary(self._ir.unique()) + + # TODO @dangotbanned: Needs a full impl + def get(self, index: int) -> Expr: + return self._with_unary(self._ir.get(index)) + + # TODO @dangotbanned: Needs a `expr_parsing` test + def contains(self, item: IntoExpr) -> Expr: + item_ir = parse_into_expr_ir(item, str_as_lit=True) + contains = self._ir.contains() + if not item_ir.is_scalar: + raise function_arg_non_scalar_error(contains, "item", item_ir) + return self._expr._from_ir(contains.to_function_expr(self._expr._ir, item_ir)) From dd051a311e8a34d5fa57e17354b1ed5819da4ff4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:51:48 +0000 Subject: [PATCH 096/215] test: Add `test_list_contains_invalid` --- narwhals/_plan/_parse.py | 8 +++++-- narwhals/_plan/exceptions.py | 5 +++++ narwhals/_plan/expressions/lists.py | 9 ++++---- narwhals/_plan/functions.py | 4 ++-- tests/plan/expr_parsing_test.py | 35 ++++++++++++++++++++++++++++- 5 files changed, 52 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/_parse.py b/narwhals/_plan/_parse.py index 5b23cefde4..8ba064c1d2 100644 --- a/narwhals/_plan/_parse.py +++ b/narwhals/_plan/_parse.py @@ -18,7 +18,11 @@ is_selector, ) 
from narwhals._plan.common import flatten_hash_safe -from narwhals._plan.exceptions import invalid_into_expr_error, is_iterable_error +from narwhals._plan.exceptions import ( + invalid_into_expr_error, + is_iterable_error, + list_literal_error, +) from narwhals._utils import qualified_type_name from narwhals.dependencies import get_polars from narwhals.exceptions import InvalidOperationError @@ -127,7 +131,7 @@ def parse_into_expr_ir( expr = col(input) elif isinstance(input, list): if list_as_series is None: - raise TypeError(input) # pragma: no cover + raise list_literal_error(input) expr = lit(list_as_series(input)) else: expr = lit(input, dtype=dtype) diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 23aa8af096..cac4616b9b 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -55,6 +55,11 @@ def function_arg_non_scalar_error( return InvalidOperationError(msg) +def list_literal_error(value: Any) -> TypeError: + msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." 
+ return TypeError(msg) + + # TODO @dangotbanned: Use arguments in error message def hist_bins_monotonic_error(bins: Seq[float]) -> ComputeError: # noqa: ARG001 msg = "bins must increase monotonically" diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 0ae34f901a..5e1fbfa9ac 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -27,14 +27,16 @@ class Get(ListFunction): class Contains(ListFunction): """N-ary (expr, item).""" - def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: + def unwrap_input( + self, node: FExpr[Self], / + ) -> tuple[ExprIR, ExprIR]: # pragma: no cover expr, item = node.input return expr, item class IRListNamespace(IRNamespace): len: ClassVar = Len - unique: ClassVar = Unique + unique: ClassVar = Unique # pragma: no cover contains: ClassVar = Contains def get(self, index: int) -> Get: @@ -49,14 +51,13 @@ def _ir_namespace(self) -> type[IRListNamespace]: def len(self) -> Expr: return self._with_unary(self._ir.len()) - def unique(self) -> Expr: + def unique(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.unique()) # TODO @dangotbanned: Needs a full impl def get(self, index: int) -> Expr: return self._with_unary(self._ir.get(index)) - # TODO @dangotbanned: Needs a `expr_parsing` test def contains(self, item: IntoExpr) -> Expr: item_ir = parse_into_expr_ir(item, str_as_lit=True) contains = self._ir.contains() diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index fcec956ff1..a78ac28521 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -8,6 +8,7 @@ from narwhals._duration import Interval from narwhals._plan import _guards, _parse, common, expressions as ir, selectors as cs from narwhals._plan._dispatch import get_dispatch_name +from narwhals._plan.exceptions import list_literal_error from narwhals._plan.expressions import functions as F from narwhals._plan.expressions.literal import 
ScalarLiteral, SeriesLiteral from narwhals._plan.expressions.ranges import DateRange, IntRange, RangeFunction @@ -82,8 +83,7 @@ def lit( if _guards.is_series(value): return SeriesLiteral(value=value).to_literal().to_narwhals() if not _guards.is_non_nested_literal(value): - msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." - raise TypeError(msg) + raise list_literal_error(value) if dtype is None: dtype = common.py_to_narwhals_dtype(value, Version.MAIN) else: diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index aabf84798e..b2fba856a6 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -15,7 +15,7 @@ from narwhals._plan import expressions as ir from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.expressions import functions as F, operators as ops -from narwhals._plan.expressions.literal import SeriesLiteral +from narwhals._plan.expressions.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.expressions.ranges import IntRange from narwhals._utils import Implementation from narwhals.exceptions import ( @@ -744,3 +744,36 @@ def test_rolling_expr_invalid( a.rolling_var(window_size, min_samples=min_samples) with context: a.rolling_std(window_size, min_samples=min_samples) + + +def test_list_contains_invalid() -> None: + a = nwp.col("a") + + ok = a.list.contains("a") + assert_expr_ir_equal( + ok, + ir.FunctionExpr( + input=( + ir.col("a"), + ir.Literal(value=ScalarLiteral(value="a", dtype=nw.String())), + ), + function=ir.lists.Contains(), + options=ir.lists.Contains().function_options, + ), + ) + assert a.list.contains(a.first()) + assert a.list.contains(1) + assert a.list.contains(nwp.lit(1)) + assert a.list.contains(dt.datetime(2000, 2, 1, 9, 26, 5)) + assert a.list.contains(a.abs().fill_null(5).mode(keep="any")) + + with pytest.raises( + InvalidOperationError, match=r"list.contains.+non-scalar.+`col\('a'\)" + ): + a.list.contains(a) + + 
with pytest.raises(InvalidOperationError, match=r"list.contains.+non-scalar.+abs"): + a.list.contains(a.abs()) + + with pytest.raises(TypeError, match=r"list.+not.+supported.+nw.lit.+1.+2.+3"): + a.list.contains([1, 2, 3]) # type: ignore[arg-type] From 2b458bbec486432b180180c6df005a717b5608e6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 28 Nov 2025 16:13:43 +0000 Subject: [PATCH 097/215] test: Add `test_list_get` --- tests/plan/list_get_test.py | 43 +++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/plan/list_get_test.py diff --git a/tests/plan/list_get_test.py b/tests/plan/list_get_test.py new file mode 100644 index 0000000000..83813a9390 --- /dev/null +++ b/tests/plan/list_get_test.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from narwhals._plan.typing import OneOrIterable + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "a": [[1, 2], [3, 4, None], None, [None]], + "b": [[None, "o"], ["b", None, "b"], [None, "oops", None, "hi"], None], + } + + +a = nwp.nth(0) +b = nwp.col("b") + + +@pytest.mark.xfail(reason="TODO: `ArrowExpr.list.get`", raises=NotImplementedError) +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + (a.list.get(0), {"a": [1, 3, None, None]}), + (b.list.get(1), {"b": ["o", None, "oops", None]}), + ], +) +def test_list_get( + data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data +) -> None: # pragma: no cover + df = dataframe(data).with_columns( + a.cast(nw.List(nw.Int32())), b.cast(nw.List(nw.String)) + ) + result = df.select(exprs) + assert_equal_data(result, expected) From e91f04fd444b4d9f1e1d5afd058346a0dd5550da Mon Sep 17 00:00:00 2001 From: dangotbanned 
<125183946+dangotbanned@users.noreply.github.com> Date: Fri, 28 Nov 2025 16:20:33 +0000 Subject: [PATCH 098/215] add `test_list_get_invalid` --- tests/plan/expr_parsing_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index b2fba856a6..7cdccbdd9b 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -777,3 +777,13 @@ def test_list_contains_invalid() -> None: with pytest.raises(TypeError, match=r"list.+not.+supported.+nw.lit.+1.+2.+3"): a.list.contains([1, 2, 3]) # type: ignore[arg-type] + + +@pytest.mark.xfail(reason="Not implemented `index` validation yet") +def test_list_get_invalid() -> None: # pragma: no cover + a = nwp.col("a") + assert a.list.get(0) + with pytest.raises(TypeError, match=re_compile(r"must be.+int.+got.+str")): + a.list.get("not an index") # type: ignore[arg-type] + with pytest.raises(ValueError, match=re_compile(r"-1.+out of bounds.+0")): + a.list.get(-1) From 55d20eaca456781666b81cab5a3632b519882fd9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 28 Nov 2025 18:15:52 +0000 Subject: [PATCH 099/215] feat: Add `Expr.list.get` --- narwhals/_plan/arrow/expr.py | 4 ++-- narwhals/_plan/arrow/functions.py | 16 ++++++++++++++++ narwhals/_plan/arrow/typing.py | 6 +++--- narwhals/_plan/expressions/lists.py | 7 ++++++- tests/plan/expr_parsing_test.py | 9 +++++---- tests/plan/list_get_test.py | 5 +---- 6 files changed, 33 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index aa44e52a1a..97253968da 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -846,8 +846,8 @@ def len(self, node: FExpr[lists.Len], frame: Frame, name: str) -> Expr | Scalar: return self.with_native(fn.list_len(native), name) def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: - msg = "TODO: 
`ArrowExpr.list.get`" - raise NotImplementedError(msg) + native = node.input[0].dispatch(self.compliant, frame, name).native + return self.with_native(fn.list_get(native, node.function.index), name) unique = not_implemented() contains = not_implemented() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 3c9be52abf..5871ef87e0 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -326,6 +326,22 @@ def list_len(native: ArrowAny) -> ArrowAny: return result +@t.overload +def list_get( + native: ChunkedList[DataTypeT], index: int +) -> ChunkedArray[Scalar[DataTypeT]]: ... +@t.overload +def list_get(native: ListArray[DataTypeT], index: int) -> Array[Scalar[DataTypeT]]: ... +@t.overload +def list_get(native: ListScalar[DataTypeT], index: int) -> Scalar[DataTypeT]: ... +@t.overload +def list_get(native: SameArrowT, index: int) -> SameArrowT: ... +def list_get(native: ArrowAny, index: int) -> ArrowAny: + list_get_: Incomplete = pc.list_element + result: ArrowAny = list_get_(native, index) + return result + + @t.overload def when_then( predicate: Predicate, then: SameArrowT, otherwise: SameArrowT diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index af67f7f1e2..334df56b2b 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -32,7 +32,7 @@ IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type" IntegerScalar: TypeAlias = "Scalar[IntegerType]" DateScalar: TypeAlias = "Scalar[Date32Type]" - ListScalar: TypeAlias = "Scalar[pa.ListType[Any]]" + ListScalar: TypeAlias = "Scalar[pa.ListType[DataTypeT_co]]" class NativeArrowSeries(NativeSeries, Protocol): @property @@ -182,8 +182,8 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot ChunkedStruct: TypeAlias = "ChunkedArray[pa.StructScalar]" StructArray: TypeAlias = "pa.StructArray | 
Array[pa.StructScalar]" -ChunkedList: TypeAlias = "ChunkedArray[ListScalar]" -ListArray: TypeAlias = "Array[ListScalar]" +ChunkedList: TypeAlias = "ChunkedArray[ListScalar[DataTypeT_co]]" +ListArray: TypeAlias = "Array[ListScalar[DataTypeT_co]]" Arrow: TypeAlias = "ChunkedOrScalar[ScalarT_co] | Array[ScalarT_co]" ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 5e1fbfa9ac..708e584c74 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -7,6 +7,8 @@ from narwhals._plan.exceptions import function_arg_non_scalar_error from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions +from narwhals._utils import ensure_type +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from typing_extensions import Self @@ -54,8 +56,11 @@ def len(self) -> Expr: def unique(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.unique()) - # TODO @dangotbanned: Needs a full impl def get(self, index: int) -> Expr: + ensure_type(index, int, param_name="index") + if index < 0: + msg = f"`index` is out of bounds; must be >= 0, got {index}" + raise InvalidOperationError(msg) return self._with_unary(self._ir.get(index)) def contains(self, item: IntoExpr) -> Expr: diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 7cdccbdd9b..e5b37e2996 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -779,11 +779,12 @@ def test_list_contains_invalid() -> None: a.list.contains([1, 2, 3]) # type: ignore[arg-type] -@pytest.mark.xfail(reason="Not implemented `index` validation yet") -def test_list_get_invalid() -> None: # pragma: no cover +def test_list_get_invalid() -> None: a = nwp.col("a") assert a.list.get(0) - with pytest.raises(TypeError, match=re_compile(r"must be.+int.+got.+str")): + pattern = 
re_compile(r"expected.+int.+got.+str.+'not an index'") + with pytest.raises(TypeError, match=pattern): a.list.get("not an index") # type: ignore[arg-type] - with pytest.raises(ValueError, match=re_compile(r"-1.+out of bounds.+0")): + pattern = re_compile(r"index.+out of bounds.+>= 0.+got -1") + with pytest.raises(InvalidOperationError, match=pattern): a.list.get(-1) diff --git a/tests/plan/list_get_test.py b/tests/plan/list_get_test.py index 83813a9390..64fa644349 100644 --- a/tests/plan/list_get_test.py +++ b/tests/plan/list_get_test.py @@ -25,7 +25,6 @@ def data() -> Data: b = nwp.col("b") -@pytest.mark.xfail(reason="TODO: `ArrowExpr.list.get`", raises=NotImplementedError) @pytest.mark.parametrize( ("exprs", "expected"), [ @@ -33,9 +32,7 @@ def data() -> Data: (b.list.get(1), {"b": ["o", None, "oops", None]}), ], ) -def test_list_get( - data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data -) -> None: # pragma: no cover +def test_list_get(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None: df = dataframe(data).with_columns( a.cast(nw.List(nw.Int32())), b.cast(nw.List(nw.String)) ) From d8363e1f0dc89bf43e0d3adb9a6eb8b36ae067a0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 28 Nov 2025 19:29:45 +0000 Subject: [PATCH 100/215] fix: Ensure nulls are preserved in `Expr.unique` during `group_by` Realized after writing this that I didn't test it https://github.com/narwhals-dev/narwhals/pull/3332/files#r2572340215 --- narwhals/_plan/arrow/options.py | 1 + tests/plan/group_by_test.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index 254f6ca6da..5e2f328bbe 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -152,6 +152,7 @@ def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]: boolean.All: scalar_aggregate(ignore_nulls=True), boolean.Any: 
scalar_aggregate(ignore_nulls=True), functions.NullCount: count("only_null"), + functions.Unique: count("all"), } diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index d9bba6223e..d82d372d0b 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -588,8 +588,13 @@ def test_group_by_agg_last( "b_first": [3, 1, 3, 2, 1], }, ), + ( + ["d"], + [nwp.col("e", "b").unique()], + {"d": ["one", "three"], "e": [[1, 2], [None, 1]], "b": [[1, 3], [1, 2, 3]]}, + ), ], - ids=["Unique-Single", "Unique-Multi", "Unique-Selector-Fancy"], + ids=["Unique-Single", "Unique-Multi", "Unique-Selector-Fancy", "Unique-Nulls"], ) def test_group_by_agg_unique( keys: Sequence[str], aggs: Sequence[IntoExpr], expected: Mapping[str, Any] @@ -599,6 +604,7 @@ def test_group_by_agg_unique( "b": [1, 2, 1, 3, 3], "c": [5, 4, 3, 2, 1], "d": ["three", "three", "one", "three", "one"], + "e": [None, 1, 1, None, 2], } df = dataframe(data) result = df.group_by(keys).agg(aggs).sort(keys) From 2f3efe138991d06b3241e70757a71d6e290e9a53 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 28 Nov 2025 23:03:25 +0000 Subject: [PATCH 101/215] feat: Add native `str.zfill` --- narwhals/_plan/arrow/expr.py | 28 +++++++++++++++++++++++++-- narwhals/_plan/arrow/functions.py | 12 ++++++++++++ narwhals/_plan/compliant/accessors.py | 8 +++++++- narwhals/_plan/compliant/expr.py | 5 +++++ narwhals/_plan/expressions/lists.py | 6 ++---- narwhals/_plan/expressions/strings.py | 12 ++++++++++-- tests/plan/str_zfill_test.py | 17 ++++++++++++++++ 7 files changed, 79 insertions(+), 9 deletions(-) create mode 100644 tests/plan/str_zfill_test.py diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 97253968da..db0c49f0ad 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -21,6 +21,7 @@ from narwhals._plan.compliant.accessors import ( ExprCatNamespace, ExprListNamespace, + 
ExprStringNamespace, ExprStructNamespace, ) from narwhals._plan.compliant.column import ExprDispatch @@ -54,10 +55,15 @@ from typing_extensions import Self, TypeAlias - from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame + from narwhals._plan.arrow.dataframe import ArrowDataFrame, ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.arrow.typing import ChunkedArrayAny, P, VectorFunction - from narwhals._plan.expressions import BinaryExpr, FunctionExpr as FExpr, lists + from narwhals._plan.expressions import ( + BinaryExpr, + FunctionExpr as FExpr, + lists, + strings, + ) from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, @@ -669,6 +675,10 @@ def cat(self) -> ArrowCatNamespace[Expr]: def list(self) -> ArrowListNamespace[Expr]: return ArrowListNamespace(self) + @property + def str(self) -> ArrowStringNamespace[Expr]: + return ArrowStringNamespace(self) + @property def struct(self) -> ArrowStructNamespace[Expr]: return ArrowStructNamespace(self) @@ -791,6 +801,10 @@ def cat(self) -> ArrowCatNamespace[Scalar]: def list(self) -> ArrowListNamespace[Scalar]: return ArrowListNamespace(self) + @property + def str(self) -> ArrowStringNamespace[Scalar]: + return ArrowStringNamespace(self) + @property def struct(self) -> ArrowStructNamespace[Scalar]: return ArrowStructNamespace(self) @@ -853,6 +867,16 @@ def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: contains = not_implemented() +class ArrowStringNamespace( + ExprStringNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] +): + def zfill( + self, node: FExpr[strings.ZFill], frame: ArrowDataFrame, name: str + ) -> ArrowExpr | ArrowScalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + return self.with_native(fn.str_zfill(native, node.function.length), name) + + class ArrowStructNamespace( ExprStructNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): diff --git 
a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 5871ef87e0..d9b550c16d 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -99,6 +99,9 @@ HAS_ARANGE: Final = BACKEND_VERSION >= (21,) """`pyarrow.arange` added in https://github.com/apache/arrow/pull/46778""" +HAS_ZFILL: Final = BACKEND_VERSION >= (21,) +"""`pyarrow.compute.utf8_zero_fill` added in https://github.com/apache/arrow/pull/46815""" + I64: Final = pa.int64() F64: Final = pa.float64() @@ -342,6 +345,15 @@ def list_get(native: ArrowAny, index: int) -> ArrowAny: return result +def str_zfill(native: ChunkedOrScalarAny, length: int) -> ChunkedOrScalarAny: + if HAS_ZFILL: + zfill: Incomplete = pc.utf8_zero_fill # type:ignore[attr-defined] + result: ChunkedOrScalarAny = zfill(native, length) + return result + msg = "TODO: Port hand-rolled `str.zfill` from `main`" + raise NotImplementedError(msg) + + @t.overload def when_then( predicate: Predicate, then: SameArrowT, otherwise: SameArrowT diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 3f89a49c8e..1b93d0969f 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -5,7 +5,7 @@ from narwhals._plan.compliant.typing import ExprT_co, FrameT_contra if TYPE_CHECKING: - from narwhals._plan.expressions import FunctionExpr as FExpr, lists + from narwhals._plan.expressions import FunctionExpr as FExpr, lists, strings from narwhals._plan.expressions.categorical import GetCategories from narwhals._plan.expressions.struct import FieldByName @@ -31,6 +31,12 @@ def unique( ) -> ExprT_co: ... +class ExprStringNamespace(Protocol[FrameT_contra, ExprT_co]): + def zfill( + self, node: FExpr[strings.ZFill], frame: FrameT_contra, name: str + ) -> ExprT_co: ... 
+ + class ExprStructNamespace(Protocol[FrameT_contra, ExprT_co]): def field( self, node: FExpr[FieldByName], frame: FrameT_contra, name: str diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 6a7a996090..849ba271e3 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -19,6 +19,7 @@ from narwhals._plan.compliant.accessors import ( ExprCatNamespace, ExprListNamespace, + ExprStringNamespace, ExprStructNamespace, ) from narwhals._plan.compliant.scalar import CompliantScalar, EagerScalar @@ -272,6 +273,10 @@ def list( self, ) -> ExprListNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... @property + def str( + self, + ) -> ExprStringNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... + @property def struct( self, ) -> ExprStructNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 708e584c74..40c2364894 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -40,9 +40,7 @@ class IRListNamespace(IRNamespace): len: ClassVar = Len unique: ClassVar = Unique # pragma: no cover contains: ClassVar = Contains - - def get(self, index: int) -> Get: - return Get(index=index) + get: ClassVar = Get class ExprListNamespace(ExprNamespace[IRListNamespace]): @@ -61,7 +59,7 @@ def get(self, index: int) -> Expr: if index < 0: msg = f"`index` is out of bounds; must be >= 0, got {index}" raise InvalidOperationError(msg) - return self._with_unary(self._ir.get(index)) + return self._with_unary(self._ir.get(index=index)) def contains(self, item: IntoExpr) -> Expr: item_ir = parse_into_expr_ir(item, str_as_lit=True) diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index ebbaeb0177..037fc41c52 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -4,7 +4,7 
@@ from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FEOptions, FunctionOptions from narwhals._utils import not_implemented if TYPE_CHECKING: @@ -75,6 +75,11 @@ class ToDatetime(StringFunction): format: str | None +class ZFill(StringFunction, config=FEOptions.renamed("zfill")): + __slots__ = ("length",) + length: int + + class IRStringNamespace(IRNamespace): len_chars: ClassVar = LenChars to_lowercase: ClassVar = ToUppercase @@ -82,6 +87,7 @@ class IRStringNamespace(IRNamespace): split: ClassVar = Split starts_with: ClassVar = StartsWith ends_with: ClassVar = EndsWith + zfill: ClassVar = ZFill def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 @@ -167,6 +173,8 @@ def to_lowercase(self) -> Expr: # pragma: no cover def to_uppercase(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_uppercase()) + def zfill(self, length: int) -> Expr: + return self._with_unary(self._ir.zfill(length=length)) + to_date = not_implemented() to_titlecase = not_implemented() - zfill = not_implemented() diff --git a/tests/plan/str_zfill_test.py b/tests/plan/str_zfill_test.py new file mode 100644 index 0000000000..e0c4313243 --- /dev/null +++ b/tests/plan/str_zfill_test.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe +from tests.utils import PYARROW_VERSION + + +@pytest.mark.xfail( + PYARROW_VERSION < (21,), reason="TODO: `str.zfill` port", raises=NotImplementedError +) +def test_str_zfill() -> None: # pragma: no cover + data = {"a": ["-1", "+1", "1", "12", "123", "99999", "+9999", None]} + expected = {"a": ["-01", "+01", "001", "012", "123", "99999", "+9999", None]} + result = dataframe(data).select(nwp.col("a").str.zfill(3)) + 
assert_equal_data(result, expected) From d0cc2dfda8c7236464566d19534742c363b968ce Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 28 Nov 2025 23:06:36 +0000 Subject: [PATCH 102/215] revert autocomplete --- narwhals/_plan/arrow/expr.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index db0c49f0ad..131f9fdf7a 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -55,7 +55,7 @@ from typing_extensions import Self, TypeAlias - from narwhals._plan.arrow.dataframe import ArrowDataFrame, ArrowDataFrame as Frame + from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.arrow.typing import ChunkedArrayAny, P, VectorFunction from narwhals._plan.expressions import ( @@ -870,9 +870,7 @@ def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: class ArrowStringNamespace( ExprStringNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): - def zfill( - self, node: FExpr[strings.ZFill], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + def zfill(self, node: FExpr[strings.ZFill], frame: Frame, name: str) -> Expr | Scalar: native = node.input[0].dispatch(self.compliant, frame, name).native return self.with_native(fn.str_zfill(native, node.function.length), name) From d5771bb6ea17ec6c5b4509dea21f14999d1bfb51 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 29 Nov 2025 16:18:21 +0000 Subject: [PATCH 103/215] Adapt hand-rolled `str.zfill` from `main` --- narwhals/_plan/arrow/functions.py | 61 +++++++++++++++++++++++++++++-- tests/plan/str_zfill_test.py | 8 +--- 2 files changed, 58 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index d9b550c16d..0a455b6fa5 
100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -345,13 +345,66 @@ def list_get(native: ArrowAny, index: int) -> ArrowAny: return result +def str_len_chars(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: # pragma: no cover + len_chars: Incomplete = pc.utf8_length + result: ChunkedOrScalarAny = len_chars(native) + return result + + +def str_slice( + native: ChunkedOrScalarAny, offset: int, length: int | None = None +) -> ChunkedOrScalarAny: # pragma: no cover + stop = length if length is None else offset + length + return pc.utf8_slice_codeunits(native, offset, stop=stop) + + +def str_pad_start( + native: ChunkedOrScalarAny, length: int, fill_char: str = " " +) -> ChunkedOrScalarAny: # pragma: no cover + return pc.utf8_lpad(native, length, fill_char) + + def str_zfill(native: ChunkedOrScalarAny, length: int) -> ChunkedOrScalarAny: if HAS_ZFILL: - zfill: Incomplete = pc.utf8_zero_fill # type:ignore[attr-defined] + zfill: Incomplete = pc.utf8_zero_fill # type: ignore[attr-defined] result: ChunkedOrScalarAny = zfill(native, length) - return result - msg = "TODO: Port hand-rolled `str.zfill` from `main`" - raise NotImplementedError(msg) + else: + result = _str_zfill_compat(native, length) + return result + + +# TODO @dangotbanned: Finish tidying this up +def _str_zfill_compat( + native: ChunkedOrScalarAny, length: int +) -> Incomplete: # pragma: no cover + dtype = string_type([native.type]) + hyphen, plus = lit("-", dtype), lit("+", dtype) + + padded_remaining = str_pad_start(str_slice(native, 1), length - 1, "0") + padded_lt_length = str_pad_start(native, length, "0") + + binary_join: Incomplete = pc.binary_join_element_wise + if isinstance(native, pa.Scalar): + case_1: ArrowAny = hyphen # starts with hyphen and less than length + case_2: ArrowAny = plus # starts with plus and less than length + else: + arr_len = len(native) + case_1 = repeat_unchecked(hyphen, arr_len) + case_2 = repeat_unchecked(plus, arr_len) + + 
first_char = str_slice(native, 0, 1) + lt_length = lt(str_len_chars(native), lit(length)) + first_hyphen_lt_length = and_(eq(first_char, hyphen), lt_length) + first_plus_lt_length = and_(eq(first_char, plus), lt_length) + return when_then( + first_hyphen_lt_length, + binary_join(case_1, padded_remaining, ""), + when_then( + first_plus_lt_length, + binary_join(case_2, padded_remaining, ""), + when_then(lt_length, padded_lt_length, native), + ), + ) @t.overload diff --git a/tests/plan/str_zfill_test.py b/tests/plan/str_zfill_test.py index e0c4313243..1d471b70de 100644 --- a/tests/plan/str_zfill_test.py +++ b/tests/plan/str_zfill_test.py @@ -1,16 +1,10 @@ from __future__ import annotations -import pytest - import narwhals._plan as nwp from tests.plan.utils import assert_equal_data, dataframe -from tests.utils import PYARROW_VERSION -@pytest.mark.xfail( - PYARROW_VERSION < (21,), reason="TODO: `str.zfill` port", raises=NotImplementedError -) -def test_str_zfill() -> None: # pragma: no cover +def test_str_zfill() -> None: data = {"a": ["-1", "+1", "1", "12", "123", "99999", "+9999", None]} expected = {"a": ["-01", "+01", "001", "012", "123", "99999", "+9999", None]} result = dataframe(data).select(nwp.col("a").str.zfill(3)) From b97aff94482c02b50540df564a50069b132680f7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 29 Nov 2025 16:33:07 +0000 Subject: [PATCH 104/215] feat: Add `str.{slice,len_chars}` Fell out of the `zfill` impl --- narwhals/_plan/arrow/expr.py | 11 +++++++++++ narwhals/_plan/arrow/functions.py | 4 ++-- narwhals/_plan/compliant/accessors.py | 6 ++++++ tests/plan/str_len_chars_test.py | 11 +++++++++++ tests/plan/str_slice_test.py | 21 +++++++++++++++++++++ 5 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 tests/plan/str_len_chars_test.py create mode 100644 tests/plan/str_slice_test.py diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 
131f9fdf7a..cf476b7369 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -870,6 +870,17 @@ def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: class ArrowStringNamespace( ExprStringNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): + def len_chars( + self, node: FExpr[strings.LenChars], frame: Frame, name: str + ) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + return self.with_native(fn.str_len_chars(native), name) + + def slice(self, node: FExpr[strings.Slice], frame: Frame, name: str) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + func = node.function + return self.with_native(fn.str_slice(native, func.offset, func.length), name) + def zfill(self, node: FExpr[strings.ZFill], frame: Frame, name: str) -> Expr | Scalar: native = node.input[0].dispatch(self.compliant, frame, name).native return self.with_native(fn.str_zfill(native, node.function.length), name) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 0a455b6fa5..f6bed79361 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -345,7 +345,7 @@ def list_get(native: ArrowAny, index: int) -> ArrowAny: return result -def str_len_chars(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: # pragma: no cover +def str_len_chars(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: len_chars: Incomplete = pc.utf8_length result: ChunkedOrScalarAny = len_chars(native) return result @@ -353,7 +353,7 @@ def str_len_chars(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: # pragma: def str_slice( native: ChunkedOrScalarAny, offset: int, length: int | None = None -) -> ChunkedOrScalarAny: # pragma: no cover +) -> ChunkedOrScalarAny: stop = length if length is None else offset + length return pc.utf8_slice_codeunits(native, offset, stop=stop) diff --git a/narwhals/_plan/compliant/accessors.py 
b/narwhals/_plan/compliant/accessors.py index 1b93d0969f..cec91e56a7 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -32,6 +32,12 @@ def unique( class ExprStringNamespace(Protocol[FrameT_contra, ExprT_co]): + def len_chars( + self, node: FExpr[strings.LenChars], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def slice( + self, node: FExpr[strings.Slice], frame: FrameT_contra, name: str + ) -> ExprT_co: ... def zfill( self, node: FExpr[strings.ZFill], frame: FrameT_contra, name: str ) -> ExprT_co: ... diff --git a/tests/plan/str_len_chars_test.py b/tests/plan/str_len_chars_test.py new file mode 100644 index 0000000000..e5e3589592 --- /dev/null +++ b/tests/plan/str_len_chars_test.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + + +def test_len_chars() -> None: + data = {"a": ["foo", "foobar", "Café", "345", "東京"]} + expected = {"a": [3, 6, 4, 3, 2]} + result = dataframe(data).select(nwp.col("a").str.len_chars()) + assert_equal_data(result, expected) diff --git a/tests/plan/str_slice_test.py b/tests/plan/str_slice_test.py new file mode 100644 index 0000000000..b2ee40c300 --- /dev/null +++ b/tests/plan/str_slice_test.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.mark.parametrize( + ("offset", "length", "expected"), + [(1, 2, {"a": ["da", "df"]}), (-2, None, {"a": ["as", "as"]})], +) +def test_str_slice(offset: int, length: int | None, expected: Data) -> None: + data = {"a": ["fdas", "edfas"]} + result = dataframe(data).select(nwp.col("a").str.slice(offset, length)) + assert_equal_data(result, expected) From 7be2d290eec77f4143752c0d6f99c4741b72d411 Mon Sep 17 00:00:00 2001 From: dangotbanned 
<125183946+dangotbanned@users.noreply.github.com> Date: Sat, 29 Nov 2025 17:06:38 +0000 Subject: [PATCH 105/215] Add nodes for missing namespace methods Might add an impl for `to_titlecase`, since its simple, but the others are a hassle --- narwhals/_plan/expressions/strings.py | 20 +++++++++++++++---- narwhals/_plan/expressions/temporal.py | 27 ++++++++++++++++---------- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 037fc41c52..9b4b3ae043 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -5,7 +5,6 @@ from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FEOptions, FunctionOptions -from narwhals._utils import not_implemented if TYPE_CHECKING: from narwhals._plan.expr import Expr @@ -16,6 +15,7 @@ class StringFunction(Function, accessor="str", options=FunctionOptions.elementwi class LenChars(StringFunction): ... class ToLowercase(StringFunction): ... class ToUppercase(StringFunction): ... +class ToTitlecase(StringFunction): ... 
# fmt: on class ConcatStr(HorizontalFunction, StringFunction): __slots__ = ("ignore_nulls", "separator") @@ -70,6 +70,11 @@ class StripChars(StringFunction): characters: str | None +class ToDate(StringFunction): + __slots__ = ("format",) + format: str | None + + class ToDatetime(StringFunction): __slots__ = ("format",) format: str | None @@ -84,6 +89,7 @@ class IRStringNamespace(IRNamespace): len_chars: ClassVar = LenChars to_lowercase: ClassVar = ToUppercase to_uppercase: ClassVar = ToLowercase + to_titlecase: ClassVar = ToTitlecase split: ClassVar = Split starts_with: ClassVar = StartsWith ends_with: ClassVar = EndsWith @@ -116,6 +122,9 @@ def head(self, n: int = 5) -> Slice: def tail(self, n: int = 5) -> Slice: return self.slice(-n) + def to_date(self, format: str | None = None) -> ToDate: # pragma: no cover + return ToDate(format=format) + def to_datetime(self, format: str | None = None) -> ToDatetime: # pragma: no cover return ToDatetime(format=format) @@ -164,6 +173,9 @@ def tail(self, n: int = 5) -> Expr: def split(self, by: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.split(by=by)) + def to_date(self, format: str | None = None) -> Expr: # pragma: no cover + return self._with_unary(self._ir.to_date(format)) + def to_datetime(self, format: str | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_datetime(format)) @@ -173,8 +185,8 @@ def to_lowercase(self) -> Expr: # pragma: no cover def to_uppercase(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_uppercase()) + def to_titlecase(self) -> Expr: # pragma: no cover + return self._with_unary(self._ir.to_titlecase()) + def zfill(self, length: int) -> Expr: return self._with_unary(self._ir.zfill(length=length)) - - to_date = not_implemented() - to_titlecase = not_implemented() diff --git a/narwhals/_plan/expressions/temporal.py b/narwhals/_plan/expressions/temporal.py index 6e703899ba..257495d2b6 100644 --- a/narwhals/_plan/expressions/temporal.py 
+++ b/narwhals/_plan/expressions/temporal.py @@ -6,10 +6,9 @@ from narwhals._plan._function import Function from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions -from narwhals._utils import not_implemented if TYPE_CHECKING: - from typing_extensions import TypeAlias, TypeIs + from typing_extensions import Self, TypeAlias, TypeIs from narwhals._duration import IntervalUnit from narwhals._plan.expr import Expr @@ -72,18 +71,24 @@ def __repr__(self) -> str: return f"{super().__repr__()}[{self.time_unit!r}]" -class Truncate(TemporalFunction): +class _IntervalFunction(TemporalFunction): __slots__ = ("multiple", "unit") multiple: int unit: IntervalUnit - @staticmethod - def from_string(every: str, /) -> Truncate: - return Truncate.from_interval(Interval.parse(every)) + @classmethod + def from_string(cls, interval: str, /) -> Self: + return cls.from_interval(Interval.parse(interval)) - @staticmethod - def from_interval(every: Interval, /) -> Truncate: - return Truncate(multiple=every.multiple, unit=every.unit) + @classmethod + def from_interval(cls, interval: Interval, /) -> Self: + return cls(multiple=interval.multiple, unit=interval.unit) + + +# fmt: off +class Truncate(_IntervalFunction): ... +class OffsetBy(_IntervalFunction): ... 
+# fmt: on class IRDateTimeNamespace(IRNamespace): @@ -107,6 +112,7 @@ class IRDateTimeNamespace(IRNamespace): to_string: ClassVar = ToString replace_time_zone: ClassVar = ReplaceTimeZone convert_time_zone: ClassVar = ConvertTimeZone + offset_by: ClassVar = staticmethod(OffsetBy.from_string) truncate: ClassVar = staticmethod(Truncate.from_string) timestamp: ClassVar = staticmethod(Timestamp.from_time_unit) @@ -182,4 +188,5 @@ def timestamp(self, time_unit: TimeUnit = "us") -> Expr: def truncate(self, every: str) -> Expr: return self._with_unary(self._ir.truncate(every)) - offset_by = not_implemented() + def offset_by(self, by: str) -> Expr: # pragma: no cover + return self._with_unary(self._ir.offset_by(by)) From f633ce104f10de86729bd0256ff5701237f611c4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 29 Nov 2025 21:01:09 +0000 Subject: [PATCH 106/215] feat: Impl most of `ArrowStringNamespace` `strip_chars`, `starts_with`, `ends_with` are using different native functions --- narwhals/_plan/arrow/expr.py | 76 +++++++++++++++++++++++++-- narwhals/_plan/arrow/functions.py | 35 ++++++++++++ narwhals/_plan/compliant/accessors.py | 36 +++++++++++++ tests/plan/dispatch_test.py | 6 +-- 4 files changed, 145 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index cf476b7369..6454ddd913 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -28,7 +28,7 @@ from narwhals._plan.compliant.expr import EagerExpr from narwhals._plan.compliant.scalar import EagerScalar from narwhals._plan.compliant.typing import namespace -from narwhals._plan.expressions import functions as F +from narwhals._plan.expressions import FunctionExpr as FExpr, functions as F from narwhals._plan.expressions.boolean import ( IsDuplicated, IsFirstDistinct, @@ -856,8 +856,7 @@ class ArrowListNamespace( ExprListNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): 
def len(self, node: FExpr[lists.Len], frame: Frame, name: str) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native - return self.with_native(fn.list_len(native), name) + return self.compliant._unary_function(fn.list_len)(node, frame, name) def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: native = node.input[0].dispatch(self.compliant, frame, name).native @@ -867,14 +866,14 @@ def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: contains = not_implemented() +# TODO @dangotbanned: Add tests for these, especially those using a different native function class ArrowStringNamespace( ExprStringNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): def len_chars( self, node: FExpr[strings.LenChars], frame: Frame, name: str ) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native - return self.with_native(fn.str_len_chars(native), name) + return self.compliant._unary_function(fn.str_len_chars)(node, frame, name) def slice(self, node: FExpr[strings.Slice], frame: Frame, name: str) -> Expr | Scalar: native = node.input[0].dispatch(self.compliant, frame, name).native @@ -885,6 +884,73 @@ def zfill(self, node: FExpr[strings.ZFill], frame: Frame, name: str) -> Expr | S native = node.input[0].dispatch(self.compliant, frame, name).native return self.with_native(fn.str_zfill(native, node.function.length), name) + def contains( + self, node: FExpr[strings.Contains], frame: Frame, name: str + ) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + func = node.function + result = fn.str_contains(native, func.pattern, literal=func.literal) + return self.with_native(result, name) + + def ends_with( + self, node: FExpr[strings.EndsWith], frame: Frame, name: str + ) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + return self.with_native(fn.str_ends_with(native, 
node.function.suffix), name) + + def replace( + self, node: FExpr[strings.Replace], frame: Frame, name: str + ) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + func = node.function + pattern, value, literal, n = func.pattern, func.value, func.literal, func.n + result = fn.str_replace(native, pattern, value, literal=literal, n=n) + return self.with_native(result, name) + + def replace_all( + self, node: FExpr[strings.ReplaceAll], frame: Frame, name: str + ) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + func = node.function + pattern, value, literal = func.pattern, func.value, func.literal + result = fn.str_replace_all(native, pattern, value, literal=literal) + return self.with_native(result, name) + + def split(self, node: FExpr[strings.Split], frame: Frame, name: str) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + return self.with_native(fn.str_split(native, node.function.by), name) + + def starts_with( + self, node: FExpr[strings.StartsWith], frame: Frame, name: str + ) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + return self.with_native(fn.str_starts_with(native, node.function.prefix), name) + + def strip_chars( + self, node: FExpr[strings.StripChars], frame: Frame, name: str + ) -> Expr | Scalar: + native = node.input[0].dispatch(self.compliant, frame, name).native + characters = node.function.characters + return self.with_native(fn.str_strip_chars(native, characters), name) + + def to_uppercase( + self, node: FExpr[strings.ToUppercase], frame: Frame, name: str + ) -> Expr | Scalar: + return self.compliant._unary_function(fn.str_to_uppercase)(node, frame, name) + + def to_lowercase( + self, node: FExpr[strings.ToLowercase], frame: Frame, name: str + ) -> Expr | Scalar: + return self.compliant._unary_function(fn.str_to_lowercase)(node, frame, name) + + def to_titlecase( + self, node: 
FExpr[strings.ToTitlecase], frame: Frame, name: str + ) -> Expr | Scalar: + return self.compliant._unary_function(fn.str_to_titlecase)(node, frame, name) + + to_date = not_implemented() + to_datetime = not_implemented() + class ArrowStructNamespace( ExprStructNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index f6bed79361..b7d7bc3877 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -364,6 +364,41 @@ def str_pad_start( return pc.utf8_lpad(native, length, fill_char) +_StringFunction: TypeAlias = "Callable[[ChunkedOrScalarAny,str], ChunkedOrScalarAny]" +str_starts_with = t.cast("_StringFunction", pc.starts_with) +str_ends_with = t.cast("_StringFunction", pc.ends_with) +str_split = t.cast("_StringFunction", pc.split_pattern) +str_to_uppercase = pc.utf8_upper +str_to_lowercase = pc.utf8_lower +str_to_titlecase = pc.utf8_title + + +def str_contains( + native: Incomplete, pattern: str, *, literal: bool = False +) -> Incomplete: + func = pc.match_substring if literal else pc.match_substring_regex + return func(native, pattern) + + +def str_strip_chars(native: Incomplete, characters: str | None) -> Incomplete: + if characters: + return pc.utf8_trim(native, characters) + return pc.utf8_trim_whitespace(native) + + +def str_replace( + native: Incomplete, pattern: str, value: str, *, literal: bool = False, n: int = 1 +) -> Incomplete: + fn = pc.replace_substring if literal else pc.replace_substring_regex + return fn(native, pattern, replacement=value, max_replacements=n) + + +def str_replace_all( + native: Incomplete, pattern: str, value: str, *, literal: bool = False +) -> Incomplete: + return str_replace(native, pattern, value, literal=literal, n=-1) + + def str_zfill(native: ChunkedOrScalarAny, length: int) -> ChunkedOrScalarAny: if HAS_ZFILL: zfill: Incomplete = pc.utf8_zero_fill # type: ignore[attr-defined] diff --git 
a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index cec91e56a7..535b299dd8 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -32,12 +32,48 @@ def unique( class ExprStringNamespace(Protocol[FrameT_contra, ExprT_co]): + def contains( + self, node: FExpr[strings.Contains], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def ends_with( + self, node: FExpr[strings.EndsWith], frame: FrameT_contra, name: str + ) -> ExprT_co: ... def len_chars( self, node: FExpr[strings.LenChars], frame: FrameT_contra, name: str ) -> ExprT_co: ... + def replace( + self, node: FExpr[strings.Replace], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def replace_all( + self, node: FExpr[strings.ReplaceAll], frame: FrameT_contra, name: str + ) -> ExprT_co: ... def slice( self, node: FExpr[strings.Slice], frame: FrameT_contra, name: str ) -> ExprT_co: ... + def split( + self, node: FExpr[strings.Split], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def starts_with( + self, node: FExpr[strings.StartsWith], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def strip_chars( + self, node: FExpr[strings.StripChars], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def to_uppercase( + self, node: FExpr[strings.ToUppercase], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def to_lowercase( + self, node: FExpr[strings.ToLowercase], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def to_titlecase( + self, node: FExpr[strings.ToTitlecase], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def to_date( + self, node: FExpr[strings.ToDate], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def to_datetime( + self, node: FExpr[strings.ToDatetime], frame: FrameT_contra, name: str + ) -> ExprT_co: ... def zfill( self, node: FExpr[strings.ZFill], frame: FrameT_contra, name: str ) -> ExprT_co: ... 
diff --git a/tests/plan/dispatch_test.py b/tests/plan/dispatch_test.py index ddfac3051e..23cc085959 100644 --- a/tests/plan/dispatch_test.py +++ b/tests/plan/dispatch_test.py @@ -47,11 +47,11 @@ def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: df.select(nwp.col("c").ewm_mean()) missing_protocol = re_compile( - r"str\.contains.+has not been implemented.+compliant.+" - r"Hint.+try adding.+CompliantExpr\.str\.contains\(\)" + r"dt\.offset_by.+has not been implemented.+compliant.+" + r"Hint.+try adding.+CompliantExpr\.dt\.offset_by\(\)" ) with pytest.raises(NotImplementedError, match=missing_protocol): - df.select(nwp.col("d").str.contains("a")) + df.select(nwp.col("d").dt.offset_by("1d")) with pytest.raises( TypeError, From f4cecb0478c67cd94817ba46af5443513156cbb6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 30 Nov 2025 16:50:52 +0000 Subject: [PATCH 107/215] refactor: Extend ` _unary_function`, docs, fix typing - Still not quite as ergonomic as I'd like - Should make it easier to refactor later - Now that more things hit the same path --- narwhals/_plan/arrow/expr.py | 146 ++++++++++++++++++------------ narwhals/_plan/arrow/functions.py | 65 +++++++++---- narwhals/_plan/arrow/typing.py | 21 ++++- 3 files changed, 150 insertions(+), 82 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 6454ddd913..863afea08f 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -57,7 +57,12 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.arrow.typing import ChunkedArrayAny, P, VectorFunction + from narwhals._plan.arrow.typing import ( + ChunkedArrayAny, + P, + UnaryFunctionP, + VectorFunction, + ) from narwhals._plan.expressions import ( BinaryExpr, FunctionExpr as FExpr, @@ -130,7 +135,7 @@ def pow(self, node: FExpr[Pow], frame: Frame, 
name: str) -> StoresNativeT_co: base, exponent = node.function.unwrap_input(node) base_ = base.dispatch(self, frame, "base").native exponent_ = exponent.dispatch(self, frame, "exponent").native - return self._with_native(pc.power(base_, exponent_), name) + return self._with_native(fn.power(base_, exponent_), name) def fill_null( self, node: FExpr[FillNull], frame: Frame, name: str @@ -156,17 +161,46 @@ def is_between( result = fn.is_between(native, lower, upper, node.function.closed) return self._with_native(result, name) + @overload + def _unary_function( + self, fn_native: UnaryFunctionP[P], /, *args: P.args, **kwds: P.kwargs + ) -> Callable[[FExpr[Any], Frame, str], StoresNativeT_co]: ... + @overload + def _unary_function( + self, fn_native: Callable[[ChunkedOrScalarAny], ChunkedOrScalarAny], / + ) -> Callable[[FExpr[Any], Frame, str], StoresNativeT_co]: ... def _unary_function( - self, fn_native: Callable[[Any], Any], / + self, fn_native: UnaryFunctionP[P], /, *args: P.args, **kwds: P.kwargs ) -> Callable[[FExpr[Any], Frame, str], StoresNativeT_co]: - def func(node: FExpr[Any], frame: Frame, name: str) -> StoresNativeT_co: + """Return a function with the signature `(node, frame, name)`. + + Handles dispatching prior expressions, and rewrapping the result of this one. + + Arity refers to the number of expression inputs to a function (after expanding). 
+ + So a **unary** function will look like: + + col("a").round(2) + + Which unravels to: + + FunctionExpr( + input=(Column(name="a"),), + # ^ length-1 tuple + function=Round(decimals=2), + # ^ non-expression argument + options=..., + ) + """ + + def func(node: FExpr[Any], frame: Frame, name: str, /) -> StoresNativeT_co: native = node.input[0].dispatch(self, frame, name).native - return self._with_native(fn_native(native), name) + return self._with_native(fn_native(native, *args, **kwds), name) return func def abs(self, node: FExpr[Abs], frame: Frame, name: str) -> StoresNativeT_co: - return self._unary_function(pc.abs)(node, frame, name) + return self._unary_function(fn.abs_)(node, frame, name) def not_(self, node: FExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.not_)(node, frame, name) @@ -196,16 +230,14 @@ def is_in_expr( def is_in_series( self, node: FExpr[IsInSeries[ChunkedArrayAny]], frame: Frame, name: str ) -> StoresNativeT_co: - native = node.input[0].dispatch(self, frame, name).native other = node.function.other.unwrap().to_native() - return self._with_native(fn.is_in(native, other), name) + return self._unary_function(fn.is_in, other)(node, frame, name) def is_in_seq( self, node: FExpr[IsInSeq], frame: Frame, name: str ) -> StoresNativeT_co: - native = node.input[0].dispatch(self, frame, name).native other = fn.array(node.function.other) - return self._with_native(fn.is_in(native, other), name) + return self._unary_function(fn.is_in, other)(node, frame, name) def is_nan(self, node: FExpr[IsNan], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.is_nan)(node, frame, name) @@ -241,24 +273,22 @@ def ternary_expr( return self._with_native(result, name) def log(self, node: FExpr[F.Log], frame: Frame, name: str) -> StoresNativeT_co: - native = node.input[0].dispatch(self, frame, name).native - return self._with_native(fn.log(native, node.function.base), name) + return 
self._unary_function(fn.log, node.function.base)(node, frame, name) def exp(self, node: FExpr[F.Exp], frame: Frame, name: str) -> StoresNativeT_co: - return self._unary_function(pc.exp)(node, frame, name) + return self._unary_function(fn.exp)(node, frame, name) def sqrt(self, node: FExpr[F.Sqrt], frame: Frame, name: str) -> StoresNativeT_co: - return self._unary_function(pc.sqrt)(node, frame, name) + return self._unary_function(fn.sqrt)(node, frame, name) def round(self, node: FExpr[F.Round], frame: Frame, name: str) -> StoresNativeT_co: - native = node.input[0].dispatch(self, frame, name).native - return self._with_native(fn.round(native, node.function.decimals), name) + return self._unary_function(fn.round, node.function.decimals)(node, frame, name) def ceil(self, node: FExpr[F.Ceil], frame: Frame, name: str) -> StoresNativeT_co: - return self._unary_function(pc.ceil)(node, frame, name) + return self._unary_function(fn.ceil)(node, frame, name) def floor(self, node: FExpr[F.Floor], frame: Frame, name: str) -> StoresNativeT_co: - return self._unary_function(pc.floor)(node, frame, name) + return self._unary_function(fn.floor)(node, frame, name) def clip(self, node: FExpr[F.Clip], frame: Frame, name: str) -> StoresNativeT_co: expr, lower, upper = node.function.unwrap_input(node) @@ -292,11 +322,9 @@ def clip_upper( def replace_strict( self, node: FExpr[F.ReplaceStrict], frame: Frame, name: str ) -> StoresNativeT_co: - func = node.function - native = node.input[0].dispatch(self, frame, name).native - dtype = fn.dtype_native(func.return_dtype, self.version) - result = fn.replace_strict(native, func.old, func.new, dtype) - return self._with_native(result, name) + old, new = node.function.old, node.function.new + dtype = fn.dtype_native(node.function.return_dtype, self.version) + return self._unary_function(fn.replace_strict, old, new, dtype)(node, frame, name) def replace_strict_default( self, node: FExpr[F.ReplaceStrictDefault], frame: Frame, name: str @@ -845,6 
+873,19 @@ def version(self) -> Version: def with_native(self, native: ChunkedOrScalarAny, name: str, /) -> Expr | Scalar: return self.compliant._with_native(native, name) + @overload + def unary( + self, fn_native: UnaryFunctionP[P], /, *args: P.args, **kwds: P.kwargs + ) -> Callable[[FExpr[Any], Frame, str], Expr | Scalar]: ... + @overload + def unary( + self, fn_native: Callable[[ChunkedOrScalarAny], ChunkedOrScalarAny], / + ) -> Callable[[FExpr[Any], Frame, str], Expr | Scalar]: ... + def unary( + self, fn_native: UnaryFunctionP[P], /, *args: P.args, **kwds: P.kwargs + ) -> Callable[[FExpr[Any], Frame, str], Expr | Scalar]: + return self.compliant._unary_function(fn_native, *args, **kwds) + class ArrowCatNamespace(ExprCatNamespace["Frame", "Expr"], ArrowAccessor[ExprOrScalarT]): def get_categories(self, node: FExpr[GetCategories], frame: Frame, name: str) -> Expr: @@ -856,11 +897,10 @@ class ArrowListNamespace( ExprListNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): def len(self, node: FExpr[lists.Len], frame: Frame, name: str) -> Expr | Scalar: - return self.compliant._unary_function(fn.list_len)(node, frame, name) + return self.unary(fn.list_len)(node, frame, name) def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native - return self.with_native(fn.list_get(native, node.function.index), name) + return self.unary(fn.list_get, node.function.index)(node, frame, name) unique = not_implemented() contains = not_implemented() @@ -873,80 +913,71 @@ class ArrowStringNamespace( def len_chars( self, node: FExpr[strings.LenChars], frame: Frame, name: str ) -> Expr | Scalar: - return self.compliant._unary_function(fn.str_len_chars)(node, frame, name) + return self.unary(fn.str_len_chars)(node, frame, name) def slice(self, node: FExpr[strings.Slice], frame: Frame, name: str) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, 
name).native - func = node.function - return self.with_native(fn.str_slice(native, func.offset, func.length), name) + offset, length = node.function.offset, node.function.length + return self.unary(fn.str_slice, offset, length)(node, frame, name) def zfill(self, node: FExpr[strings.ZFill], frame: Frame, name: str) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native - return self.with_native(fn.str_zfill(native, node.function.length), name) + return self.unary(fn.str_zfill, node.function.length)(node, frame, name) def contains( self, node: FExpr[strings.Contains], frame: Frame, name: str ) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native - func = node.function - result = fn.str_contains(native, func.pattern, literal=func.literal) - return self.with_native(result, name) + pattern, literal = node.function.pattern, node.function.literal + return self.unary(fn.str_contains, pattern, literal=literal)(node, frame, name) def ends_with( self, node: FExpr[strings.EndsWith], frame: Frame, name: str ) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native - return self.with_native(fn.str_ends_with(native, node.function.suffix), name) + return self.unary(fn.str_ends_with, node.function.suffix)(node, frame, name) def replace( self, node: FExpr[strings.Replace], frame: Frame, name: str ) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native func = node.function - pattern, value, literal, n = func.pattern, func.value, func.literal, func.n - result = fn.str_replace(native, pattern, value, literal=literal, n=n) - return self.with_native(result, name) + pattern, value, literal, n = (func.pattern, func.value, func.literal, func.n) + replace = fn.str_replace + return self.unary(replace, pattern, value, literal=literal, n=n)( + node, frame, name + ) def replace_all( self, node: FExpr[strings.ReplaceAll], frame: Frame, name: str ) -> Expr | Scalar: 
- native = node.input[0].dispatch(self.compliant, frame, name).native func = node.function - pattern, value, literal = func.pattern, func.value, func.literal - result = fn.str_replace_all(native, pattern, value, literal=literal) - return self.with_native(result, name) + pattern, value, literal = (func.pattern, func.value, func.literal) + replace = fn.str_replace_all + return self.unary(replace, pattern, value, literal=literal)(node, frame, name) def split(self, node: FExpr[strings.Split], frame: Frame, name: str) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native - return self.with_native(fn.str_split(native, node.function.by), name) + return self.unary(fn.str_split, node.function.by)(node, frame, name) def starts_with( self, node: FExpr[strings.StartsWith], frame: Frame, name: str ) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native - return self.with_native(fn.str_starts_with(native, node.function.prefix), name) + return self.unary(fn.str_starts_with, node.function.prefix)(node, frame, name) def strip_chars( self, node: FExpr[strings.StripChars], frame: Frame, name: str ) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native - characters = node.function.characters - return self.with_native(fn.str_strip_chars(native, characters), name) + return self.unary(fn.str_strip_chars, node.function.characters)(node, frame, name) def to_uppercase( self, node: FExpr[strings.ToUppercase], frame: Frame, name: str ) -> Expr | Scalar: - return self.compliant._unary_function(fn.str_to_uppercase)(node, frame, name) + return self.unary(fn.str_to_uppercase)(node, frame, name) def to_lowercase( self, node: FExpr[strings.ToLowercase], frame: Frame, name: str ) -> Expr | Scalar: - return self.compliant._unary_function(fn.str_to_lowercase)(node, frame, name) + return self.unary(fn.str_to_lowercase)(node, frame, name) def to_titlecase( self, node: FExpr[strings.ToTitlecase], frame: 
Frame, name: str ) -> Expr | Scalar: - return self.compliant._unary_function(fn.str_to_titlecase)(node, frame, name) + return self.unary(fn.str_to_titlecase)(node, frame, name) to_date = not_implemented() to_datetime = not_implemented() @@ -956,5 +987,4 @@ class ArrowStructNamespace( ExprStructNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): def field(self, node: FExpr[FieldByName], frame: Frame, name: str) -> Expr | Scalar: - native = node.input[0].dispatch(self.compliant, frame, name).native - return self.with_native(fn.struct_field(native, node.function.name), name) + return self.unary(fn.struct_field, node.function.name)(node, frame, name) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index b7d7bc3877..b89034a5f7 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -42,6 +42,7 @@ BinaryNumericTemporal, BinOp, BooleanLengthPreserving, + BooleanScalar, ChunkedArray, ChunkedArrayAny, ChunkedList, @@ -62,6 +63,7 @@ ListArray, ListScalar, NativeScalar, + NumericScalar, Predicate, SameArrowT, Scalar, @@ -71,6 +73,7 @@ StringType, StructArray, UnaryFunction, + UnaryNumeric, VectorFunction, ) from narwhals._plan.compliant.typing import SeriesT @@ -117,14 +120,22 @@ class MinMax(ir.AggExpr): IntoColumnAgg: TypeAlias = Callable[[str], ir.AggExpr] """Helper constructor for single-column aggregations.""" -is_null = pc.is_null -is_not_null = t.cast("UnaryFunction[ScalarAny,pa.BooleanScalar]", pc.is_valid) -is_nan = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_nan) -is_finite = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_finite) -not_ = t.cast("UnaryFunction[pa.BooleanScalar ,pa.BooleanScalar]", pc.invert) +is_null = t.cast("UnaryFunction[ScalarAny, BooleanScalar]", pc.is_null) +is_not_null = t.cast("UnaryFunction[ScalarAny,BooleanScalar]", pc.is_valid) +is_nan = t.cast("UnaryFunction[ScalarAny, BooleanScalar]", pc.is_nan) +is_finite = 
t.cast("UnaryFunction[ScalarAny, BooleanScalar]", pc.is_finite) +not_ = t.cast("UnaryFunction[ScalarAny, BooleanScalar]", pc.invert) -def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: +@overload +def is_not_nan(native: ChunkedArrayAny) -> ChunkedArray[BooleanScalar]: ... +@overload +def is_not_nan(native: ScalarAny) -> BooleanScalar: ... +@overload +def is_not_nan(native: ChunkedOrScalarAny) -> ChunkedOrScalar[BooleanScalar]: ... +@overload +def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[BooleanScalar]: ... +def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[BooleanScalar]: return not_(is_nan(native)) @@ -143,15 +154,20 @@ def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: add = t.cast("BinaryNumericTemporal", pc.add) sub = t.cast("BinaryNumericTemporal", pc.subtract) multiply = pc.multiply -power = t.cast("BinaryFunction[pc.NumericScalar, pc.NumericScalar]", pc.power) +power = t.cast("BinaryFunction[NumericScalar, NumericScalar]", pc.power) floordiv = _floordiv +abs_ = t.cast("UnaryNumeric", pc.abs) +exp = t.cast("UnaryNumeric", pc.exp) +sqrt = t.cast("UnaryNumeric", pc.sqrt) +ceil = t.cast("UnaryNumeric", pc.ceil) +floor = t.cast("UnaryNumeric", pc.floor) -def truediv(lhs: Any, rhs: Any) -> Any: +def truediv(lhs: Incomplete, rhs: Incomplete) -> Incomplete: return pc.divide(*cast_for_truediv(lhs, rhs)) -def modulus(lhs: Any, rhs: Any) -> Any: +def modulus(lhs: Incomplete, rhs: Incomplete) -> Incomplete: floor_div = floordiv(lhs, rhs) return sub(lhs, multiply(floor_div, rhs)) @@ -286,6 +302,8 @@ def struct_field(native: StructArray, field: Field, /) -> ArrayAny: ... def struct_field(native: pa.StructScalar, field: Field, /) -> ScalarAny: ... @t.overload def struct_field(native: SameArrowT, field: Field, /) -> SameArrowT: ... +@t.overload +def struct_field(native: ChunkedOrScalarAny, field: Field, /) -> ChunkedOrScalarAny: ... 
def struct_field(native: ArrowAny, field: Field, /) -> ArrowAny: """Retrieve one `Struct` field.""" func = t.cast("Callable[[Any,Any], ArrowAny]", pc.struct_field) @@ -323,6 +341,8 @@ def list_len(native: ListArray) -> pa.UInt32Array: ... def list_len(native: ListScalar) -> pa.UInt32Scalar: ... @t.overload def list_len(native: SameArrowT) -> SameArrowT: ... +@t.overload +def list_len(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[pa.UInt32Scalar]: ... def list_len(native: ArrowAny) -> ArrowAny: length: Incomplete = pc.list_value_length result: ArrowAny = length(native).cast(pa.uint32()) @@ -339,6 +359,8 @@ def list_get(native: ListArray[DataTypeT], index: int) -> Array[Scalar[DataTypeT def list_get(native: ListScalar[DataTypeT], index: int) -> Scalar[DataTypeT]: ... @t.overload def list_get(native: SameArrowT, index: int) -> SameArrowT: ... +@t.overload +def list_get(native: ChunkedOrScalarAny, index: int) -> ChunkedOrScalarAny: ... def list_get(native: ArrowAny, index: int) -> ArrowAny: list_get_: Incomplete = pc.list_element result: ArrowAny = list_get_(native, index) @@ -364,13 +386,14 @@ def str_pad_start( return pc.utf8_lpad(native, length, fill_char) -_StringFunction: TypeAlias = "Callable[[ChunkedOrScalarAny,str], ChunkedOrScalarAny]" -str_starts_with = t.cast("_StringFunction", pc.starts_with) -str_ends_with = t.cast("_StringFunction", pc.ends_with) -str_split = t.cast("_StringFunction", pc.split_pattern) -str_to_uppercase = pc.utf8_upper -str_to_lowercase = pc.utf8_lower -str_to_titlecase = pc.utf8_title +_StringFunction0: TypeAlias = "Callable[[ChunkedOrScalarAny], ChunkedOrScalarAny]" +_StringFunction1: TypeAlias = "Callable[[ChunkedOrScalarAny, str], ChunkedOrScalarAny]" +str_starts_with = t.cast("_StringFunction1", pc.starts_with) +str_ends_with = t.cast("_StringFunction1", pc.ends_with) +str_split = t.cast("_StringFunction1", pc.split_pattern) +str_to_uppercase = t.cast("_StringFunction0", pc.utf8_upper) +str_to_lowercase = 
t.cast("_StringFunction0", pc.utf8_lower) +str_to_titlecase = t.cast("_StringFunction0", pc.utf8_title) def str_contains( @@ -462,15 +485,15 @@ def when_then( return pc.if_else(predicate, then, otherwise) -def any_(native: Any) -> pa.BooleanScalar: +def any_(native: Incomplete) -> pa.BooleanScalar: return pc.any(native, min_count=0) -def all_(native: Any) -> pa.BooleanScalar: +def all_(native: Incomplete) -> pa.BooleanScalar: return pc.all(native, min_count=0) -def sum_(native: Any) -> NativeScalar: +def sum_(native: Incomplete) -> NativeScalar: return pc.sum(native, min_count=0) @@ -769,6 +792,10 @@ def is_in( def is_in(values: ArrayAny, /, other: ChunkedOrArrayAny) -> Array[pa.BooleanScalar]: ... @t.overload def is_in(values: ScalarAny, /, other: ChunkedOrArrayAny) -> pa.BooleanScalar: ... +@t.overload +def is_in( + values: ChunkedOrScalarAny, /, other: ChunkedOrArrayAny +) -> ChunkedOrScalarAny: ... def is_in(values: ArrowAny, /, other: ChunkedOrArrayAny) -> ArrowAny: """Check if elements of `values` are present in `other`. diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index 334df56b2b..7f8d369595 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -11,6 +11,7 @@ import pyarrow as pa import pyarrow.compute as pc from pyarrow.lib import ( + BoolType, Date32Type, Int8Type, Int16Type, @@ -33,6 +34,8 @@ IntegerScalar: TypeAlias = "Scalar[IntegerType]" DateScalar: TypeAlias = "Scalar[Date32Type]" ListScalar: TypeAlias = "Scalar[pa.ListType[DataTypeT_co]]" + BooleanScalar: TypeAlias = "Scalar[BoolType]" + NumericScalar: TypeAlias = "pc.NumericScalar" class NativeArrowSeries(NativeSeries, Protocol): @property @@ -45,9 +48,16 @@ def columns(self) -> Sequence[NativeArrowSeries]: ... P = ParamSpec("P") + class UnaryFunctionP(Protocol[P]): + """A function wrapping at-most 1 `Expr` input.""" + + def __call__( + self, native: ChunkedOrScalarAny, /, *args: P.args, **kwds: P.kwargs + ) -> ChunkedOrScalarAny: ... 
+ class VectorFunction(Protocol[P]): def __call__( - self, native: ChunkedArrayAny, *args: P.args, **kwds: P.kwargs + self, native: ChunkedArrayAny, /, *args: P.args, **kwds: P.kwargs ) -> ChunkedArrayAny: ... class BooleanLengthPreserving(Protocol): @@ -68,7 +78,7 @@ def __call__( ) NumericOrTemporalScalar: TypeAlias = "pc.NumericOrTemporalScalar" NumericOrTemporalScalarT = TypeVar( - "NumericOrTemporalScalarT", bound=NumericOrTemporalScalar, default="pc.NumericScalar" + "NumericOrTemporalScalarT", bound=NumericOrTemporalScalar, default="NumericScalar" ) @@ -152,16 +162,17 @@ def __call__( class BinaryComp( - BinaryFunction[ScalarPT_contra, "pa.BooleanScalar"], Protocol[ScalarPT_contra] + BinaryFunction[ScalarPT_contra, "BooleanScalar"], Protocol[ScalarPT_contra] ): ... -class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Protocol): ... +class BinaryLogical(BinaryFunction["BooleanScalar", "BooleanScalar"], Protocol): ... BinaryNumericTemporal: TypeAlias = BinaryFunction[ NumericOrTemporalScalarT, NumericOrTemporalScalarT ] +UnaryNumeric: TypeAlias = UnaryFunction["NumericScalar", "NumericScalar"] DataType: TypeAlias = "pa.DataType" DataTypeT = TypeVar("DataTypeT", bound=DataType, default=Any) DataTypeT_co = TypeVar("DataTypeT_co", bound=DataType, covariant=True, default=Any) @@ -189,7 +200,7 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" SameArrowT = TypeVar("SameArrowT", ChunkedArrayAny, ArrayAny, ScalarAny) ArrowT = TypeVar("ArrowT", bound=ArrowAny) -Predicate: TypeAlias = "Arrow[pa.BooleanScalar]" +Predicate: TypeAlias = "Arrow[BooleanScalar]" """Any `pyarrow` container that wraps boolean.""" NativeScalar: TypeAlias = ScalarAny From 8456f6f16fce5743886df1a592ef22a5fc705078 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 30 Nov 2025 18:02:49 +0000 Subject: [PATCH 108/215] feat: Add `coalesce` 
--- narwhals/_plan/__init__.py | 2 + narwhals/_plan/arrow/namespace.py | 9 +++ narwhals/_plan/compliant/namespace.py | 3 + narwhals/_plan/expressions/functions.py | 1 + narwhals/_plan/functions.py | 10 +-- tests/plan/coalesce_test.py | 93 +++++++++++++++++++++++++ 6 files changed, 113 insertions(+), 5 deletions(-) create mode 100644 tests/plan/coalesce_test.py diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index c53d1576ba..9eaa64fa62 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -7,6 +7,7 @@ all, all_horizontal, any_horizontal, + coalesce, col, concat_str, date_range, @@ -37,6 +38,7 @@ "all", "all_horizontal", "any_horizontal", + "coalesce", "col", "concat_str", "date_range", diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index fa0e9e1e71..29c74e75be 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -116,6 +116,15 @@ def func(node: FunctionExpr[Any], frame: Frame, name: str) -> Expr | Scalar: return func + def coalesce( + self, node: FunctionExpr[F.Coalesce], frame: Frame, name: str + ) -> Expr | Scalar: + it = (self._expr.from_ir(e, frame, name).native for e in node.input) + result = pc.coalesce(*it) + if isinstance(result, pa.Scalar): + return self._scalar.from_native(result, name, self.version) + return self._expr.from_native(result, name, self.version) + def any_horizontal( self, node: FunctionExpr[AnyHorizontal], frame: Frame, name: str ) -> Expr | Scalar: diff --git a/narwhals/_plan/compliant/namespace.py b/narwhals/_plan/compliant/namespace.py index 9fa62aad93..c3ebb47b55 100644 --- a/narwhals/_plan/compliant/namespace.py +++ b/narwhals/_plan/compliant/namespace.py @@ -54,6 +54,9 @@ def col(self, node: ir.Column, frame: FrameT, name: str) -> ExprT_co: ... def concat_str( self, node: FunctionExpr[ConcatStr], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... 
+ def coalesce( + self, node: FunctionExpr[F.Coalesce], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... def date_range( self, node: ir.RangeExpr[DateRange], frame: FrameT, name: str ) -> ExprT_co: ... diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index 379cb7b2f9..0b5c10b0d7 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -78,6 +78,7 @@ class SumHorizontal(HorizontalFunction): ... class MinHorizontal(HorizontalFunction): ... class MaxHorizontal(HorizontalFunction): ... class MeanHorizontal(HorizontalFunction): ... +class Coalesce(HorizontalFunction): ... # fmt: on class Hist(Function): """Only supported for `Series` so far.""" diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index a78ac28521..49afc36ccf 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -59,11 +59,6 @@ def format() -> Expr: raise NotImplementedError(msg) -def coalesce() -> Expr: - msg = "nwp.coalesce" - raise NotImplementedError(msg) - - def col(*names: str | t.Iterable[str]) -> Expr: flat = tuple(flatten(names)) return ( @@ -169,6 +164,11 @@ def concat_str( ) +def coalesce(exprs: IntoExpr | t.Iterable[IntoExpr], *more_exprs: IntoExpr) -> Expr: + it = _parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) + return F.Coalesce().to_function_expr(*it).to_narwhals() + + def when( *predicates: IntoExprColumn | t.Iterable[IntoExprColumn], **constraints: t.Any ) -> When: diff --git a/tests/plan/coalesce_test.py b/tests/plan/coalesce_test.py new file mode 100644 index 0000000000..a3005d7a0d --- /dev/null +++ b/tests/plan/coalesce_test.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + 
+@pytest.fixture(scope="module") +def data_int() -> Data: + return { + "a": [0, None, None, None, None], + "b": [1, None, None, 5, 3], + "c": [5, None, 3, 2, 1], + } + + +@pytest.fixture(scope="module") +def data_str() -> Data: + return { + "a": ["0", None, None, None, None], + "b": ["1", None, None, "5", "3"], + "c": ["5", None, "3", "2", "1"], + } + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.coalesce("a", "b", "c"), {"a": [0, None, 3, 5, 3]}), + ( + nwp.coalesce("a", "b", "c", nwp.lit(-100)).alias("lit"), + {"lit": [0, -100, 3, 5, 3]}, + ), + ( + nwp.coalesce(nwp.lit(None, nw.Int64), "b", "c", 500).alias("into_lit"), + {"into_lit": [1, 500, 3, 5, 3]}, + ), + ], +) +def test_coalesce_numeric(data_int: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data_int).select(expr) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.coalesce("a", "b", "c").alias("no_lit"), + {"no_lit": ["0", None, "3", "5", "3"]}, + ), + (nwp.coalesce("a", "b", "c", nwp.lit("xyz")), {"a": ["0", "xyz", "3", "5", "3"]}), + ], +) +def test_coalesce_strings(data_str: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data_str).select(expr) + assert_equal_data(result, expected) + + +def test_coalesce_series(data_str: Data) -> None: + df = dataframe(data_str) + ser = df.get_column("b").alias("b_renamed") + exprs = nwp.coalesce(ser, "a", nwp.col("c").fill_null("filled")), nwp.lit("ignored") + result = df.select(exprs) + assert_equal_data(result, {"b_renamed": ["1", "filled", "3", "5", "3"]}) + + +def test_coalesce_raises_non_expr() -> None: + class NotAnExpr: ... 
+ + with pytest.raises( + TypeError, match=re.escape("'NotAnExpr' is not supported in `nw.lit`") + ): + nwp.coalesce("a", "b", "c", NotAnExpr()) # type: ignore[arg-type] + + +def test_coalesce_multi_output() -> None: + data = { + "col1": [True, None, False, False, None], + "col2": [True, False, True, False, None], + } + df = dataframe(data) + result = df.select(nwp.coalesce(nwp.all(), True)) + expected = {"col1": [True, False, False, False, True]} + assert_equal_data(result, expected) From 06dc04bc36388221d99c328a20986e8b57ec9137 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 30 Nov 2025 18:53:37 +0000 Subject: [PATCH 109/215] feat: Add `format` --- narwhals/_plan/__init__.py | 2 ++ narwhals/_plan/functions.py | 29 ++++++++++++++++++++++++----- tests/plan/format_test.py | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 5 deletions(-) create mode 100644 tests/plan/format_test.py diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index 9eaa64fa62..1528ea6ac6 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -12,6 +12,7 @@ concat_str, date_range, exclude, + format, int_range, len, lit, @@ -43,6 +44,7 @@ "concat_str", "date_range", "exclude", + "format", "int_range", "len", "lit", diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 49afc36ccf..ed985b3f95 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -3,7 +3,7 @@ import builtins import datetime as dt import typing as t -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Final from narwhals._duration import Interval from narwhals._plan import _guards, _parse, common, expressions as ir, selectors as cs @@ -53,10 +53,7 @@ t.Any, CompliantSeries[NativeSeriesT], t.Any, t.Any ] - -def format() -> Expr: - msg = "nwp.format" - raise NotImplementedError(msg) +_dtypes: Final = Version.MAIN.dtypes def col(*names: str | 
t.Iterable[str]) -> Expr: @@ -169,6 +166,28 @@ def coalesce(exprs: IntoExpr | t.Iterable[IntoExpr], *more_exprs: IntoExpr) -> E return F.Coalesce().to_function_expr(*it).to_narwhals() +def format(f_string: str, *args: IntoExpr) -> Expr: + """Format expressions as a string. + + Arguments: + f_string: A string that with placeholders. + args: Expression(s) that fill the placeholders. + """ + if (n_placeholders := f_string.count("{}")) != builtins.len(args): + msg = f"number of placeholders should equal the number of arguments. Expected {n_placeholders} arguments, got {builtins.len(args)}." + raise ValueError(msg) + string = _dtypes.String() + exprs: list[ir.ExprIR] = [] + it = iter(args) + for i, s in enumerate(f_string.split("{}")): + if i > 0: + exprs.append(_parse.parse_into_expr_ir(next(it))) + if s: + exprs.append(lit(s, string)._ir) + f = ConcatStr(separator="", ignore_nulls=False) + return f.to_function_expr(*exprs).to_narwhals() + + def when( *predicates: IntoExprColumn | t.Iterable[IntoExprColumn], **constraints: t.Any ) -> When: diff --git a/tests/plan/format_test.py b/tests/plan/format_test.py new file mode 100644 index 0000000000..84c479fd77 --- /dev/null +++ b/tests/plan/format_test.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.format("hello {} {} wassup", "name", nwp.col("surname")), + [ + "hello bob builder wassup", + "hello alice wonderlander wassup", + "hello dodo extinct wassup", + ], + ), + ( + nwp.format("{} {} wassup", "name", nwp.col("surname")), + ["bob builder wassup", "alice wonderlander wassup", "dodo extinct wassup"], + ), + ], +) +def test_format(expr: nwp.Expr, expected: list[str]) -> None: + data = { + "name": ["bob", "alice", "dodo"], + "surname": ["builder", "wonderlander", "extinct"], + } + result = dataframe(data).select(fmt=expr) + 
assert_equal_data(result, {"fmt": expected}) + + +def test_format_invalid() -> None: + with pytest.raises(ValueError, match="Expected 2 arguments, got 1"): + nwp.format("hello {} {} wassup", "name") From 11fad9977aa6364034fe0451211e346ada6ee8f9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 30 Nov 2025 21:56:04 +0000 Subject: [PATCH 110/215] start adding `{all,any}_horizontal(ignore_nulls=True)` Just have the `ArrowNamespace` impls left --- narwhals/_plan/_parse.py | 2 +- narwhals/_plan/expressions/boolean.py | 8 +- narwhals/_plan/functions.py | 22 ++++- tests/plan/all_any_horizontal_test.py | 115 ++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 7 deletions(-) create mode 100644 tests/plan/all_any_horizontal_test.py diff --git a/narwhals/_plan/_parse.py b/narwhals/_plan/_parse.py index 8ba064c1d2..ecdfc92e26 100644 --- a/narwhals/_plan/_parse.py +++ b/narwhals/_plan/_parse.py @@ -335,7 +335,7 @@ def _combine_predicates(predicates: Iterator[ExprIR], /) -> ExprIR: inputs = (first,) else: return first - return AllHorizontal().to_function_expr(*inputs) + return AllHorizontal(ignore_nulls=False).to_function_expr(*inputs) def _is_iterable(obj: Iterable[T] | Any) -> TypeIs[Iterable[T]]: diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index 97d33fa31e..5b24fed778 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -22,9 +22,13 @@ # fmt: off class BooleanFunction(Function, options=FunctionOptions.elementwise): ... class All(BooleanFunction, options=FunctionOptions.aggregation): ... -class AllHorizontal(HorizontalFunction, BooleanFunction): ... class Any(BooleanFunction, options=FunctionOptions.aggregation): ... -class AnyHorizontal(HorizontalFunction, BooleanFunction): ... 
+class AllHorizontal(HorizontalFunction, BooleanFunction): + __slots__ = ("ignore_nulls",) + ignore_nulls: bool +class AnyHorizontal(HorizontalFunction, BooleanFunction): + __slots__ = ("ignore_nulls",) + ignore_nulls: bool class IsDuplicated(BooleanFunction, options=FunctionOptions.length_preserving): ... class IsFinite(BooleanFunction): ... class IsFirstDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index ed985b3f95..897b63037e 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -116,15 +116,29 @@ def sum(*columns: str) -> Expr: # TODO @dangotbanned: Support `ignore_nulls=...` -def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: +# NOTE: `polars` doesn't support yet +# Current behavior is equivalent to `ignore_nulls=False` +def all_horizontal( + *exprs: IntoExpr | t.Iterable[IntoExpr], ignore_nulls: bool = False +) -> Expr: it = _parse.parse_into_seq_of_expr_ir(*exprs) - return ir.boolean.AllHorizontal().to_function_expr(*it).to_narwhals() + return ( + ir.boolean.AllHorizontal(ignore_nulls=ignore_nulls) + .to_function_expr(*it) + .to_narwhals() + ) # TODO @dangotbanned: Support `ignore_nulls=...` -def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: +def any_horizontal( + *exprs: IntoExpr | t.Iterable[IntoExpr], ignore_nulls: bool = False +) -> Expr: it = _parse.parse_into_seq_of_expr_ir(*exprs) - return ir.boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() + return ( + ir.boolean.AnyHorizontal(ignore_nulls=ignore_nulls) + .to_function_expr(*it) + .to_narwhals() + ) def sum_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: diff --git a/tests/plan/all_any_horizontal_test.py b/tests/plan/all_any_horizontal_test.py new file mode 100644 index 0000000000..ae1ed9447f --- /dev/null +++ b/tests/plan/all_any_horizontal_test.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing 
import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +# test_allh_iterator has a different length +@pytest.fixture(scope="module") +def data() -> Data: + return { + # test_allh, test_allh_series, test_allh_all, test_allh_nth, test_horizontal_expressions_empty + "a": [False, False, True], + "b": [False, True, True], + # test_all_ignore_nulls, test_allh_kleene, test_anyh_dask, + "c": [True, True, False], + "d": [True, None, None], + "e": [None, True, False], + } + + +XFAIL_NOT_IMPL = pytest.mark.xfail( + reason="TODO: `{all,any}_horizontal(ignore_nulls=True)`", raises=AssertionError +) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.all_horizontal("a", nwp.col("b"), ignore_nulls=True), + {"a": [False, False, True]}, + ), + pytest.param( + nwp.all_horizontal("c", "d", ignore_nulls=True), + {"c": [True, True, False]}, + id="ignore_nulls-1", + marks=XFAIL_NOT_IMPL, + ), + (nwp.all_horizontal("c", "d", ignore_nulls=False), {"c": [True, None, False]}), + ( + nwp.all_horizontal(nwp.nth(0, 1), ignore_nulls=True), + {"a": [False, False, True]}, + ), + pytest.param( + nwp.all_horizontal( + nwp.col("a"), nwp.nth(0), ncs.first(), "a", ignore_nulls=True + ), + {"a": [False, False, True]}, + id="duplicated", + ), + ( + nwp.all_horizontal(nwp.exclude("a", "b"), ignore_nulls=False), + {"c": [None, None, False]}, + ), + pytest.param( + nwp.all_horizontal(ncs.all() - ncs.by_index(0, 1), ignore_nulls=True), + {"c": [True, True, False]}, + id="ignore_nulls-2", + marks=XFAIL_NOT_IMPL, + ), + ], +) +def test_all_horizontal(data: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data).select(expr) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.any_horizontal("a", nwp.col("b"), ignore_nulls=True), + {"a": [False, True, 
True]}, + ), + (nwp.any_horizontal("c", "d", ignore_nulls=False), {"c": [True, True, None]}), + pytest.param( + nwp.any_horizontal("c", "d", ignore_nulls=True), + {"c": [True, True, False]}, + id="ignore_nulls-1", + marks=XFAIL_NOT_IMPL, + ), + ( + nwp.any_horizontal(nwp.nth(0, 1), ignore_nulls=False), + {"a": [False, True, True]}, + ), + pytest.param( + nwp.any_horizontal( + nwp.col("a"), nwp.nth(0), ncs.first(), "a", ignore_nulls=True + ), + {"a": [False, False, True]}, + id="duplicated", + ), + ( + nwp.any_horizontal(nwp.exclude("a", "b"), ignore_nulls=False), + {"c": [True, True, None]}, + ), + pytest.param( + nwp.any_horizontal(ncs.all() - ncs.by_index(0, 1), ignore_nulls=True), + {"c": [True, True, False]}, + id="ignore_nulls-2", + marks=XFAIL_NOT_IMPL, + ), + ], +) +def test_any_horizontal(data: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data).select(expr) + assert_equal_data(result, expected) From 29e1f2a31d3d980741d8476c04d5f0e9422673b3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:19:27 +0000 Subject: [PATCH 111/215] feat: Support `{all,any}_horizontal(ignore_nulls=True)` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I forgot I made it easy for myself with `fill` 😅 --- narwhals/_plan/arrow/functions.py | 15 +++++++++++++-- narwhals/_plan/arrow/namespace.py | 10 +++++----- tests/plan/all_any_horizontal_test.py | 11 ----------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index b89034a5f7..e1d32869ea 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -698,11 +698,22 @@ def _fill_null_forward_limit(native: ChunkedArrayAny, limit: int) -> ChunkedArra return when_then(or_(is_not_null, beyond_limit), native, native.take(index_not_null)) +@t.overload +def fill_null( + native: ChunkedOrScalarT, 
value: NonNestedLiteral | ArrowAny +) -> ChunkedOrScalarT: ... +@t.overload def fill_null( native: ChunkedOrArrayT, value: ScalarAny | NonNestedLiteral | ChunkedOrArrayT -) -> ChunkedOrArrayT: +) -> ChunkedOrArrayT: ... +@t.overload +def fill_null( + native: ChunkedOrScalarAny, value: ChunkedOrScalarAny | NonNestedLiteral +) -> ChunkedOrScalarAny: ... +def fill_null(native: ArrowAny, value: ArrowAny | NonNestedLiteral) -> ArrowAny: fill_value: Incomplete = value - return pc.fill_null(native, fill_value) + result: ArrowAny = pc.fill_null(native, fill_value) + return result @t.overload diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 29c74e75be..6c23a38e4a 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -100,15 +100,13 @@ def lit( nw_ser.to_native(), name or node.name, nw_ser.version ) - # NOTE: Update with `ignore_nulls`/`fill_null` behavior once added to each `Function` - # https://github.com/narwhals-dev/narwhals/pull/2719 def _horizontal_function( self, fn_native: Callable[[Any, Any], Any], /, fill: NonNestedLiteral = None ) -> Callable[[FunctionExpr[Any], Frame, str], Expr | Scalar]: def func(node: FunctionExpr[Any], frame: Frame, name: str) -> Expr | Scalar: it = (self._expr.from_ir(e, frame, name).native for e in node.input) if fill is not None: - it = (pc.fill_null(native, fn.lit(fill)) for native in it) + it = (fn.fill_null(native, fill) for native in it) result = reduce(fn_native, it) if isinstance(result, pa.Scalar): return self._scalar.from_native(result, name, self.version) @@ -128,12 +126,14 @@ def coalesce( def any_horizontal( self, node: FunctionExpr[AnyHorizontal], frame: Frame, name: str ) -> Expr | Scalar: - return self._horizontal_function(fn.or_)(node, frame, name) + fill = False if node.function.ignore_nulls else None + return self._horizontal_function(fn.or_, fill)(node, frame, name) def all_horizontal( self, node: FunctionExpr[AllHorizontal], frame: Frame, name: 
str ) -> Expr | Scalar: - return self._horizontal_function(fn.and_)(node, frame, name) + fill = True if node.function.ignore_nulls else None + return self._horizontal_function(fn.and_, fill)(node, frame, name) def sum_horizontal( self, node: FunctionExpr[F.SumHorizontal], frame: Frame, name: str diff --git a/tests/plan/all_any_horizontal_test.py b/tests/plan/all_any_horizontal_test.py index ae1ed9447f..df9a4859e0 100644 --- a/tests/plan/all_any_horizontal_test.py +++ b/tests/plan/all_any_horizontal_test.py @@ -16,21 +16,14 @@ @pytest.fixture(scope="module") def data() -> Data: return { - # test_allh, test_allh_series, test_allh_all, test_allh_nth, test_horizontal_expressions_empty "a": [False, False, True], "b": [False, True, True], - # test_all_ignore_nulls, test_allh_kleene, test_anyh_dask, "c": [True, True, False], "d": [True, None, None], "e": [None, True, False], } -XFAIL_NOT_IMPL = pytest.mark.xfail( - reason="TODO: `{all,any}_horizontal(ignore_nulls=True)`", raises=AssertionError -) - - @pytest.mark.parametrize( ("expr", "expected"), [ @@ -42,7 +35,6 @@ def data() -> Data: nwp.all_horizontal("c", "d", ignore_nulls=True), {"c": [True, True, False]}, id="ignore_nulls-1", - marks=XFAIL_NOT_IMPL, ), (nwp.all_horizontal("c", "d", ignore_nulls=False), {"c": [True, None, False]}), ( @@ -64,7 +56,6 @@ def data() -> Data: nwp.all_horizontal(ncs.all() - ncs.by_index(0, 1), ignore_nulls=True), {"c": [True, True, False]}, id="ignore_nulls-2", - marks=XFAIL_NOT_IMPL, ), ], ) @@ -85,7 +76,6 @@ def test_all_horizontal(data: Data, expr: nwp.Expr, expected: Data) -> None: nwp.any_horizontal("c", "d", ignore_nulls=True), {"c": [True, True, False]}, id="ignore_nulls-1", - marks=XFAIL_NOT_IMPL, ), ( nwp.any_horizontal(nwp.nth(0, 1), ignore_nulls=False), @@ -106,7 +96,6 @@ def test_all_horizontal(data: Data, expr: nwp.Expr, expected: Data) -> None: nwp.any_horizontal(ncs.all() - ncs.by_index(0, 1), ignore_nulls=True), {"c": [True, True, False]}, id="ignore_nulls-2", - 
marks=XFAIL_NOT_IMPL, ), ], ) From 4a86c6806a28a17021e579ace5c44b9c10a7a362 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:27:04 +0000 Subject: [PATCH 112/215] =?UTF-8?q?=F0=9F=A7=B9=F0=9F=A7=B9=F0=9F=A7=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- narwhals/_plan/arrow/namespace.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 6c23a38e4a..84efdc43bf 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -25,7 +25,7 @@ from narwhals._plan.arrow.typing import ChunkedArray, IntegerScalar from narwhals._plan.expressions import expr, functions as F from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal - from narwhals._plan.expressions.expr import FunctionExpr, RangeExpr + from narwhals._plan.expressions.expr import FunctionExpr as FExpr, RangeExpr from narwhals._plan.expressions.ranges import DateRange, IntRange from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.series import Series as NwSeries @@ -102,8 +102,8 @@ def lit( def _horizontal_function( self, fn_native: Callable[[Any, Any], Any], /, fill: NonNestedLiteral = None - ) -> Callable[[FunctionExpr[Any], Frame, str], Expr | Scalar]: - def func(node: FunctionExpr[Any], frame: Frame, name: str) -> Expr | Scalar: + ) -> Callable[[FExpr[Any], Frame, str], Expr | Scalar]: + def func(node: FExpr[Any], frame: Frame, name: str) -> Expr | Scalar: it = (self._expr.from_ir(e, frame, name).native for e in node.input) if fill is not None: it = (fn.fill_null(native, fill) for native in it) @@ -114,9 +114,7 @@ def func(node: FunctionExpr[Any], frame: Frame, name: str) -> Expr | Scalar: return func - def coalesce( - self, node: FunctionExpr[F.Coalesce], frame: Frame, name: str - ) -> Expr | Scalar: + 
def coalesce(self, node: FExpr[F.Coalesce], frame: Frame, name: str) -> Expr | Scalar: it = (self._expr.from_ir(e, frame, name).native for e in node.input) result = pc.coalesce(*it) if isinstance(result, pa.Scalar): @@ -124,38 +122,38 @@ def coalesce( return self._expr.from_native(result, name, self.version) def any_horizontal( - self, node: FunctionExpr[AnyHorizontal], frame: Frame, name: str + self, node: FExpr[AnyHorizontal], frame: Frame, name: str ) -> Expr | Scalar: fill = False if node.function.ignore_nulls else None return self._horizontal_function(fn.or_, fill)(node, frame, name) def all_horizontal( - self, node: FunctionExpr[AllHorizontal], frame: Frame, name: str + self, node: FExpr[AllHorizontal], frame: Frame, name: str ) -> Expr | Scalar: fill = True if node.function.ignore_nulls else None return self._horizontal_function(fn.and_, fill)(node, frame, name) def sum_horizontal( - self, node: FunctionExpr[F.SumHorizontal], frame: Frame, name: str + self, node: FExpr[F.SumHorizontal], frame: Frame, name: str ) -> Expr | Scalar: return self._horizontal_function(fn.add, fill=0)(node, frame, name) def min_horizontal( - self, node: FunctionExpr[F.MinHorizontal], frame: Frame, name: str + self, node: FExpr[F.MinHorizontal], frame: Frame, name: str ) -> Expr | Scalar: return self._horizontal_function(fn.min_horizontal)(node, frame, name) def max_horizontal( - self, node: FunctionExpr[F.MaxHorizontal], frame: Frame, name: str + self, node: FExpr[F.MaxHorizontal], frame: Frame, name: str ) -> Expr | Scalar: return self._horizontal_function(fn.max_horizontal)(node, frame, name) def mean_horizontal( - self, node: FunctionExpr[F.MeanHorizontal], frame: Frame, name: str + self, node: FExpr[F.MeanHorizontal], frame: Frame, name: str ) -> Expr | Scalar: int64 = pa.int64() inputs = [self._expr.from_ir(e, frame, name).native for e in node.input] - filled = (pc.fill_null(native, fn.lit(0)) for native in inputs) + filled = (fn.fill_null(native, 0) for native in inputs) # 
NOTE: `mypy` doesn't like that `add` is overloaded sum_not_null = reduce( fn.add, # type: ignore[arg-type] @@ -167,7 +165,7 @@ def mean_horizontal( return self._expr.from_native(result, name, self.version) def concat_str( - self, node: FunctionExpr[ConcatStr], frame: Frame, name: str + self, node: FExpr[ConcatStr], frame: Frame, name: str ) -> Expr | Scalar: exprs = (self._expr.from_ir(e, frame, name) for e in node.input) aligned = (ser.native for ser in self._expr.align(exprs)) From 91f222d17c1b63b9d947abccf46c6163f503918e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 1 Dec 2025 20:56:47 +0000 Subject: [PATCH 113/215] feat: Add `{DataFrame,Series}.sample` `Expr.sample` is deprecated, but still not added to `ExprIR` yet Planning to just defer to the `Series` impl --- narwhals/_plan/arrow/common.py | 12 +++- narwhals/_plan/arrow/dataframe.py | 4 ++ narwhals/_plan/arrow/functions.py | 21 ++++++- narwhals/_plan/arrow/series.py | 6 ++ narwhals/_plan/compliant/dataframe.py | 11 ++++ narwhals/_plan/compliant/series.py | 11 ++++ narwhals/_plan/dataframe.py | 26 ++++++++ narwhals/_plan/expressions/literal.py | 7 +++ narwhals/_plan/series.py | 42 ++++++++++++- tests/plan/sample_test.py | 91 +++++++++++++++++++++++++++ 10 files changed, 228 insertions(+), 3 deletions(-) create mode 100644 tests/plan/sample_test.py diff --git a/narwhals/_plan/arrow/common.py b/narwhals/_plan/arrow/common.py index 217f0a193d..bd6f4b8e2e 100644 --- a/narwhals/_plan/arrow/common.py +++ b/narwhals/_plan/arrow/common.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic -from narwhals._plan.arrow.functions import BACKEND_VERSION +from narwhals._plan.arrow.functions import BACKEND_VERSION, random_indices from narwhals._typing_compat import TypeVar from narwhals._utils import Implementation, Version, _StoresNative @@ -43,6 +43,10 @@ def _with_native(self, native: NativeT) -> Self: msg = 
f"{type(self).__name__}._with_native" raise NotImplementedError(msg) + def __len__(self) -> int: + msg = f"{type(self).__name__}.__len__" + raise NotImplementedError(msg) + if BACKEND_VERSION >= (18,): def _gather(self, indices: Indices) -> NativeT: @@ -62,3 +66,9 @@ def gather_every(self, n: int, offset: int = 0) -> Self: def slice(self, offset: int, length: int | None = None) -> Self: return self._with_native(self.native.slice(offset=offset, length=length)) + + def sample_n( + self, n: int = 1, *, with_replacement: bool = False, seed: int | None = None + ) -> Self: + mask = random_indices(len(self), n, with_replacement=with_replacement, seed=seed) + return self.gather(mask) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 0497c6cc41..0784b1dd00 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -48,6 +48,10 @@ def _with_native(self, native: pa.Table) -> Self: def _group_by(self) -> type[GroupBy]: return GroupBy + @property + def shape(self) -> tuple[int, int]: + return self.native.shape + def group_by_resolver(self, resolver: GroupByResolver, /) -> GroupBy: return self._group_by.from_resolver(self, resolver) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index e1d32869ea..9fc65705a7 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -580,7 +580,11 @@ def n_unique(native: Any) -> pa.Int64Scalar: return count(native, mode="all") -def round(native: ChunkedOrScalarAny, decimals: int = 0) -> ChunkedOrScalarAny: +@t.overload +def round(native: ChunkedOrScalarAny, decimals: int = ...) -> ChunkedOrScalarAny: ... +@t.overload +def round(native: ChunkedOrArrayT, decimals: int = ...) -> ChunkedOrArrayT: ... 
+def round(native: ArrowAny, decimals: int = 0) -> ArrowAny: return pc.round(native, decimals, round_mode="half_towards_infinity") @@ -863,6 +867,21 @@ def concat_str( return concat(*it, lit(separator, dtype), options=join) # type: ignore[no-any-return] +def random_indices( + end: int, /, n: int, *, with_replacement: bool = False, seed: int | None = None +) -> ArrayAny: + """Generate `n` random indices within the range `[0, end)`.""" + # NOTE: Review this path if anything changes upstream + # https://github.com/apache/arrow/issues/47288#issuecomment-3597653670 + if with_replacement: + rand_values = pc.random(n, initializer="system" if seed is None else seed) + return round(multiply(rand_values, lit(end - 1))).cast(I64) + + import numpy as np # ignore-banned-import + + return array(np.random.default_rng(seed).choice(np.arange(end), n, replace=False)) + + def sort_indices( native: ChunkedOrArrayAny | pa.Table, *order_by: str, diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index f54b5eec72..eefc776697 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -258,3 +258,9 @@ def zip_with(self, mask: Self, other: Self | None) -> Self: predicate = mask.native.combine_chunks() right = other.native if other is not None else other return self._with_native(fn.when_then(predicate, self.native, right)) + + def all(self) -> bool: + return fn.all_(self.native).as_py() + + def any(self) -> bool: + return fn.any_(self.native).as_py() diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index f3be22775a..83f4ff21ca 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -80,6 +80,8 @@ class CompliantDataFrame( implementation: ClassVar[_EagerAllowedImpl] _native: NativeDataFrameT + @property + def shape(self) -> tuple[int, int]: ... def __len__(self) -> int: ... @property def _group_by(self) -> type[DataFrameGroupBy[Self]]: ... 
@@ -170,6 +172,15 @@ def to_series(self, index: int = 0) -> SeriesT: ... def to_polars(self) -> pl.DataFrame: ... def with_row_index(self, name: str) -> Self: ... def slice(self, offset: int, length: int | None = None) -> Self: ... + def sample_frac( + self, fraction: float, *, with_replacement: bool = False, seed: int | None = None + ) -> Self: + n = int(len(self) * fraction) + return self.sample_n(n, with_replacement=with_replacement, seed=seed) + + def sample_n( + self, n: int = 1, *, with_replacement: bool = False, seed: int | None = None + ) -> Self: ... class EagerDataFrame( diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index a9b3ee3893..1d118ad1dd 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -116,6 +116,8 @@ def name(self) -> str: def native(self) -> NativeSeriesT: return self._native + def all(self) -> bool: ... + def any(self) -> bool: ... def alias(self, name: str) -> Self: return self.from_native(self.native, name, version=self.version) @@ -155,6 +157,15 @@ def rolling_sum( def rolling_var( self, window_size: int, *, min_samples: int, center: bool = False, ddof: int = 1 ) -> Self: ... + def sample_frac( + self, fraction: float, *, with_replacement: bool = False, seed: int | None = None + ) -> Self: + n = int(len(self) * fraction) + return self.sample_n(n, with_replacement=with_replacement, seed=seed) + + def sample_n( + self, n: int = 1, *, with_replacement: bool = False, seed: int | None = None + ) -> Self: ... def scatter(self, indices: Self, values: Self) -> Self: ... def slice(self, offset: int, length: int | None = None) -> Self: ... def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: ... 
diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index ccf78d3978..a39f6e3653 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -151,6 +151,10 @@ class DataFrame( def implementation(self) -> _EagerAllowedImpl: return self._compliant.implementation + @property + def shape(self) -> tuple[int, int]: + return self._compliant.shape + def __len__(self) -> int: # pragma: no cover return len(self._compliant) @@ -310,6 +314,28 @@ def with_row_index( def slice(self, offset: int, length: int | None = None) -> Self: # pragma: no cover return type(self)(self._compliant.slice(offset=offset, length=length)) + def sample( + self, + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, + ) -> Self: + if n is not None and fraction is not None: + msg = "cannot specify both `n` and `fraction`" + raise ValueError(msg) + df = self._compliant + if fraction is not None: + result = df.sample_frac( + fraction, with_replacement=with_replacement, seed=seed + ) + elif n is None: + result = df.sample_n(with_replacement=with_replacement, seed=seed) + else: + result = df.sample_n(n, with_replacement=with_replacement, seed=seed) + return type(self)(result) + def _is_join_strategy(obj: Any) -> TypeIs[JoinStrategy]: return obj in {"inner", "left", "full", "cross", "anti", "semi"} diff --git a/narwhals/_plan/expressions/literal.py b/narwhals/_plan/expressions/literal.py index 7d46c8436c..b6b8659512 100644 --- a/narwhals/_plan/expressions/literal.py +++ b/narwhals/_plan/expressions/literal.py @@ -7,6 +7,8 @@ from narwhals._plan.typing import LiteralT, NativeSeriesT, NonNestedLiteralT if TYPE_CHECKING: + from collections.abc import Iterator + from typing_extensions import TypeIs from narwhals._plan.expressions.expr import Literal @@ -74,6 +76,11 @@ def __repr__(self) -> str: def unwrap(self) -> Series[NativeSeriesT]: return self.value + @property + def __immutable_values__(self) -> 
Iterator[Any]: + # NOTE: Adding `Series.__eq__` means this needed a manual override + yield from (self.name, self.dtype, id(self.value)) + def is_literal_scalar( obj: Literal[NonNestedLiteralT] | Any, diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 53b9806052..4e2a1d0cf6 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -24,7 +24,13 @@ from narwhals._plan.dataframe import DataFrame from narwhals._typing import EagerAllowed, IntoBackend, _EagerAllowedImpl from narwhals.dtypes import DType - from narwhals.typing import IntoDType, NonNestedLiteral, SizedMultiIndexSelector + from narwhals.typing import ( + IntoDType, + NonNestedLiteral, + NumericLiteral, + SizedMultiIndexSelector, + TemporalLiteral, + ) Incomplete: TypeAlias = Any @@ -49,6 +55,10 @@ def name(self) -> str: def implementation(self) -> _EagerAllowedImpl: return self._compliant.implementation + @property + def shape(self) -> tuple[int]: + return (self._compliant.len(),) + def __init__(self, compliant: CompliantSeries[NativeSeriesT_co], /) -> None: self._compliant = compliant @@ -183,6 +193,36 @@ def fill_nan(self, value: float | Self | None) -> Self: other = self._unwrap_compliant(value) if is_series(value) else value return type(self)(self._compliant.fill_nan(other)) + def sample( + self, + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, + ) -> Self: + if n is not None and fraction is not None: + msg = "cannot specify both `n` and `fraction`" + raise ValueError(msg) + s = self._compliant + if fraction is not None: + result = s.sample_frac(fraction, with_replacement=with_replacement, seed=seed) + elif n is None: + result = s.sample_n(with_replacement=with_replacement, seed=seed) + else: + result = s.sample_n(n, with_replacement=with_replacement, seed=seed) + return type(self)(result) + + def __eq__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: # type: ignore[override] + 
other_ = self._unwrap_compliant(other) if is_series(other) else other + return type(self)(self._compliant.__eq__(other_)) + + def all(self) -> bool: + return self._compliant.all() + + def any(self) -> bool: # pragma: no cover + return self._compliant.any() + class SeriesV1(Series[NativeSeriesT_co]): _version: ClassVar[Version] = Version.V1 diff --git a/tests/plan/sample_test.py b/tests/plan/sample_test.py new file mode 100644 index 0000000000..84d32b02ee --- /dev/null +++ b/tests/plan/sample_test.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from tests.plan.utils import dataframe, series + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [1, 2, 3] * 10, "b": [4, 5, 6] * 10} + + +@pytest.fixture(scope="module") +def data_big() -> Data: + return {"a": list(range(100))} + + +@pytest.mark.parametrize("n", [None, 1, 7, 29]) +def test_sample_n_series(data: Data, n: int | None) -> None: + result = series(data["a"]).sample(n).shape + expected = (1,) if n is None else (n,) + assert result == expected + + +def test_sample_fraction_series(data: Data) -> None: + result = series(data["a"]).sample(fraction=0.1).shape + expected = (3,) + assert result == expected + + +@pytest.mark.parametrize("n", [10]) +def test_sample_with_seed_series(data_big: Data, n: int) -> None: + ser = series(data_big["a"]) + seed1 = ser.sample(n, seed=123) + seed2 = ser.sample(n, seed=123) + seed3 = ser.sample(n, seed=42) + result = {"res1": [(seed1 == seed2).all()], "res2": [(seed1 == seed3).all()]} + expected = {"res1": [True], "res2": [False]} + assert result == expected + + +@pytest.mark.parametrize("n", [2, None, 1, 18]) +def test_sample_n_dataframe(data: Data, n: int | None) -> None: + result = dataframe(data).sample(n=n).shape + expected = (1, 2) if n is None else (n, 2) + assert result == expected + + +def test_sample_fraction_dataframe(data: Data) -> 
None: + result = dataframe(data).sample(fraction=0.5).shape + expected = (15, 2) + assert result == expected + + +@pytest.mark.parametrize("n", [10]) +def test_sample_with_seed_dataframe(data_big: Data, n: int) -> None: + df = dataframe(data_big) + r1 = df.sample(n, seed=123).to_native() + r2 = df.sample(n, seed=123).to_native() + r3 = df.sample(n, seed=42).to_native() + assert r1.equals(r2) + assert not r1.equals(r3) + + +# NOTE: `with_replacement=True` has no tests on `main`? +@pytest.mark.xfail +def test_sample_with_replacement_series() -> None: + msg = "TODO: add tests" + raise NotImplementedError(msg) + + +# NOTE: `with_replacement=True` has no tests on `main`? +@pytest.mark.xfail +def test_sample_with_replacement_dataframe() -> None: + msg = "TODO: add tests" + raise NotImplementedError(msg) + + +def test_sample_invalid(data: Data) -> None: + df = dataframe(data) + ser = df.to_series() + + with pytest.raises(ValueError, match=r"cannot specify both `n` and `fraction`"): + df.sample(n=1, fraction=0.5) + with pytest.raises(ValueError, match=r"cannot specify both `n` and `fraction`"): + ser.sample(n=567, fraction=0.1) From d88d2c13fd4bf0ab5e3b1da3eb843f36cc42796e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 1 Dec 2025 21:42:15 +0000 Subject: [PATCH 114/215] test: cover `with_replacement=True` --- narwhals/_plan/dataframe.py | 12 ++++++++---- narwhals/_plan/series.py | 6 +++++- tests/plan/sample_test.py | 30 ++++++++++++++++++------------ 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index a39f6e3653..d7f073c96f 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -23,6 +23,7 @@ ) from narwhals._utils import Implementation, Version, generate_repr from narwhals.dependencies import is_pyarrow_table +from narwhals.exceptions import ShapeError from narwhals.schema import Schema from narwhals.typing import 
IntoDType, JoinStrategy @@ -155,7 +156,7 @@ def implementation(self) -> _EagerAllowedImpl: def shape(self) -> tuple[int, int]: return self._compliant.shape - def __len__(self) -> int: # pragma: no cover + def __len__(self) -> int: return len(self._compliant) @property @@ -215,7 +216,7 @@ def to_dict( } return self._compliant.to_dict(as_series=as_series) - def to_series(self, index: int = 0) -> Series[NativeSeriesT]: # pragma: no cover + def to_series(self, index: int = 0) -> Series[NativeSeriesT]: return self._series(self._compliant.to_series(index)) def to_polars(self) -> pl.DataFrame: @@ -224,7 +225,7 @@ def to_polars(self) -> pl.DataFrame: def gather_every(self, n: int, offset: int = 0) -> Self: return self._with_compliant(self._compliant.gather_every(n, offset)) - def get_column(self, name: str) -> Series[NativeSeriesT]: # pragma: no cover + def get_column(self, name: str) -> Series[NativeSeriesT]: return self._series(self._compliant.get_column(name)) @overload @@ -311,7 +312,7 @@ def with_row_index( return self._with_compliant(self._compliant.with_row_index(name)) return super().with_row_index(name, order_by=order_by) - def slice(self, offset: int, length: int | None = None) -> Self: # pragma: no cover + def slice(self, offset: int, length: int | None = None) -> Self: return type(self)(self._compliant.slice(offset=offset, length=length)) def sample( @@ -332,6 +333,9 @@ def sample( ) elif n is None: result = df.sample_n(with_replacement=with_replacement, seed=seed) + elif not with_replacement and n > len(self): + msg = "cannot take a larger sample than the total population when `with_replacement=false`" + raise ShapeError(msg) else: result = df.sample_n(n, with_replacement=with_replacement, seed=seed) return type(self)(result) diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 4e2a1d0cf6..d2e4c6ed5f 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -13,6 +13,7 @@ qualified_type_name, ) from narwhals.dependencies import 
is_pyarrow_chunked_array +from narwhals.exceptions import ShapeError if TYPE_CHECKING: from collections.abc import Iterator @@ -138,7 +139,7 @@ def gather_every(self, n: int, offset: int = 0) -> Self: def has_nulls(self) -> bool: # pragma: no cover return self._compliant.has_nulls() - def slice(self, offset: int, length: int | None = None) -> Self: # pragma: no cover + def slice(self, offset: int, length: int | None = None) -> Self: return type(self)(self._compliant.slice(offset=offset, length=length)) def sort( @@ -209,6 +210,9 @@ def sample( result = s.sample_frac(fraction, with_replacement=with_replacement, seed=seed) elif n is None: result = s.sample_n(with_replacement=with_replacement, seed=seed) + elif not with_replacement and n > len(self): + msg = "cannot take a larger sample than the total population when `with_replacement=false`" + raise ShapeError(msg) else: result = s.sample_n(n, with_replacement=with_replacement, seed=seed) return type(self)(result) diff --git a/tests/plan/sample_test.py b/tests/plan/sample_test.py index 84d32b02ee..7d29423883 100644 --- a/tests/plan/sample_test.py +++ b/tests/plan/sample_test.py @@ -4,6 +4,7 @@ import pytest +from narwhals.exceptions import ShapeError from tests.plan.utils import dataframe, series if TYPE_CHECKING: @@ -67,25 +68,30 @@ def test_sample_with_seed_dataframe(data_big: Data, n: int) -> None: assert not r1.equals(r3) -# NOTE: `with_replacement=True` has no tests on `main`? -@pytest.mark.xfail -def test_sample_with_replacement_series() -> None: - msg = "TODO: add tests" - raise NotImplementedError(msg) +@pytest.mark.parametrize("n", [39, 42, 20, 99]) +def test_sample_with_replacement_series(data: Data, n: int) -> None: + result = series(data["a"]).slice(0, 10).sample(n, with_replacement=True) + assert len(result) == n -# NOTE: `with_replacement=True` has no tests on `main`? 
-@pytest.mark.xfail -def test_sample_with_replacement_dataframe() -> None: - msg = "TODO: add tests" - raise NotImplementedError(msg) +@pytest.mark.parametrize("n", [10, 15, 28, 100]) +def test_sample_with_replacement_dataframe(data: Data, n: int) -> None: + result = dataframe(data).slice(0, 5).sample(n, with_replacement=True) + assert len(result) == n def test_sample_invalid(data: Data) -> None: df = dataframe(data) ser = df.to_series() - with pytest.raises(ValueError, match=r"cannot specify both `n` and `fraction`"): + both_n_fraction = r"cannot specify both `n` and `fraction`" + too_high_n = r"cannot take a larger sample than the total population when `with_replacement=false`" + + with pytest.raises(ValueError, match=both_n_fraction): df.sample(n=1, fraction=0.5) - with pytest.raises(ValueError, match=r"cannot specify both `n` and `fraction`"): + with pytest.raises(ValueError, match=both_n_fraction): ser.sample(n=567, fraction=0.1) + with pytest.raises(ShapeError, match=too_high_n): + df.sample(n=1_000) + with pytest.raises(ShapeError, match=too_high_n): + ser.sample(n=2_000) From e5c2165a9f3b82fa9d8cf8f4f924e00abe16b69e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 2 Dec 2025 11:28:17 +0000 Subject: [PATCH 115/215] test: xfail `over(order_by=...).meta.root_names()` on nightly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Quite glad I added this test now Likely would've gone unnoticed otherwise 😄 https://github.com/narwhals-dev/narwhals/actions/runs/19837257915/job/56837407975#step:7:543 --- narwhals/_plan/expressions/expr.py | 3 ++- tests/plan/meta_test.py | 24 ++++++++++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 7f173d2b1b..87271ea1db 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -397,8 +397,9 @@ def 
__repr__(self) -> str: args = f"partition_by={list(self.partition_by)!r}, order_by={list(order)!r}" return f"{self.expr!r}.over({args})" + # TODO @dangotbanned: Update to align with https://github.com/pola-rs/polars/pull/25117/files#diff-45d1f22172e291bd4a5ce36d1fb8233698394f9590bcf11382b9c99b5449fff5 def iter_root_names(self) -> t.Iterator[ExprIR]: - # NOTE: `order_by` is never considered in `polars` + # NOTE: `order_by` ~~is~~ was never considered in `polars` # To match that behavior for `root_names` - but still expand in all other cases # - this little escape hatch exists # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86 diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index 36e1e12160..503dd26b3e 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -17,15 +17,31 @@ pytest.importorskip("polars") import polars as pl -if POLARS_VERSION >= (1, 0): - # https://github.com/pola-rs/polars/pull/16743 - OVER_CASE = ( +if POLARS_VERSION >= (1, 0): # https://github.com/pola-rs/polars/pull/16743 + if POLARS_VERSION >= (1, 36): # pragma: no cover + # TODO @dangotbanned: Update special-casing in `OrderedWindowExpr` + # https://github.com/pola-rs/polars/pull/25117/files#diff-45d1f22172e291bd4a5ce36d1fb8233698394f9590bcf11382b9c99b5449fff5 + marks: tuple[pytest.MarkDecorator, ...] 
= ( + pytest.mark.xfail( + reason=( + "`polars==1.36.0b1` now considers `order_by` in `root_names`\n" + r"https://github.com/pola-rs/polars/pull/25117" + ), + raises=AssertionError, + ), + ) + else: + marks = () + OVER_CASE = pytest.param( nwp.col("a").last().over("b", order_by="c"), pl.col("a").last().over("b", order_by="c"), ["a", "b"], + marks=marks, ) else: # pragma: no cover - OVER_CASE = (nwp.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"]) + OVER_CASE = pytest.param( + nwp.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"] + ) if POLARS_VERSION >= (0, 20, 5): LEN_CASE = (nwp.len(), pl.len(), "len") else: # pragma: no cover From 6c16ff57bb6ede75351c9ffc87da5b7c987166a1 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 2 Dec 2025 13:40:32 +0000 Subject: [PATCH 116/215] refactor: Move `unique` to `Series` --- narwhals/_plan/arrow/expr.py | 3 +-- narwhals/_plan/arrow/series.py | 3 +++ narwhals/_plan/compliant/series.py | 1 + narwhals/_plan/series.py | 3 +++ 4 files changed, 8 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 863afea08f..66ebefb9b4 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -629,8 +629,7 @@ def _cumulative(self, node: FExpr[CumAgg], frame: Frame, name: str) -> Self: return self._with_native(fn.cumulative(native, node.function), name) def unique(self, node: FExpr[F.Unique], frame: Frame, name: str) -> Self: - result = self._dispatch_expr(node.input[0], frame, name).native.unique() - return self._with_native(result, name) + return self.from_series(self._dispatch_expr(node.input[0], frame, name).unique()) def gather_every(self, node: FExpr[F.GatherEvery], frame: Frame, name: str) -> Self: series = self._dispatch_expr(node.input[0], frame, name) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index eefc776697..5437403cd9 100644 --- 
a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -264,3 +264,6 @@ def all(self) -> bool: def any(self) -> bool: return fn.any_(self.native).as_py() + + def unique(self, *, maintain_order: bool = False) -> Self: + return self._with_native(self.native.unique()) diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 1d118ad1dd..5d2ee7c7ce 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -178,4 +178,5 @@ def to_narwhals(self) -> Series[NativeSeriesT]: def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ... def to_polars(self) -> pl.Series: ... + def unique(self, *, maintain_order: bool = False) -> Self: ... def zip_with(self, mask: Self, other: Self) -> Self: ... diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index d2e4c6ed5f..b7dafbc92f 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -227,6 +227,9 @@ def all(self) -> bool: def any(self) -> bool: # pragma: no cover return self._compliant.any() + def unique(self, *, maintain_order: bool = False) -> Self: # pragma: no cover + return type(self)(self._compliant.unique(maintain_order=maintain_order)) + class SeriesV1(Series[NativeSeriesT_co]): _version: ClassVar[Version] = Version.V1 From c537650e5bb98df6a06e3d8030c2866001420aae Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 2 Dec 2025 14:04:55 +0000 Subject: [PATCH 117/215] feat: Add `Expr.sample`* Not trying to revive from deprecation --- narwhals/_plan/arrow/expr.py | 14 +++++++++ narwhals/_plan/compliant/expr.py | 8 +++++ narwhals/_plan/compliant/scalar.py | 4 +++ narwhals/_plan/expr.py | 15 +++++++++- narwhals/_plan/expressions/functions.py | 29 ++++++++++++++++++ tests/plan/sample_test.py | 39 ++++++++++++++++++++++++- 6 files changed, 107 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py 
index 66ebefb9b4..67f132c3ee 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -636,6 +636,20 @@ def gather_every(self, node: FExpr[F.GatherEvery], frame: Frame, name: str) -> S n, offset = node.function.n, node.function.offset return self.from_series(series.gather_every(n=n, offset=offset)) + def sample_n(self, node: FExpr[F.SampleN], frame: Frame, name: str) -> Self: + series = self._dispatch_expr(node.input[0], frame, name) + func = node.function + n, replace, seed = func.n, func.with_replacement, func.seed + result = series.sample_n(n, with_replacement=replace, seed=seed) + return self.from_series(result) + + def sample_frac(self, node: FExpr[F.SampleFrac], frame: Frame, name: str) -> Self: + series = self._dispatch_expr(node.input[0], frame, name) + func = node.function + fraction, replace, seed = func.fraction, func.with_replacement, func.seed + result = series.sample_frac(fraction, with_replacement=replace, seed=seed) + return self.from_series(result) + def drop_nulls(self, node: FExpr[F.DropNulls], frame: Frame, name: str) -> Self: return self._vector_function(fn.drop_nulls)(node, frame, name) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 849ba271e3..8f0fa2730b 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -300,6 +300,14 @@ def is_in_series( def map_batches( self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str ) -> Self | EagerScalar[FrameT_contra, SeriesT]: ... + # NOTE: `n=1` can behave similar to an aggregation in `select(...)`, but requires `.first()` + # to trigger broadcasting in `with_columns(...)` + def sample_n( + self, node: FunctionExpr[F.SampleN], frame: FrameT_contra, name: str + ) -> Self: ... + def sample_frac( + self, node: FunctionExpr[F.SampleFrac], frame: FrameT_contra, name: str + ) -> Self: ... 
def __bool__(self) -> Literal[True]: # NOTE: Avoids falling back to `__len__` when truth-testing on dispatch return True diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index 18a265b282..ca4562c4ab 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -155,6 +155,10 @@ def __bool__(self) -> Literal[True]: def to_python(self) -> PythonLiteral: ... gather_every = not_implemented() # type: ignore[misc] + # NOTE: `n=1` and `fraction=1.0` *could* be special-cased here + # but seems low-priority for a deprecated method + sample_n = not_implemented() # type: ignore[misc] + sample_frac = not_implemented() # type: ignore[misc] class LazyScalar( diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 2e055cb3d9..b757ea828e 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -24,6 +24,7 @@ SortOptions, rolling_options, ) +from narwhals._typing_compat import deprecated from narwhals._utils import Version, no_default, not_implemented from narwhals.exceptions import ComputeError @@ -419,6 +420,19 @@ def map_batches( ) ) + # TODO @dangotbanned: Come back to this when *properly* building out `Version` support + @deprecated("Use `v1.Expr.sample` or `{DataFrame,Series}.sample` instead") + def sample( + self, + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, + ) -> Self: + f = F.sample(n, fraction=fraction, with_replacement=with_replacement, seed=seed) + return self._with_unary(f) + def any(self) -> Self: return self._with_unary(ir.boolean.Any()) @@ -626,7 +640,6 @@ def str(self) -> ExprStringNamespace: return ExprStringNamespace(_expr=self) is_close = not_implemented() - sample = not_implemented() head = not_implemented() tail = not_implemented() diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index 0b5c10b0d7..a87eb56ccd 100644 --- 
a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -214,3 +214,32 @@ def to_function_expr(self, *inputs: ExprIR) -> AnonymousExpr: options = self.function_options return AnonymousExpr(input=inputs, function=self, options=options) + + +class SampleN(Function): + __slots__ = ("n", "seed", "with_replacement") + n: int + with_replacement: bool + seed: int | None + + +class SampleFrac(Function): + __slots__ = ("fraction", "seed", "with_replacement") + fraction: float + with_replacement: bool + seed: int | None + + +def sample( + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, +) -> SampleFrac | SampleN: + if n is not None and fraction is not None: + msg = "cannot specify both `n` and `fraction`" + raise ValueError(msg) + if fraction is not None: + return SampleFrac(fraction=fraction, with_replacement=with_replacement, seed=seed) + return SampleN(n=1 if n is None else n, with_replacement=with_replacement, seed=seed) diff --git a/tests/plan/sample_test.py b/tests/plan/sample_test.py index 7d29423883..97f63228b0 100644 --- a/tests/plan/sample_test.py +++ b/tests/plan/sample_test.py @@ -1,13 +1,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import sys +from contextlib import AbstractContextManager, nullcontext +from typing import TYPE_CHECKING, Any import pytest +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs from narwhals.exceptions import ShapeError from tests.plan.utils import dataframe, series if TYPE_CHECKING: + from collections.abc import Callable + from tests.conftest import Data @@ -21,6 +27,14 @@ def data_big() -> Data: return {"a": list(range(100))} +if sys.version_info >= (3, 13): + # NOTE: (#2705) Would've added the handling for `category` + # The default triggers a warning, but only on `>=3.13` + deprecated_call: Callable[..., AbstractContextManager[Any]] = pytest.deprecated_call +else: # pragma: no 
cover + deprecated_call = nullcontext + + @pytest.mark.parametrize("n", [None, 1, 7, 29]) def test_sample_n_series(data: Data, n: int | None) -> None: result = series(data["a"]).sample(n).shape @@ -80,6 +94,25 @@ def test_sample_with_replacement_dataframe(data: Data, n: int) -> None: assert len(result) == n +@pytest.mark.parametrize( + ("base", "kwds", "expected"), + [ + (nwp.col("a"), {"n": 2}, (2, 1)), + (nwp.all(), {"n": 1}, (1, 2)), + (nwp.nth(1, 0), {}, (1, 2)), + (~ncs.string(), {"fraction": 0.5}, (15, 2)), + (ncs.last(), {"n": 75, "with_replacement": True, "seed": 99}, (75, 1)), + ], +) +def test_sample_expr( + data: Data, base: nwp.Expr, kwds: dict[str, Any], expected: tuple[int, int] +) -> None: + with deprecated_call(): + expr = base.sample(**kwds) + result = dataframe(data).select(expr).shape + assert result == expected + + def test_sample_invalid(data: Data) -> None: df = dataframe(data) ser = df.to_series() @@ -91,7 +124,11 @@ def test_sample_invalid(data: Data) -> None: df.sample(n=1, fraction=0.5) with pytest.raises(ValueError, match=both_n_fraction): ser.sample(n=567, fraction=0.1) + with pytest.raises(ValueError, match=both_n_fraction), deprecated_call(): + nwp.col("a").sample(n=30, fraction=0.3) with pytest.raises(ShapeError, match=too_high_n): df.sample(n=1_000) with pytest.raises(ShapeError, match=too_high_n): ser.sample(n=2_000) + with pytest.raises(ShapeError), deprecated_call(): + df.with_columns(nwp.col("b").sample(123, with_replacement=True)) From 4885cba614e3e19a25278f667354427ded00ac3c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 2 Dec 2025 17:52:45 +0000 Subject: [PATCH 118/215] test: Cover `str.replace(_all)` --- narwhals/_plan/expressions/strings.py | 10 +-- tests/plan/str_replace_test.py | 110 ++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 6 deletions(-) create mode 100644 tests/plan/str_replace_test.py diff --git a/narwhals/_plan/expressions/strings.py 
b/narwhals/_plan/expressions/strings.py index 9b4b3ae043..d879a79107 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -97,12 +97,12 @@ class IRStringNamespace(IRNamespace): def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 - ) -> Replace: # pragma: no cover + ) -> Replace: return Replace(pattern=pattern, value=value, literal=literal, n=n) def replace_all( self, pattern: str, value: str, *, literal: bool = False - ) -> ReplaceAll: # pragma: no cover + ) -> ReplaceAll: return ReplaceAll(pattern=pattern, value=value, literal=literal) def strip_chars( @@ -140,13 +140,11 @@ def len_chars(self) -> Expr: # TODO @dangotbanned: Support `value: IntoExpr` def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 - ) -> Expr: # pragma: no cover + ) -> Expr: return self._with_unary(self._ir.replace(pattern, value, literal=literal, n=n)) # TODO @dangotbanned: Support `value: IntoExpr` - def replace_all( - self, pattern: str, value: str, *, literal: bool = False - ) -> Expr: # pragma: no cover + def replace_all(self, pattern: str, value: str, *, literal: bool = False) -> Expr: return self._with_unary(self._ir.replace_all(pattern, value, literal=literal)) def strip_chars(self, characters: str | None = None) -> Expr: # pragma: no cover diff --git a/tests/plan/str_replace_test.py b/tests/plan/str_replace_test.py new file mode 100644 index 0000000000..c1bb3baa84 --- /dev/null +++ b/tests/plan/str_replace_test.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import Final + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +A1: Final = ["123abc", "abc456"] +A2: Final = ["abc abc", "abc456"] +A3: Final = ["abc abc abc", "456abc"] +A4: Final = ["Dollar $ign", "literal"] +B: Final = ["ghi", "jkl"] + + +replace_scalar = pytest.mark.parametrize( + ("data", "pattern", "value", "n", "literal", "expected"), + [ 
+ (A1, r"abc\b", "ABC", 1, False, ["123ABC", "abc456"]), + (A2, r"abc", "", 1, False, [" abc", "456"]), + (A3, r"abc", "", -1, False, [" ", "456"]), + (A4, r"$", "S", -1, True, ["Dollar Sign", "literal"]), + ], +) +replace_vector = pytest.mark.parametrize( + ("data", "pattern", "value", "n", "literal", "expected"), + [ + (A1, r"abc", "b", 1, False, ["123ghi", "jkl456"]), + (A2, r"abc", "b", 1, False, ["ghi abc", "jkl456"]), + (A3, r"abc", "b", -1, False, ["ghi ghi ghi", "456jkl"]), + (A4, r"$", "b", -1, True, ["Dollar ghiign", "literal"]), + ], +) +replace_all_scalar = pytest.mark.parametrize( + ("data", "pattern", "value", "literal", "expected"), + [ + (A1, r"abc\b", "ABC", False, ["123ABC", "abc456"]), + (A2, r"abc", "", False, [" ", "456"]), + (A3, r"abc", "", False, [" ", "456"]), + (A4, r"$", "S", True, ["Dollar Sign", "literal"]), + ], +) +replace_all_vector = pytest.mark.parametrize( + ("data", "pattern", "value", "literal", "expected"), + [ + (A1, r"abc", "b", False, ["123ghi", "jkl456"]), + (A2, r"abc", "b", False, ["ghi ghi", "jkl456"]), + (A4, r"$", "b", True, ["Dollar ghiign", "literal"]), + ], +) + + +XFAIL_STR_REPLACE_EXPR = pytest.mark.xfail( + reason="`replace(_all)(value:Expr)` is not yet supported for `pyarrow`" +) + + +@replace_scalar +def test_str_replace_scalar( + data: list[str], + pattern: str, + value: str, + n: int, + *, + literal: bool, + expected: list[str], +) -> None: + df = dataframe({"a": data}) + result = df.select(nwp.col("a").str.replace(pattern, value, n=n, literal=literal)) + assert_equal_data(result, {"a": expected}) + + +@XFAIL_STR_REPLACE_EXPR +@replace_vector +def test_str_replace_vector( + data: list[str], + pattern: str, + value: str, + n: int, + *, + literal: bool, + expected: list[str], +) -> None: # pragma: no cover + df = dataframe({"a": data, "b": B}) + result = df.select( + nwp.col("a").str.replace(pattern, nwp.col(value), n=n, literal=literal) # type: ignore[arg-type] + ) + assert_equal_data(result, {"a": expected}) 
+ + +@replace_all_scalar +def test_str_replace_all_scalar( + data: list[str], pattern: str, value: str, *, literal: bool, expected: list[str] +) -> None: + df = dataframe({"a": data}) + result = df.select(nwp.col("a").str.replace_all(pattern, value, literal=literal)) + assert_equal_data(result, {"a": expected}) + + +@XFAIL_STR_REPLACE_EXPR +@replace_all_vector +def test_str_replace_all_vector( + data: list[str], pattern: str, value: str, *, literal: bool, expected: list[str] +) -> None: # pragma: no cover + df = dataframe({"a": data, "b": B}) + result = df.select( + nwp.col("a").str.replace_all(pattern, nwp.col(value), literal=literal) # type: ignore[arg-type] + ) + assert_equal_data(result, {"a": expected}) From 9881c931a43f8d0f79caed72eba7d9890d557fd2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 2 Dec 2025 22:07:54 +0000 Subject: [PATCH 119/215] feat: Support some forms of `str.replace(value: Expr)` - Scalar expressions (finished) - `n=1` (work-in-progress) --- narwhals/_plan/arrow/expr.py | 23 +++++++++ narwhals/_plan/arrow/functions.py | 69 ++++++++++++++++++++++++++- narwhals/_plan/compliant/accessors.py | 6 +++ narwhals/_plan/expressions/strings.py | 59 ++++++++++++++++++++--- tests/plan/str_replace_test.py | 67 +++++++++++++++++++------- 5 files changed, 197 insertions(+), 27 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 67f132c3ee..5c64b49e68 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -956,6 +956,28 @@ def replace( node, frame, name ) + def replace_expr( + self, node: FExpr[strings.ReplaceExpr], frame: Frame, name: str + ) -> Expr | Scalar: + func = node.function + pattern, literal, n = (func.pattern, func.literal, func.n) + expr, other = func.unwrap_input(node) + prev = expr.dispatch(self.compliant, frame, name) + value = other.dispatch(self.compliant, frame, name) + if isinstance(value, ArrowScalar): + result = 
fn.str_replace( + prev.native, pattern, value.native.as_py(), literal=literal, n=n + ) + elif isinstance(prev, ArrowExpr): + result = fn.str_replace_vector( + prev.native, pattern, value.native, literal=literal, n=n + ) + else: + # not sure this even makes sense + msg = "TODO: `ArrowScalar.str.replace(value: ArrowExpr)`" + raise NotImplementedError(msg) + return self.with_native(result, name) + def replace_all( self, node: FExpr[strings.ReplaceAll], frame: Frame, name: str ) -> Expr | Scalar: @@ -994,6 +1016,7 @@ def to_titlecase( to_date = not_implemented() to_datetime = not_implemented() + replace_all_expr = not_implemented() class ArrowStructNamespace( diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 9fc65705a7..6ae488b3e2 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -416,6 +416,48 @@ def str_replace( return fn(native, pattern, replacement=value, max_replacements=n) +# NOTE: Starting with the "easiest" cases first +def str_replace_vector( + native: ChunkedArrayAny, + pattern: str, + replacements: ChunkedArrayAny, + *, + literal: bool = False, + n: int = 1, +) -> ChunkedArrayAny: + if n == 1: + return _str_replace_vector_n_1(native, pattern, replacements, literal=literal) + msg = f"`pyarrow` currently only supports `str.replace(value: Expr, n=1)`, got {n=} " + raise NotImplementedError(msg) + + +# TODO @dangotbanned: Super in need of a tidy +def _str_replace_vector_n_1( + native: ChunkedArrayAny, + pattern: str, + replacements: ChunkedArrayAny, + *, + literal: bool = False, +) -> ChunkedArrayAny: + # NOTE: `-1` equals no match + fn_find = pc.find_substring if literal else pc.find_substring_regex + first_idx_match = fn_find(native, pattern=pattern) + failed = lit(-1) + has_match = not_eq(first_idx_match, failed) + if not any_(has_match).as_py(): + # fastpath, no work to do + return native + fn_split = pc.split_pattern if literal else pc.split_pattern_regex + table = 
pa.Table.from_arrays([native, replacements], ["0", "1"]).filter(has_match) # pyright: ignore[reportArgumentType] + list_todo = fn_split(table.column(0).combine_chunks(), pattern, max_splits=1).values + mask_replace = eq(list_todo, lit("", list_todo.type)) + replaced_wrong_shape = replace_with_mask(list_todo, mask_replace, table.column(1)) + fully_replaced = concat_str(replaced_wrong_shape[0::2], replaced_wrong_shape[1::2]) + if all_(has_match).as_py(): + return chunked_array(fully_replaced) + return replace_with_mask(native, has_match, fully_replaced) + + def str_replace_all( native: Incomplete, pattern: str, value: str, *, literal: bool = False ) -> Incomplete: @@ -789,6 +831,17 @@ def replace_strict_default( return chunked_array(result) if isinstance(native, pa.ChunkedArray) else result[0] +def replace_with_mask( + native: ChunkedOrArrayT, mask: Predicate, replacements: ChunkedOrArrayAny +) -> ChunkedOrArrayT: + if not isinstance(mask, pa.BooleanArray): + mask = t.cast("pa.BooleanArray", array(mask)) + if not isinstance(replacements, pa.Array): + replacements = array(replacements) + result: ChunkedOrArrayT = pc.replace_with_mask(native, mask, replacements) + return result + + def is_between( native: ChunkedOrScalar[ScalarT], lower: ChunkedOrScalar[ScalarT], @@ -857,9 +910,21 @@ def binary( return _DISPATCH_BINARY[op](lhs, rhs) +@t.overload +def concat_str( + *arrays: ChunkedArrayAny, separator: str = ..., ignore_nulls: bool = ... +) -> ChunkedArray[StringScalar]: ... +@t.overload +def concat_str( + *arrays: ArrayAny, separator: str = ..., ignore_nulls: bool = ... +) -> Array[StringScalar]: ... +@t.overload +def concat_str( + *arrays: ScalarAny, separator: str = ..., ignore_nulls: bool = ... +) -> StringScalar: ... 
def concat_str( - *arrays: ChunkedArrayAny, separator: str = "", ignore_nulls: bool = False -) -> ChunkedArray[StringScalar]: + *arrays: ArrowAny, separator: str = "", ignore_nulls: bool = False +) -> Arrow[StringScalar]: dtype = string_type(obj.type for obj in arrays) it = (obj.cast(dtype) for obj in arrays) concat: Incomplete = pc.binary_join_element_wise diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 535b299dd8..b5cd5464de 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -47,6 +47,12 @@ def replace( def replace_all( self, node: FExpr[strings.ReplaceAll], frame: FrameT_contra, name: str ) -> ExprT_co: ... + def replace_expr( + self, node: FExpr[strings.ReplaceExpr], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def replace_all_expr( + self, node: FExpr[strings.ReplaceAllExpr], frame: FrameT_contra, name: str + ) -> ExprT_co: ... def slice( self, node: FExpr[strings.Slice], frame: FrameT_contra, name: str ) -> ExprT_co: ... 
diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index d879a79107..0acba61e92 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -3,11 +3,15 @@ from typing import TYPE_CHECKING, ClassVar from narwhals._plan._function import Function, HorizontalFunction +from narwhals._plan._parse import parse_into_expr_ir from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: + from typing_extensions import Self + from narwhals._plan.expr import Expr + from narwhals._plan.expressions import ExprIR, FunctionExpr as FExpr # fmt: off @@ -42,6 +46,21 @@ class Replace(StringFunction): n: int +# NOTE: Alternatively, do something like `list.contains` (always wrapping) +# There's a much bigger divide between backend-support though, so opting out is easier this way +class ReplaceExpr(StringFunction): + """N-ary (expr, value).""" + + def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, value = node.input + return expr, value + + __slots__ = ("literal", "n", "pattern") + pattern: str + literal: bool + n: int + + class ReplaceAll(StringFunction): __slots__ = ("literal", "pattern", "value") pattern: str @@ -49,6 +68,20 @@ class ReplaceAll(StringFunction): literal: bool +class ReplaceAllExpr(StringFunction): + """N-ary (expr, value).""" + + def unwrap_input( + self, node: FExpr[Self], / + ) -> tuple[ExprIR, ExprIR]: # pragma: no cover + expr, value = node.input + return expr, value + + __slots__ = ("literal", "pattern") + pattern: str + literal: bool + + class Slice(StringFunction): __slots__ = ("length", "offset") offset: int @@ -87,8 +120,8 @@ class ZFill(StringFunction, config=FEOptions.renamed("zfill")): class IRStringNamespace(IRNamespace): len_chars: ClassVar = LenChars - to_lowercase: ClassVar = ToUppercase - to_uppercase: ClassVar = ToLowercase + to_lowercase: ClassVar = 
ToLowercase + to_uppercase: ClassVar = ToUppercase to_titlecase: ClassVar = ToTitlecase split: ClassVar = Split starts_with: ClassVar = StartsWith @@ -139,13 +172,25 @@ def len_chars(self) -> Expr: # TODO @dangotbanned: Support `value: IntoExpr` def replace( - self, pattern: str, value: str, *, literal: bool = False, n: int = 1 + self, pattern: str, value: str | Expr, *, literal: bool = False, n: int = 1 ) -> Expr: - return self._with_unary(self._ir.replace(pattern, value, literal=literal, n=n)) + if isinstance(value, str): + return self._with_unary( + self._ir.replace(pattern, value, literal=literal, n=n) + ) + other = parse_into_expr_ir(value, str_as_lit=True) + replace = ReplaceExpr(pattern=pattern, literal=literal, n=n) + return self._expr._from_ir(replace.to_function_expr(self._expr._ir, other)) # TODO @dangotbanned: Support `value: IntoExpr` - def replace_all(self, pattern: str, value: str, *, literal: bool = False) -> Expr: - return self._with_unary(self._ir.replace_all(pattern, value, literal=literal)) + def replace_all( + self, pattern: str, value: str | Expr, *, literal: bool = False + ) -> Expr: + if isinstance(value, str): + return self._with_unary(self._ir.replace_all(pattern, value, literal=literal)) + other = parse_into_expr_ir(value, str_as_lit=True) + replace = ReplaceAllExpr(pattern=pattern, literal=literal) + return self._expr._from_ir(replace.to_function_expr(self._expr._ir, other)) def strip_chars(self, characters: str | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.strip_chars(characters)) @@ -180,7 +225,7 @@ def to_datetime(self, format: str | None = None) -> Expr: # pragma: no cover def to_lowercase(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_lowercase()) - def to_uppercase(self) -> Expr: # pragma: no cover + def to_uppercase(self) -> Expr: return self._with_unary(self._ir.to_uppercase()) def to_titlecase(self) -> Expr: # pragma: no cover diff --git a/tests/plan/str_replace_test.py 
b/tests/plan/str_replace_test.py index c1bb3baa84..edcf44644c 100644 --- a/tests/plan/str_replace_test.py +++ b/tests/plan/str_replace_test.py @@ -12,7 +12,14 @@ A3: Final = ["abc abc abc", "456abc"] A4: Final = ["Dollar $ign", "literal"] B: Final = ["ghi", "jkl"] - +XFAIL_STR_REPLACE_EXPR = pytest.mark.xfail( + reason="`replace(value:Expr, n>1)` is not yet supported for `pyarrow`", + raises=NotImplementedError, +) +XFAIL_STR_REPLACE_ALL_EXPR = pytest.mark.xfail( + reason="`replace_all(value:Expr)` is not yet supported for `pyarrow`", + raises=NotImplementedError, +) replace_scalar = pytest.mark.parametrize( ("data", "pattern", "value", "n", "literal", "expected"), @@ -26,10 +33,42 @@ replace_vector = pytest.mark.parametrize( ("data", "pattern", "value", "n", "literal", "expected"), [ - (A1, r"abc", "b", 1, False, ["123ghi", "jkl456"]), - (A2, r"abc", "b", 1, False, ["ghi abc", "jkl456"]), - (A3, r"abc", "b", -1, False, ["ghi ghi ghi", "456jkl"]), - (A4, r"$", "b", -1, True, ["Dollar ghiign", "literal"]), + (A1, r"abc", nwp.col("b"), 1, False, ["123ghi", "jkl456"]), + (A2, r"abc", nwp.col("b"), 1, False, ["ghi abc", "jkl456"]), + pytest.param( + A3, + r"abc", + nwp.col("b"), + -1, + False, + ["ghi ghi ghi", "456jkl"], + marks=XFAIL_STR_REPLACE_EXPR, + ), + pytest.param( + A4, + r"$", + nwp.col("b"), + -1, + True, + ["Dollar ghiign", "literal"], + marks=XFAIL_STR_REPLACE_EXPR, + ), + ( + ["dogcatdogcat", "dog dog"], + "cat", + nwp.col("b").last(), + 1, + True, + ["dogjkldogcat", "dog dog"], + ), + ( + A3, + r"^abc", + nwp.col("b").str.to_uppercase(), + 1, + False, + ["GHI abc abc", "456abc"], + ), ], ) replace_all_scalar = pytest.mark.parametrize( @@ -51,11 +90,6 @@ ) -XFAIL_STR_REPLACE_EXPR = pytest.mark.xfail( - reason="`replace(_all)(value:Expr)` is not yet supported for `pyarrow`" -) - - @replace_scalar def test_str_replace_scalar( data: list[str], @@ -71,21 +105,18 @@ def test_str_replace_scalar( assert_equal_data(result, {"a": expected}) 
-@XFAIL_STR_REPLACE_EXPR @replace_vector def test_str_replace_vector( data: list[str], pattern: str, - value: str, + value: nwp.Expr, n: int, *, literal: bool, expected: list[str], -) -> None: # pragma: no cover +) -> None: df = dataframe({"a": data, "b": B}) - result = df.select( - nwp.col("a").str.replace(pattern, nwp.col(value), n=n, literal=literal) # type: ignore[arg-type] - ) + result = df.select(nwp.col("a").str.replace(pattern, value, n=n, literal=literal)) assert_equal_data(result, {"a": expected}) @@ -98,13 +129,13 @@ def test_str_replace_all_scalar( assert_equal_data(result, {"a": expected}) -@XFAIL_STR_REPLACE_EXPR +@XFAIL_STR_REPLACE_ALL_EXPR @replace_all_vector def test_str_replace_all_vector( data: list[str], pattern: str, value: str, *, literal: bool, expected: list[str] ) -> None: # pragma: no cover df = dataframe({"a": data, "b": B}) result = df.select( - nwp.col("a").str.replace_all(pattern, nwp.col(value), literal=literal) # type: ignore[arg-type] + nwp.col("a").str.replace_all(pattern, nwp.col(value), literal=literal) ) assert_equal_data(result, {"a": expected}) From 44a9d1e4575b7d369ee1ba5f5edeac00756dc39a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 14:02:50 +0000 Subject: [PATCH 120/215] feat(DRAFT): Pull more out of `str_replace` `splitn` and `find` can be aligned closer with `polars`, so I've left notes for what that needs --- narwhals/_plan/arrow/functions.py | 136 +++++++++++++++++++++++++----- narwhals/_plan/arrow/options.py | 14 +++ tests/plan/str_replace_test.py | 27 ++++-- 3 files changed, 150 insertions(+), 27 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 6ae488b3e2..6a3f54da01 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -386,21 +386,118 @@ def str_pad_start( return pc.utf8_lpad(native, length, fill_char) +@t.overload +def str_find( + native: ChunkedArrayAny, + 
pattern: str, + *, + literal: bool = ..., + not_found: int | None = ..., +) -> ChunkedArray[IntegerScalar]: ... +@t.overload +def str_find( + native: Array, pattern: str, *, literal: bool = ..., not_found: int | None = ... +) -> Array[IntegerScalar]: ... +@t.overload +def str_find( + native: ScalarAny, pattern: str, *, literal: bool = ..., not_found: int | None = ... +) -> IntegerScalar: ... +def str_find( + native: Arrow[StringScalar], + pattern: str, + *, + literal: bool = False, + not_found: int | None = -1, +) -> Arrow[IntegerScalar]: + """Return the bytes offset of the first substring matching a pattern. + + To match `pl.Expr.str.find` behavior, pass `not_found=None`. + + Note: + `pyarrow` distinguishes null *inputs* with `None` and failed matches with `-1`. + """ + # NOTE: `pyarrow-stubs` uses concrete types here + fn_name = "find_substring" if literal else "find_substring_regex" + result: Arrow[IntegerScalar] = pc.call_function( + fn_name, [native], pa_options.match_substring(pattern) + ) + if not_found == -1: + return result + return when_then(eq(result, lit(-1)), lit(not_found, result.type), result) + + _StringFunction0: TypeAlias = "Callable[[ChunkedOrScalarAny], ChunkedOrScalarAny]" _StringFunction1: TypeAlias = "Callable[[ChunkedOrScalarAny, str], ChunkedOrScalarAny]" str_starts_with = t.cast("_StringFunction1", pc.starts_with) str_ends_with = t.cast("_StringFunction1", pc.ends_with) -str_split = t.cast("_StringFunction1", pc.split_pattern) str_to_uppercase = t.cast("_StringFunction0", pc.utf8_upper) str_to_lowercase = t.cast("_StringFunction0", pc.utf8_lower) str_to_titlecase = t.cast("_StringFunction0", pc.utf8_title) +def _str_split( + native: ArrowAny, by: str, n: int | None = None, *, literal: bool = True +) -> Arrow[ListScalar]: + name = "split_pattern" if literal else "split_pattern_regex" + result: Arrow[ListScalar] = pc.call_function( + name, [native], pa_options.split_pattern(by, n) + ) + return result + + +@t.overload +def str_split( + 
native: ChunkedOrScalarAny, by: str, *, literal: bool = ... +) -> ChunkedOrScalar[ListScalar]: ... +@t.overload +def str_split(native: ArrayAny, by: str, *, literal: bool = ...) -> pa.ListArray[Any]: ... +@t.overload +def str_split(native: ArrowAny, by: str, *, literal: bool = ...) -> Arrow[ListScalar]: ... +def str_split(native: ArrowAny, by: str, *, literal: bool = True) -> Arrow[ListScalar]: + return _str_split(native, by, literal=literal) + + +@t.overload +def str_splitn( + native: ArrayAny, by: str, n: int, *, literal: bool = ..., as_struct: bool = ... +) -> pa.ListArray[Any]: ... +@t.overload +def str_splitn( + native: ArrowAny, by: str, n: int, *, literal: bool = ..., as_struct: bool = ... +) -> Arrow[ListScalar]: ... +def str_splitn( + native: ArrowAny, by: str, n: int, *, literal: bool = True, as_struct: bool = False +) -> Arrow[ListScalar]: + """Split the string by a substring, restricted to returning at most `n` items.""" + result = _str_split(native, by, n, literal=literal) + if as_struct: + # NOTE: `polars` would return a struct w/ field names (`'field_0`, ..., 'field_n-1`) + msg = "TODO: `ArrowExpr.str.splitn`" + raise NotImplementedError(msg) + return result + + +@t.overload def str_contains( - native: Incomplete, pattern: str, *, literal: bool = False -) -> Incomplete: - func = pc.match_substring if literal else pc.match_substring_regex - return func(native, pattern) + native: ChunkedArrayAny, pattern: str, *, literal: bool = ... +) -> ChunkedArray[pa.BooleanScalar]: ... +@t.overload +def str_contains( + native: ChunkedOrScalarAny, pattern: str, *, literal: bool = ... +) -> ChunkedOrScalar[pa.BooleanScalar]: ... +@t.overload +def str_contains( + native: ArrowAny, pattern: str, *, literal: bool = ... +) -> Arrow[pa.BooleanScalar]: ... 
+def str_contains( + native: ArrowAny, pattern: str, *, literal: bool = False +) -> Arrow[pa.BooleanScalar]: + """Check if the string contains a substring that matches a pattern.""" + name = "match_substring" if literal else "match_substring_regex" + result: Arrow[pa.BooleanScalar] = pc.call_function( + name, [native], pa_options.match_substring(pattern) + ) + return result def str_strip_chars(native: Incomplete, characters: str | None) -> Incomplete: @@ -439,21 +536,18 @@ def _str_replace_vector_n_1( *, literal: bool = False, ) -> ChunkedArrayAny: - # NOTE: `-1` equals no match - fn_find = pc.find_substring if literal else pc.find_substring_regex - first_idx_match = fn_find(native, pattern=pattern) - failed = lit(-1) - has_match = not_eq(first_idx_match, failed) + has_match = str_contains(native, pattern, literal=literal) if not any_(has_match).as_py(): # fastpath, no work to do return native - fn_split = pc.split_pattern if literal else pc.split_pattern_regex - table = pa.Table.from_arrays([native, replacements], ["0", "1"]).filter(has_match) # pyright: ignore[reportArgumentType] - list_todo = fn_split(table.column(0).combine_chunks(), pattern, max_splits=1).values + table = pa.Table.from_arrays([native, replacements], ["0", "1"]).filter(has_match) + # Needs better name + list_todo = str_splitn(array(table.column(0)), pattern, n=2, literal=literal).values mask_replace = eq(list_todo, lit("", list_todo.type)) + # Needs better name replaced_wrong_shape = replace_with_mask(list_todo, mask_replace, table.column(1)) fully_replaced = concat_str(replaced_wrong_shape[0::2], replaced_wrong_shape[1::2]) - if all_(has_match).as_py(): + if all_(has_match, ignore_nulls=False).as_py(): return chunked_array(fully_replaced) return replace_with_mask(native, has_match, fully_replaced) @@ -509,11 +603,13 @@ def _str_zfill_compat( @t.overload def when_then( - predicate: Predicate, then: SameArrowT, otherwise: SameArrowT + predicate: Predicate, then: SameArrowT, otherwise: SameArrowT 
| None ) -> SameArrowT: ... @t.overload +def when_then(predicate: Predicate, then: ScalarAny, otherwise: ArrowT) -> ArrowT: ... +@t.overload def when_then( - predicate: Predicate, then: ArrowT, otherwise: NonNestedLiteral = ... + predicate: Predicate, then: ArrowT, otherwise: ScalarAny | NonNestedLiteral = ... ) -> ArrowT: ... @t.overload def when_then( @@ -527,12 +623,12 @@ def when_then( return pc.if_else(predicate, then, otherwise) -def any_(native: Incomplete) -> pa.BooleanScalar: - return pc.any(native, min_count=0) +def any_(native: Incomplete, *, ignore_nulls: bool = True) -> pa.BooleanScalar: + return pc.any(native, min_count=0, skip_nulls=ignore_nulls) -def all_(native: Incomplete) -> pa.BooleanScalar: - return pc.all(native, min_count=0) +def all_(native: Incomplete, *, ignore_nulls: bool = True) -> pa.BooleanScalar: + return pc.all(native, min_count=0, skip_nulls=ignore_nulls) def sum_(native: Incomplete) -> NativeScalar: diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index 5e2f328bbe..3d44487bc7 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -133,6 +133,20 @@ def rank( ) +def match_substring(pattern: str) -> pc.MatchSubstringOptions: + return pc.MatchSubstringOptions(pattern) + + +def split_pattern(by: str, n: int | None = None) -> pc.SplitPatternOptions: + """Similar to `str.splitn`. 
+ + Some glue for `max_splits=n - 1` + """ + if n is not None: + return pc.SplitPatternOptions(by, max_splits=n - 1) + return pc.SplitPatternOptions(by) + + def _generate_agg() -> Mapping[type[agg.AggExpr], acero.AggregateOptions]: from narwhals._plan.expressions import aggregation as agg diff --git a/tests/plan/str_replace_test.py b/tests/plan/str_replace_test.py index edcf44644c..ff4b8f075b 100644 --- a/tests/plan/str_replace_test.py +++ b/tests/plan/str_replace_test.py @@ -1,16 +1,20 @@ from __future__ import annotations -from typing import Final +from typing import TYPE_CHECKING, Final import pytest import narwhals._plan as nwp from tests.plan.utils import assert_equal_data, dataframe +if TYPE_CHECKING: + from collections.abc import Sequence + A1: Final = ["123abc", "abc456"] A2: Final = ["abc abc", "abc456"] A3: Final = ["abc abc abc", "456abc"] A4: Final = ["Dollar $ign", "literal"] +A5: Final = [None, "oop"] B: Final = ["ghi", "jkl"] XFAIL_STR_REPLACE_EXPR = pytest.mark.xfail( reason="`replace(value:Expr, n>1)` is not yet supported for `pyarrow`", @@ -33,8 +37,12 @@ replace_vector = pytest.mark.parametrize( ("data", "pattern", "value", "n", "literal", "expected"), [ - (A1, r"abc", nwp.col("b"), 1, False, ["123ghi", "jkl456"]), - (A2, r"abc", nwp.col("b"), 1, False, ["ghi abc", "jkl456"]), + pytest.param( + A1, r"abc", nwp.col("b"), 1, False, ["123ghi", "jkl456"], id="n-1-single" + ), + pytest.param( + A2, r"abc", nwp.col("b"), 1, False, ["ghi abc", "jkl456"], id="n-1-mixed" + ), pytest.param( A3, r"abc", @@ -43,6 +51,7 @@ False, ["ghi ghi ghi", "456jkl"], marks=XFAIL_STR_REPLACE_EXPR, + id="replace_all", ), pytest.param( A4, @@ -52,23 +61,27 @@ True, ["Dollar ghiign", "literal"], marks=XFAIL_STR_REPLACE_EXPR, + id="literal-replace_all", ), - ( + pytest.param( ["dogcatdogcat", "dog dog"], "cat", nwp.col("b").last(), 1, True, ["dogjkldogcat", "dog dog"], + id="agg-replacement", ), - ( + pytest.param( A3, r"^abc", nwp.col("b").str.to_uppercase(), 1, False, 
["GHI abc abc", "456abc"], + id="transformed-replacement", ), + pytest.param(A5, r"o", nwp.col("b"), 1, False, [None, "jklop"], id="null-input"), ], ) replace_all_scalar = pytest.mark.parametrize( @@ -107,13 +120,13 @@ def test_str_replace_scalar( @replace_vector def test_str_replace_vector( - data: list[str], + data: Sequence[str | None], pattern: str, value: nwp.Expr, n: int, *, literal: bool, - expected: list[str], + expected: Sequence[str | None], ) -> None: df = dataframe({"a": data, "b": B}) result = df.select(nwp.col("a").str.replace(pattern, value, n=n, literal=literal)) From 0fea5b580ce625fa80c81c49eb8c02c30eb5a957 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 14:04:08 +0000 Subject: [PATCH 121/215] cov --- tests/plan/sample_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plan/sample_test.py b/tests/plan/sample_test.py index 97f63228b0..5491a4189b 100644 --- a/tests/plan/sample_test.py +++ b/tests/plan/sample_test.py @@ -27,7 +27,7 @@ def data_big() -> Data: return {"a": list(range(100))} -if sys.version_info >= (3, 13): +if sys.version_info >= (3, 13): # pragma: no cover # NOTE: (#2705) Would've added the handling for `category` # The default triggers a warning, but only on `>=3.13` deprecated_call: Callable[..., AbstractContextManager[Any]] = pytest.deprecated_call From 91cfa33e4b60e1c62806b6dc6e2d4e9abd351b34 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 17:35:06 +0000 Subject: [PATCH 122/215] =?UTF-8?q?feat:=20Support=20`str.replace(value:Ex?= =?UTF-8?q?pr,=20n=3D-1)`=20as=20well=20=F0=9F=A5=B3=F0=9F=A5=B3?= =?UTF-8?q?=F0=9F=A5=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit aka `replace_all`, need to link that up now - but quite happy w/ how clean it's looking --- narwhals/_plan/arrow/functions.py | 75 ++++++++++++++++++++++++--- 
narwhals/_plan/expressions/strings.py | 5 +- tests/plan/str_replace_test.py | 7 +-- 3 files changed, 73 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 6a3f54da01..228bbfdc2e 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -367,6 +367,37 @@ def list_get(native: ArrowAny, index: int) -> ArrowAny: return result +@t.overload +def list_join( + native: ChunkedList[StringType], separator: Arrow[StringScalar] | str +) -> ChunkedArray[StringScalar]: ... +@t.overload +def list_join( + native: ListArray[StringType], separator: Arrow[StringScalar] | str +) -> pa.StringArray: ... +@t.overload +def list_join( + native: ListScalar[StringType], separator: Arrow[StringScalar] | str +) -> pa.StringScalar: ... +def list_join(native: ArrowAny, separator: Arrow[StringScalar] | str) -> ArrowAny: + """Join all string items in a sublist and place a separator between them. + + Each list of values in the first input is joined using each second input as separator. + If any input list is null or contains a null, the corresponding output will be null. 
+ """ + return pc.binary_join(native, separator) + + +def str_join(native: Arrow[StringScalar], separator: str) -> StringScalar: + """Vertically concatenate the string values in the column to a single string value.""" + if isinstance(native, pa.Scalar): + # already joined + return native + offsets = [0, len(native)] + scalar = pa.ListArray.from_arrays(offsets, array(native))[0] + return list_join(scalar, separator) + + def str_len_chars(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: len_chars: Incomplete = pc.utf8_length result: ChunkedOrScalarAny = len_chars(native) @@ -524,11 +555,13 @@ def str_replace_vector( ) -> ChunkedArrayAny: if n == 1: return _str_replace_vector_n_1(native, pattern, replacements, literal=literal) + if n == -1: + return _str_replace_all_vector(native, pattern, replacements, literal=literal) msg = f"`pyarrow` currently only supports `str.replace(value: Expr, n=1)`, got {n=} " raise NotImplementedError(msg) -# TODO @dangotbanned: Super in need of a tidy +# TODO @dangotbanned: Is this worth keeping, now that `replace_all` is simpler? 
def _str_replace_vector_n_1( native: ChunkedArrayAny, pattern: str, @@ -540,18 +573,48 @@ def _str_replace_vector_n_1( if not any_(has_match).as_py(): # fastpath, no work to do return native - table = pa.Table.from_arrays([native, replacements], ["0", "1"]).filter(has_match) - # Needs better name - list_todo = str_splitn(array(table.column(0)), pattern, n=2, literal=literal).values - mask_replace = eq(list_todo, lit("", list_todo.type)) + tbl_matches = pa.Table.from_arrays([native, replacements], ["0", "1"]).filter( + has_match + ) + list_split_by = str_splitn( + array(tbl_matches.column(0)), pattern, n=2, literal=literal + ) + list_flat = list_split_by.values + needs_replacing = eq(list_flat, lit("", list_flat.type)) # Needs better name - replaced_wrong_shape = replace_with_mask(list_todo, mask_replace, table.column(1)) + replaced_wrong_shape = replace_with_mask( + list_flat, needs_replacing, tbl_matches.column(1) + ) fully_replaced = concat_str(replaced_wrong_shape[0::2], replaced_wrong_shape[1::2]) if all_(has_match, ignore_nulls=False).as_py(): return chunked_array(fully_replaced) return replace_with_mask(native, has_match, fully_replaced) +# TODO @dangotbanned: Link `str.replace_all` up to this +# TODO @dangotbanned: Share more with `n=1` +def _str_replace_all_vector( + native: ChunkedArrayAny, + pattern: str, + replacements: ChunkedArrayAny, + *, + literal: bool = False, +) -> ChunkedArrayAny: + has_match = str_contains(native, pattern, literal=literal) + if not any_(has_match).as_py(): + # fastpath, no work to do + return native + tbl_matches = pa.Table.from_arrays([native, replacements], ["0", "1"]).filter( + has_match + ) + # here we can have unequal-length lists + list_split_by = str_split(array(tbl_matches.column(0)), pattern, literal=literal) + fully_replaced = list_join(list_split_by, tbl_matches.column(1)) + if all_(has_match, ignore_nulls=False).as_py(): + return chunked_array(fully_replaced) + return replace_with_mask(native, has_match, 
fully_replaced) + + def str_replace_all( native: Incomplete, pattern: str, value: str, *, literal: bool = False ) -> Incomplete: diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 0acba61e92..678154395e 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -46,8 +46,8 @@ class Replace(StringFunction): n: int -# NOTE: Alternatively, do something like `list.contains` (always wrapping) -# There's a much bigger divide between backend-support though, so opting out is easier this way +# TODO @dangotbanned: Undo the (`Expr`) split and just have `Replace` +# This needs to handle scalars *anyway*, so no point in separating class ReplaceExpr(StringFunction): """N-ary (expr, value).""" @@ -68,6 +68,7 @@ class ReplaceAll(StringFunction): literal: bool +# TODO @dangotbanned: Undo the (`Expr`) split and just have `ReplaceAll` class ReplaceAllExpr(StringFunction): """N-ary (expr, value).""" diff --git a/tests/plan/str_replace_test.py b/tests/plan/str_replace_test.py index ff4b8f075b..0f38f0d58a 100644 --- a/tests/plan/str_replace_test.py +++ b/tests/plan/str_replace_test.py @@ -16,10 +16,7 @@ A4: Final = ["Dollar $ign", "literal"] A5: Final = [None, "oop"] B: Final = ["ghi", "jkl"] -XFAIL_STR_REPLACE_EXPR = pytest.mark.xfail( - reason="`replace(value:Expr, n>1)` is not yet supported for `pyarrow`", - raises=NotImplementedError, -) + XFAIL_STR_REPLACE_ALL_EXPR = pytest.mark.xfail( reason="`replace_all(value:Expr)` is not yet supported for `pyarrow`", raises=NotImplementedError, @@ -50,7 +47,6 @@ -1, False, ["ghi ghi ghi", "456jkl"], - marks=XFAIL_STR_REPLACE_EXPR, id="replace_all", ), pytest.param( @@ -60,7 +56,6 @@ -1, True, ["Dollar ghiign", "literal"], - marks=XFAIL_STR_REPLACE_EXPR, id="literal-replace_all", ), pytest.param( From 60b643349c9a9dee40eec96a800f99dc43195664 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 
17:47:16 +0000 Subject: [PATCH 123/215] factor-in `ReplaceExpr` --- narwhals/_plan/arrow/expr.py | 10 ---------- narwhals/_plan/compliant/accessors.py | 3 --- narwhals/_plan/expressions/strings.py | 23 +++-------------------- 3 files changed, 3 insertions(+), 33 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 5c64b49e68..e832fd772f 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -948,16 +948,6 @@ def ends_with( def replace( self, node: FExpr[strings.Replace], frame: Frame, name: str - ) -> Expr | Scalar: - func = node.function - pattern, value, literal, n = (func.pattern, func.value, func.literal, func.n) - replace = fn.str_replace - return self.unary(replace, pattern, value, literal=literal, n=n)( - node, frame, name - ) - - def replace_expr( - self, node: FExpr[strings.ReplaceExpr], frame: Frame, name: str ) -> Expr | Scalar: func = node.function pattern, literal, n = (func.pattern, func.literal, func.n) diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index b5cd5464de..0db0058923 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -47,9 +47,6 @@ def replace( def replace_all( self, node: FExpr[strings.ReplaceAll], frame: FrameT_contra, name: str ) -> ExprT_co: ... - def replace_expr( - self, node: FExpr[strings.ReplaceExpr], frame: FrameT_contra, name: str - ) -> ExprT_co: ... def replace_all_expr( self, node: FExpr[strings.ReplaceAllExpr], frame: FrameT_contra, name: str ) -> ExprT_co: ... 
diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 678154395e..a9d15bd678 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -39,16 +39,6 @@ class EndsWith(StringFunction): class Replace(StringFunction): - __slots__ = ("literal", "n", "pattern", "value") - pattern: str - value: str - literal: bool - n: int - - -# TODO @dangotbanned: Undo the (`Expr`) split and just have `Replace` -# This needs to handle scalars *anyway*, so no point in separating -class ReplaceExpr(StringFunction): """N-ary (expr, value).""" def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: @@ -129,10 +119,8 @@ class IRStringNamespace(IRNamespace): ends_with: ClassVar = EndsWith zfill: ClassVar = ZFill - def replace( - self, pattern: str, value: str, *, literal: bool = False, n: int = 1 - ) -> Replace: - return Replace(pattern=pattern, value=value, literal=literal, n=n) + def replace(self, pattern: str, *, literal: bool = False, n: int = 1) -> Replace: + return Replace(pattern=pattern, literal=literal, n=n) def replace_all( self, pattern: str, value: str, *, literal: bool = False @@ -171,16 +159,11 @@ def _ir_namespace(self) -> type[IRStringNamespace]: def len_chars(self) -> Expr: return self._with_unary(self._ir.len_chars()) - # TODO @dangotbanned: Support `value: IntoExpr` def replace( self, pattern: str, value: str | Expr, *, literal: bool = False, n: int = 1 ) -> Expr: - if isinstance(value, str): - return self._with_unary( - self._ir.replace(pattern, value, literal=literal, n=n) - ) other = parse_into_expr_ir(value, str_as_lit=True) - replace = ReplaceExpr(pattern=pattern, literal=literal, n=n) + replace = self._ir.replace(pattern, literal=literal, n=n) return self._expr._from_ir(replace.to_function_expr(self._expr._ir, other)) # TODO @dangotbanned: Support `value: IntoExpr` From 41634b646fd3e65246e5fae2d46e3834957f6ca3 Mon Sep 17 00:00:00 2001 From: dangotbanned 
<125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 18:09:49 +0000 Subject: [PATCH 124/215] feat: Support `str.replace_all(value: Expr)` Able to pass all of the original tests now, but not satisfied with their coverage yet --- narwhals/_plan/arrow/expr.py | 11 +++++------ narwhals/_plan/arrow/functions.py | 14 ++++++-------- narwhals/_plan/compliant/accessors.py | 3 --- narwhals/_plan/expressions/strings.py | 22 ++++++---------------- tests/plan/str_replace_test.py | 9 +++------ 5 files changed, 20 insertions(+), 39 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index e832fd772f..3a8b77bc48 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -7,7 +7,7 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import narwhals_to_native_dtype -from narwhals._plan import expressions as ir +from narwhals._plan import common, expressions as ir from narwhals._plan._guards import ( is_function_expr, is_iterable_reject, @@ -971,10 +971,10 @@ def replace( def replace_all( self, node: FExpr[strings.ReplaceAll], frame: Frame, name: str ) -> Expr | Scalar: - func = node.function - pattern, value, literal = (func.pattern, func.value, func.literal) - replace = fn.str_replace_all - return self.unary(replace, pattern, value, literal=literal)(node, frame, name) + rewrite: FExpr[Any] = common.replace( + node, function=node.function.to_replace_n(-1) + ) + return self.replace(rewrite, frame, name) def split(self, node: FExpr[strings.Split], frame: Frame, name: str) -> Expr | Scalar: return self.unary(fn.str_split, node.function.by)(node, frame, name) @@ -1006,7 +1006,6 @@ def to_titlecase( to_date = not_implemented() to_datetime = not_implemented() - replace_all_expr = not_implemented() class ArrowStructNamespace( diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 228bbfdc2e..f18ca20c2e 100644 --- a/narwhals/_plan/arrow/functions.py +++ 
b/narwhals/_plan/arrow/functions.py @@ -544,7 +544,12 @@ def str_replace( return fn(native, pattern, replacement=value, max_replacements=n) -# NOTE: Starting with the "easiest" cases first +def str_replace_all( + native: Incomplete, pattern: str, value: str, *, literal: bool = False +) -> Incomplete: + return str_replace(native, pattern, value, literal=literal, n=-1) + + def str_replace_vector( native: ChunkedArrayAny, pattern: str, @@ -591,7 +596,6 @@ def _str_replace_vector_n_1( return replace_with_mask(native, has_match, fully_replaced) -# TODO @dangotbanned: Link `str.replace_all` up to this # TODO @dangotbanned: Share more with `n=1` def _str_replace_all_vector( native: ChunkedArrayAny, @@ -615,12 +619,6 @@ def _str_replace_all_vector( return replace_with_mask(native, has_match, fully_replaced) -def str_replace_all( - native: Incomplete, pattern: str, value: str, *, literal: bool = False -) -> Incomplete: - return str_replace(native, pattern, value, literal=literal, n=-1) - - def str_zfill(native: ChunkedOrScalarAny, length: int) -> ChunkedOrScalarAny: if HAS_ZFILL: zfill: Incomplete = pc.utf8_zero_fill # type: ignore[attr-defined] diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 0db0058923..535b299dd8 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -47,9 +47,6 @@ def replace( def replace_all( self, node: FExpr[strings.ReplaceAll], frame: FrameT_contra, name: str ) -> ExprT_co: ... - def replace_all_expr( - self, node: FExpr[strings.ReplaceAllExpr], frame: FrameT_contra, name: str - ) -> ExprT_co: ... def slice( self, node: FExpr[strings.Slice], frame: FrameT_contra, name: str ) -> ExprT_co: ... 
diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index a9d15bd678..103c6bb1f4 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -52,14 +52,6 @@ def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: class ReplaceAll(StringFunction): - __slots__ = ("literal", "pattern", "value") - pattern: str - value: str - literal: bool - - -# TODO @dangotbanned: Undo the (`Expr`) split and just have `ReplaceAll` -class ReplaceAllExpr(StringFunction): """N-ary (expr, value).""" def unwrap_input( @@ -68,6 +60,9 @@ def unwrap_input( expr, value = node.input return expr, value + def to_replace_n(self, n: int) -> Replace: + return Replace(pattern=self.pattern, literal=self.literal, n=n) + __slots__ = ("literal", "pattern") pattern: str literal: bool @@ -122,10 +117,8 @@ class IRStringNamespace(IRNamespace): def replace(self, pattern: str, *, literal: bool = False, n: int = 1) -> Replace: return Replace(pattern=pattern, literal=literal, n=n) - def replace_all( - self, pattern: str, value: str, *, literal: bool = False - ) -> ReplaceAll: - return ReplaceAll(pattern=pattern, value=value, literal=literal) + def replace_all(self, pattern: str, *, literal: bool = False) -> ReplaceAll: + return ReplaceAll(pattern=pattern, literal=literal) def strip_chars( self, characters: str | None = None @@ -166,14 +159,11 @@ def replace( replace = self._ir.replace(pattern, literal=literal, n=n) return self._expr._from_ir(replace.to_function_expr(self._expr._ir, other)) - # TODO @dangotbanned: Support `value: IntoExpr` def replace_all( self, pattern: str, value: str | Expr, *, literal: bool = False ) -> Expr: - if isinstance(value, str): - return self._with_unary(self._ir.replace_all(pattern, value, literal=literal)) other = parse_into_expr_ir(value, str_as_lit=True) - replace = ReplaceAllExpr(pattern=pattern, literal=literal) + replace = self._ir.replace_all(pattern, literal=literal) return 
self._expr._from_ir(replace.to_function_expr(self._expr._ir, other)) def strip_chars(self, characters: str | None = None) -> Expr: # pragma: no cover diff --git a/tests/plan/str_replace_test.py b/tests/plan/str_replace_test.py index 0f38f0d58a..2bdc93bb89 100644 --- a/tests/plan/str_replace_test.py +++ b/tests/plan/str_replace_test.py @@ -17,10 +17,6 @@ A5: Final = [None, "oop"] B: Final = ["ghi", "jkl"] -XFAIL_STR_REPLACE_ALL_EXPR = pytest.mark.xfail( - reason="`replace_all(value:Expr)` is not yet supported for `pyarrow`", - raises=NotImplementedError, -) replace_scalar = pytest.mark.parametrize( ("data", "pattern", "value", "n", "literal", "expected"), @@ -88,6 +84,8 @@ (A4, r"$", "S", True, ["Dollar Sign", "literal"]), ], ) + +# TODO @dangotbanned: Cover more than these cases, it's just a repeat of `-1` replace_all_vector = pytest.mark.parametrize( ("data", "pattern", "value", "literal", "expected"), [ @@ -137,11 +135,10 @@ def test_str_replace_all_scalar( assert_equal_data(result, {"a": expected}) -@XFAIL_STR_REPLACE_ALL_EXPR @replace_all_vector def test_str_replace_all_vector( data: list[str], pattern: str, value: str, *, literal: bool, expected: list[str] -) -> None: # pragma: no cover +) -> None: df = dataframe({"a": data, "b": B}) result = df.select( nwp.col("a").str.replace_all(pattern, nwp.col(value), literal=literal) From 3b03f2af54f68dc7be0e16e06e5995af5c3bd415 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 19:12:37 +0000 Subject: [PATCH 125/215] partially handle `ignore_nulls` in `{list,str}_join` --- narwhals/_plan/arrow/functions.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index f18ca20c2e..0e1e58f172 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -369,30 +369,49 @@ def list_get(native: ArrowAny, index: int) -> 
ArrowAny: @t.overload def list_join( - native: ChunkedList[StringType], separator: Arrow[StringScalar] | str + native: ChunkedList[StringType], + separator: Arrow[StringScalar] | str, + *, + ignore_nulls: bool = ..., ) -> ChunkedArray[StringScalar]: ... @t.overload def list_join( - native: ListArray[StringType], separator: Arrow[StringScalar] | str + native: ListArray[StringType], + separator: Arrow[StringScalar] | str, + *, + ignore_nulls: bool = ..., ) -> pa.StringArray: ... @t.overload def list_join( - native: ListScalar[StringType], separator: Arrow[StringScalar] | str + native: ListScalar[StringType], + separator: Arrow[StringScalar] | str, + *, + ignore_nulls: bool = ..., ) -> pa.StringScalar: ... -def list_join(native: ArrowAny, separator: Arrow[StringScalar] | str) -> ArrowAny: +def list_join( + native: ArrowAny, separator: Arrow[StringScalar] | str, *, ignore_nulls: bool = False +) -> ArrowAny: """Join all string items in a sublist and place a separator between them. Each list of values in the first input is joined using each second input as separator. If any input list is null or contains a null, the corresponding output will be null. 
""" + if ignore_nulls: + # NOTE: `polars` default is `True`, will need to handle that if this becomes api + msg = "TODO: `ArrowExpr.list.join(ignore_nulls=True)`" + raise NotImplementedError(msg) return pc.binary_join(native, separator) -def str_join(native: Arrow[StringScalar], separator: str) -> StringScalar: +def str_join( + native: Arrow[StringScalar], separator: str, *, ignore_nulls: bool = True +) -> StringScalar: """Vertically concatenate the string values in the column to a single string value.""" if isinstance(native, pa.Scalar): # already joined return native + if ignore_nulls and native.null_count: + native = native.drop_null() offsets = [0, len(native)] scalar = pa.ListArray.from_arrays(offsets, array(native))[0] return list_join(scalar, separator) From 2786a09b0ebd66c8a05cbfa55a409577ff5d9221 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 19:35:37 +0000 Subject: [PATCH 126/215] test: Cover more in `str.replace_all` --- narwhals/_plan/expressions/strings.py | 8 +++--- tests/plan/str_replace_test.py | 38 +++++++++++++++++++++------ 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 103c6bb1f4..5dc730e8e5 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -120,9 +120,7 @@ def replace(self, pattern: str, *, literal: bool = False, n: int = 1) -> Replace def replace_all(self, pattern: str, *, literal: bool = False) -> ReplaceAll: return ReplaceAll(pattern=pattern, literal=literal) - def strip_chars( - self, characters: str | None = None - ) -> StripChars: # pragma: no cover + def strip_chars(self, characters: str | None = None) -> StripChars: return StripChars(characters=characters) def contains(self, pattern: str, *, literal: bool = False) -> Contains: @@ -166,7 +164,7 @@ def replace_all( replace = self._ir.replace_all(pattern, literal=literal) 
return self._expr._from_ir(replace.to_function_expr(self._expr._ir, other)) - def strip_chars(self, characters: str | None = None) -> Expr: # pragma: no cover + def strip_chars(self, characters: str | None = None) -> Expr: return self._with_unary(self._ir.strip_chars(characters)) def starts_with(self, prefix: str) -> Expr: @@ -196,7 +194,7 @@ def to_date(self, format: str | None = None) -> Expr: # pragma: no cover def to_datetime(self, format: str | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_datetime(format)) - def to_lowercase(self) -> Expr: # pragma: no cover + def to_lowercase(self) -> Expr: return self._with_unary(self._ir.to_lowercase()) def to_uppercase(self) -> Expr: diff --git a/tests/plan/str_replace_test.py b/tests/plan/str_replace_test.py index 2bdc93bb89..d2190a650d 100644 --- a/tests/plan/str_replace_test.py +++ b/tests/plan/str_replace_test.py @@ -85,13 +85,32 @@ ], ) -# TODO @dangotbanned: Cover more than these cases, it's just a repeat of `-1` + replace_all_vector = pytest.mark.parametrize( ("data", "pattern", "value", "literal", "expected"), [ - (A1, r"abc", "b", False, ["123ghi", "jkl456"]), - (A2, r"abc", "b", False, ["ghi ghi", "jkl456"]), - (A4, r"$", "b", True, ["Dollar ghiign", "literal"]), + pytest.param(A1, r"abc", nwp.col("b"), False, ["123ghi", "jkl456"], id="single"), + pytest.param(A2, r"abc", nwp.col("b"), False, ["ghi ghi", "jkl456"], id="mixed"), + pytest.param( + A4, r"$", nwp.col("b"), True, ["Dollar ghiign", "literal"], id="literal" + ), + pytest.param(A5, r"o", nwp.col("b"), False, [None, "jkljklp"], id="null-input"), + pytest.param( + A3, + r"\d", + nwp.col("b").first(), + False, + ["abc abc abc", "ghighighiabc"], + id="agg-replacement", + ), + pytest.param( + A3, + r" ?abc$", + nwp.lit(" HELLO").str.to_lowercase().str.strip_chars(), + False, + ["abc abchello", "456hello"], + id="transformed-replacement", + ), ], ) @@ -137,10 +156,13 @@ def test_str_replace_all_scalar( @replace_all_vector def 
test_str_replace_all_vector( - data: list[str], pattern: str, value: str, *, literal: bool, expected: list[str] + data: Sequence[str | None], + pattern: str, + value: nwp.Expr, + *, + literal: bool, + expected: Sequence[str | None], ) -> None: df = dataframe({"a": data, "b": B}) - result = df.select( - nwp.col("a").str.replace_all(pattern, nwp.col(value), literal=literal) - ) + result = df.select(nwp.col("a").str.replace_all(pattern, value, literal=literal)) assert_equal_data(result, {"a": expected}) From 9cc469bcf9afe698d23c85b616d90b026c66c573 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 20:01:13 +0000 Subject: [PATCH 127/215] refactor: Merge `str_replace` impls --- narwhals/_plan/arrow/functions.py | 62 ++++++++++--------------------- 1 file changed, 20 insertions(+), 42 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 0e1e58f172..d9ddc32cf5 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -577,49 +577,23 @@ def str_replace_vector( literal: bool = False, n: int = 1, ) -> ChunkedArrayAny: - if n == 1: - return _str_replace_vector_n_1(native, pattern, replacements, literal=literal) + n_: Literal[1] | None if n == -1: - return _str_replace_all_vector(native, pattern, replacements, literal=literal) - msg = f"`pyarrow` currently only supports `str.replace(value: Expr, n=1)`, got {n=} " - raise NotImplementedError(msg) - - -# TODO @dangotbanned: Is this worth keeping, now that `replace_all` is simpler? 
-def _str_replace_vector_n_1( - native: ChunkedArrayAny, - pattern: str, - replacements: ChunkedArrayAny, - *, - literal: bool = False, -) -> ChunkedArrayAny: - has_match = str_contains(native, pattern, literal=literal) - if not any_(has_match).as_py(): - # fastpath, no work to do - return native - tbl_matches = pa.Table.from_arrays([native, replacements], ["0", "1"]).filter( - has_match - ) - list_split_by = str_splitn( - array(tbl_matches.column(0)), pattern, n=2, literal=literal - ) - list_flat = list_split_by.values - needs_replacing = eq(list_flat, lit("", list_flat.type)) - # Needs better name - replaced_wrong_shape = replace_with_mask( - list_flat, needs_replacing, tbl_matches.column(1) - ) - fully_replaced = concat_str(replaced_wrong_shape[0::2], replaced_wrong_shape[1::2]) - if all_(has_match, ignore_nulls=False).as_py(): - return chunked_array(fully_replaced) - return replace_with_mask(native, has_match, fully_replaced) + n_ = None + elif n == 1: + n_ = 1 + else: + msg = f"`pyarrow` currently only supports `str.replace(value: Expr, n=1)`, got {n=} " + raise NotImplementedError(msg) + return _str_replace_vector(native, pattern, replacements, n_, literal=literal) -# TODO @dangotbanned: Share more with `n=1` -def _str_replace_all_vector( +def _str_replace_vector( native: ChunkedArrayAny, pattern: str, replacements: ChunkedArrayAny, + # TODO @dangotbanned: Might be simple to do `n>1` now? 
+ n: Literal[1] | None = None, *, literal: bool = False, ) -> ChunkedArrayAny: @@ -630,12 +604,16 @@ def _str_replace_all_vector( tbl_matches = pa.Table.from_arrays([native, replacements], ["0", "1"]).filter( has_match ) - # here we can have unequal-length lists - list_split_by = str_split(array(tbl_matches.column(0)), pattern, literal=literal) - fully_replaced = list_join(list_split_by, tbl_matches.column(1)) + matches = tbl_matches.column(0) + match_replacements = tbl_matches.column(1) + if n is None: + list_split_by = str_split(matches, pattern, literal=literal) + else: + list_split_by = str_splitn(matches, pattern, n + 1, literal=literal) + replaced = list_join(list_split_by, match_replacements) if all_(has_match, ignore_nulls=False).as_py(): - return chunked_array(fully_replaced) - return replace_with_mask(native, has_match, fully_replaced) + return chunked_array(replaced) + return replace_with_mask(native, has_match, array(replaced)) def str_zfill(native: ChunkedOrScalarAny, length: int) -> ChunkedOrScalarAny: From 3f0d8e0266d1496ed9f96bb791e2cafe008aedbc Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 20:07:29 +0000 Subject: [PATCH 128/215] test: Add a test for `str.replace(value: Expr, n>1)` --- tests/plan/str_replace_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/plan/str_replace_test.py b/tests/plan/str_replace_test.py index d2190a650d..933da4afbc 100644 --- a/tests/plan/str_replace_test.py +++ b/tests/plan/str_replace_test.py @@ -36,6 +36,18 @@ pytest.param( A2, r"abc", nwp.col("b"), 1, False, ["ghi abc", "jkl456"], id="n-1-mixed" ), + pytest.param( + A3, + r"a", + nwp.col("b"), + 2, + False, + ["ghibc ghibc abc", "456jklbc"], + id="n-2-mixed", + marks=pytest.mark.xfail( + reason="TODO: str.replace(value: Expr, n>1)", raises=NotImplementedError + ), + ), pytest.param( A3, r"abc", From abe577d9b50e977b56c0fe7d0bce5a6a50e071d0 Mon Sep 17 00:00:00 2001 From: 
dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 20:39:20 +0000 Subject: [PATCH 129/215] feat: Support `str.replace(value: Expr, n>1)` --- narwhals/_plan/arrow/functions.py | 24 ++---------------------- tests/plan/str_replace_test.py | 3 --- 2 files changed, 2 insertions(+), 25 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index d9ddc32cf5..d3b5ca5676 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -575,27 +575,7 @@ def str_replace_vector( replacements: ChunkedArrayAny, *, literal: bool = False, - n: int = 1, -) -> ChunkedArrayAny: - n_: Literal[1] | None - if n == -1: - n_ = None - elif n == 1: - n_ = 1 - else: - msg = f"`pyarrow` currently only supports `str.replace(value: Expr, n=1)`, got {n=} " - raise NotImplementedError(msg) - return _str_replace_vector(native, pattern, replacements, n_, literal=literal) - - -def _str_replace_vector( - native: ChunkedArrayAny, - pattern: str, - replacements: ChunkedArrayAny, - # TODO @dangotbanned: Might be simple to do `n>1` now? 
- n: Literal[1] | None = None, - *, - literal: bool = False, + n: int | None = 1, ) -> ChunkedArrayAny: has_match = str_contains(native, pattern, literal=literal) if not any_(has_match).as_py(): @@ -606,7 +586,7 @@ def _str_replace_vector( ) matches = tbl_matches.column(0) match_replacements = tbl_matches.column(1) - if n is None: + if n is None or n == -1: list_split_by = str_split(matches, pattern, literal=literal) else: list_split_by = str_splitn(matches, pattern, n + 1, literal=literal) diff --git a/tests/plan/str_replace_test.py b/tests/plan/str_replace_test.py index 933da4afbc..b61db9010e 100644 --- a/tests/plan/str_replace_test.py +++ b/tests/plan/str_replace_test.py @@ -44,9 +44,6 @@ False, ["ghibc ghibc abc", "456jklbc"], id="n-2-mixed", - marks=pytest.mark.xfail( - reason="TODO: str.replace(value: Expr, n>1)", raises=NotImplementedError - ), ), pytest.param( A3, From 24f2358825f66837a0184a985eb658087ef76a34 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 3 Dec 2025 21:28:18 +0000 Subject: [PATCH 130/215] cursed typing --- narwhals/_plan/arrow/functions.py | 39 ++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index d3b5ca5676..b4c6ca40e9 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -26,7 +26,7 @@ import datetime as dt from collections.abc import Iterable, Mapping - from typing_extensions import TypeAlias, TypeIs + from typing_extensions import TypeAlias, TypeIs, TypeVarTuple, Unpack from narwhals._arrow.typing import Incomplete, PromoteOptions from narwhals._plan.arrow.acero import Field @@ -88,6 +88,8 @@ PythonLiteral, ) + Ts = TypeVarTuple("Ts") + BACKEND_VERSION = Implementation.PYARROW._backend_version() """Static backend version for `pyarrow`.""" @@ -507,6 +509,15 @@ def str_split(native: ArrowAny, by: str, *, literal: bool = True) -> 
Arrow[ListS return _str_split(native, by, literal=literal) +@t.overload +def str_splitn( + native: ChunkedOrScalarAny, + by: str, + n: int, + *, + literal: bool = ..., + as_struct: bool = ..., +) -> ChunkedOrScalar[ListScalar]: ... @t.overload def str_splitn( native: ArrayAny, by: str, n: int, *, literal: bool = ..., as_struct: bool = ... @@ -581,15 +592,11 @@ def str_replace_vector( if not any_(has_match).as_py(): # fastpath, no work to do return native - tbl_matches = pa.Table.from_arrays([native, replacements], ["0", "1"]).filter( - has_match - ) - matches = tbl_matches.column(0) - match_replacements = tbl_matches.column(1) + match, match_replacements = filter_arrays(has_match, native, replacements) if n is None or n == -1: - list_split_by = str_split(matches, pattern, literal=literal) + list_split_by = str_split(match, pattern, literal=literal) else: - list_split_by = str_splitn(matches, pattern, n + 1, literal=literal) + list_split_by = str_splitn(match, pattern, n + 1, literal=literal) replaced = list_join(list_split_by, match_replacements) if all_(has_match, ignore_nulls=False).as_py(): return chunked_array(replaced) @@ -1289,3 +1296,19 @@ def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataT and isinstance(first[0], str) and isinstance(first[1], pa.DataType) ) + + +def filter_arrays( + predicate: ChunkedOrArray[BooleanScalar] | pc.Expression, + *arrays: Unpack[Ts], + ignore_nulls: bool = True, +) -> tuple[Unpack[Ts]]: + """Apply the same filter to multiple arrays, returning them independently. + + Note: + The typing here is a minefield. You'll get an `*arrays`-length `tuple[ChunkedArray, ...]`. 
+ """ + table: Incomplete = pa.Table.from_arrays + tmp = [str(i) for i in range(len(arrays))] + result = table(arrays, tmp).filter(predicate, "drop" if ignore_nulls else "emit_null") + return t.cast("tuple[Unpack[Ts]]", tuple(result.columns)) From 10100a2066770a98a2e75cd158fba22ed12a5ac2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 29 Nov 2025 14:33:51 +0000 Subject: [PATCH 131/215] feat(DRAFT): Add native `linear_space` --- narwhals/_plan/arrow/functions.py | 42 +++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index b4c6ca40e9..af06966a5b 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1210,6 +1210,48 @@ def date_range( return ca.cast(pa.date32()) +def linear_space( + start: int, end: int, num_samples: int, *, closed: ClosedInterval = "both" +) -> ChunkedArray[pc.NumericScalar]: + """Based on [`np.linspace`]. + + Use when implementing `hist`. + + [`np.linspace`]: https://github.com/numpy/numpy/blob/v2.3.0/numpy/_core/function_base.py#L26-L187 + """ + if num_samples < 0: + msg = f"Number of samples, {num_samples}, must be non-negative." 
+ raise ValueError(msg) + if num_samples == 1: + msg = f"num_samples {num_samples} is not >= 2" + raise NotImplementedError(msg) + if closed == "both": + range_end = num_samples + div = num_samples - 1 + elif closed == "left": + range_end = num_samples + div = num_samples + elif closed == "right": + range_end = num_samples + 1 + div = num_samples + elif closed == "none": + range_end = num_samples + 1 + div = num_samples + 1 + ca: ChunkedArray[pc.NumericScalar] = int_range(0, range_end).cast(F64) + delta = float(end - start) + step = delta / div + if step == 0: + ca = truediv(ca, lit(div)) + ca = multiply(ca, lit(delta)) + else: + ca = multiply(ca, lit(step)) + if start != 0: + ca = add(ca, lit(start, F64)) + if closed in {"right", "none"}: + return ca.slice(1) + return ca + + def repeat(value: ScalarAny | NonNestedLiteral, n: int) -> ArrayAny: value = value if isinstance(value, pa.Scalar) else lit(value) return repeat_unchecked(value, n) From 048f37245ece12ba2c19ad3098d6fc3a5d42bbf0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:23:33 +0000 Subject: [PATCH 132/215] feat(DRAFT): More `Series.hist` prep - Port most of the test suite - Add `DataFrame`, `Series` features used in the tests - Extend `hist` signatures --- narwhals/_plan/arrow/series.py | 18 ++ narwhals/_plan/compliant/dataframe.py | 1 + narwhals/_plan/compliant/series.py | 10 ++ narwhals/_plan/dataframe.py | 69 +++++++- narwhals/_plan/expr.py | 6 +- narwhals/_plan/series.py | 60 ++++++- tests/plan/hist_test.py | 235 ++++++++++++++++++++++++++ 7 files changed, 393 insertions(+), 6 deletions(-) create mode 100644 tests/plan/hist_test.py diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 5437403cd9..c2d8d3a47b 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -99,6 +99,18 @@ def scatter(self, indices: Self, values: Self) -> Self: def is_in(self, other: Self) -> Self: 
return self._with_native(fn.is_in(self.native, other.native)) + def is_nan(self) -> Self: + return self._with_native(fn.is_nan(self.native)) + + def is_null(self) -> Self: + return self._with_native(fn.is_null(self.native)) + + def is_not_nan(self) -> Self: + return self._with_native(fn.is_not_nan(self.native)) + + def is_not_null(self) -> Self: + return self._with_native(fn.is_not_null(self.native)) + def has_nulls(self) -> bool: return bool(self.native.null_count) @@ -265,5 +277,11 @@ def all(self) -> bool: def any(self) -> bool: return fn.any_(self.native).as_py() + def sum(self) -> float: + return fn.sum_(self.native).as_py() # type: ignore[no-any-return] + + def count(self) -> int: + return fn.count(self.native).as_py() + def unique(self, *, maintain_order: bool = False) -> Self: return self._with_native(self.native.unique()) diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index 83f4ff21ca..e8af7c4532 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -138,6 +138,7 @@ def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Se return self._group_by.from_resolver(self, resolver) def filter(self, predicate: NamedIR, /) -> Self: ... + def iter_columns(self) -> Iterator[SeriesT]: ... def join( self, other: Self, diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 5d2ee7c7ce..186b1e96c8 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -118,6 +118,8 @@ def native(self) -> NativeSeriesT: def all(self) -> bool: ... def any(self) -> bool: ... + def sum(self) -> float: ... + def count(self) -> int: ... def alias(self, name: str) -> Self: return self.from_native(self.native, name, version=self.version) @@ -145,6 +147,14 @@ def is_empty(self) -> bool: return len(self) == 0 def is_in(self, other: Self) -> Self: ... + def is_nan(self) -> Self: ... + def is_null(self) -> Self: ... 
+ def is_not_nan(self) -> Self: + return self.is_nan().__invert__() + + def is_not_null(self) -> Self: + return self.is_null().__invert__() + def rolling_mean( self, window_size: int, *, min_samples: int, center: bool = False ) -> Self: ... diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index d7f073c96f..b791c71780 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -25,18 +25,31 @@ from narwhals.dependencies import is_pyarrow_table from narwhals.exceptions import ShapeError from narwhals.schema import Schema -from narwhals.typing import IntoDType, JoinStrategy +from narwhals.typing import EagerAllowed, IntoBackend, IntoDType, IntoSchema, JoinStrategy if TYPE_CHECKING: - from collections.abc import Iterable, Mapping, Sequence + from collections.abc import Iterable, Iterator, Mapping, Sequence import polars as pl import pyarrow as pa from typing_extensions import Self, TypeAlias, TypeIs from narwhals._plan.arrow.typing import NativeArrowDataFrame - from narwhals._plan.compliant.dataframe import CompliantDataFrame, CompliantFrame - from narwhals._typing import _EagerAllowedImpl + from narwhals._plan.compliant.dataframe import ( + CompliantDataFrame, + CompliantFrame, + EagerDataFrame, + ) + from narwhals._plan.compliant.namespace import EagerNamespace + from narwhals._plan.compliant.series import CompliantSeries + from narwhals._typing import Arrow, _EagerAllowedImpl + + EagerNs: TypeAlias = EagerNamespace[ + EagerDataFrame[Any, NativeDataFrameT, Any], + CompliantSeries[NativeSeriesT], + Any, + Any, + ] Incomplete: TypeAlias = Any @@ -143,6 +156,15 @@ def with_row_index( return self._with_compliant(self._compliant.with_row_index_by(name, by_names)) +def _dataframe_from_dict( + data: Mapping[str, Any], + schema: IntoSchema | None, + ns: EagerNs[NativeDataFrameT, NativeSeriesT], + /, +) -> DataFrame[NativeDataFrameT, NativeSeriesT]: + return ns._dataframe.from_dict(data, schema=schema).to_narwhals() + + class DataFrame( 
BaseFrame[NativeDataFrameT_co], Generic[NativeDataFrameT_co, NativeSeriesT] ): @@ -196,6 +218,41 @@ def from_native( raise NotImplementedError(type(native)) + @overload + @classmethod + def from_dict( + cls: type[DataFrame[Any, Any]], + data: Mapping[str, Any], + schema: IntoSchema | None = ..., + *, + backend: Arrow, + ) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: ... + @overload + @classmethod + def from_dict( + cls: type[DataFrame[Any, Any]], + data: Mapping[str, Any], + schema: IntoSchema | None = None, + *, + backend: IntoBackend[EagerAllowed] | None = ..., + ) -> DataFrame[Any, Any]: ... + @classmethod + def from_dict( + cls: type[DataFrame[Any, Any]], + data: Mapping[str, Any], + schema: IntoSchema | None = None, + *, + backend: IntoBackend[EagerAllowed] | None = None, + ) -> DataFrame[Any, Any]: + from narwhals._plan import functions as F + + if backend is None: + msg = f"`from_dict({backend=})`" + raise NotImplementedError(msg) + + ns = F._eager_namespace(backend) + return _dataframe_from_dict(data, schema, ns) + @overload def to_dict( self, *, as_series: Literal[True] = ... 
@@ -254,6 +311,10 @@ def group_by( def row(self, index: int) -> tuple[Any, ...]: return self._compliant.row(index) + def iter_columns(self) -> Iterator[Series[NativeSeriesT]]: + for series in self._compliant.iter_columns(): + yield self._series(series) + def join( self, other: Self, diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index b757ea828e..540ba60d0a 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -195,8 +195,12 @@ def hist( bins: Sequence[float] | None = None, *, bin_count: int | None = None, - include_breakpoint: bool = True, + include_breakpoint: bool = True, # NOTE: `pl.Expr.hist` default is `False` + include_category: bool = False, ) -> Self: + if include_category: + msg = f"`Expr.hist({include_category=})` is not yet implemented" + raise NotImplementedError(msg) node: F.Hist if bins is not None: if bin_count is not None: diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index b7dafbc92f..19ae1fe7df 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Any, ClassVar, Generic from narwhals._plan._guards import is_series @@ -187,6 +187,18 @@ def scatter( def is_in(self, other: Iterable[Any]) -> Self: return type(self)(self._compliant.is_in(self._parse_into_compliant(other))) + def is_nan(self) -> Self: + return type(self)(self._compliant.is_nan()) + + def is_null(self) -> Self: + return type(self)(self._compliant.is_null()) + + def is_not_nan(self) -> Self: # pragma: no cover + return type(self)(self._compliant.is_not_nan()) + + def is_not_null(self) -> Self: # pragma: no cover + return type(self)(self._compliant.is_not_null()) + def null_count(self) -> int: return self._compliant.null_count() @@ -221,15 +233,61 @@ def __eq__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: # typ other_ = 
self._unwrap_compliant(other) if is_series(other) else other return type(self)(self._compliant.__eq__(other_)) + def __or__(self, other: bool | Self, /) -> Self: + other_ = self._unwrap_compliant(other) if is_series(other) else other + return type(self)(self._compliant.__or__(other_)) + + def __invert__(self) -> Self: + return type(self)(self._compliant.__invert__()) + + def __add__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: + other_ = self._unwrap_compliant(other) if is_series(other) else other + return type(self)(self._compliant.__add__(other_)) + def all(self) -> bool: return self._compliant.all() def any(self) -> bool: # pragma: no cover return self._compliant.any() + def sum(self) -> float: + return self._compliant.sum() + + def count(self) -> int: + return self._compliant.count() + def unique(self, *, maintain_order: bool = False) -> Self: # pragma: no cover return type(self)(self._compliant.unique(maintain_order=maintain_order)) + def hist( + self, + bins: Sequence[float] | None = None, + *, + bin_count: int | None = None, + # NOTE: `pl.Series.hist` defaults are the opposite of `pl.Expr.hist` + include_breakpoint: bool = True, + include_category: bool = False, # NOTE: `pl.Series.hist` default is `True`, but that would be breaking (ish) for narwhals + ) -> DataFrame[Incomplete, NativeSeriesT_co]: + from narwhals._plan import functions as F + + result = ( + self.to_frame() + .select( + F.col(self.name).hist( + bins, + bin_count=bin_count, + include_breakpoint=include_breakpoint, + include_category=include_category, + ) + ) + .to_series() + ) + if not include_breakpoint and not include_category: + return result.to_frame() + msg = f"`Series.hist({include_breakpoint=}, {include_category=})` requires `Series.struct.unnest`" + raise NotImplementedError(msg) + return result.struct.unnest() + class SeriesV1(Series[NativeSeriesT_co]): _version: ClassVar[Version] = Version.V1 diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py new file 
mode 100644 index 0000000000..cacdba0078 --- /dev/null +++ b/tests/plan/hist_test.py @@ -0,0 +1,235 @@ +# TODO(unassigned): cudf has too many spurious failures. Report and revisit? +# Modin is too slow so is excluded. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data + +if TYPE_CHECKING: + from collections.abc import Sequence + + from narwhals.typing import EagerAllowed + from tests.conftest import Data + + +XFAIL_HIST_NOT_IMPLEMENTED = pytest.mark.xfail( + reason="`ArrowExpr.hist_*` is not yet implemented" +) + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "int": [0, 1, 2, 3, 4, 5, 6], + "float": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + "int_shuffled": [1, 0, 2, 3, 6, 5, 4], + "float_shuffled": [1.0, 0.0, 2.0, 3.0, 6.0, 5.0, 4.0], + } + + +@pytest.fixture(scope="module") +def data_missing(data: Data) -> Data: + return {"has_nan": [float("nan"), *data["int"]], "has_null": [None, *data["int"]]} + + +@pytest.fixture(scope="module", params=["pyarrow"]) +def backend(request: pytest.FixtureRequest) -> EagerAllowed: + result: EagerAllowed = request.param + return result + + +@pytest.fixture( + scope="module", params=[True, False], ids=["breakpoint-True", "breakpoint-False"] +) +def include_breakpoint(request: pytest.FixtureRequest) -> bool: + result: bool = request.param + return result + + +counts_and_expected = [ + { + "bin_count": 4, + "expected_bins": [0, 1.5, 3.0, 4.5, 6.0], + "expected_count": [2, 2, 1, 2], + }, + { + "bin_count": 12, + "expected_bins": [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0], + "expected_count": [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + }, + {"bin_count": 1, "expected_bins": [0, 6], "expected_count": [7]}, + {"bin_count": 0, "expected_bins": [], "expected_count": []}, +] + + +SHIFT_BINS_BY = 10 +"""shift bins property""" + + +# TODO @dangotbanned: Try to avoid 
all this looping (3x `iter_columns` in a single test?) +@XFAIL_HIST_NOT_IMPLEMENTED +@pytest.mark.parametrize( + ("bins", "expected"), + [ + ([-float("inf"), 2.5, 5.5, float("inf")], [3, 3, 1]), + ([1.0, 2.5, 5.5, float("inf")], [2, 3, 1]), + ([1.0, 2.5, 5.5], [2, 3]), + ([-10.0, -1.0, 2.5, 5.5], [0, 3, 3]), + ([1.0, 2.0625], [2]), + ([1], []), + ([0, 10], [7]), + ], + ids=str, +) +def test_hist_bin( + data: Data, + data_missing: Data, + backend: EagerAllowed, + bins: list[float], + expected: Sequence[float], + *, + include_breakpoint: bool, +) -> None: + df = nwp.DataFrame.from_dict(data, backend=backend).with_columns( + float=nwp.col("int").cast(nw.Float64) + ) + expected_full = {"count": expected} + if include_breakpoint: + expected_full = {"breakpoint": bins[1:], **expected_full} + # smoke tests + for series in df.iter_columns(): + result = series.hist(bins=bins, include_breakpoint=include_breakpoint) + assert_equal_data(result, expected_full) + + # result size property + assert len(result) == max(len(bins) - 1, 0) + + # shift bins property + shifted_bins = [b + SHIFT_BINS_BY for b in bins] + expected_full = {"count": expected} + if include_breakpoint: + expected_full = {"breakpoint": shifted_bins[1:], **expected_full} + + for series in df.iter_columns(): + result = (series + SHIFT_BINS_BY).hist( + shifted_bins, include_breakpoint=include_breakpoint + ) + assert_equal_data(result, expected_full) + + # missing/nan results + df = nwp.DataFrame.from_dict(data_missing, backend=backend) + expected_full = {"count": expected} + if include_breakpoint: + expected_full = {"breakpoint": bins[1:], **expected_full} + for series in df.iter_columns(): + result = series.hist(bins, include_breakpoint=include_breakpoint) + assert_equal_data(result, expected_full) + + +# TODO @dangotbanned: Avoid using `del` +# TODO @dangotbanned: Split up `params` +# TODO @dangotbanned: Try to avoid all this looping +@XFAIL_HIST_NOT_IMPLEMENTED +@pytest.mark.parametrize("params", 
counts_and_expected) +def test_hist_count( + data: Data, + data_missing: Data, + backend: EagerAllowed, + *, + params: dict[str, Any], + include_breakpoint: bool, +) -> None: + df = nwp.DataFrame.from_dict(data, backend=backend).with_columns( + float=nwp.col("int").cast(nw.Float64) + ) + bin_count = params["bin_count"] + + expected_bins = params["expected_bins"] + expected = {"breakpoint": expected_bins[1:], "count": params["expected_count"]} + if not include_breakpoint: + del expected["breakpoint"] + + # smoke tests + for col in df.columns: + result = df.get_column(col).hist( + bin_count=bin_count, include_breakpoint=include_breakpoint + ) + assert_equal_data(result, expected) + + # result size property + + assert len(result) == bin_count + if bin_count > 0: + assert result.get_column("count").sum() == df.get_column(col).count() + + # missing/nan results + df = nwp.DataFrame.from_dict(data_missing, backend=backend) + + for col in df.columns: + result = df.get_column(col).hist( + bin_count=bin_count, include_breakpoint=include_breakpoint + ) + assert_equal_data(result, expected) + + # result size property + assert len(result) == bin_count + ser = df.get_column(col) + if bin_count > 0: + # NOTE: Could this just be a filter? 
+ assert ( + result.get_column("count").sum() + == (~(ser.is_nan() | ser.is_null())).sum() + ) + + +# TODO @dangotbanned: parametrize into 3 cases +@XFAIL_HIST_NOT_IMPLEMENTED +def test_hist_count_no_spread(backend: EagerAllowed) -> None: + data_ = {"all_zero": [0, 0, 0], "all_non_zero": [5, 5, 5]} + df = nwp.DataFrame.from_dict(data_, backend=backend) + + result = df.get_column("all_zero").hist(bin_count=4, include_breakpoint=True) + expected = {"breakpoint": [-0.25, 0.0, 0.25, 0.5], "count": [0, 3, 0, 0]} + assert_equal_data(result, expected) + + result = df.get_column("all_non_zero").hist(bin_count=4, include_breakpoint=True) + expected = {"breakpoint": [4.75, 5.0, 5.25, 5.5], "count": [0, 3, 0, 0]} + assert_equal_data(result, expected) + + result = df.get_column("all_zero").hist(bin_count=1, include_breakpoint=True) + expected = {"breakpoint": [0.5], "count": [3]} + assert_equal_data(result, expected) + + +# TODO @dangotbanned: parametrize into 2 cases? +@XFAIL_HIST_NOT_IMPLEMENTED +def test_hist_no_data(backend: EagerAllowed, *, include_breakpoint: bool) -> None: + data_: Data = {"values": []} + df = nwp.DataFrame.from_dict(data_, {"values": nw.Float64()}, backend=backend) + s = df.to_series() + for bin_count in [1, 10]: + result = s.hist(bin_count=bin_count, include_breakpoint=include_breakpoint) + assert len(result) == bin_count + assert result.get_column("count").sum() == 0 + + if include_breakpoint: + bps = result.get_column("breakpoint").to_list() + assert bps[0] == (1 / bin_count) + if bin_count > 1: + assert bps[-1] == 1 + + result = s.hist(bins=[1, 5, 10], include_breakpoint=include_breakpoint) + assert len(result) == 2 + assert result.get_column("count").sum() == 0 + + +@XFAIL_HIST_NOT_IMPLEMENTED +def test_hist_small_bins(backend: EagerAllowed) -> None: + s = nwp.Series.from_iterable([1, 2, 3], name="values", backend=backend) + result = s.hist(bins=None, bin_count=None) + assert len(result) == 10 From 9e92e08de5244123cb63917b39398775532f41c4 Mon Sep 
17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 15:30:05 +0000 Subject: [PATCH 133/215] feat: Add `DataFrame.to_struct` Needed for `Expr.hist` --- narwhals/_plan/arrow/dataframe.py | 13 ++++++++++++- narwhals/_plan/arrow/functions.py | 18 ++++++++++++++++++ narwhals/_plan/compliant/dataframe.py | 1 + narwhals/_plan/dataframe.py | 3 +++ 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 0784b1dd00..4e2a5d3817 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -24,7 +24,7 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence import polars as pl - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._plan.arrow.typing import ChunkedArrayAny from narwhals._plan.compliant.group_by import GroupByResolver @@ -34,6 +34,8 @@ from narwhals.dtypes import DType from narwhals.typing import IntoSchema +Incomplete: TypeAlias = Any + class ArrowDataFrame( FrameSeries["pa.Table"], EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"] @@ -127,6 +129,15 @@ def with_row_index_by( column = fn.unsort_indices(indices) return self._with_native(self.native.add_column(0, name, column)) + def to_struct(self, name: str = "") -> Series: + native = self.native + struct = ( + native.to_struct_array() + if fn.BACKEND_VERSION >= (15, 0) + else fn.struct(self.columns, values=native.columns) + ) + return Series.from_native(struct, name, version=self.version) + def get_column(self, name: str) -> Series: chunked = self.native.column(name) return Series.from_native(chunked, name, version=self.version) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index af06966a5b..9b337da209 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -296,6 +296,24 @@ def string_type(data_types: 
Iterable[DataType] = (), /) -> StringType | LargeStr return pa.large_string() if has_large_string(data_types) else pa.string() +@t.overload +def struct(names: Sequence[str], values: Iterable[ScalarAny]) -> pa.StructScalar: ... +@t.overload +def struct( + names: Sequence[str], values: Iterable[ChunkedOrScalarAny] +) -> ChunkedStruct: ... +@t.overload +def struct( + names: Sequence[str], values: Iterable[ArrayAny | ScalarAny] +) -> pa.StructArray: ... +def struct(names: Sequence[str], values: Iterable[Any]) -> Any: + """Convert `values` into a struct. + + The output shape will be scalar if all inputs are scalar, otherwise any scalars will be broadcast to arrays. + """ + return pc.make_struct(*values, options=pc.MakeStructOptions(field_names=names)) + + @t.overload def struct_field(native: ChunkedStruct, field: Field, /) -> ChunkedArrayAny: ... @t.overload diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index e8af7c4532..b502848534 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -170,6 +170,7 @@ def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: return DataFrame[NativeDataFrameT, NativeSeriesT](self) def to_series(self, index: int = 0) -> SeriesT: ... + def to_struct(self, name: str = "") -> SeriesT: ... def to_polars(self) -> pl.DataFrame: ... def with_row_index(self, name: str) -> Self: ... def slice(self, offset: int, length: int | None = None) -> Self: ... 
diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index b791c71780..27a237ed36 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -276,6 +276,9 @@ def to_dict( def to_series(self, index: int = 0) -> Series[NativeSeriesT]: return self._series(self._compliant.to_series(index)) + def to_struct(self, name: str = "") -> Series[NativeSeriesT]: + return self._series(self._compliant.to_struct(name)) + def to_polars(self) -> pl.DataFrame: return self._compliant.to_polars() From 1172f82a647e799c71841db8b097dec6a314f988 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 17:03:36 +0000 Subject: [PATCH 134/215] feat: Add `Series.struct.unnest` Also needed in `Expr.hist` --- narwhals/_plan/arrow/series.py | 49 +++++++++++++++++++++++++-- narwhals/_plan/compliant/accessors.py | 12 ++++++- narwhals/_plan/compliant/series.py | 4 +++ narwhals/_plan/compliant/typing.py | 1 + narwhals/_plan/series.py | 47 +++++++++++++++++++++---- narwhals/_plan/typing.py | 1 + 6 files changed, 105 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index c2d8d3a47b..35530964bc 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -1,12 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast +import pyarrow as pa import pyarrow.compute as pc from narwhals._arrow.utils import narwhals_to_native_dtype, native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn, options from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries +from narwhals._plan.compliant.accessors import SeriesStructNamespace as StructNamespace from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.compliant.typing import namespace from narwhals._plan.expressions import functions as F @@ -20,6 +22,7 @@ from 
typing_extensions import Self from narwhals._plan.arrow.dataframe import ArrowDataFrame as DataFrame + from narwhals._plan.arrow.namespace import ArrowNamespace as Namespace from narwhals._plan.arrow.typing import ChunkedArrayAny from narwhals.dtypes import DType from narwhals.typing import ( @@ -53,7 +56,7 @@ def to_polars(self) -> pl.Series: import polars as pl # ignore-banned-import # NOTE: Recommended in https://github.com/pola-rs/polars/issues/22921#issuecomment-2908506022 - return pl.Series(self.native) + return pl.Series(self.name, self.native) def __len__(self) -> int: return self.native.length() @@ -285,3 +288,45 @@ def count(self) -> int: def unique(self, *, maintain_order: bool = False) -> Self: return self._with_native(self.native.unique()) + + @property + def struct(self) -> SeriesStructNamespace: + return SeriesStructNamespace(self) + + +class SeriesStructNamespace(StructNamespace[ArrowSeries, "DataFrame"]): + def __init__(self, compliant: ArrowSeries, /) -> None: + self._compliant: ArrowSeries = compliant + + @property + def compliant(self) -> ArrowSeries: + return self._compliant + + @property + def native(self) -> ChunkedArrayAny: + return self.compliant.native + + def __narwhals_namespace__(self) -> Namespace: + return namespace(self.compliant) + + @property + def version(self) -> Version: + return self.compliant.version + + def with_native(self, native: ChunkedArrayAny, name: str, /) -> ArrowSeries: + return self.compliant.from_native(native, name, version=self.version) + + def unnest(self) -> DataFrame: + if len(self.native): + table = pa.Table.from_struct_array(self.native) + else: + # TODO @dangotbanned: Report empty bug upstream, no option to pass a schema to resolve the error + # `ValueError: Must pass schema, or at least one RecordBatch` + # https://github.com/apache/arrow/blob/b2e8f2505ba3eafe65a78ece6ae87fa7d0c1c133/python/pyarrow/table.pxi#L4943-L4949 + tp_struct = cast("pa.StructType", self.native.type) + table = 
pa.schema(tp_struct.fields).empty_table() + return namespace(self)._dataframe.from_native(table, self.version) + + # name overriding *may* be wrong + def field(self, name: str) -> ArrowSeries: + return self.with_native(fn.struct_field(self.native, name), name) diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 535b299dd8..3ba16b4e0d 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -2,7 +2,12 @@ from typing import TYPE_CHECKING, Protocol -from narwhals._plan.compliant.typing import ExprT_co, FrameT_contra +from narwhals._plan.compliant.typing import ( + DataFrameT_co, + ExprT_co, + FrameT_contra, + SeriesT_co, +) if TYPE_CHECKING: from narwhals._plan.expressions import FunctionExpr as FExpr, lists, strings @@ -83,3 +88,8 @@ class ExprStructNamespace(Protocol[FrameT_contra, ExprT_co]): def field( self, node: FExpr[FieldByName], frame: FrameT_contra, name: str ) -> ExprT_co: ... + + +class SeriesStructNamespace(Protocol[SeriesT_co, DataFrameT_co]): + def field(self, name: str) -> SeriesT_co: ... + def unnest(self) -> DataFrameT_co: ... diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 186b1e96c8..5f38505d81 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -12,6 +12,7 @@ import polars as pl from typing_extensions import Self, TypeAlias + from narwhals._plan.compliant.accessors import SeriesStructNamespace from narwhals._plan.series import Series from narwhals._typing import _EagerAllowedImpl from narwhals.dtypes import DType @@ -190,3 +191,6 @@ def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: def to_polars(self) -> pl.Series: ... def unique(self, *, maintain_order: bool = False) -> Self: ... def zip_with(self, mask: Self, other: Self) -> Self: ... + + @property + def struct(self) -> SeriesStructNamespace[Self, Incomplete]: ... 
diff --git a/narwhals/_plan/compliant/typing.py b/narwhals/_plan/compliant/typing.py index ed9af83a50..01daea4c40 100644 --- a/narwhals/_plan/compliant/typing.py +++ b/narwhals/_plan/compliant/typing.py @@ -49,6 +49,7 @@ FrameT_co = TypeVar("FrameT_co", bound=FrameAny, covariant=True) FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) DataFrameT = TypeVar("DataFrameT", bound=DataFrameAny) +DataFrameT_co = TypeVar("DataFrameT_co", bound=DataFrameAny, covariant=True) NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", covariant=True) EagerExprT_co = TypeVar("EagerExprT_co", bound=EagerExprAny, covariant=True) diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 19ae1fe7df..a24e3f2c1d 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -4,13 +4,14 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic from narwhals._plan._guards import is_series -from narwhals._plan.typing import NativeSeriesT, NativeSeriesT_co, OneOrIterable +from narwhals._plan.typing import NativeSeriesT, NativeSeriesT_co, OneOrIterable, SeriesT from narwhals._utils import ( Implementation, Version, generate_repr, is_eager_allowed, qualified_type_name, + unstable, ) from narwhals.dependencies import is_pyarrow_chunked_array from narwhals.exceptions import ShapeError @@ -259,6 +260,7 @@ def count(self) -> int: def unique(self, *, maintain_order: bool = False) -> Self: # pragma: no cover return type(self)(self._compliant.unique(maintain_order=maintain_order)) + @unstable def hist( self, bins: Sequence[float] | None = None, @@ -267,7 +269,20 @@ def hist( # NOTE: `pl.Series.hist` defaults are the opposite of `pl.Expr.hist` include_breakpoint: bool = True, include_category: bool = False, # NOTE: `pl.Series.hist` default is `True`, but that would be breaking (ish) for narwhals + _use_current_polars_behavior: bool = False, ) -> DataFrame[Incomplete, NativeSeriesT_co]: + """Well ... 
+ 
+ `_use_current_polars_behavior` would preserve the series name, in line with current `polars`: 
+ 
+ import polars as pl 
+ ser = pl.Series("original_name", [0, 1, 2, 3, 4, 5, 6]) 
+ hist = ser.hist(bin_count=4, include_breakpoint=False, include_category=False) 
+ hist.to_dict(as_series=False) 
+ {'original_name': [2, 2, 1, 2]} 
+ 
+ But all of our tests expect `"count"` as the name 🤔 
+ """ from narwhals._plan import functions as F result = ( @@ -281,12 +296,32 @@ def hist( ) ) .to_series() + .struct.unnest() ) - if not include_breakpoint and not include_category: - return result.to_frame() - msg = f"`Series.hist({include_breakpoint=}, {include_category=})` requires `Series.struct.unnest`" - raise NotImplementedError(msg) - return result.struct.unnest() + + if ( + not include_breakpoint + and not include_category + and _use_current_polars_behavior + ): + return result.rename({"count": self.name}) + return result + + @property + def struct(self) -> SeriesStructNamespace[Self]: + return SeriesStructNamespace(self) + + +class SeriesStructNamespace(Generic[SeriesT]): + def __init__(self, series: SeriesT) -> None: + self._series: SeriesT = series + + def unnest(self) -> DataFrame[Any, Any]: + """Convert this struct Series to a DataFrame with a separate column for each field.""" + result: DataFrame[Any, Any] = ( + self._series._compliant.struct.unnest().to_narwhals() + ) + return result class SeriesV1(Series[NativeSeriesT_co]): diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index f8d212432a..ad9c2881c2 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -129,6 +129,7 @@ OneOrIterable: TypeAlias = "T | Iterable[T]" OneOrSeq: TypeAlias = t.Union[T, Seq[T]] DataFrameT = TypeVar("DataFrameT", bound="DataFrame[t.Any, t.Any]") +SeriesT = TypeVar("SeriesT", bound="Series[t.Any]") Order: TypeAlias = t.Literal["ascending", "descending"] NonCrossJoinStrategy: TypeAlias = t.Literal["inner", "left", "full", "semi", "anti"] PartialSeries: TypeAlias = 
"Callable[[Iterable[t.Any]], Series[NativeSeriesAnyT]]" From 36ff66ebfd6910f0457c826c56bea31428c321ca Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 18:47:54 +0000 Subject: [PATCH 135/215] kinda fix `struct` typing it is bad, but I only found out this even exists from `cpp` docs https://arrow.apache.org/docs/cpp/compute.html#structural-transforms --- narwhals/_plan/arrow/dataframe.py | 2 +- narwhals/_plan/arrow/functions.py | 32 ++++++++++++++++++++++--------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 4e2a5d3817..73d17d4140 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -134,7 +134,7 @@ def to_struct(self, name: str = "") -> Series: struct = ( native.to_struct_array() if fn.BACKEND_VERSION >= (15, 0) - else fn.struct(self.columns, values=native.columns) + else fn.struct(native.column_names, native.columns) ) return Series.from_native(struct, name, version=self.version) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 9b337da209..d91605bdde 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -16,7 +16,7 @@ floordiv_compat as _floordiv, narwhals_to_native_dtype as _dtype_native, ) -from narwhals._plan import expressions as ir +from narwhals._plan import common, expressions as ir from narwhals._plan._guards import is_non_nested_literal from narwhals._plan.arrow import options as pa_options from narwhals._plan.expressions import functions as F, operators as ops @@ -296,22 +296,36 @@ def string_type(data_types: Iterable[DataType] = (), /) -> StringType | LargeStr return pa.large_string() if has_large_string(data_types) else pa.string() +# NOTE: `mypy` isn't happy, but this broadcasting behavior is worth documenting @t.overload -def struct(names: Sequence[str], values: 
Iterable[ScalarAny]) -> pa.StructScalar: ... +def struct(names: Iterable[str], columns: Iterable[ChunkedArrayAny]) -> ChunkedStruct: ... @t.overload -def struct( - names: Sequence[str], values: Iterable[ChunkedOrScalarAny] +def struct(names: Iterable[str], columns: Iterable[ArrayAny]) -> pa.StructArray: ... +@t.overload +def struct(  # type: ignore[overload-overlap] + names: Iterable[str], columns: Iterable[ScalarAny] | Iterable[NonNestedLiteral] +) -> pa.StructScalar: ... +@t.overload +def struct(  # type: ignore[overload-overlap] + names: Iterable[str], columns: Iterable[ChunkedArrayAny | NonNestedLiteral] ) -> ChunkedStruct: ... @t.overload def struct( - names: Sequence[str], values: Iterable[ArrayAny | ScalarAny] + names: Iterable[str], columns: Iterable[ArrayAny | NonNestedLiteral] ) -> pa.StructArray: ... -def struct(names: Sequence[str], values: Iterable[Any]) -> Any: - """Convert `values` into a struct. +@t.overload +def struct(names: Iterable[str], columns: Iterable[ArrowAny]) -> Incomplete: ... +def struct(names: Iterable[str], columns: Iterable[Incomplete]) -> Incomplete: + """Collect columns into a struct. - The output shape will be scalar if all inputs are scalar, otherwise any scalars will be broadcast to arrays. + Arguments: + names: Names of the struct fields to create. + columns: Value(s) to collect into a struct. Scalars will be broadcast unless all + inputs are scalar. 
""" - return pc.make_struct(*values, options=pc.MakeStructOptions(field_names=names)) + return pc.make_struct( + *columns, options=pc.MakeStructOptions(common.ensure_seq_str(names)) + ) @t.overload From 5bd4651e1395ca84c14ad6914f0ecd2b27d812a3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:01:16 +0000 Subject: [PATCH 136/215] test: un-xfail all but 1 `hist` test --- tests/plan/hist_test.py | 84 ++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 34 deletions(-) diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py index cacdba0078..889d6cd94c 100644 --- a/tests/plan/hist_test.py +++ b/tests/plan/hist_test.py @@ -1,5 +1,3 @@ -# TODO(unassigned): cudf has too many spurious failures. Report and revisit? -# Modin is too slow so is excluded. from __future__ import annotations from typing import TYPE_CHECKING, Any @@ -16,10 +14,8 @@ from narwhals.typing import EagerAllowed from tests.conftest import Data - -XFAIL_HIST_NOT_IMPLEMENTED = pytest.mark.xfail( - reason="`ArrowExpr.hist_*` is not yet implemented" -) +pytest.importorskip("pyarrow") +import pyarrow as pa @pytest.fixture(scope="module") @@ -51,28 +47,11 @@ def include_breakpoint(request: pytest.FixtureRequest) -> bool: return result -counts_and_expected = [ - { - "bin_count": 4, - "expected_bins": [0, 1.5, 3.0, 4.5, 6.0], - "expected_count": [2, 2, 1, 2], - }, - { - "bin_count": 12, - "expected_bins": [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0], - "expected_count": [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], - }, - {"bin_count": 1, "expected_bins": [0, 6], "expected_count": [7]}, - {"bin_count": 0, "expected_bins": [], "expected_count": []}, -] - - SHIFT_BINS_BY = 10 """shift bins property""" # TODO @dangotbanned: Try to avoid all this looping (3x `iter_columns` in a single test?) 
-@XFAIL_HIST_NOT_IMPLEMENTED @pytest.mark.parametrize( ("bins", "expected"), [ @@ -131,11 +110,43 @@ def test_hist_bin( assert_equal_data(result, expected_full) +params_params = pytest.mark.parametrize( + "params", + [ + { + "bin_count": 4, + "expected_bins": [0, 1.5, 3.0, 4.5, 6.0], + "expected_count": [2, 2, 1, 2], + }, + { + "bin_count": 12, + "expected_bins": [ + 0, + 0.5, + 1.0, + 1.5, + 2.0, + 2.5, + 3.0, + 3.5, + 4.0, + 4.5, + 5.0, + 5.5, + 6.0, + ], + "expected_count": [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + }, + {"bin_count": 1, "expected_bins": [0, 6], "expected_count": [7]}, + {"bin_count": 0, "expected_bins": [], "expected_count": []}, + ], +) + + # TODO @dangotbanned: Avoid using `del` # TODO @dangotbanned: Split up `params` # TODO @dangotbanned: Try to avoid all this looping -@XFAIL_HIST_NOT_IMPLEMENTED -@pytest.mark.parametrize("params", counts_and_expected) +@params_params def test_hist_count( data: Data, data_missing: Data, @@ -148,11 +159,9 @@ def test_hist_count( float=nwp.col("int").cast(nw.Float64) ) bin_count = params["bin_count"] - - expected_bins = params["expected_bins"] - expected = {"breakpoint": expected_bins[1:], "count": params["expected_count"]} - if not include_breakpoint: - del expected["breakpoint"] + expected = {"count": params["expected_count"]} + if include_breakpoint: + expected = {"breakpoint": params["expected_bins"][1:], **expected} # smoke tests for col in df.columns: @@ -188,7 +197,6 @@ def test_hist_count( # TODO @dangotbanned: parametrize into 3 cases -@XFAIL_HIST_NOT_IMPLEMENTED def test_hist_count_no_spread(backend: EagerAllowed) -> None: data_ = {"all_zero": [0, 0, 0], "all_non_zero": [5, 5, 5]} df = nwp.DataFrame.from_dict(data_, backend=backend) @@ -206,9 +214,18 @@ def test_hist_count_no_spread(backend: EagerAllowed) -> None: assert_equal_data(result, expected) +# TODO @dangotbanned: Fix length? # TODO @dangotbanned: parametrize into 2 cases? 
-@XFAIL_HIST_NOT_IMPLEMENTED -def test_hist_no_data(backend: EagerAllowed, *, include_breakpoint: bool) -> None: +def test_hist_no_data( + backend: EagerAllowed, *, include_breakpoint: bool, request: pytest.FixtureRequest +) -> None: + request.applymarker( + pytest.mark.xfail( + include_breakpoint, + reason="TODO: Investigate `Column 1 named count expected length 0 but got length 1`", + raises=pa.ArrowInvalid, + ) + ) data_: Data = {"values": []} df = nwp.DataFrame.from_dict(data_, {"values": nw.Float64()}, backend=backend) s = df.to_series() @@ -228,7 +245,6 @@ def test_hist_no_data(backend: EagerAllowed, *, include_breakpoint: bool) -> Non assert result.get_column("count").sum() == 0 -@XFAIL_HIST_NOT_IMPLEMENTED def test_hist_small_bins(backend: EagerAllowed) -> None: s = nwp.Series.from_iterable([1, 2, 3], name="values", backend=backend) result = s.hist(bins=None, bin_count=None) From d69ad77de45390e7a9928d4a81d4385cfd4dbf38 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:04:15 +0000 Subject: [PATCH 137/215] feat: rough port of `ArrowSeries.hist` - Switches over to the native `linear_space` - Still has a lot of `numpy` I want to factor out - Not planning to keep as functions --- narwhals/_plan/arrow/expr.py | 26 ++++++-- narwhals/_plan/arrow/functions.py | 99 +++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 3a8b77bc48..d70f748d8f 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -701,11 +701,27 @@ def rolling_expr( result = method(s, size, min_samples=samples, center=center, ddof=ddof) return self.from_series(result) - # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_compliant/series.py#L349-L415 - # - 
https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L1060-L1076 - # - https://github.com/narwhals-dev/narwhals/blob/84ce86c618c0103cb08bc63d68a709c424da2106/narwhals/_arrow/series.py#L1130-L1215 - hist_bins = not_implemented() - hist_bin_count = not_implemented() + def hist_bins(self, node: FExpr[F.HistBins], frame: Frame, name: str) -> Self: + s = self._dispatch_expr(node.input[0], frame, name) + func = node.function + struct_data = fn.hist_with_bins( + s.native, list(func.bins), include_breakpoint=func.include_breakpoint + ) + return self.from_series( + namespace(self)._dataframe.from_dict(struct_data).to_struct(name) + ) + + def hist_bin_count( + self, node: FExpr[F.HistBinCount], frame: Frame, name: str + ) -> Self: + s = self._dispatch_expr(node.input[0], frame, name) + func = node.function + struct_data = fn.hist_with_bin_count( + s.native, func.bin_count, include_breakpoint=func.include_breakpoint + ) + return self.from_series( + namespace(self)._dataframe.from_dict(struct_data).to_struct(name) + ) # ewm_mean = not_implemented() # noqa: ERA001 @property diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index d91605bdde..b81108b6f7 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1308,6 +1308,105 @@ def nulls_like(n: int, native: ArrowAny) -> ArrayAny: return result +def zeros(n: int, /) -> pa.Int64Array: + return pa.repeat(0, n) + + +def _hist_is_empty_series(native: ChunkedArrayAny) -> bool: + is_null = native.is_null(nan_is_null=True) + arr = t.cast("pa.BooleanArray", is_null.combine_chunks()) + return arr.false_count == 0 + + +def _hist_calculate_breakpoint( + arg: int | list[float], / +) -> list[float] | ChunkedArray[NumericScalar]: + bins = linear_space(0, 1, arg + 1).slice(1) if isinstance(arg, int) else arg + return bins[1:] + + +def _hist_data_empty(*, include_breakpoint: bool) -> Mapping[str, list[Any]]: + return 
{"breakpoint": [], "count": []} if include_breakpoint else {"count": []} + + +def _hist_series_empty( + arg: int | list[float], *, include_breakpoint: bool +) -> dict[str, ChunkedOrArrayAny | list[float]]: + count = zeros(arg) if isinstance(arg, int) else zeros(len(arg) - 1) + if include_breakpoint: + return {"breakpoint": _hist_calculate_breakpoint(arg), "count": count} + return {"count": count} + + +# TODO @dangotbanned: ughhhhh +# figure out whatever this is supposed to be called +def _hist_calculate_bins( + native: ChunkedArrayAny, bin_count: int +) -> ChunkedArray[NumericScalar]: + d = pc.min_max(native) + lower, upper = d["min"].as_py(), d["max"].as_py() + if lower == upper: + lower -= 0.5 + upper += 0.5 + return linear_space(lower, upper, bin_count + 1) + + +def _hist_calculate_hist( + native: ChunkedArrayAny, + bins: list[float] | ChunkedArray[NumericScalar], + *, + include_breakpoint: bool, +) -> Mapping[str, Iterable[Any]]: + if len(bins) == 2: + is_between_bins = and_(gt_eq(native, lit(bins[0])), lt_eq(native, lit(bins[1]))) + count = sum_(is_between_bins.cast(pa.uint8())) + if include_breakpoint: + return {"breakpoint": [bins[-1]], "count": [count]} + return {"count": [count]} + # TODO @dangotbanned: replacing `np.searchsorted` + # TODO @dangotbanned: replacing `np.isin` x2 assign weirdness + # Handle multiple bins + import numpy as np # ignore-banned-import + + bin_indices = np.searchsorted(bins, native, side="left") + # lowest bin is inclusive + bin_indices = pc.if_else(pc.equal(native, lit(bins[0])), 1, bin_indices) + + # Align unique categories and counts appropriately + obs_cats, obs_counts = np.unique(bin_indices, return_counts=True) + obj_cats = np.arange(1, len(bins)) + counts = np.zeros_like(obj_cats) + counts[np.isin(obj_cats, obs_cats)] = obs_counts[np.isin(obs_cats, obj_cats)] + + if include_breakpoint: + return {"breakpoint": bins[1:], "count": counts} + return {"count": counts} + + +def hist_with_bins( + native: ChunkedArrayAny, bins: 
list[float], *, include_breakpoint: bool +) -> Mapping[str, Iterable[Any]]: + if len(bins) <= 1: + return _hist_data_empty(include_breakpoint=include_breakpoint) + if _hist_is_empty_series(native): + return _hist_series_empty(bins, include_breakpoint=include_breakpoint) + return _hist_calculate_hist(native, bins, include_breakpoint=include_breakpoint) + + +def hist_with_bin_count( + native: ChunkedArrayAny, bin_count: int, *, include_breakpoint: bool +) -> Mapping[str, Iterable[Any]]: + if bin_count == 0: + return _hist_data_empty(include_breakpoint=include_breakpoint) + if _hist_is_empty_series(native): + return _hist_series_empty(bin_count, include_breakpoint=include_breakpoint) + return _hist_calculate_hist( + native, + _hist_calculate_bins(native, bin_count), + include_breakpoint=include_breakpoint, + ) + + def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) From 36f3b683b2dc18cd04b1ed2479f628f371e7f12d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:04:27 +0000 Subject: [PATCH 138/215] cov --- narwhals/_plan/dataframe.py | 2 +- narwhals/_plan/series.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 27a237ed36..4ceb06a2ea 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -276,7 +276,7 @@ def to_dict( def to_series(self, index: int = 0) -> Series[NativeSeriesT]: return self._series(self._compliant.to_series(index)) - def to_struct(self, name: str = "") -> Series[NativeSeriesT]: + def to_struct(self, name: str = "") -> Series[NativeSeriesT]: # pragma: no cover return self._series(self._compliant.to_struct(name)) def to_polars(self) -> pl.DataFrame: diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index a24e3f2c1d..1acdc52701 100644 --- a/narwhals/_plan/series.py +++ 
b/narwhals/_plan/series.py @@ -303,7 +303,7 @@ def hist( not include_breakpoint and not include_category and _use_current_polars_behavior - ): + ): # pragma: no cover return result.rename({"count": self.name}) return result From da791a7d4451a0d3b1ebfbd98f4f40aa5d63bcae Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:25:11 +0000 Subject: [PATCH 139/215] fix: `pyarrow<15` compat --- narwhals/_plan/arrow/dataframe.py | 2 +- narwhals/_plan/arrow/functions.py | 3 +++ narwhals/_plan/arrow/series.py | 26 +++++++++++++++++--------- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 73d17d4140..375e513dfa 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -133,7 +133,7 @@ def to_struct(self, name: str = "") -> Series: native = self.native struct = ( native.to_struct_array() - if fn.BACKEND_VERSION >= (15, 0) + if fn.HAS_FROM_TO_STRUCT_ARRAY else fn.struct(native.column_names, native.columns) ) return Series.from_native(struct, name, version=self.version) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index b81108b6f7..3565ad4488 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -95,6 +95,9 @@ RANK_ACCEPTS_CHUNKED: Final = BACKEND_VERSION >= (14,) +HAS_FROM_TO_STRUCT_ARRAY: Final = BACKEND_VERSION >= (15,) +"""`pyarrow.Table.{from,to}_struct_array` added in https://github.com/apache/arrow/pull/38520""" + HAS_SCATTER: Final = BACKEND_VERSION >= (20,) """`pyarrow.compute.scatter` added in https://github.com/apache/arrow/pull/44394""" diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 35530964bc..2d60e65f0f 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -19,7 +19,7 @@ from collections.abc import Iterable import polars as pl - from 
typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._plan.arrow.dataframe import ArrowDataFrame as DataFrame from narwhals._plan.arrow.namespace import ArrowNamespace as Namespace @@ -33,6 +33,8 @@ _1DArray, ) +Incomplete: TypeAlias = Any + class ArrowSeries(FrameSeries["ChunkedArrayAny"], CompliantSeries["ChunkedArrayAny"]): _name: str @@ -317,14 +319,20 @@ def with_native(self, native: ChunkedArrayAny, name: str, /) -> ArrowSeries: return self.compliant.from_native(native, name, version=self.version) def unnest(self) -> DataFrame: - if len(self.native): - table = pa.Table.from_struct_array(self.native) - else: - # TODO @dangotbanned: Report empty bug upstream, no option to pass a schema to resolve the error - # `ValueError: Must pass schema, or at least one RecordBatch` - # https://github.com/apache/arrow/blob/b2e8f2505ba3eafe65a78ece6ae87fa7d0c1c133/python/pyarrow/table.pxi#L4943-L4949 - tp_struct = cast("pa.StructType", self.native.type) - table = pa.schema(tp_struct.fields).empty_table() + native = cast("pa.ChunkedArray[pa.StructScalar]", self.native) + if fn.HAS_FROM_TO_STRUCT_ARRAY: + if len(self.native): + table = pa.Table.from_struct_array(native) + else: + # TODO @dangotbanned: Report empty bug upstream, no option to pass a schema to resolve the error + # `ValueError: Must pass schema, or at least one RecordBatch` + # https://github.com/apache/arrow/blob/b2e8f2505ba3eafe65a78ece6ae87fa7d0c1c133/python/pyarrow/table.pxi#L4943-L4949 + table = pa.schema(native.type.fields).empty_table() + else: # pragma: no cover + # NOTE: Too strict, doesn't allow `Array[StructScalar]` + rec_batch: Incomplete = pa.RecordBatch.from_struct_array + batches = (rec_batch(chunk) for chunk in native.chunks) + table = pa.Table.from_batches(batches, pa.schema(native.type.fields)) return namespace(self)._dataframe.from_native(table, self.version) # name overriding *may* be wrong From 9ab3d00cd6f29d4521b96734c2268eb307d05d42 Mon Sep 17 00:00:00 
2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 20:05:02 +0000 Subject: [PATCH 140/215] import --- narwhals/_plan/arrow/series.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 2d60e65f0f..8a9c9c1de3 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, cast -import pyarrow as pa +import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc from narwhals._arrow.utils import narwhals_to_native_dtype, native_to_narwhals_dtype From 642dcb40466b63e0607d2ba30041aff5aa5ad3a3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 21:15:11 +0000 Subject: [PATCH 141/215] fix: Don't slice twice The `slice` on the left was new, I missed the other one that was there before --- narwhals/_plan/arrow/functions.py | 3 +-- tests/plan/hist_test.py | 13 +------------ 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 3565ad4488..c1d6484edf 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1324,8 +1324,7 @@ def _hist_is_empty_series(native: ChunkedArrayAny) -> bool: def _hist_calculate_breakpoint( arg: int | list[float], / ) -> list[float] | ChunkedArray[NumericScalar]: - bins = linear_space(0, 1, arg + 1).slice(1) if isinstance(arg, int) else arg - return bins[1:] + return linear_space(0, 1, arg + 1).slice(1) if isinstance(arg, int) else arg[1:] def _hist_data_empty(*, include_breakpoint: bool) -> Mapping[str, list[Any]]: diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py index 889d6cd94c..06848800b5 100644 --- a/tests/plan/hist_test.py +++ b/tests/plan/hist_test.py @@ -15,7 +15,6 @@ from tests.conftest import Data pytest.importorskip("pyarrow") -import 
pyarrow as pa @pytest.fixture(scope="module") @@ -214,18 +213,8 @@ def test_hist_count_no_spread(backend: EagerAllowed) -> None: assert_equal_data(result, expected) -# TODO @dangotbanned: Fix length? # TODO @dangotbanned: parametrize into 2 cases? -def test_hist_no_data( - backend: EagerAllowed, *, include_breakpoint: bool, request: pytest.FixtureRequest -) -> None: - request.applymarker( - pytest.mark.xfail( - include_breakpoint, - reason="TODO: Investigate `Column 1 named count expected length 0 but got length 1`", - raises=pa.ArrowInvalid, - ) - ) +def test_hist_no_data(backend: EagerAllowed, *, include_breakpoint: bool) -> None: data_: Data = {"values": []} df = nwp.DataFrame.from_dict(data_, {"values": nw.Float64()}, backend=backend) s = df.to_series() From 2f3a557ddd557a6df6f8f6de92540791cad281f3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 21:31:39 +0000 Subject: [PATCH 142/215] perf: Don't slice twice (again) Last one was 3x slice --- narwhals/_plan/arrow/functions.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index c1d6484edf..5362c2d0ef 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1257,9 +1257,6 @@ def linear_space( if num_samples < 0: msg = f"Number of samples, {num_samples}, must be non-negative." 
raise ValueError(msg) - if num_samples == 1: - msg = f"num_samples {num_samples} is not >= 2" - raise NotImplementedError(msg) if closed == "both": range_end = num_samples div = num_samples - 1 @@ -1274,12 +1271,15 @@ def linear_space( div = num_samples + 1 ca: ChunkedArray[pc.NumericScalar] = int_range(0, range_end).cast(F64) delta = float(end - start) - step = delta / div - if step == 0: - ca = truediv(ca, lit(div)) - ca = multiply(ca, lit(delta)) + if div > 0: + step = delta / div + if step == 0: + ca = truediv(ca, lit(div)) + ca = multiply(ca, lit(delta)) + else: + ca = multiply(ca, lit(step)) else: - ca = multiply(ca, lit(step)) + ca = multiply(ca, lit(delta)) if start != 0: ca = add(ca, lit(start, F64)) if closed in {"right", "none"}: @@ -1324,7 +1324,7 @@ def _hist_is_empty_series(native: ChunkedArrayAny) -> bool: def _hist_calculate_breakpoint( arg: int | list[float], / ) -> list[float] | ChunkedArray[NumericScalar]: - return linear_space(0, 1, arg + 1).slice(1) if isinstance(arg, int) else arg[1:] + return linear_space(0, 1, arg, closed="right") if isinstance(arg, int) else arg[1:] def _hist_data_empty(*, include_breakpoint: bool) -> Mapping[str, list[Any]]: From e4f0f83a15808349f2f4970e5549b791671fe601 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 22:01:50 +0000 Subject: [PATCH 143/215] fix: `pyarrow<18` compat --- narwhals/_plan/arrow/functions.py | 14 ++++++++++++++ narwhals/_plan/arrow/series.py | 6 +++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 5362c2d0ef..bd7e71ad64 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -98,6 +98,9 @@ HAS_FROM_TO_STRUCT_ARRAY: Final = BACKEND_VERSION >= (15,) """`pyarrow.Table.{from,to}_struct_array` added in https://github.com/apache/arrow/pull/38520""" +HAS_STRUCT_TYPE_FIELDS: Final = BACKEND_VERSION >= (18,) 
+"""`pyarrow.StructType.fields` added in https://github.com/apache/arrow/pull/43481""" + HAS_SCATTER: Final = BACKEND_VERSION >= (20,) """`pyarrow.compute.scatter` added in https://github.com/apache/arrow/pull/44394""" @@ -331,6 +334,13 @@ def struct(names: Iterable[str], columns: Iterable[Incomplete]) -> Incomplete: ) +def struct_schema(native: Arrow[pa.StructScalar] | pa.StructType) -> pa.Schema: + """Get the struct definition as a schema.""" + tp = native.type if _is_arrow(native) else native + fields = tp.fields if HAS_STRUCT_TYPE_FIELDS else list(tp) + return pa.schema(fields) + + @t.overload def struct_field(native: ChunkedStruct, field: Field, /) -> ChunkedArrayAny: ... @t.overload @@ -1473,6 +1483,10 @@ def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataT ) +def _is_arrow(obj: Arrow[ScalarT] | Any) -> TypeIs[Arrow[ScalarT]]: + return isinstance(obj, (pa.Scalar, pa.Array, pa.ChunkedArray)) + + def filter_arrays( predicate: ChunkedOrArray[BooleanScalar] | pc.Expression, *arrays: Unpack[Ts], diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 8a9c9c1de3..d6da36937a 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -321,18 +321,18 @@ def with_native(self, native: ChunkedArrayAny, name: str, /) -> ArrowSeries: def unnest(self) -> DataFrame: native = cast("pa.ChunkedArray[pa.StructScalar]", self.native) if fn.HAS_FROM_TO_STRUCT_ARRAY: - if len(self.native): + if len(native): table = pa.Table.from_struct_array(native) else: # TODO @dangotbanned: Report empty bug upstream, no option to pass a schema to resolve the error # `ValueError: Must pass schema, or at least one RecordBatch` # https://github.com/apache/arrow/blob/b2e8f2505ba3eafe65a78ece6ae87fa7d0c1c133/python/pyarrow/table.pxi#L4943-L4949 - table = pa.schema(native.type.fields).empty_table() + table = fn.struct_schema(native).empty_table() else: # pragma: no cover # NOTE: Too strict, doesn't allow 
`Array[StructScalar]` rec_batch: Incomplete = pa.RecordBatch.from_struct_array batches = (rec_batch(chunk) for chunk in native.chunks) - table = pa.Table.from_batches(batches, pa.schema(native.type.fields)) + table = pa.Table.from_batches(batches, fn.struct_schema(native)) return namespace(self)._dataframe.from_native(table, self.version) # name overriding *may* be wrong From 8886421119d567716e0ad9039958f75b73465d82 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 4 Dec 2025 22:44:42 +0000 Subject: [PATCH 144/215] fix: `pyarrow<21` compat this is getting silly now --- narwhals/_plan/arrow/dataframe.py | 14 +++++++++----- narwhals/_plan/arrow/functions.py | 3 +++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 375e513dfa..2c1bbedda6 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -131,11 +131,15 @@ def with_row_index_by( def to_struct(self, name: str = "") -> Series: native = self.native - struct = ( - native.to_struct_array() - if fn.HAS_FROM_TO_STRUCT_ARRAY - else fn.struct(native.column_names, native.columns) - ) + if fn.TO_STRUCT_ARRAY_ACCEPTS_EMPTY: + struct = native.to_struct_array() + elif fn.HAS_FROM_TO_STRUCT_ARRAY: + if len(native): + struct = native.to_struct_array() + else: + struct = fn.chunked_array([], pa.struct(native.schema)) + else: + struct = fn.struct(native.column_names, native.columns) return Series.from_native(struct, name, version=self.version) def get_column(self, name: str) -> Series: diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index bd7e71ad64..1928bba343 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -110,6 +110,9 @@ HAS_ARANGE: Final = BACKEND_VERSION >= (21,) """`pyarrow.arange` added in https://github.com/apache/arrow/pull/46778""" +TO_STRUCT_ARRAY_ACCEPTS_EMPTY: Final 
= BACKEND_VERSION >= (21,) +"""`pyarrow.Table.to_struct_array` fixed in https://github.com/apache/arrow/pull/46357""" + HAS_ZFILL: Final = BACKEND_VERSION >= (21,) """`pyarrow.compute.utf8_zero_fill` added in https://github.com/apache/arrow/pull/46815""" From bb1d1b06d39cf15a8a2cdde8eebc02e49fc264c5 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 12:28:51 +0000 Subject: [PATCH 145/215] chore: remove outdated comments --- narwhals/_plan/functions.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 897b63037e..334eea1e91 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -115,9 +115,6 @@ def sum(*columns: str) -> Expr: return col(columns).sum() -# TODO @dangotbanned: Support `ignore_nulls=...` -# NOTE: `polars` doesn't support yet -# Current behavior is equivalent to `ignore_nulls=False` def all_horizontal( *exprs: IntoExpr | t.Iterable[IntoExpr], ignore_nulls: bool = False ) -> Expr: @@ -129,7 +126,6 @@ def all_horizontal( ) -# TODO @dangotbanned: Support `ignore_nulls=...` def any_horizontal( *exprs: IntoExpr | t.Iterable[IntoExpr], ignore_nulls: bool = False ) -> Expr: From b1b6f85951f5df4b4df1e5e6d897179bdf9d8777 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 12:57:07 +0000 Subject: [PATCH 146/215] refactor: Reuse `is_between`, rather than re-implementing --- narwhals/_plan/arrow/functions.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 1928bba343..6bd8ef9e9d 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -85,6 +85,7 @@ IntoArrowSchema, IntoDType, NonNestedLiteral, + NumericLiteral, PythonLiteral, ) @@ -1033,12 +1034,14 @@ def replace_with_mask( def is_between( native: 
ChunkedOrScalar[ScalarT], - lower: ChunkedOrScalar[ScalarT], - upper: ChunkedOrScalar[ScalarT], + lower: ChunkedOrScalar[ScalarT] | NumericLiteral, + upper: ChunkedOrScalar[ScalarT] | NumericLiteral, closed: ClosedInterval, -) -> ChunkedOrScalar[pa.BooleanScalar]: +) -> ChunkedOrScalar[BooleanScalar]: fn_lhs, fn_rhs = _IS_BETWEEN[closed] - return and_(fn_lhs(native, lower), fn_rhs(native, upper)) # type: ignore[no-any-return] + low, high = (el if _is_arrow(el) else lit(el) for el in (lower, upper)) + out: ChunkedOrScalar[BooleanScalar] = and_(fn_lhs(native, low), fn_rhs(native, high)) + return out @t.overload @@ -1373,8 +1376,9 @@ def _hist_calculate_hist( include_breakpoint: bool, ) -> Mapping[str, Iterable[Any]]: if len(bins) == 2: - is_between_bins = and_(gt_eq(native, lit(bins[0])), lt_eq(native, lit(bins[1]))) - count = sum_(is_between_bins.cast(pa.uint8())) + # NOTE: I still don't like this summing a mask to get a count + # TODO @dangotbanned: Isn't there a compute function for this? 
+ count = sum_(is_between(native, bins[0], bins[1], closed="both").cast(pa.uint8())) if include_breakpoint: return {"breakpoint": [bins[-1]], "count": [count]} return {"count": [count]} From 08ad7f18a260d65f397e4b6bccbc0c87e80f1247 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 17:10:43 +0000 Subject: [PATCH 147/215] refactor: Keep everything-but `np.searchsorted` native I've pulled out `search_sorted` so this is more visible --- narwhals/_plan/arrow/functions.py | 61 ++++++++++++++++++----- tests/plan/hist_test.py | 80 ++++++++++++++++++------------- 2 files changed, 95 insertions(+), 46 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 6bd8ef9e9d..83a3130c6f 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1369,6 +1369,42 @@ def _hist_calculate_bins( return linear_space(lower, upper, bin_count + 1) +SearchSortedSide: TypeAlias = Literal["left", "right"] + + +# TODO @dangotbanned: replacing `np.searchsorted`? +@t.overload +def search_sorted( + native: ChunkedOrArrayT, + element: ChunkedOrArray[NumericScalar] | Sequence[float], + *, + side: SearchSortedSide = ..., +) -> ChunkedOrArrayT: ... + + +# NOTE: scalar case may work with only `partition_nth_indices`? +@t.overload +def search_sorted( + native: ChunkedOrArrayT, element: float, *, side: SearchSortedSide = ... +) -> ScalarAny: ... 
+ + +def search_sorted( + native: ChunkedOrArrayT, + element: ChunkedOrArray[NumericScalar] | Sequence[float] | float, + *, + side: SearchSortedSide = "left", +) -> ChunkedOrArrayT | ScalarAny: + import numpy as np # ignore-banned-import + + indices = np.searchsorted(element, native, side=side) + if isinstance(indices, np.generic): + return lit(indices) + if isinstance(native, pa.ChunkedArray): + return chunked_array([indices]) + return array(indices) + + def _hist_calculate_hist( native: ChunkedArrayAny, bins: list[float] | ChunkedArray[NumericScalar], @@ -1382,20 +1418,21 @@ def _hist_calculate_hist( if include_breakpoint: return {"breakpoint": [bins[-1]], "count": [count]} return {"count": [count]} - # TODO @dangotbanned: replacing `np.searchsorted` - # TODO @dangotbanned: replacing `np.isin` x2 assign weirdness - # Handle multiple bins - import numpy as np # ignore-banned-import - bin_indices = np.searchsorted(bins, native, side="left") # lowest bin is inclusive - bin_indices = pc.if_else(pc.equal(native, lit(bins[0])), 1, bin_indices) - - # Align unique categories and counts appropriately - obs_cats, obs_counts = np.unique(bin_indices, return_counts=True) - obj_cats = np.arange(1, len(bins)) - counts = np.zeros_like(obj_cats) - counts[np.isin(obj_cats, obs_cats)] = obs_counts[np.isin(obs_cats, obj_cats)] + # NOTE: `np.unique` behavior sorts first + value_counts = ( + when_then(not_eq(native, lit(bins[0])), search_sorted(native, bins), 1) + .sort() + .value_counts() + ) + values, counts = struct_fields(value_counts, "values", "counts") + bin_count = len(bins) + # TODO @dangotbanned: I'd still like to do this in less steps, but it is *more* native + int_range_ = int_range(1, bin_count, chunked=False) + mask = is_in(int_range_, values) + replacements = counts.filter(is_in(values, int_range_)) + counts = replace_with_mask(zeros(bin_count - 1), mask, replacements) if include_breakpoint: return {"breakpoint": bins[1:], "count": counts} diff --git 
a/tests/plan/hist_test.py b/tests/plan/hist_test.py index 06848800b5..4f40607c75 100644 --- a/tests/plan/hist_test.py +++ b/tests/plan/hist_test.py @@ -54,15 +54,16 @@ def include_breakpoint(request: pytest.FixtureRequest) -> bool: @pytest.mark.parametrize( ("bins", "expected"), [ - ([-float("inf"), 2.5, 5.5, float("inf")], [3, 3, 1]), - ([1.0, 2.5, 5.5, float("inf")], [2, 3, 1]), - ([1.0, 2.5, 5.5], [2, 3]), - ([-10.0, -1.0, 2.5, 5.5], [0, 3, 3]), - ([1.0, 2.0625], [2]), - ([1], []), - ([0, 10], [7]), + pytest.param( + [-float("inf"), 2.5, 5.5, float("inf")], [3, 3, 1], id="4_bins-neg-inf-inf" + ), + pytest.param([1.0, 2.5, 5.5, float("inf")], [2, 3, 1], id="4_bins-inf"), + pytest.param([1.0, 2.5, 5.5], [2, 3], id="3_bins"), + pytest.param([-10.0, -1.0, 2.5, 5.5], [0, 3, 3], id="4_bins"), + pytest.param([1.0, 2.0625], [2], id="2_bins-1"), + pytest.param([1], [], id="1_bins"), + pytest.param([0, 10], [7], id="2_bins-2"), ], - ids=str, ) def test_hist_bin( data: Data, @@ -112,32 +113,43 @@ def test_hist_bin( params_params = pytest.mark.parametrize( "params", [ - { - "bin_count": 4, - "expected_bins": [0, 1.5, 3.0, 4.5, 6.0], - "expected_count": [2, 2, 1, 2], - }, - { - "bin_count": 12, - "expected_bins": [ - 0, - 0.5, - 1.0, - 1.5, - 2.0, - 2.5, - 3.0, - 3.5, - 4.0, - 4.5, - 5.0, - 5.5, - 6.0, - ], - "expected_count": [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], - }, - {"bin_count": 1, "expected_bins": [0, 6], "expected_count": [7]}, - {"bin_count": 0, "expected_bins": [], "expected_count": []}, + pytest.param( + { + "bin_count": 4, + "expected_bins": [0, 1.5, 3.0, 4.5, 6.0], + "expected_count": [2, 2, 1, 2], + }, + id="bin_count-4", + ), + pytest.param( + { + "bin_count": 12, + "expected_bins": [ + 0, + 0.5, + 1.0, + 1.5, + 2.0, + 2.5, + 3.0, + 3.5, + 4.0, + 4.5, + 5.0, + 5.5, + 6.0, + ], + "expected_count": [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + }, + id="bin_count-12", + ), + pytest.param( + {"bin_count": 1, "expected_bins": [0, 6], "expected_count": [7]}, + 
id="bin_count-1", + ), + pytest.param( + {"bin_count": 0, "expected_bins": [], "expected_count": []}, id="bin_count-0" + ), ], ) From 96e5fb52b20ff06bd0903849cb15fd3b735386bd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 17:32:57 +0000 Subject: [PATCH 148/215] refactor: Make `BooleanArray.{false,true}_count` easier to access --- narwhals/_plan/arrow/functions.py | 34 +++++++++++++++++++++++-------- narwhals/_plan/arrow/typing.py | 2 +- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 83a3130c6f..0f8b3ec758 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -43,6 +43,7 @@ BinOp, BooleanLengthPreserving, BooleanScalar, + BoolType, ChunkedArray, ChunkedArrayAny, ChunkedList, @@ -117,9 +118,10 @@ HAS_ZFILL: Final = BACKEND_VERSION >= (21,) """`pyarrow.compute.utf8_zero_fill` added in https://github.com/apache/arrow/pull/46815""" - +# NOTE: Common data type instances to share I64: Final = pa.int64() F64: Final = pa.float64() +BOOL: Final = pa.bool_() class MinMax(ir.AggExpr): @@ -1032,6 +1034,20 @@ def replace_with_mask( return result +@t.overload +def is_between( + native: ChunkedArray[ScalarT], + lower: ChunkedOrScalar[ScalarT] | NumericLiteral, + upper: ChunkedOrScalar[ScalarT] | NumericLiteral, + closed: ClosedInterval, +) -> ChunkedArray[BooleanScalar]: ... +@t.overload +def is_between( + native: ChunkedOrScalar[ScalarT], + lower: ChunkedOrScalar[ScalarT] | NumericLiteral, + upper: ChunkedOrScalar[ScalarT] | NumericLiteral, + closed: ClosedInterval, +) -> ChunkedOrScalar[BooleanScalar]: ... 
def is_between( native: ChunkedOrScalar[ScalarT], lower: ChunkedOrScalar[ScalarT] | NumericLiteral, @@ -1332,9 +1348,7 @@ def zeros(n: int, /) -> pa.Int64Array: def _hist_is_empty_series(native: ChunkedArrayAny) -> bool: - is_null = native.is_null(nan_is_null=True) - arr = t.cast("pa.BooleanArray", is_null.combine_chunks()) - return arr.false_count == 0 + return array(native.is_null(nan_is_null=True), BOOL).false_count == 0 def _hist_calculate_breakpoint( @@ -1412,9 +1426,9 @@ def _hist_calculate_hist( include_breakpoint: bool, ) -> Mapping[str, Iterable[Any]]: if len(bins) == 2: - # NOTE: I still don't like this summing a mask to get a count - # TODO @dangotbanned: Isn't there a compute function for this? - count = sum_(is_between(native, bins[0], bins[1], closed="both").cast(pa.uint8())) + count = array( + is_between(native, bins[0], bins[1], closed="both"), BOOL + ).true_count if include_breakpoint: return {"breakpoint": [bins[-1]], "count": [count]} return {"count": [count]} @@ -1470,6 +1484,8 @@ def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: @overload def array(data: ArrowAny, /) -> ArrayAny: ... @overload +def array(data: Arrow[BooleanScalar], dtype: BoolType, /) -> pa.BooleanArray: ... +@overload def array( data: Iterable[PythonLiteral], dtype: DataType | None = None, / ) -> ArrayAny: ... @@ -1479,7 +1495,9 @@ def array( """Convert `data` into an Array instance. Note: - `dtype` is not used for existing `pyarrow` data, use `cast` instead. + `dtype` is **not used** for existing `pyarrow` data, but it can be used to signal + the concrete `Array` subclass that is returned. + To actually changed the type, use `cast` instead. 
""" if isinstance(data, pa.ChunkedArray): return data.combine_chunks() diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index 7f8d369595..63ee251997 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -11,7 +11,7 @@ import pyarrow as pa import pyarrow.compute as pc from pyarrow.lib import ( - BoolType, + BoolType as BoolType, Date32Type, Int8Type, Int16Type, From 57c854cb70f4be50fa0e2fc1bd6f30619694defc Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 20:52:47 +0000 Subject: [PATCH 149/215] feat: Add `Series.{cast,drop_nulls,drop_nans}` Will simplify the tests for `hist` --- narwhals/_plan/arrow/series.py | 10 +++++++++- narwhals/_plan/compliant/series.py | 2 ++ narwhals/_plan/series.py | 9 +++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index d6da36937a..b92341b19b 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -283,7 +283,8 @@ def any(self) -> bool: return fn.any_(self.native).as_py() def sum(self) -> float: - return fn.sum_(self.native).as_py() # type: ignore[no-any-return] + result: float = fn.sum_(self.native).as_py() + return result def count(self) -> int: return fn.count(self.native).as_py() @@ -291,6 +292,13 @@ def count(self) -> int: def unique(self, *, maintain_order: bool = False) -> Self: return self._with_native(self.native.unique()) + def drop_nulls(self) -> Self: + return self._with_native(self.native.drop_null()) + + def drop_nans(self) -> Self: + predicate: Incomplete = fn.is_not_nan(self.native) + return self._with_native(self.native.filter(predicate, "emit_null")) + @property def struct(self) -> SeriesStructNamespace: return SeriesStructNamespace(self) diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 5f38505d81..7a51d7802e 100644 --- 
a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -131,6 +131,8 @@ def cum_min(self, *, reverse: bool = False) -> Self: ... def cum_prod(self, *, reverse: bool = False) -> Self: ... def cum_sum(self, *, reverse: bool = False) -> Self: ... def diff(self, n: int = 1) -> Self: ... + def drop_nulls(self) -> Self: ... + def drop_nans(self) -> Self: ... def fill_nan(self, value: float | Self | None) -> Self: ... def fill_null(self, value: NonNestedLiteral | Self) -> Self: ... def fill_null_with_strategy( diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 1acdc52701..3e31e4e480 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -125,6 +125,9 @@ def __iter__(self) -> Iterator[Any]: # pragma: no cover def alias(self, name: str) -> Self: return type(self)(self._compliant.alias(name)) + def cast(self, dtype: IntoDType) -> Self: # pragma: no cover + return type(self)(self._compliant.cast(dtype)) + def __len__(self) -> int: return len(self._compliant) @@ -260,6 +263,12 @@ def count(self) -> int: def unique(self, *, maintain_order: bool = False) -> Self: # pragma: no cover return type(self)(self._compliant.unique(maintain_order=maintain_order)) + def drop_nulls(self) -> Self: # pragma: no cover + return type(self)(self._compliant.drop_nulls()) + + def drop_nans(self) -> Self: + return type(self)(self._compliant.drop_nans()) + @unstable def hist( self, From bcea3fb9f13e1a07d4d10c1b4ce2721c08862d0b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 21:10:59 +0000 Subject: [PATCH 150/215] test: Simplify, parallelize `hist(bin_count=...)` Rather than 1 monster test per backend, these are now 48 very short ones --- narwhals/_plan/series.py | 8 +-- tests/plan/hist_test.py | 145 ++++++++++++++++++--------------------- 2 files changed, 72 insertions(+), 81 deletions(-) diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 
3e31e4e480..fa689e3174 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -191,10 +191,10 @@ def scatter( def is_in(self, other: Iterable[Any]) -> Self: return type(self)(self._compliant.is_in(self._parse_into_compliant(other))) - def is_nan(self) -> Self: + def is_nan(self) -> Self: # pragma: no cover return type(self)(self._compliant.is_nan()) - def is_null(self) -> Self: + def is_null(self) -> Self: # pragma: no cover return type(self)(self._compliant.is_null()) def is_not_nan(self) -> Self: # pragma: no cover @@ -237,11 +237,11 @@ def __eq__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: # typ other_ = self._unwrap_compliant(other) if is_series(other) else other return type(self)(self._compliant.__eq__(other_)) - def __or__(self, other: bool | Self, /) -> Self: + def __or__(self, other: bool | Self, /) -> Self: # pragma: no cover other_ = self._unwrap_compliant(other) if is_series(other) else other return type(self)(self._compliant.__or__(other_)) - def __invert__(self) -> Self: + def __invert__(self) -> Self: # pragma: no cover return type(self)(self._compliant.__invert__()) def __add__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py index 4f40607c75..cdfd89e258 100644 --- a/tests/plan/hist_test.py +++ b/tests/plan/hist_test.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pytest @@ -27,11 +27,28 @@ def data() -> Data: } +@pytest.fixture(scope="module") +def schema_data() -> nw.Schema: + return nw.Schema( + { + "int": nw.Int64(), + "float": nw.Float64(), + "int_shuffled": nw.Int64(), + "float_shuffled": nw.Float64(), + } + ) + + @pytest.fixture(scope="module") def data_missing(data: Data) -> Data: return {"has_nan": [float("nan"), *data["int"]], "has_null": [None, *data["int"]]} +@pytest.fixture(scope="module") +def schema_data_missing() -> nw.Schema: + return 
nw.Schema({"has_nan": nw.Float64(), "has_null": nw.Int64()}) + + @pytest.fixture(scope="module", params=["pyarrow"]) def backend(request: pytest.FixtureRequest) -> EagerAllowed: result: EagerAllowed = request.param @@ -65,7 +82,7 @@ def include_breakpoint(request: pytest.FixtureRequest) -> bool: pytest.param([0, 10], [7], id="2_bins-2"), ], ) -def test_hist_bin( +def test_hist_bins( data: Data, data_missing: Data, backend: EagerAllowed, @@ -111,100 +128,74 @@ def test_hist_bin( params_params = pytest.mark.parametrize( - "params", + ("bin_count", "expected_bins", "expected_count"), [ - pytest.param( - { - "bin_count": 4, - "expected_bins": [0, 1.5, 3.0, 4.5, 6.0], - "expected_count": [2, 2, 1, 2], - }, - id="bin_count-4", - ), - pytest.param( - { - "bin_count": 12, - "expected_bins": [ - 0, - 0.5, - 1.0, - 1.5, - 2.0, - 2.5, - 3.0, - 3.5, - 4.0, - 4.5, - 5.0, - 5.5, - 6.0, - ], - "expected_count": [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], - }, - id="bin_count-12", - ), - pytest.param( - {"bin_count": 1, "expected_bins": [0, 6], "expected_count": [7]}, - id="bin_count-1", - ), - pytest.param( - {"bin_count": 0, "expected_bins": [], "expected_count": []}, id="bin_count-0" + (4, [1.5, 3.0, 4.5, 6.0], [2, 2, 1, 2]), + ( + 12, + [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0], + [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], ), + (1, [6], [7]), + (0, [], []), ], ) -# TODO @dangotbanned: Avoid using `del` -# TODO @dangotbanned: Split up `params` -# TODO @dangotbanned: Try to avoid all this looping +@pytest.mark.parametrize("column", ["int", "float", "int_shuffled", "float_shuffled"]) @params_params -def test_hist_count( +def test_hist_bin_count( data: Data, - data_missing: Data, + schema_data: nw.Schema, backend: EagerAllowed, + column: str, + bin_count: int, + expected_bins: list[float], + expected_count: list[int], *, - params: dict[str, Any], include_breakpoint: bool, ) -> None: - df = nwp.DataFrame.from_dict(data, backend=backend).with_columns( - 
float=nwp.col("int").cast(nw.Float64) - ) - bin_count = params["bin_count"] - expected = {"count": params["expected_count"]} + values, dtype = data[column], schema_data[column] + ser = nwp.Series.from_iterable(values, name=column, dtype=dtype, backend=backend) if include_breakpoint: - expected = {"breakpoint": params["expected_bins"][1:], **expected} + expected = {"breakpoint": expected_bins, "count": expected_count} + else: + expected = {"count": expected_count} - # smoke tests - for col in df.columns: - result = df.get_column(col).hist( - bin_count=bin_count, include_breakpoint=include_breakpoint - ) - assert_equal_data(result, expected) + result = ser.hist(bin_count=bin_count, include_breakpoint=include_breakpoint) - # result size property + assert_equal_data(result, expected) + assert len(result) == bin_count + if bin_count > 0: + assert result.get_column("count").sum() == ser.drop_nans().count() - assert len(result) == bin_count - if bin_count > 0: - assert result.get_column("count").sum() == df.get_column(col).count() - # missing/nan results - df = nwp.DataFrame.from_dict(data_missing, backend=backend) +@pytest.mark.parametrize("column", ["has_nan", "has_null"]) +@params_params +def test_hist_bin_count_missing( + data_missing: Data, + schema_data_missing: nw.Schema, + backend: EagerAllowed, + column: str, + bin_count: int, + expected_bins: list[float], + expected_count: list[int], + *, + include_breakpoint: bool, +) -> None: + values, dtype = data_missing[column], schema_data_missing[column] + ser = nwp.Series.from_iterable(values, name=column, dtype=dtype, backend=backend) + if include_breakpoint: + expected = {"breakpoint": expected_bins, "count": expected_count} + else: + expected = {"count": expected_count} - for col in df.columns: - result = df.get_column(col).hist( - bin_count=bin_count, include_breakpoint=include_breakpoint - ) - assert_equal_data(result, expected) + result = ser.hist(bin_count=bin_count, include_breakpoint=include_breakpoint) - # 
result size property - assert len(result) == bin_count - ser = df.get_column(col) - if bin_count > 0: - # NOTE: Could this just be a filter? - assert ( - result.get_column("count").sum() - == (~(ser.is_nan() | ser.is_null())).sum() - ) + assert_equal_data(result, expected) + assert len(result) == bin_count + if bin_count > 0: + assert result.get_column("count").sum() == ser.drop_nans().count() # TODO @dangotbanned: parametrize into 3 cases From a443e745f426b870f33734596a64ba394d34b5fd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 21:33:54 +0000 Subject: [PATCH 151/215] test: break up smaller `hist` stuff --- tests/plan/hist_test.py | 75 +++++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 33 deletions(-) diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py index cdfd89e258..3b86f63488 100644 --- a/tests/plan/hist_test.py +++ b/tests/plan/hist_test.py @@ -198,46 +198,55 @@ def test_hist_bin_count_missing( assert result.get_column("count").sum() == ser.drop_nans().count() -# TODO @dangotbanned: parametrize into 3 cases -def test_hist_count_no_spread(backend: EagerAllowed) -> None: - data_ = {"all_zero": [0, 0, 0], "all_non_zero": [5, 5, 5]} - df = nwp.DataFrame.from_dict(data_, backend=backend) - - result = df.get_column("all_zero").hist(bin_count=4, include_breakpoint=True) - expected = {"breakpoint": [-0.25, 0.0, 0.25, 0.5], "count": [0, 3, 0, 0]} - assert_equal_data(result, expected) - - result = df.get_column("all_non_zero").hist(bin_count=4, include_breakpoint=True) - expected = {"breakpoint": [4.75, 5.0, 5.25, 5.5], "count": [0, 3, 0, 0]} - assert_equal_data(result, expected) - - result = df.get_column("all_zero").hist(bin_count=1, include_breakpoint=True) - expected = {"breakpoint": [0.5], "count": [3]} +@pytest.mark.parametrize( + ("column", "bin_count", "expected_breakpoint", "expected_count"), + [ + ("all_zero", 4, [-0.25, 0.0, 0.25, 0.5], [0, 3, 0, 0]), + 
("all_non_zero", 4, [4.75, 5.0, 5.25, 5.5], [0, 3, 0, 0]), + ("all_zero", 1, [0.5], [3]), + ], +) +def test_hist_bin_count_no_spread( + backend: EagerAllowed, + column: str, + bin_count: int, + expected_breakpoint: list[float], + expected_count: list[int], +) -> None: + data = {"all_zero": [0, 0, 0], "all_non_zero": [5, 5, 5]} + ser = nwp.DataFrame.from_dict(data, backend=backend).get_column(column) + result = ser.hist(bin_count=bin_count, include_breakpoint=True) + expected = {"breakpoint": expected_breakpoint, "count": expected_count} assert_equal_data(result, expected) -# TODO @dangotbanned: parametrize into 2 cases? -def test_hist_no_data(backend: EagerAllowed, *, include_breakpoint: bool) -> None: - data_: Data = {"values": []} - df = nwp.DataFrame.from_dict(data_, {"values": nw.Float64()}, backend=backend) - s = df.to_series() - for bin_count in [1, 10]: - result = s.hist(bin_count=bin_count, include_breakpoint=include_breakpoint) - assert len(result) == bin_count - assert result.get_column("count").sum() == 0 +@pytest.mark.parametrize("bins", [[1, 5, 10]]) +def test_hist_bins_no_data( + backend: EagerAllowed, bins: list[int], *, include_breakpoint: bool +) -> None: + s = nwp.Series.from_iterable([], dtype=nw.Float64(), backend=backend) + result = s.hist(bins, include_breakpoint=include_breakpoint) + assert len(result) == 2 + assert result.get_column("count").sum() == 0 - if include_breakpoint: - bps = result.get_column("breakpoint").to_list() - assert bps[0] == (1 / bin_count) - if bin_count > 1: - assert bps[-1] == 1 - result = s.hist(bins=[1, 5, 10], include_breakpoint=include_breakpoint) - assert len(result) == 2 +@pytest.mark.parametrize("bin_count", [1, 10]) +def test_hist_bin_count_no_data( + backend: EagerAllowed, bin_count: int, *, include_breakpoint: bool +) -> None: + s = nwp.Series.from_iterable([], dtype=nw.Float64(), backend=backend) + result = s.hist(bin_count=bin_count, include_breakpoint=include_breakpoint) + assert len(result) == bin_count 
assert result.get_column("count").sum() == 0 + if include_breakpoint: + bps = result.get_column("breakpoint").to_list() + assert bps[0] == (1 / bin_count) + if bin_count > 1: + assert bps[-1] == 1 + -def test_hist_small_bins(backend: EagerAllowed) -> None: - s = nwp.Series.from_iterable([1, 2, 3], name="values", backend=backend) +def test_hist_bins_none(backend: EagerAllowed) -> None: + s = nwp.Series.from_iterable([1, 2, 3], backend=backend) result = s.hist(bins=None, bin_count=None) assert len(result) == 10 From 251cff58d4432a415c97460085b58cadd74521ed Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 22:16:08 +0000 Subject: [PATCH 152/215] tidier? --- narwhals/_plan/arrow/functions.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 0f8b3ec758..9afbb3ef89 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1351,23 +1351,19 @@ def _hist_is_empty_series(native: ChunkedArrayAny) -> bool: return array(native.is_null(nan_is_null=True), BOOL).false_count == 0 -def _hist_calculate_breakpoint( - arg: int | list[float], / -) -> list[float] | ChunkedArray[NumericScalar]: - return linear_space(0, 1, arg, closed="right") if isinstance(arg, int) else arg[1:] - - +# TODO @dangotbanned: Really need to reduce repeating this breakpoint/results stuff def _hist_data_empty(*, include_breakpoint: bool) -> Mapping[str, list[Any]]: return {"breakpoint": [], "count": []} if include_breakpoint else {"count": []} def _hist_series_empty( arg: int | list[float], *, include_breakpoint: bool -) -> dict[str, ChunkedOrArrayAny | list[float]]: - count = zeros(arg) if isinstance(arg, int) else zeros(len(arg) - 1) - if include_breakpoint: - return {"breakpoint": _hist_calculate_breakpoint(arg), "count": count} - return {"count": count} +) -> dict[str, Iterable[Any]]: + n = 
arg if isinstance(arg, int) else len(arg) - 1 + if not include_breakpoint: + return {"count": zeros(n)} + bp = linear_space(0, 1, arg, closed="right") if isinstance(arg, int) else arg[1:] + return {"breakpoint": bp, "count": zeros(n)} # TODO @dangotbanned: ughhhhh @@ -1426,11 +1422,10 @@ def _hist_calculate_hist( include_breakpoint: bool, ) -> Mapping[str, Iterable[Any]]: if len(bins) == 2: - count = array( - is_between(native, bins[0], bins[1], closed="both"), BOOL - ).true_count + upper = bins[1] + count = array(is_between(native, bins[0], upper, closed="both"), BOOL).true_count if include_breakpoint: - return {"breakpoint": [bins[-1]], "count": [count]} + return {"breakpoint": [upper], "count": [count]} return {"count": [count]} # lowest bin is inclusive From 129b8e628e9eb584a17eac350bbef48b58aa730e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 22:27:19 +0000 Subject: [PATCH 153/215] factor-in `_hist_calculate_bins` --- narwhals/_plan/arrow/functions.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 9afbb3ef89..1f90ad2719 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1366,19 +1366,6 @@ def _hist_series_empty( return {"breakpoint": bp, "count": zeros(n)} -# TODO @dangotbanned: ughhhhh -# figure out whatever this is supposed to be called -def _hist_calculate_bins( - native: ChunkedArrayAny, bin_count: int -) -> ChunkedArray[NumericScalar]: - d = pc.min_max(native) - lower, upper = d["min"].as_py(), d["max"].as_py() - if lower == upper: - lower -= 0.5 - upper += 0.5 - return linear_space(lower, upper, bin_count + 1) - - SearchSortedSide: TypeAlias = Literal["left", "right"] @@ -1465,11 +1452,15 @@ def hist_with_bin_count( return _hist_data_empty(include_breakpoint=include_breakpoint) if _hist_is_empty_series(native): return 
_hist_series_empty(bin_count, include_breakpoint=include_breakpoint) - return _hist_calculate_hist( - native, - _hist_calculate_bins(native, bin_count), - include_breakpoint=include_breakpoint, - ) + + # TODO @dangotbanned: Can this be done in a more ergomomic way? + d = pc.min_max(native) + lower, upper = d["min"].as_py(), d["max"].as_py() + if lower == upper: + lower -= 0.5 # TODO @dangotbanned: What is adjustment this called? + upper += 0.5 + bins = linear_space(lower, upper, bin_count + 1) + return _hist_calculate_hist(native, bins, include_breakpoint=include_breakpoint) def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: From 1e1d2cb716b6a1c2c8d0a9ee3d42be49b9b180e8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 22:29:56 +0000 Subject: [PATCH 154/215] move `search_sorted` --- narwhals/_plan/arrow/functions.py | 43 ++++++++++++++----------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 1f90ad2719..f66f9fa8bc 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1347,25 +1347,6 @@ def zeros(n: int, /) -> pa.Int64Array: return pa.repeat(0, n) -def _hist_is_empty_series(native: ChunkedArrayAny) -> bool: - return array(native.is_null(nan_is_null=True), BOOL).false_count == 0 - - -# TODO @dangotbanned: Really need to reduce repeating this breakpoint/results stuff -def _hist_data_empty(*, include_breakpoint: bool) -> Mapping[str, list[Any]]: - return {"breakpoint": [], "count": []} if include_breakpoint else {"count": []} - - -def _hist_series_empty( - arg: int | list[float], *, include_breakpoint: bool -) -> dict[str, Iterable[Any]]: - n = arg if isinstance(arg, int) else len(arg) - 1 - if not include_breakpoint: - return {"count": zeros(n)} - bp = linear_space(0, 1, arg, closed="right") if isinstance(arg, int) else arg[1:] - return {"breakpoint": 
bp, "count": zeros(n)} - - SearchSortedSide: TypeAlias = Literal["left", "right"] @@ -1377,21 +1358,18 @@ def search_sorted( *, side: SearchSortedSide = ..., ) -> ChunkedOrArrayT: ... - - # NOTE: scalar case may work with only `partition_nth_indices`? @t.overload def search_sorted( native: ChunkedOrArrayT, element: float, *, side: SearchSortedSide = ... ) -> ScalarAny: ... - - def search_sorted( native: ChunkedOrArrayT, element: ChunkedOrArray[NumericScalar] | Sequence[float] | float, *, side: SearchSortedSide = "left", ) -> ChunkedOrArrayT | ScalarAny: + """Find indices where elements should be inserted to maintain order.""" import numpy as np # ignore-banned-import indices = np.searchsorted(element, native, side=side) @@ -1402,6 +1380,25 @@ def search_sorted( return array(indices) +def _hist_is_empty_series(native: ChunkedArrayAny) -> bool: + return array(native.is_null(nan_is_null=True), BOOL).false_count == 0 + + +# TODO @dangotbanned: Really need to reduce repeating this breakpoint/results stuff +def _hist_data_empty(*, include_breakpoint: bool) -> Mapping[str, list[Any]]: + return {"breakpoint": [], "count": []} if include_breakpoint else {"count": []} + + +def _hist_series_empty( + arg: int | list[float], *, include_breakpoint: bool +) -> dict[str, Iterable[Any]]: + n = arg if isinstance(arg, int) else len(arg) - 1 + if not include_breakpoint: + return {"count": zeros(n)} + bp = linear_space(0, 1, arg, closed="right") if isinstance(arg, int) else arg[1:] + return {"breakpoint": bp, "count": zeros(n)} + + def _hist_calculate_hist( native: ChunkedArrayAny, bins: list[float] | ChunkedArray[NumericScalar], From 31f4a7d37d68bd2f9a69edda6b1333479f8a8c6f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 23:25:38 +0000 Subject: [PATCH 155/215] refactor: `_hist_is_empty_series` -> `is_only_nulls` Slightly better (but still quite awful) name --- narwhals/_plan/arrow/functions.py | 14 ++++++++------ 1 
file changed, 8 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index f66f9fa8bc..76c57e9a94 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -917,6 +917,12 @@ def preserve_nulls( drop_nulls = t.cast("VectorFunction[...]", pc.drop_null) + +def is_only_nulls(native: ChunkedOrArrayAny, *, nan_is_null: bool = False) -> bool: + """Return True if `native` has no non-null values (and optionally include NaN).""" + return array(native.is_null(nan_is_null=nan_is_null), BOOL).false_count == 0 + + _FILL_NULL_STRATEGY: Mapping[FillNullStrategy, UnaryFunction] = { "forward": pc.fill_null_forward, "backward": pc.fill_null_backward, @@ -1380,10 +1386,6 @@ def search_sorted( return array(indices) -def _hist_is_empty_series(native: ChunkedArrayAny) -> bool: - return array(native.is_null(nan_is_null=True), BOOL).false_count == 0 - - # TODO @dangotbanned: Really need to reduce repeating this breakpoint/results stuff def _hist_data_empty(*, include_breakpoint: bool) -> Mapping[str, list[Any]]: return {"breakpoint": [], "count": []} if include_breakpoint else {"count": []} @@ -1437,7 +1439,7 @@ def hist_with_bins( ) -> Mapping[str, Iterable[Any]]: if len(bins) <= 1: return _hist_data_empty(include_breakpoint=include_breakpoint) - if _hist_is_empty_series(native): + if is_only_nulls(native, nan_is_null=True): return _hist_series_empty(bins, include_breakpoint=include_breakpoint) return _hist_calculate_hist(native, bins, include_breakpoint=include_breakpoint) @@ -1447,7 +1449,7 @@ def hist_with_bin_count( ) -> Mapping[str, Iterable[Any]]: if bin_count == 0: return _hist_data_empty(include_breakpoint=include_breakpoint) - if _hist_is_empty_series(native): + if is_only_nulls(native, nan_is_null=True): return _hist_series_empty(bin_count, include_breakpoint=include_breakpoint) # TODO @dangotbanned: Can this be done in a more ergomomic way? 
From 679a216d607a2af94fb925b6881dc1a96e12fa1e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 5 Dec 2025 23:37:56 +0000 Subject: [PATCH 156/215] refactor: Move `hist_with_bins` -> `ArrowExpr.hist_bins` `ArrowSeries` will **always** defer to `ArrowExpr` --- narwhals/_plan/arrow/expr.py | 19 +++++++++++-------- narwhals/_plan/arrow/functions.py | 14 ++------------ 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index d70f748d8f..44919fdade 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -702,14 +702,17 @@ def rolling_expr( return self.from_series(result) def hist_bins(self, node: FExpr[F.HistBins], frame: Frame, name: str) -> Self: - s = self._dispatch_expr(node.input[0], frame, name) - func = node.function - struct_data = fn.hist_with_bins( - s.native, list(func.bins), include_breakpoint=func.include_breakpoint - ) - return self.from_series( - namespace(self)._dataframe.from_dict(struct_data).to_struct(name) - ) + native = self._dispatch_expr(node.input[0], frame, name).native + bins = list(node.function.bins) + include = node.function.include_breakpoint + if len(bins) <= 1: + data = fn._hist_data_empty(include_breakpoint=include) + elif fn.is_only_nulls(native, nan_is_null=True): + data = fn._hist_series_empty(bins, include_breakpoint=include) + else: + data = fn._hist_calculate_hist(native, bins, include_breakpoint=include) + ns = namespace(self) + return self.from_series(ns._dataframe.from_dict(data).to_struct(name)) def hist_bin_count( self, node: FExpr[F.HistBinCount], frame: Frame, name: str diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 76c57e9a94..34570b7659 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1387,13 +1387,13 @@ def search_sorted( # TODO @dangotbanned: Really need to reduce repeating this 
breakpoint/results stuff -def _hist_data_empty(*, include_breakpoint: bool) -> Mapping[str, list[Any]]: +def _hist_data_empty(*, include_breakpoint: bool) -> Mapping[str, Iterable[Any]]: return {"breakpoint": [], "count": []} if include_breakpoint else {"count": []} def _hist_series_empty( arg: int | list[float], *, include_breakpoint: bool -) -> dict[str, Iterable[Any]]: +) -> Mapping[str, Iterable[Any]]: n = arg if isinstance(arg, int) else len(arg) - 1 if not include_breakpoint: return {"count": zeros(n)} @@ -1434,16 +1434,6 @@ def _hist_calculate_hist( return {"count": counts} -def hist_with_bins( - native: ChunkedArrayAny, bins: list[float], *, include_breakpoint: bool -) -> Mapping[str, Iterable[Any]]: - if len(bins) <= 1: - return _hist_data_empty(include_breakpoint=include_breakpoint) - if is_only_nulls(native, nan_is_null=True): - return _hist_series_empty(bins, include_breakpoint=include_breakpoint) - return _hist_calculate_hist(native, bins, include_breakpoint=include_breakpoint) - - def hist_with_bin_count( native: ChunkedArrayAny, bin_count: int, *, include_breakpoint: bool ) -> Mapping[str, Iterable[Any]]: From f756df0ca8c948ae56737df360796e0219e93334 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 6 Dec 2025 13:02:41 +0000 Subject: [PATCH 157/215] refactor: Move more of `hist_with_bin_count` -> `pyarrow` --- narwhals/_plan/arrow/functions.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 34570b7659..a15d60ab9b 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1284,7 +1284,7 @@ def date_range( def linear_space( - start: int, end: int, num_samples: int, *, closed: ClosedInterval = "both" + start: float, end: float, num_samples: int, *, closed: ClosedInterval = "both" ) -> ChunkedArray[pc.NumericScalar]: """Based on [`np.linspace`]. 
@@ -1434,21 +1434,22 @@ def _hist_calculate_hist( return {"count": counts} +# NOTE: `Decimal` is not supported, but excluding it from the typing is surprisingly complicated +# https://docs.rs/polars-core/0.52.0/polars_core/datatypes/enum.DataType.html#method.is_primitive_numeric def hist_with_bin_count( - native: ChunkedArrayAny, bin_count: int, *, include_breakpoint: bool + native: ChunkedArray[NumericScalar], bin_count: int, *, include_breakpoint: bool ) -> Mapping[str, Iterable[Any]]: if bin_count == 0: return _hist_data_empty(include_breakpoint=include_breakpoint) if is_only_nulls(native, nan_is_null=True): return _hist_series_empty(bin_count, include_breakpoint=include_breakpoint) - - # TODO @dangotbanned: Can this be done in a more ergomomic way? - d = pc.min_max(native) - lower, upper = d["min"].as_py(), d["max"].as_py() - if lower == upper: - lower -= 0.5 # TODO @dangotbanned: What is adjustment this called? - upper += 0.5 - bins = linear_space(lower, upper, bin_count + 1) + lower: ScalarAny = min_(native) + upper: ScalarAny = max_(native) + if lower.equals(upper): + # All data points are identical - use unit interval + rhs = lit(0.5) + lower, upper = sub(lower, rhs), add(upper, rhs) + bins = linear_space(lower.as_py(), upper.as_py(), bin_count + 1) return _hist_calculate_hist(native, bins, include_breakpoint=include_breakpoint) From 6132607c79e57b9230918d4393bb932f79148269 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 6 Dec 2025 13:37:47 +0000 Subject: [PATCH 158/215] refactor: More `hist` organizing --- narwhals/_plan/arrow/expr.py | 38 ++++++++++++------ narwhals/_plan/arrow/functions.py | 52 +++++++++---------------- narwhals/_plan/expressions/functions.py | 9 ++++- 3 files changed, 53 insertions(+), 46 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 44919fdade..aa5dbd2d59 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ 
-703,28 +703,42 @@ def rolling_expr( def hist_bins(self, node: FExpr[F.HistBins], frame: Frame, name: str) -> Self: native = self._dispatch_expr(node.input[0], frame, name).native - bins = list(node.function.bins) - include = node.function.include_breakpoint + func = node.function + bins = func.bins + include = func.include_breakpoint if len(bins) <= 1: - data = fn._hist_data_empty(include_breakpoint=include) + data = func.empty_data elif fn.is_only_nulls(native, nan_is_null=True): - data = fn._hist_series_empty(bins, include_breakpoint=include) + data = fn.hist_zeroed_data(bins, include_breakpoint=include) else: - data = fn._hist_calculate_hist(native, bins, include_breakpoint=include) + data = fn.hist_bins(native, bins, include_breakpoint=include) ns = namespace(self) return self.from_series(ns._dataframe.from_dict(data).to_struct(name)) def hist_bin_count( self, node: FExpr[F.HistBinCount], frame: Frame, name: str ) -> Self: - s = self._dispatch_expr(node.input[0], frame, name) + native = self._dispatch_expr(node.input[0], frame, name).native func = node.function - struct_data = fn.hist_with_bin_count( - s.native, func.bin_count, include_breakpoint=func.include_breakpoint - ) - return self.from_series( - namespace(self)._dataframe.from_dict(struct_data).to_struct(name) - ) + bin_count = func.bin_count + include = func.include_breakpoint + if bin_count == 0: + data = func.empty_data + elif fn.is_only_nulls(native, nan_is_null=True): + data = fn.hist_zeroed_data(bin_count, include_breakpoint=include) + else: + # NOTE: `Decimal` is not supported, but excluding it from the typing is surprisingly complicated + # https://docs.rs/polars-core/0.52.0/polars_core/datatypes/enum.DataType.html#method.is_primitive_numeric + lower: NativeScalar = fn.min_(native) + upper: NativeScalar = fn.max_(native) + if lower.equals(upper): + # All data points are identical - use unit interval + rhs = fn.lit(0.5) + lower, upper = fn.sub(lower, rhs), fn.add(upper, rhs) + bins = 
fn.linear_space(lower.as_py(), upper.as_py(), bin_count + 1) + data = fn.hist_bins(native, bins, include_breakpoint=include) + ns = namespace(self) + return self.from_series(ns._dataframe.from_dict(data).to_struct(name)) # ewm_mean = not_implemented() # noqa: ERA001 @property diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index a15d60ab9b..6a2dfb838b 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1386,27 +1386,21 @@ def search_sorted( return array(indices) -# TODO @dangotbanned: Really need to reduce repeating this breakpoint/results stuff -def _hist_data_empty(*, include_breakpoint: bool) -> Mapping[str, Iterable[Any]]: - return {"breakpoint": [], "count": []} if include_breakpoint else {"count": []} - - -def _hist_series_empty( - arg: int | list[float], *, include_breakpoint: bool -) -> Mapping[str, Iterable[Any]]: - n = arg if isinstance(arg, int) else len(arg) - 1 - if not include_breakpoint: - return {"count": zeros(n)} - bp = linear_space(0, 1, arg, closed="right") if isinstance(arg, int) else arg[1:] - return {"breakpoint": bp, "count": zeros(n)} - - -def _hist_calculate_hist( +def hist_bins( native: ChunkedArrayAny, - bins: list[float] | ChunkedArray[NumericScalar], + bins: Sequence[float] | ChunkedArray[NumericScalar], *, include_breakpoint: bool, ) -> Mapping[str, Iterable[Any]]: + """Bin values into buckets and count their occurrences. 
+ + Notes: + Assumes that the following edge cases have been handled: + - `len(bins) >= 2` + - `bins` increase monotonically + - `bin[0] != bin[-1]` + - `native` contains values that are non-null (including NaN) + """ if len(bins) == 2: upper = bins[1] count = array(is_between(native, bins[0], upper, closed="both"), BOOL).true_count @@ -1434,23 +1428,15 @@ def _hist_calculate_hist( return {"count": counts} -# NOTE: `Decimal` is not supported, but excluding it from the typing is surprisingly complicated -# https://docs.rs/polars-core/0.52.0/polars_core/datatypes/enum.DataType.html#method.is_primitive_numeric -def hist_with_bin_count( - native: ChunkedArray[NumericScalar], bin_count: int, *, include_breakpoint: bool +def hist_zeroed_data( + arg: int | Sequence[float], *, include_breakpoint: bool ) -> Mapping[str, Iterable[Any]]: - if bin_count == 0: - return _hist_data_empty(include_breakpoint=include_breakpoint) - if is_only_nulls(native, nan_is_null=True): - return _hist_series_empty(bin_count, include_breakpoint=include_breakpoint) - lower: ScalarAny = min_(native) - upper: ScalarAny = max_(native) - if lower.equals(upper): - # All data points are identical - use unit interval - rhs = lit(0.5) - lower, upper = sub(lower, rhs), add(upper, rhs) - bins = linear_space(lower.as_py(), upper.as_py(), bin_count + 1) - return _hist_calculate_hist(native, bins, include_breakpoint=include_breakpoint) + # NOTE: If adding `linear_space` and `zeros` to `CompliantNamespace`, consider moving this. 
+ n = arg if isinstance(arg, int) else len(arg) - 1 + if not include_breakpoint: + return {"count": zeros(n)} + bp = linear_space(0, 1, arg, closed="right") if isinstance(arg, int) else arg[1:] + return {"breakpoint": bp, "count": zeros(n)} def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index a87eb56ccd..d919448baf 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -9,7 +9,7 @@ from narwhals._plan.options import FunctionFlags, FunctionOptions if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Mapping from typing import Any from _typeshed import ConvertibleToInt @@ -108,6 +108,13 @@ def from_bin_count( ) -> HistBinCount: return HistBinCount(bin_count=int(count), include_breakpoint=include_breakpoint) + @property + def empty_data(self) -> Mapping[str, Iterable[Any]]: + # NOTE: May need to adapt for `include_category`? + return ( + {"breakpoint": [], "count": []} if self.include_breakpoint else {"count": []} + ) + class HistBins(Hist): __slots__ = ("bins",) From 3e835931f33b7aa1aa743601e7fc5af93da083e2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 6 Dec 2025 15:29:39 +0000 Subject: [PATCH 159/215] stop fiddling w/ `hist` impl (for now) --- narwhals/_plan/arrow/functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 6a2dfb838b..d69787110d 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1356,7 +1356,7 @@ def zeros(n: int, /) -> pa.Int64Array: SearchSortedSide: TypeAlias = Literal["left", "right"] -# TODO @dangotbanned: replacing `np.searchsorted`? +# NOTE @dangotbanned: (wish) replacing `np.searchsorted`? 
@t.overload def search_sorted( native: ChunkedOrArrayT, @@ -1417,7 +1417,6 @@ def hist_bins( ) values, counts = struct_fields(value_counts, "values", "counts") bin_count = len(bins) - # TODO @dangotbanned: I'd still like to do this in less steps, but it is *more* native int_range_ = int_range(1, bin_count, chunked=False) mask = is_in(int_range_, values) replacements = counts.filter(is_in(values, int_range_)) From 96cb095a7b1d84aad03da3e0449875ab3e0acd46 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 6 Dec 2025 16:03:10 +0000 Subject: [PATCH 160/215] test: Split out more tools for test gen Mainly just `test_hist_bins` left now --- tests/plan/hist_test.py | 109 ++++++++++++++++++++++------------------ 1 file changed, 60 insertions(+), 49 deletions(-) diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py index 3b86f63488..4b1997ac8f 100644 --- a/tests/plan/hist_test.py +++ b/tests/plan/hist_test.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest @@ -39,6 +39,12 @@ def schema_data() -> nw.Schema: ) +@pytest.fixture(scope="module", params=["int", "float", "int_shuffled", "float_shuffled"]) +def column_data(request: pytest.FixtureRequest) -> str: + result: str = request.param + return result + + @pytest.fixture(scope="module") def data_missing(data: Data) -> Data: return {"has_nan": [float("nan"), *data["int"]], "has_null": [None, *data["int"]]} @@ -49,27 +55,46 @@ def schema_data_missing() -> nw.Schema: return nw.Schema({"has_nan": nw.Float64(), "has_null": nw.Int64()}) +@pytest.fixture(scope="module", params=["has_nan", "has_null"]) +def column_data_missing(request: pytest.FixtureRequest) -> str: + result: str = request.param + return result + + @pytest.fixture(scope="module", params=["pyarrow"]) def backend(request: pytest.FixtureRequest) -> EagerAllowed: result: EagerAllowed = request.param return result 
-@pytest.fixture( - scope="module", params=[True, False], ids=["breakpoint-True", "breakpoint-False"] -) +@pytest.fixture(scope="module", params=[True, False]) def include_breakpoint(request: pytest.FixtureRequest) -> bool: result: bool = request.param return result +def _series( + name: str, source: Data, schema: nw.Schema, backend: EagerAllowed, / +) -> nwp.Series[Any]: + values, dtype = (source[name], schema[name]) + return nwp.Series.from_iterable(values, name=name, dtype=dtype, backend=backend) + + +def _expected( + bins: Sequence[float], count: Sequence[int], *, include_breakpoint: bool +) -> dict[str, Any]: + if not include_breakpoint: + return {"count": count} + return {"breakpoint": bins[1:] if len(bins) > len(count) else bins, "count": count} + + SHIFT_BINS_BY = 10 """shift bins property""" # TODO @dangotbanned: Try to avoid all this looping (3x `iter_columns` in a single test?) @pytest.mark.parametrize( - ("bins", "expected"), + ("bins", "expected_count"), [ pytest.param( [-float("inf"), 2.5, 5.5, float("inf")], [3, 3, 1], id="4_bins-neg-inf-inf" @@ -86,48 +111,44 @@ def test_hist_bins( data: Data, data_missing: Data, backend: EagerAllowed, - bins: list[float], - expected: Sequence[float], + bins: Sequence[float], + expected_count: Sequence[int], *, include_breakpoint: bool, ) -> None: df = nwp.DataFrame.from_dict(data, backend=backend).with_columns( float=nwp.col("int").cast(nw.Float64) ) - expected_full = {"count": expected} - if include_breakpoint: - expected_full = {"breakpoint": bins[1:], **expected_full} + expected = _expected(bins, expected_count, include_breakpoint=include_breakpoint) + # smoke tests for series in df.iter_columns(): result = series.hist(bins=bins, include_breakpoint=include_breakpoint) - assert_equal_data(result, expected_full) + assert_equal_data(result, expected) # result size property assert len(result) == max(len(bins) - 1, 0) # shift bins property shifted_bins = [b + SHIFT_BINS_BY for b in bins] - expected_full = {"count": 
expected} - if include_breakpoint: - expected_full = {"breakpoint": shifted_bins[1:], **expected_full} - + expected = _expected( + shifted_bins, expected_count, include_breakpoint=include_breakpoint + ) for series in df.iter_columns(): result = (series + SHIFT_BINS_BY).hist( shifted_bins, include_breakpoint=include_breakpoint ) - assert_equal_data(result, expected_full) + assert_equal_data(result, expected) # missing/nan results df = nwp.DataFrame.from_dict(data_missing, backend=backend) - expected_full = {"count": expected} - if include_breakpoint: - expected_full = {"breakpoint": bins[1:], **expected_full} + expected = _expected(bins, expected_count, include_breakpoint=include_breakpoint) for series in df.iter_columns(): result = series.hist(bins, include_breakpoint=include_breakpoint) - assert_equal_data(result, expected_full) + assert_equal_data(result, expected) -params_params = pytest.mark.parametrize( +bin_count_cases = pytest.mark.parametrize( ("bin_count", "expected_bins", "expected_count"), [ (4, [1.5, 3.0, 4.5, 6.0], [2, 2, 1, 2]), @@ -142,56 +163,46 @@ def test_hist_bins( ) -@pytest.mark.parametrize("column", ["int", "float", "int_shuffled", "float_shuffled"]) -@params_params +@bin_count_cases def test_hist_bin_count( data: Data, schema_data: nw.Schema, backend: EagerAllowed, - column: str, + column_data: str, bin_count: int, - expected_bins: list[float], - expected_count: list[int], + expected_bins: Sequence[float], + expected_count: Sequence[int], *, include_breakpoint: bool, ) -> None: - values, dtype = data[column], schema_data[column] - ser = nwp.Series.from_iterable(values, name=column, dtype=dtype, backend=backend) - if include_breakpoint: - expected = {"breakpoint": expected_bins, "count": expected_count} - else: - expected = {"count": expected_count} - + ser = _series(column_data, data, schema_data, backend) + expected = _expected( + expected_bins, expected_count, include_breakpoint=include_breakpoint + ) result = ser.hist(bin_count=bin_count, 
include_breakpoint=include_breakpoint) - assert_equal_data(result, expected) assert len(result) == bin_count if bin_count > 0: assert result.get_column("count").sum() == ser.drop_nans().count() -@pytest.mark.parametrize("column", ["has_nan", "has_null"]) -@params_params +@bin_count_cases def test_hist_bin_count_missing( data_missing: Data, schema_data_missing: nw.Schema, backend: EagerAllowed, - column: str, + column_data_missing: str, bin_count: int, - expected_bins: list[float], - expected_count: list[int], + expected_bins: Sequence[float], + expected_count: Sequence[int], *, include_breakpoint: bool, ) -> None: - values, dtype = data_missing[column], schema_data_missing[column] - ser = nwp.Series.from_iterable(values, name=column, dtype=dtype, backend=backend) - if include_breakpoint: - expected = {"breakpoint": expected_bins, "count": expected_count} - else: - expected = {"count": expected_count} - + ser = _series(column_data_missing, data_missing, schema_data_missing, backend) + expected = _expected( + expected_bins, expected_count, include_breakpoint=include_breakpoint + ) result = ser.hist(bin_count=bin_count, include_breakpoint=include_breakpoint) - assert_equal_data(result, expected) assert len(result) == bin_count if bin_count > 0: @@ -210,8 +221,8 @@ def test_hist_bin_count_no_spread( backend: EagerAllowed, column: str, bin_count: int, - expected_breakpoint: list[float], - expected_count: list[int], + expected_breakpoint: Sequence[float], + expected_count: Sequence[int], ) -> None: data = {"all_zero": [0, 0, 0], "all_non_zero": [5, 5, 5]} ser = nwp.DataFrame.from_dict(data, backend=backend).get_column(column) From 6f74bf9a2f694b8f38e9387e1724daa5b9acdf40 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 6 Dec 2025 17:18:45 +0000 Subject: [PATCH 161/215] test: Split up `test_hist_bins` --- tests/plan/hist_test.py | 66 +++++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 26 
deletions(-) diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py index 4b1997ac8f..48c8400002 100644 --- a/tests/plan/hist_test.py +++ b/tests/plan/hist_test.py @@ -91,9 +91,7 @@ def _expected( SHIFT_BINS_BY = 10 """shift bins property""" - -# TODO @dangotbanned: Try to avoid all this looping (3x `iter_columns` in a single test?) -@pytest.mark.parametrize( +bins_cases = pytest.mark.parametrize( ("bins", "expected_count"), [ pytest.param( @@ -107,45 +105,61 @@ def _expected( pytest.param([0, 10], [7], id="2_bins-2"), ], ) + + +@bins_cases def test_hist_bins( data: Data, - data_missing: Data, + schema_data: nw.Schema, backend: EagerAllowed, + column_data: str, bins: Sequence[float], expected_count: Sequence[int], *, include_breakpoint: bool, ) -> None: - df = nwp.DataFrame.from_dict(data, backend=backend).with_columns( - float=nwp.col("int").cast(nw.Float64) - ) + ser = _series(column_data, data, schema_data, backend) expected = _expected(bins, expected_count, include_breakpoint=include_breakpoint) + result = ser.hist(bins, include_breakpoint=include_breakpoint) + assert_equal_data(result, expected) + assert len(result) == max(len(bins) - 1, 0) - # smoke tests - for series in df.iter_columns(): - result = series.hist(bins=bins, include_breakpoint=include_breakpoint) - assert_equal_data(result, expected) - - # result size property - assert len(result) == max(len(bins) - 1, 0) - # shift bins property +@bins_cases +def test_hist_bins_shifted( + data: Data, + schema_data: nw.Schema, + backend: EagerAllowed, + column_data: str, + bins: Sequence[float], + expected_count: Sequence[int], + *, + include_breakpoint: bool, +) -> None: shifted_bins = [b + SHIFT_BINS_BY for b in bins] expected = _expected( shifted_bins, expected_count, include_breakpoint=include_breakpoint ) - for series in df.iter_columns(): - result = (series + SHIFT_BINS_BY).hist( - shifted_bins, include_breakpoint=include_breakpoint - ) - assert_equal_data(result, expected) - - # missing/nan 
results - df = nwp.DataFrame.from_dict(data_missing, backend=backend) + ser = _series(column_data, data, schema_data, backend) + SHIFT_BINS_BY + result = ser.hist(shifted_bins, include_breakpoint=include_breakpoint) + assert_equal_data(result, expected) + + +@bins_cases +def test_hist_bins_missing( + data_missing: Data, + schema_data_missing: nw.Schema, + backend: EagerAllowed, + column_data_missing: str, + bins: Sequence[float], + expected_count: Sequence[int], + *, + include_breakpoint: bool, +) -> None: + ser = _series(column_data_missing, data_missing, schema_data_missing, backend) expected = _expected(bins, expected_count, include_breakpoint=include_breakpoint) - for series in df.iter_columns(): - result = series.hist(bins, include_breakpoint=include_breakpoint) - assert_equal_data(result, expected) + result = ser.hist(bins, include_breakpoint=include_breakpoint) + assert_equal_data(result, expected) bin_count_cases = pytest.mark.parametrize( From 4af7b1077712ff6cb8aa19a8d5b9686710459e51 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 6 Dec 2025 18:32:05 +0000 Subject: [PATCH 162/215] chore: re-cover things Most of this was added to support `hist` tests, but they now have fewer dependencies --- narwhals/_plan/arrow/series.py | 4 +- narwhals/_plan/dataframe.py | 35 ++++++++++++++--- narwhals/_plan/series.py | 35 +++++++++-------- narwhals/_plan/typing.py | 1 + tests/plan/compliant_test.py | 69 +++++++++++++++++++++++++++++++++- tests/plan/utils.py | 19 ++++++++-- 6 files changed, 137 insertions(+), 26 deletions(-) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index b92341b19b..72e01a4fb1 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -297,7 +297,9 @@ def drop_nulls(self) -> Self: def drop_nans(self) -> Self: predicate: Incomplete = fn.is_not_nan(self.native) - return self._with_native(self.native.filter(predicate, "emit_null")) + return 
self._with_native( + self.native.filter(predicate, null_selection_behavior="emit_null") + ) @property def struct(self) -> SeriesStructNamespace: diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 4ceb06a2ea..7455ad8cc5 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -4,6 +4,7 @@ from narwhals._plan import _parse from narwhals._plan._expansion import expand_selector_irs_names, prepare_projection +from narwhals._plan._guards import is_series from narwhals._plan.common import ensure_seq_str, temp from narwhals._plan.group_by import GroupBy, Grouped from narwhals._plan.options import SortMultipleOptions @@ -16,6 +17,7 @@ NativeDataFrameT_co, NativeFrameT_co, NativeSeriesT, + NativeSeriesT2, NonCrossJoinStrategy, OneOrIterable, PartialSeries, @@ -34,6 +36,7 @@ import pyarrow as pa from typing_extensions import Self, TypeAlias, TypeIs + from narwhals._native import NativeSeries from narwhals._plan.arrow.typing import NativeArrowDataFrame from narwhals._plan.compliant.dataframe import ( CompliantDataFrame, @@ -84,7 +87,7 @@ def __init__(self, compliant: CompliantFrame[Any, NativeFrameT_co], /) -> None: def _with_compliant(self, compliant: CompliantFrame[Any, Incomplete], /) -> Self: return type(self)(compliant) - def to_native(self) -> NativeFrameT_co: # pragma: no cover + def to_native(self) -> NativeFrameT_co: return self._compliant.native def filter( @@ -232,10 +235,17 @@ def from_dict( def from_dict( cls: type[DataFrame[Any, Any]], data: Mapping[str, Any], - schema: IntoSchema | None = None, + schema: IntoSchema | None = ..., *, - backend: IntoBackend[EagerAllowed] | None = ..., + backend: IntoBackend[EagerAllowed], ) -> DataFrame[Any, Any]: ... + @overload + @classmethod + def from_dict( + cls: type[DataFrame[Any, Any]], + data: Mapping[str, Series[NativeSeriesT2]], + schema: IntoSchema | None = ..., + ) -> DataFrame[Any, NativeSeriesT2]: ... 
@classmethod def from_dict( cls: type[DataFrame[Any, Any]], @@ -247,8 +257,23 @@ def from_dict( from narwhals._plan import functions as F if backend is None: - msg = f"`from_dict({backend=})`" - raise NotImplementedError(msg) + unwrapped: dict[str, NativeSeries | Any] = {} + impl: _EagerAllowedImpl | None = backend + for k, v in data.items(): + if is_series(v): + current = v.implementation + if impl is None: + impl = current + elif current is not impl: + msg = f"All `Series` must share the same backend, but got:\n -{impl!r}\n -{current!r}" + raise NotImplementedError(msg) + unwrapped[k] = v.to_native() + else: + unwrapped[k] = v + if impl is None: + msg = "Calling `from_dict` without `backend` is only supported if all input values are already Narwhals Series" + raise TypeError(msg) + return _dataframe_from_dict(unwrapped, schema, F._eager_namespace(impl)) ns = F._eager_namespace(backend) return _dataframe_from_dict(data, schema, ns) diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index fa689e3174..2ef297b84f 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -119,7 +119,8 @@ def to_list(self) -> list[Any]: def to_polars(self) -> pl.Series: return self._compliant.to_polars() - def __iter__(self) -> Iterator[Any]: # pragma: no cover + # TODO @dangotbanned: Figure out if this should be yielding `pa.Scalar` + def __iter__(self) -> Iterator[Any]: yield from self.to_native() def alias(self, name: str) -> Self: @@ -131,7 +132,7 @@ def cast(self, dtype: IntoDType) -> Self: # pragma: no cover def __len__(self) -> int: return len(self._compliant) - def gather(self, indices: SizedMultiIndexSelector[Self]) -> Self: # pragma: no cover + def gather(self, indices: SizedMultiIndexSelector[Self]) -> Self: if len(indices) == 0: return self.slice(0, 0) rows = indices._compliant if isinstance(indices, Series) else indices @@ -140,19 +141,17 @@ def gather(self, indices: SizedMultiIndexSelector[Self]) -> Self: # pragma: no def gather_every(self, n: 
int, offset: int = 0) -> Self: return type(self)(self._compliant.gather_every(n, offset)) - def has_nulls(self) -> bool: # pragma: no cover + def has_nulls(self) -> bool: return self._compliant.has_nulls() def slice(self, offset: int, length: int | None = None) -> Self: return type(self)(self._compliant.slice(offset=offset, length=length)) - def sort( - self, *, descending: bool = False, nulls_last: bool = False - ) -> Self: # pragma: no cover + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: result = self._compliant.sort(descending=descending, nulls_last=nulls_last) return type(self)(result) - def is_empty(self) -> bool: # pragma: no cover + def is_empty(self) -> bool: return self._compliant.is_empty() def _unwrap_compliant( @@ -191,16 +190,16 @@ def scatter( def is_in(self, other: Iterable[Any]) -> Self: return type(self)(self._compliant.is_in(self._parse_into_compliant(other))) - def is_nan(self) -> Self: # pragma: no cover + def is_nan(self) -> Self: return type(self)(self._compliant.is_nan()) - def is_null(self) -> Self: # pragma: no cover + def is_null(self) -> Self: return type(self)(self._compliant.is_null()) - def is_not_nan(self) -> Self: # pragma: no cover + def is_not_nan(self) -> Self: return type(self)(self._compliant.is_not_nan()) - def is_not_null(self) -> Self: # pragma: no cover + def is_not_null(self) -> Self: return type(self)(self._compliant.is_not_null()) def null_count(self) -> int: @@ -237,11 +236,15 @@ def __eq__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: # typ other_ = self._unwrap_compliant(other) if is_series(other) else other return type(self)(self._compliant.__eq__(other_)) - def __or__(self, other: bool | Self, /) -> Self: # pragma: no cover + def __and__(self, other: bool | Self, /) -> Self: + other_ = self._unwrap_compliant(other) if is_series(other) else other + return type(self)(self._compliant.__and__(other_)) + + def __or__(self, other: bool | Self, /) -> Self: other_ = 
self._unwrap_compliant(other) if is_series(other) else other return type(self)(self._compliant.__or__(other_)) - def __invert__(self) -> Self: # pragma: no cover + def __invert__(self) -> Self: return type(self)(self._compliant.__invert__()) def __add__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: @@ -251,7 +254,7 @@ def __add__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: def all(self) -> bool: return self._compliant.all() - def any(self) -> bool: # pragma: no cover + def any(self) -> bool: return self._compliant.any() def sum(self) -> float: @@ -260,10 +263,10 @@ def sum(self) -> float: def count(self) -> int: return self._compliant.count() - def unique(self, *, maintain_order: bool = False) -> Self: # pragma: no cover + def unique(self, *, maintain_order: bool = False) -> Self: return type(self)(self._compliant.unique(maintain_order=maintain_order)) - def drop_nulls(self) -> Self: # pragma: no cover + def drop_nulls(self) -> Self: return type(self)(self._compliant.drop_nulls()) def drop_nans(self) -> Self: diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index ad9c2881c2..9c4a6ca927 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -92,6 +92,7 @@ "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" ) NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries", default="NativeSeries") +NativeSeriesT2 = TypeVar("NativeSeriesT2", bound="NativeSeries", default="NativeSeries") NativeSeriesAnyT = TypeVar("NativeSeriesAnyT", bound="NativeSeries", default="t.Any") NativeSeriesT_co = TypeVar( "NativeSeriesT_co", bound="NativeSeries", covariant=True, default="NativeSeries" diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index d4e55cae93..bd8821260c 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -17,7 +17,14 @@ import narwhals as nw from narwhals import _plan as nwp -from tests.plan.utils import assert_equal_data, 
dataframe, first, last, series +from tests.plan.utils import ( + assert_equal_data, + assert_equal_series, + dataframe, + first, + last, + series, +) if TYPE_CHECKING: from collections.abc import Sequence @@ -670,6 +677,66 @@ def test_series_to_polars(values: Sequence[PythonLiteral]) -> None: pl_assert_series_equal(result, expected) +def test_dataframe_iter_columns(data_small: Data) -> None: + df = dataframe(data_small) + result = df.from_dict({s.name: s for s in df.iter_columns()}).to_dict(as_series=False) + assert_equal_data(df, result) + + +def test_dataframe_from_dict_misc(data_small: Data) -> None: + pytest.importorskip("pyarrow") + items = iter(data_small.items()) + name, values = next(items) + mapping: dict[str, Any] = { + name: nwp.Series.from_iterable(values, name=name, backend="pyarrow") + } + mapping.update(items) + result = nwp.DataFrame.from_dict(mapping) + assert_equal_data(result, data_small) + + with pytest.raises(TypeError, match=r"from_dict.+without.+backend"): + nwp.DataFrame.from_dict(data_small) # type: ignore[arg-type] + + +# TODO @dangotbanned: Split this up +def test_series_misc() -> None: + pytest.importorskip("pyarrow") + + values = [1.0, None, 7.1, float("nan"), 4.9, 12.0, 1.1, float("nan"), 0.2, None] + name = "ser" + ser = nwp.Series.from_iterable(values, name=name, dtype=nw.Float64, backend="pyarrow") + assert ser.is_empty() is False + assert ser.has_nulls() + assert ser.null_count() == 2 + + is_null = ser.is_null() + is_nan = ser.is_nan() + is_not_null = ser.is_not_null() + is_not_nan = ser.is_not_nan() + is_useful = ~(is_null | is_nan) + + assert is_useful.any() + + assert_equal_series(is_null, ~is_not_null) + assert_equal_series(~is_null, is_not_null) + assert_equal_series(is_nan, ~is_not_nan) + assert_equal_series(~is_nan, is_not_nan) + + expected = [False, None, False, False, False, False, False, False, False, None] + assert_equal_series(is_null & is_nan, expected, name) + expected = [False, True, False, False, False, False, 
False, False, False, True] + assert_equal_series(is_null, expected, name) + expected = [True, False, True, False, True, True, True, False, True, False] + assert_equal_series(is_not_nan & is_not_null, expected, name) + + assert ser.unique().drop_nans().drop_nulls().count() == 6 + + assert_equal_series(ser.gather([0, 2, 4]).sort(), [1.0, 4.9, 7.1], name) + assert ser.gather([]).to_list() == [] + + assert len(list(ser)) == len(values) + + if TYPE_CHECKING: from typing_extensions import assert_type diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 09acbf750f..d7295838de 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload import pytest @@ -13,11 +13,13 @@ pytest.importorskip("pyarrow") +from collections.abc import Sequence + import pyarrow as pa if TYPE_CHECKING: import sys - from collections.abc import Iterable, Mapping, Sequence + from collections.abc import Iterable, Mapping from typing_extensions import LiteralString, TypeAlias @@ -223,9 +225,20 @@ def assert_equal_data( _assert_equal_data(result.to_dict(as_series=False), expected) +@overload +def assert_equal_series(result: nwp.Series[Any], expected: nwp.Series[Any]) -> None: ... +@overload +def assert_equal_series( + result: nwp.Series[Any], expected: Iterable[Any], name: str +) -> None: ... 
def assert_equal_series( - result: nwp.Series[Any], expected: Sequence[Any], name: str + result: nwp.Series[Any], expected: Iterable[Any], name: str = "" ) -> None: + if isinstance(expected, nwp.Series): + name = expected.name + expected = expected.to_list() + elif not isinstance(expected, Sequence): + expected = tuple(expected) assert_equal_data(result.to_frame(), {name: expected}) From 9f3d9acbcc655a2be698a7b85bba1543c51274e8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 6 Dec 2025 22:12:36 +0000 Subject: [PATCH 163/215] fix: Ensure `Expr.hist` only returns a struct when we have multiple fields --- narwhals/_plan/arrow/dataframe.py | 9 ++++-- narwhals/_plan/arrow/expr.py | 17 ++++++++--- narwhals/_plan/expr.py | 1 + narwhals/_plan/series.py | 49 +++++++++---------------------- tests/plan/hist_test.py | 27 +++++++++++++++++ tests/plan/utils.py | 4 +-- 6 files changed, 64 insertions(+), 43 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 2c1bbedda6..1c2f3e1d6c 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -74,11 +74,16 @@ def __len__(self) -> int: @classmethod def from_dict( - cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None + cls, + data: Mapping[str, Any], + /, + *, + schema: IntoSchema | None = None, + version: Version = Version.MAIN, ) -> Self: pa_schema = Schema(schema).to_arrow() if schema is not None else schema native = pa.Table.from_pydict(data, schema=pa_schema) - return cls.from_native(native, version=Version.MAIN) + return cls.from_native(native, version=version) def iter_columns(self) -> Iterator[Series]: for name, series in zip(self.columns, self.native.itercolumns()): diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index aa5dbd2d59..14e3ac893e 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -701,6 +701,17 @@ def 
rolling_expr( result = method(s, size, min_samples=samples, center=center, ddof=ddof) return self.from_series(result) + # NOTE: Should not be returning a struct when all `include_*` are false + # https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-ops/src/chunked_array/hist.rs#L223-L223 + def _hist_finish(self, data: Mapping[str, Any], name: str) -> Self: + ns = namespace(self) + if len(data) == 1: + count = next(iter(data.values())) + series = ns._series.from_iterable(count, version=self.version, name=name) + else: + series = ns._dataframe.from_dict(data, version=self.version).to_struct(name) + return self.from_series(series) + def hist_bins(self, node: FExpr[F.HistBins], frame: Frame, name: str) -> Self: native = self._dispatch_expr(node.input[0], frame, name).native func = node.function @@ -712,8 +723,7 @@ def hist_bins(self, node: FExpr[F.HistBins], frame: Frame, name: str) -> Self: data = fn.hist_zeroed_data(bins, include_breakpoint=include) else: data = fn.hist_bins(native, bins, include_breakpoint=include) - ns = namespace(self) - return self.from_series(ns._dataframe.from_dict(data).to_struct(name)) + return self._hist_finish(data, name) def hist_bin_count( self, node: FExpr[F.HistBinCount], frame: Frame, name: str @@ -737,8 +747,7 @@ def hist_bin_count( lower, upper = fn.sub(lower, rhs), fn.add(upper, rhs) bins = fn.linear_space(lower.as_py(), upper.as_py(), bin_count + 1) data = fn.hist_bins(native, bins, include_breakpoint=include) - ns = namespace(self) - return self.from_series(ns._dataframe.from_dict(data).to_struct(name)) + return self._hist_finish(data, name) # ewm_mean = not_implemented() # noqa: ERA001 @property diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 540ba60d0a..954601b982 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -190,6 +190,7 @@ def _with_unary(self, function: Function, /) -> Self: def abs(self) -> Self: return self._with_unary(F.Abs()) + # TODO 
@dangotbanned: Change the default to `False`, and update tests def hist( self, bins: Sequence[float] | None = None, diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 2ef297b84f..3b6407a0da 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Generic +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal from narwhals._plan._guards import is_series from narwhals._plan.typing import NativeSeriesT, NativeSeriesT_co, OneOrIterable, SeriesT @@ -278,46 +278,25 @@ def hist( bins: Sequence[float] | None = None, *, bin_count: int | None = None, - # NOTE: `pl.Series.hist` defaults are the opposite of `pl.Expr.hist` include_breakpoint: bool = True, - include_category: bool = False, # NOTE: `pl.Series.hist` default is `True`, but that would be breaking (ish) for narwhals - _use_current_polars_behavior: bool = False, + include_category: bool = False, + _compatibility_behavior: Literal["narwhals", "polars"] = "narwhals", ) -> DataFrame[Incomplete, NativeSeriesT_co]: - """Well ... 
- - `_use_current_polars_behavior` would preserve the series name, in line with current `polars`: - - import polars as pl - ser = pl.Series("original_name", [0, 1, 2, 3, 4, 5, 6]) - hist = ser.hist(bin_count=4, include_breakpoint=False, include_category=False) - hist_to_dict(as_series=False) - {'original_name': [2, 2, 1, 2]} - - But all of our tests expect `"count"` as the name 🤔 - """ from narwhals._plan import functions as F - result = ( - self.to_frame() - .select( - F.col(self.name).hist( - bins, - bin_count=bin_count, - include_breakpoint=include_breakpoint, - include_category=include_category, - ) + result = self.to_frame().select( + F.col(self.name).hist( + bins, + bin_count=bin_count, + include_breakpoint=include_breakpoint, + include_category=include_category, ) - .to_series() - .struct.unnest() ) - - if ( - not include_breakpoint - and not include_category - and _use_current_polars_behavior - ): # pragma: no cover - return result.rename({"count": self.name}) - return result + if not include_breakpoint and not include_category: + if _compatibility_behavior == "narwhals": + result = result.rename({self.name: "count"}) + return result + return result.to_series().struct.unnest() @property def struct(self) -> SeriesStructNamespace[Self]: diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py index 48c8400002..c3456f1864 100644 --- a/tests/plan/hist_test.py +++ b/tests/plan/hist_test.py @@ -275,3 +275,30 @@ def test_hist_bins_none(backend: EagerAllowed) -> None: s = nwp.Series.from_iterable([1, 2, 3], backend=backend) result = s.hist(bins=None, bin_count=None) assert len(result) == 10 + + +def test_hist_series_compat_flag(backend: EagerAllowed) -> None: + # NOTE: Mainly for verifying `Expr.hist` has handled naming/collecting as struct + # The flag itself is not desirable + values = [1, 3, 8, 8, 2, 1, 3] + s = nwp.Series.from_iterable(values, name="original", backend=backend) + + result = s.hist( + bin_count=4, + include_breakpoint=False, + 
include_category=False, + _compatibility_behavior="narwhals", + ) + assert_equal_data(result, {"count": [3, 2, 0, 2]}) + + result = s.hist( + bin_count=4, + include_breakpoint=False, + include_category=False, + _compatibility_behavior="polars", + ) + assert_equal_data(result, {"original": [3, 2, 0, 2]}) + + result = s.hist(bin_count=4, include_breakpoint=True, include_category=False) + expected = {"breakpoint": [2.75, 4.5, 6.25, 8.0], "count": [3, 2, 0, 2]} + assert_equal_data(result, expected) diff --git a/tests/plan/utils.py b/tests/plan/utils.py index d7295838de..70bf35541d 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -237,8 +237,8 @@ def assert_equal_series( if isinstance(expected, nwp.Series): name = expected.name expected = expected.to_list() - elif not isinstance(expected, Sequence): - expected = tuple(expected) + else: + expected = expected if isinstance(expected, Sequence) else tuple(expected) assert_equal_data(result.to_frame(), {name: expected}) From 7ad60dfe58f8471c6522cd71c1650a25dd579b61 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 6 Dec 2025 22:21:34 +0000 Subject: [PATCH 164/215] chore: Align `Expr.hist` defaults with `pl.Expr.hist` --- narwhals/_plan/expr.py | 3 +-- narwhals/_plan/expressions/functions.py | 6 ++---- tests/plan/expr_parsing_test.py | 4 ++-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 954601b982..05d19b0180 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -190,13 +190,12 @@ def _with_unary(self, function: Function, /) -> Self: def abs(self) -> Self: return self._with_unary(F.Abs()) - # TODO @dangotbanned: Change the default to `False`, and update tests def hist( self, bins: Sequence[float] | None = None, *, bin_count: int | None = None, - include_breakpoint: bool = True, # NOTE: `pl.Expr.hist` default is `False` + include_breakpoint: bool = False, include_category: 
bool = False, ) -> Self: if include_category: diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index d919448baf..89ae663241 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -81,8 +81,6 @@ class MeanHorizontal(HorizontalFunction): ... class Coalesce(HorizontalFunction): ... # fmt: on class Hist(Function): - """Only supported for `Series` so far.""" - __slots__ = ("include_breakpoint",) include_breakpoint: bool @@ -94,7 +92,7 @@ def __repr__(self) -> str: # They're also more widely defined to what will work at runtime @staticmethod def from_bins( - bins: Iterable[float], /, *, include_breakpoint: bool = True + bins: Iterable[float], /, *, include_breakpoint: bool = False ) -> HistBins: bins = tuple(bins) for i in range(1, len(bins)): @@ -104,7 +102,7 @@ def from_bins( @staticmethod def from_bin_count( - count: ConvertibleToInt = 10, /, *, include_breakpoint: bool = True + count: ConvertibleToInt = 10, /, *, include_breakpoint: bool = False ) -> HistBinCount: return HistBinCount(bin_count=int(count), include_breakpoint=include_breakpoint) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index e5b37e2996..54cbf5de68 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -584,14 +584,14 @@ def test_hist_bins() -> None: def test_hist_bin_count() -> None: bin_count_default = 10 - include_breakpoint_default = True + include_breakpoint_default = False a = nwp.col("a") hist_1 = a.hist( bin_count=bin_count_default, include_breakpoint=include_breakpoint_default ) hist_2 = a.hist() hist_3 = a.hist(bin_count=5) - hist_4 = a.hist(include_breakpoint=False) + hist_4 = a.hist(include_breakpoint=True) ir_1 = hist_1._ir ir_2 = hist_2._ir From 2d198ddfc07a649b9bd7ba8cf2717be45b4002bf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 6 Dec 2025 23:00:20 +0000 Subject: [PATCH 
165/215] test: Cover `Expr.hist` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Nice to see this all worked out 😅 --- tests/plan/hist_test.py | 112 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py index c3456f1864..28dfbfdd06 100644 --- a/tests/plan/hist_test.py +++ b/tests/plan/hist_test.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from collections.abc import Sequence - from narwhals.typing import EagerAllowed + from narwhals.typing import EagerAllowed, IntoDType from tests.conftest import Data pytest.importorskip("pyarrow") @@ -200,6 +200,116 @@ def test_hist_bin_count( assert result.get_column("count").sum() == ser.drop_nans().count() +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.all().hist(bin_count=5), + { + "int": [2, 1, 1, 1, 2], + "float": [2, 1, 1, 1, 2], + "int_shuffled": [2, 1, 1, 1, 2], + "float_shuffled": [2, 1, 1, 1, 2], + }, + ), + ( + (99 + nwp.all()).hist(bin_count=2).name.keep(), + { + "int": [4, 3], + "float": [4, 3], + "int_shuffled": [4, 3], + "float_shuffled": [4, 3], + }, + ), + ( + nwp.all().hist([-3, -2, 3, 6]).name.to_uppercase(), + { + "INT": [0, 4, 3], + "FLOAT": [0, 4, 3], + "INT_SHUFFLED": [0, 4, 3], + "FLOAT_SHUFFLED": [0, 4, 3], + }, + ), + ( + nwp.all().clip(upper_bound=4).hist([2, 3, 4, 5, 6]), + { + "int": [2, 3, 0, 0], + "float": [2, 3, 0, 0], + "int_shuffled": [2, 3, 0, 0], + "float_shuffled": [2, 3, 0, 0], + }, + ), + ( + (nwp.all() * 2.7).hist([1.3, 5.1, 8.98, 11.3]), + { + "int": [1, 2, 1], + "float": [1, 2, 1], + "int_shuffled": [1, 2, 1], + "float_shuffled": [1, 2, 1], + }, + ), + ], +) +def test_hist_expr_counts_only( + data: Data, + schema_data: nw.Schema, + backend: EagerAllowed, + expr: nwp.Expr, + expected: dict[str, Any], +) -> None: + df = nwp.DataFrame.from_dict(data, schema_data, backend=backend) + result = df.select(expr) + assert_equal_data(result, 
expected) + + +def test_hist_expr_breakpoint( + data: Data, schema_data: nw.Schema, backend: EagerAllowed +) -> None: + df = nwp.DataFrame.from_dict(data, schema_data, backend=backend) + expr = nwp.all().hist(bin_count=3, include_breakpoint=True) + result = df.select(expr) + result_schema = result.collect_schema() + + dtype_breakpoint: IntoDType = nw.Float64 + # NOTE: To match polars it would be this, but maybe i64 is okay? + dtype_count: IntoDType = nw.UInt32 + dtype_count = nw.Int64 + + dtype_struct = nw.Struct({"breakpoint": dtype_breakpoint, "count": dtype_count}) + expected_schema = nw.Schema( + [ + ("int", dtype_struct), + ("float", dtype_struct), + ("int_shuffled", dtype_struct), + ("float_shuffled", dtype_struct), + ] + ) + expected_data = { + "int": [ + {"breakpoint": 2.0, "count": 3}, + {"breakpoint": 4.0, "count": 2}, + {"breakpoint": 6.0, "count": 2}, + ], + "float": [ + {"breakpoint": 2.0, "count": 3}, + {"breakpoint": 4.0, "count": 2}, + {"breakpoint": 6.0, "count": 2}, + ], + "int_shuffled": [ + {"breakpoint": 2.0, "count": 3}, + {"breakpoint": 4.0, "count": 2}, + {"breakpoint": 6.0, "count": 2}, + ], + "float_shuffled": [ + {"breakpoint": 2.0, "count": 3}, + {"breakpoint": 4.0, "count": 2}, + {"breakpoint": 6.0, "count": 2}, + ], + } + assert result_schema == expected_schema + assert_equal_data(result, expected_data) + + @bin_count_cases def test_hist_bin_count_missing( data_missing: Data, From 8283a47635ff07ac4d04fe819a919bdfbf2aafb0 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:29:59 +0000 Subject: [PATCH 166/215] feat(expr-ir): Add `{DataFrame,Series}.explode(empty_as_nulls, keep_nulls)` (#3347) --- narwhals/_plan/arrow/dataframe.py | 60 ++++-- narwhals/_plan/arrow/expr.py | 8 +- narwhals/_plan/arrow/functions.py | 159 +++++++++++++++- narwhals/_plan/arrow/series.py | 4 + narwhals/_plan/arrow/typing.py | 15 ++ narwhals/_plan/compliant/dataframe.py | 3 +- 
narwhals/_plan/compliant/series.py | 1 + narwhals/_plan/dataframe.py | 24 ++- narwhals/_plan/exceptions.py | 5 + narwhals/_plan/options.py | 12 ++ narwhals/_plan/series.py | 5 + tests/plan/explode_test.py | 264 ++++++++++++++++++++++++++ 12 files changed, 534 insertions(+), 26 deletions(-) create mode 100644 tests/plan/explode_test.py diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 1c2f3e1d6c..ecef91b772 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -16,6 +16,7 @@ from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.compliant.dataframe import EagerDataFrame from narwhals._plan.compliant.typing import namespace +from narwhals._plan.exceptions import shape_error from narwhals._plan.expressions import NamedIR from narwhals._utils import Version, generate_repr from narwhals.schema import Schema @@ -26,10 +27,10 @@ import polars as pl from typing_extensions import Self, TypeAlias - from narwhals._plan.arrow.typing import ChunkedArrayAny + from narwhals._plan.arrow.typing import ChunkedArrayAny, ChunkedOrArrayAny from narwhals._plan.compliant.group_by import GroupByResolver from narwhals._plan.expressions import ExprIR, NamedIR - from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.options import ExplodeOptions, SortMultipleOptions from narwhals._plan.typing import NonCrossJoinStrategy from narwhals.dtypes import DType from narwhals.typing import IntoSchema @@ -162,6 +163,12 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: native = self.native.filter(~to_drop) return self._with_native(native) + def explode(self, subset: Sequence[str], options: ExplodeOptions) -> Self: + builder = fn.ExplodeBuilder.from_options(options) + if len(subset) == 1: + return self._with_native(builder.explode_column(self.native, subset[0])) + return self._with_native(builder.explode_columns(self.native, subset)) + def rename(self, mapping: 
Mapping[str, str]) -> Self: names: dict[str, str] | list[str] if fn.BACKEND_VERSION >= (17,): @@ -170,20 +177,26 @@ def rename(self, mapping: Mapping[str, str]) -> Self: names = [mapping.get(c, c) for c in self.columns] return self._with_native(self.native.rename_columns(names)) - # NOTE: Use instead of `with_columns` for trivial cases + def with_series(self, series: Series) -> Self: + """Add a new column or replace an existing one. + + Uses similar semantics as `with_columns`, but: + - for a single named `Series` + - no broadcasting (use `Scalar.broadcast` instead) + - no length checking (use `with_series_checked` instead) + """ + return self._with_native(with_array(self.native, series.name, series.native)) + + def with_series_checked(self, series: Series) -> Self: + expected, actual = len(self), len(series) + if len(series) != len(self): + raise shape_error(expected, actual) + return self.with_series(series) + def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: - native = self.native - columns = self.columns height = len(self) - for into_series in exprs: - name = into_series.name - chunked = into_series.broadcast(height).native - if name in columns: - i = columns.index(name) - native = native.set_column(i, name, chunked) - else: - native = native.append_column(name, chunked) - return self._with_native(native) + names_and_columns = ((e.name, e.broadcast(height).native) for e in exprs) + return self._with_native(with_arrays(self.native, names_and_columns)) def select_names(self, *column_names: str) -> Self: return self._with_native(self.native.select(list(column_names))) @@ -226,3 +239,22 @@ def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[S from_native = self._with_native partitions = partition_by(self.native, by, include_key=include_key) return [from_native(df) for df in partitions] + + +def with_array(table: pa.Table, name: str, column: ChunkedOrArrayAny) -> pa.Table: + column_names = table.column_names + if name in 
column_names: + return table.set_column(column_names.index(name), name, column) + return table.append_column(name, column) + + +def with_arrays( + table: pa.Table, names_and_columns: Iterable[tuple[str, ChunkedOrArrayAny]], / +) -> pa.Table: + column_names = table.column_names + for name, column in names_and_columns: + if name in column_names: + table = table.set_column(column_names.index(name), name, column) + else: + table = table.append_column(name, column) + return table diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 14e3ac893e..15ee7e5f91 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -28,6 +28,7 @@ from narwhals._plan.compliant.expr import EagerExpr from narwhals._plan.compliant.scalar import EagerScalar from narwhals._plan.compliant.typing import namespace +from narwhals._plan.exceptions import shape_error from narwhals._plan.expressions import FunctionExpr as FExpr, functions as F from narwhals._plan.expressions.boolean import ( IsDuplicated, @@ -48,7 +49,7 @@ not_implemented, qualified_type_name, ) -from narwhals.exceptions import InvalidOperationError, ShapeError +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from collections.abc import Callable, Mapping, Sequence @@ -372,7 +373,7 @@ def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Sel def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Self: if isinstance(result, pa.Scalar): return ArrowScalar.from_native(result, name, version=self.version) - return self.from_native(result, name or self.name, self.version) + return self.from_native(result, name, self.version) # NOTE: I'm not sure what I meant by # > "isn't natively supported on `ChunkedArray`" @@ -405,8 +406,7 @@ def to_series(self) -> Series: def broadcast(self, length: int, /) -> Series: if (actual_len := len(self)) != length: - msg = f"Expected object of length {length}, got {actual_len}." 
- raise ShapeError(msg) + raise shape_error(length, actual_len) return self._evaluated def __len__(self) -> int: diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index d69787110d..1969cd33d3 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -4,7 +4,8 @@ import math import typing as t -from collections.abc import Callable, Sequence +from collections.abc import Callable, Collection, Iterator, Sequence +from itertools import chain from typing import TYPE_CHECKING, Any, Final, Literal, overload import pyarrow as pa # ignore-banned-import @@ -20,13 +21,15 @@ from narwhals._plan._guards import is_non_nested_literal from narwhals._plan.arrow import options as pa_options from narwhals._plan.expressions import functions as F, operators as ops -from narwhals._utils import Implementation, Version +from narwhals._plan.options import ExplodeOptions +from narwhals._utils import Implementation, Version, no_default +from narwhals.exceptions import ShapeError if TYPE_CHECKING: import datetime as dt from collections.abc import Iterable, Mapping - from typing_extensions import TypeAlias, TypeIs, TypeVarTuple, Unpack + from typing_extensions import Self, TypeAlias, TypeIs, TypeVarTuple, Unpack from narwhals._arrow.typing import Incomplete, PromoteOptions from narwhals._plan.arrow.acero import Field @@ -35,6 +38,7 @@ ArrayAny, Arrow, ArrowAny, + ArrowListT, ArrowT, BinaryComp, BinaryFunction, @@ -63,7 +67,9 @@ LargeStringType, ListArray, ListScalar, + ListTypeT, NativeScalar, + NonListTypeT, NumericScalar, Predicate, SameArrowT, @@ -80,6 +86,7 @@ from narwhals._plan.compliant.typing import SeriesT from narwhals._plan.options import RankOptions, SortMultipleOptions, SortOptions from narwhals._plan.typing import Seq + from narwhals._typing import NoDefault from narwhals.typing import ( ClosedInterval, FillNullStrategy, @@ -386,6 +393,148 @@ def get_categories(native: ArrowAny) -> ChunkedArrayAny: return 
chunked_array(da.dictionary) +class ExplodeBuilder: + options: ExplodeOptions + + def __init__(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> None: + self.options = ExplodeOptions(empty_as_null=empty_as_null, keep_nulls=keep_nulls) + + @classmethod + def from_options(cls, options: ExplodeOptions, /) -> Self: + obj = cls.__new__(cls) + obj.options = options + return obj + + @t.overload + def explode( + self, native: ChunkedList[DataTypeT] | ListScalar[DataTypeT] + ) -> ChunkedArray[Scalar[DataTypeT]]: ... + @t.overload + def explode(self, native: ListArray[DataTypeT]) -> Array[Scalar[DataTypeT]]: ... + @t.overload + def explode( + self, native: Arrow[ListScalar[DataTypeT]] + ) -> ChunkedOrArray[Scalar[DataTypeT]]: ... + def explode( + self, native: Arrow[ListScalar[DataTypeT]] + ) -> ChunkedOrArray[Scalar[DataTypeT]]: + """Explode list elements, expanding one-level into a new array. + + Equivalent to `polars.{Expr,Series}.explode`. + """ + safe = self._fill_with_null(native) if self.options.any() else native + if not isinstance(safe, pa.Scalar): + return _list_explode(safe) + return chunked_array(_list_explode(safe)) + + def explode_column(self, native: pa.Table, column_name: str, /) -> pa.Table: + """Explode a list-typed column in the context of `native`.""" + ca = native.column(column_name) + safe = self._fill_with_null(ca) if self.options.any() else ca + exploded = _list_explode(safe) + col_idx = native.schema.get_field_index(column_name) + if len(exploded) == len(native): + return native.set_column(col_idx, column_name, exploded) + return ( + native.remove_column(col_idx) + .take(_list_parent_indices(safe)) + .add_column(col_idx, column_name, exploded) + ) + + def explode_columns(self, native: pa.Table, subset: Collection[str], /) -> pa.Table: + """Explode multiple list-typed columns in the context of `native`.""" + subset = list(subset) + arrays = native.select(subset).columns + first = arrays[0] + first_len = list_len(first) + if 
self.options.any(): + mask = self._predicate(first_len) + first_safe = self._fill_with_null(first, mask) + it = ( + _list_explode(self._fill_with_null(arr, mask)) + for arr in self._iter_ensure_shape(first_len, arrays[1:]) + ) + else: + first_safe = first + it = ( + _list_explode(arr) + for arr in self._iter_ensure_shape(first_len, arrays[1:]) + ) + first_result = _list_explode(first_safe) + if len(first_result) != len(native): + gathered = native.drop_columns(subset).take(_list_parent_indices(first_safe)) + for name, arr in zip(subset, chain([first_result], it)): + gathered = gathered.append_column(name, arr) + return gathered.select(native.column_names) + # NOTE: Not too happy about this import + from narwhals._plan.arrow.dataframe import with_arrays + + return with_arrays(native, zip(subset, chain([first_result], it))) + + def _iter_ensure_shape( + self, + first_len: ChunkedArray[pa.UInt32Scalar], + arrays: Iterable[ChunkedArrayAny], + /, + ) -> Iterator[ChunkedArrayAny]: + for arr in arrays: + if not first_len.equals(list_len(arr)): + msg = "exploded columns must have matching element counts" + raise ShapeError(msg) + yield arr + + def _predicate(self, lengths: ArrowAny, /) -> Arrow[BooleanScalar]: + """Return True for each sublist length that indicates the original sublist should be replaced with `[None]`.""" + empty_as_null, keep_nulls = self.options.empty_as_null, self.options.keep_nulls + if empty_as_null and keep_nulls: + return or_(is_null(lengths), eq(lengths, lit(0))) + if empty_as_null: + return eq(lengths, lit(0)) + return is_null(lengths) + + def _fill_with_null( + self, native: ArrowListT, mask: Arrow[BooleanScalar] | NoDefault = no_default + ) -> ArrowListT: + """Replace each sublist in `native` with `[None]`, according to `self.options`. + + Arguments: + native: List-typed arrow data. + mask: An optional, pre-computed replacement mask. By default, this is generated from `native`. 
+ """ + predicate = self._predicate(list_len(native)) if mask is no_default else mask + result: ArrowListT = when_then(predicate, lit([None], native.type), native) + return result + + +@t.overload +def _list_explode(native: ChunkedList[DataTypeT]) -> ChunkedArray[Scalar[DataTypeT]]: ... +@t.overload +def _list_explode( + native: ListArray[NonListTypeT] | ListScalar[NonListTypeT], +) -> Array[Scalar[NonListTypeT]]: ... +@t.overload +def _list_explode(native: ListArray[DataTypeT]) -> Array[Scalar[DataTypeT]]: ... +@t.overload +def _list_explode(native: ListScalar[ListTypeT]) -> ListArray[ListTypeT]: ... +def _list_explode(native: Arrow[ListScalar]) -> ChunkedOrArrayAny: + result: ChunkedOrArrayAny = pc.call_function("list_flatten", [native]) + return result + + +@t.overload +def _list_parent_indices(native: ChunkedList) -> ChunkedArray[pa.Int64Scalar]: ... +@t.overload +def _list_parent_indices(native: ListArray) -> pa.Int64Array: ... +def _list_parent_indices( + native: ChunkedOrArray[ListScalar], +) -> ChunkedOrArray[pa.Int64Scalar]: + """Don't use this withut handling nulls!""" + result: ChunkedOrArray[pa.Int64Scalar] = pc.call_function( + "list_parent_indices", [native] + ) + return result + + @t.overload def list_len(native: ChunkedList) -> ChunkedArray[pa.UInt32Scalar]: ... @t.overload @@ -393,9 +542,9 @@ def list_len(native: ListArray) -> pa.UInt32Array: ... @t.overload def list_len(native: ListScalar) -> pa.UInt32Scalar: ... @t.overload -def list_len(native: SameArrowT) -> SameArrowT: ... -@t.overload def list_len(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[pa.UInt32Scalar]: ... +@t.overload +def list_len(native: Arrow[ListScalar[Any]]) -> Arrow[pa.UInt32Scalar]: ... 
def list_len(native: ArrowAny) -> ArrowAny: length: Incomplete = pc.list_value_length result: ArrowAny = length(native).cast(pa.uint32()) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 72e01a4fb1..8a59ce42cc 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -301,6 +301,10 @@ def drop_nans(self) -> Self: self.native.filter(predicate, null_selection_behavior="emit_null") ) + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Self: + exploder = fn.ExplodeBuilder(empty_as_null=empty_as_null, keep_nulls=keep_nulls) + return self._with_native(exploder.explode(self.native)) + @property def struct(self) -> SeriesStructNamespace: return SeriesStructNamespace(self) diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index 63ee251997..766053f76a 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: import pyarrow as pa import pyarrow.compute as pc + from pyarrow import lib, types from pyarrow.lib import ( BoolType as BoolType, Date32Type, @@ -37,6 +38,19 @@ BooleanScalar: TypeAlias = "Scalar[BoolType]" NumericScalar: TypeAlias = "pc.NumericScalar" + PrimitiveNumericType: TypeAlias = "types._Integer | types._Floating" + NumericType: TypeAlias = "PrimitiveNumericType | types._Decimal" + NumericOrTemporalType: TypeAlias = "NumericType | types._Temporal" + StringOrBinaryType: TypeAlias = "StringType | LargeStringType | lib.StringViewType | lib.BinaryType | lib.LargeBinaryType | lib.BinaryViewType" + BasicType: TypeAlias = ( + "NumericOrTemporalType | StringOrBinaryType | BoolType | lib.NullType" + ) + NonListNestedType: TypeAlias = "pa.StructType | pa.DictionaryType[Any, Any] | pa.MapType[Any, Any] | pa.UnionType" + NonListType: TypeAlias = "BasicType | NonListNestedType" + NestedType: TypeAlias = "NonListNestedType | pa.ListType[Any]" + NonListTypeT = TypeVar("NonListTypeT", bound="NonListType") + 
ListTypeT = TypeVar("ListTypeT", bound="pa.ListType[Any]") + class NativeArrowSeries(NativeSeries, Protocol): @property def chunks(self) -> list[Any]: ... @@ -200,6 +214,7 @@ class BinaryLogical(BinaryFunction["BooleanScalar", "BooleanScalar"], Protocol): ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" SameArrowT = TypeVar("SameArrowT", ChunkedArrayAny, ArrayAny, ScalarAny) ArrowT = TypeVar("ArrowT", bound=ArrowAny) +ArrowListT = TypeVar("ArrowListT", bound="Arrow[ListScalar[Any]]") Predicate: TypeAlias = "Arrow[BooleanScalar]" """Any `pyarrow` container that wraps boolean.""" diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index b502848534..2787f449e2 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -30,7 +30,7 @@ from narwhals._plan.compliant.namespace import EagerNamespace from narwhals._plan.dataframe import BaseFrame, DataFrame from narwhals._plan.expressions import NamedIR - from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.options import ExplodeOptions, SortMultipleOptions from narwhals._plan.typing import Seq from narwhals._typing import _EagerAllowedImpl from narwhals._utils import Implementation, Version @@ -59,6 +59,7 @@ def to_narwhals(self) -> BaseFrame[NativeFrameT_co]: ... def columns(self) -> list[str]: ... def drop(self, columns: Sequence[str]) -> Self: ... def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... + def explode(self, subset: Sequence[str], options: ExplodeOptions) -> Self: ... # Shouldn't *need* to be `NamedIR`, but current impl depends on a name being passed around def filter(self, predicate: NamedIR, /) -> Self: ... def rename(self, mapping: Mapping[str, str]) -> Self: ... 
diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 7a51d7802e..6416adbc06 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -133,6 +133,7 @@ def cum_sum(self, *, reverse: bool = False) -> Self: ... def diff(self, n: int = 1) -> Self: ... def drop_nulls(self) -> Self: ... def drop_nans(self) -> Self: ... + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Self: ... def fill_nan(self, value: float | Self | None) -> Self: ... def fill_null(self, value: NonNestedLiteral | Self) -> Self: ... def fill_null_with_strategy( diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 7455ad8cc5..2941c30a30 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -7,7 +7,7 @@ from narwhals._plan._guards import is_series from narwhals._plan.common import ensure_seq_str, temp from narwhals._plan.group_by import GroupBy, Grouped -from narwhals._plan.options import SortMultipleOptions +from narwhals._plan.options import ExplodeOptions, SortMultipleOptions from narwhals._plan.series import Series from narwhals._plan.typing import ( ColumnNameOrSelector, @@ -25,7 +25,7 @@ ) from narwhals._utils import Implementation, Version, generate_repr from narwhals.dependencies import is_pyarrow_table -from narwhals.exceptions import ShapeError +from narwhals.exceptions import InvalidOperationError, ShapeError from narwhals.schema import Schema from narwhals.typing import EagerAllowed, IntoBackend, IntoDType, IntoSchema, JoinStrategy @@ -158,6 +158,26 @@ def with_row_index( by_names = expand_selector_irs_names(by_selectors, schema=self, require_any=True) return self._with_compliant(self._compliant.with_row_index_by(name, by_names)) + def explode( + self, + columns: OneOrIterable[ColumnNameOrSelector], + *more_columns: ColumnNameOrSelector, + empty_as_null: bool = True, + keep_nulls: bool = True, + ) -> Self: + s_ir = 
_parse.parse_into_combined_selector_ir(columns, *more_columns) + schema = self.collect_schema() + subset = expand_selector_irs_names((s_ir,), schema=schema, require_any=True) + dtypes = self.version.dtypes + tp_list = dtypes.List + for col_to_explode in subset: + dtype = schema[col_to_explode] + if dtype != tp_list: + msg = f"`explode` operation is not supported for dtype `{dtype}`, expected List type" + raise InvalidOperationError(msg) + options = ExplodeOptions(empty_as_null=empty_as_null, keep_nulls=keep_nulls) + return self._with_compliant(self._compliant.explode(subset, options)) + def _dataframe_from_dict( data: Mapping[str, Any], diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index cac4616b9b..53aaa49f67 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -66,6 +66,11 @@ def hist_bins_monotonic_error(bins: Seq[float]) -> ComputeError: # noqa: ARG001 return ComputeError(msg) +def shape_error(expected_length: int, actual_length: int) -> ShapeError: + msg = f"Expected object of length {expected_length}, got {actual_length}." 
+ return ShapeError(msg) + + def _binary_underline( left: ir.ExprIR, operator: Operator, diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 4914b8b59b..ab43aa23ed 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -328,3 +328,15 @@ def default(cls) -> Self: FEOptions = FunctionExprOptions + + +class ExplodeOptions(Immutable): + __slots__ = ("empty_as_null", "keep_nulls") + empty_as_null: bool + """Explode an empty list into a `null`.""" + keep_nulls: bool + """Explode a `null` into a `null`.""" + + def any(self) -> bool: + """Return True if we need to handle empty lists and/or nulls.""" + return self.empty_as_null or self.keep_nulls diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 3b6407a0da..db799ff37f 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -298,6 +298,11 @@ def hist( return result return result.to_series().struct.unnest() + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Self: + return type(self)( + self._compliant.explode(empty_as_null=empty_as_null, keep_nulls=keep_nulls) + ) + @property def struct(self) -> SeriesStructNamespace[Self]: return SeriesStructNamespace(self) diff --git a/tests/plan/explode_test.py b/tests/plan/explode_test.py new file mode 100644 index 0000000000..c7cc478d67 --- /dev/null +++ b/tests/plan/explode_test.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Final + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs +from narwhals.exceptions import InvalidOperationError, ShapeError +from tests.plan.utils import ( + assert_equal_data, + assert_equal_series, + dataframe, + re_compile, + series, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + + from narwhals._plan.typing import ColumnNameOrSelector + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + # 
For context, polars allows to explode multiple columns only if the columns + # have matching element counts, therefore, l1 and l2 but not l1 and l3 together. + return { + "a": ["x", "y", "z", "w"], + "l1": [[1, 2], None, [None], []], + "l2": [[3, None], None, [42], []], + "l3": [[1, 2], [3], [None], [1]], + "l4": [[1, 2], [3], [123], [456]], + "l5": [[None, None], [None], [99], [83]], + } + + +@pytest.mark.parametrize( + ("column", "expected_values"), + [("l2", [None, 3, None, None, 42]), ("l3", [1, 1, 2, 3, None])], +) +def test_explode_frame_single_col( + column: str, expected_values: list[int | None], data: Data +) -> None: + result = ( + dataframe(data) + .with_columns(nwp.col(column).cast(nw.List(nw.Int32()))) + .explode(column) + .select("a", column) + .sort("a", column, nulls_last=True) + ) + expected = {"a": ["w", "x", "x", "y", "z"], column: expected_values} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("column", "more_columns", "expected"), + [ + ( + "l1", + ["l2"], + { + "a": ["w", "x", "x", "y", "z"], + "l1": [None, 1, 2, None, None], + "l2": [None, 3, None, None, 42], + }, + ), + ( + "l3", + ["l4"], + { + "a": ["w", "x", "x", "y", "z"], + "l3": [1, 1, 2, 3, None], + "l4": [456, 1, 2, 3, 123], + }, + ), + ], +) +def test_explode_frame_multiple_cols( + column: str, + more_columns: Sequence[str], + expected: dict[str, list[str | int | None]], + data: Data, +) -> None: + result = ( + dataframe(data) + .with_columns(nwp.col(column, *more_columns).cast(nw.List(nw.Int32()))) + .explode(column, *more_columns) + .select("a", column, *more_columns) + .sort("a", column, nulls_last=True) + ) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + ncs.by_index(-1, -2, -3), + { + "a": ["w", "x", "x", "y", "z"], + "l5": [83, None, None, None, 99], + "l4": [456, 1, 2, 3, 123], + "l3": [1, 1, 2, 3, None], + }, + ), + ( + ncs.matches(r"l[3|5]"), + { + "a": ["w", "x", "x", "y", "z"], + "l3": [1, 1, 
2, 3, None], + "l5": [83, None, None, None, 99], + }, + ), + ], +) +def test_explode_frame_selectors(expr: nwp.Selector, expected: Data, data: Data) -> None: + result = ( + dataframe(data) + .with_columns(expr.cast(nw.List(nw.Int32()))) + .explode(expr) + .select("a", expr) + .sort("a", expr, nulls_last=True) + ) + assert_equal_data(result, expected) + + +def test_explode_frame_shape_error(data: Data) -> None: + with pytest.raises( + ShapeError, match=r".*exploded columns (must )?have matching element counts" + ): + dataframe(data).with_columns( + nwp.col("l1", "l2", "l3").cast(nw.List(nw.Int32())) + ).explode(ncs.list()) + + +def test_explode_frame_invalid_operation_error(data: Data) -> None: + with pytest.raises( + InvalidOperationError, + match=re_compile(r"explode.+not supported for.+string.+expected.+list"), + ): + dataframe(data).explode("a") + + +@pytest.mark.parametrize( + ("values", "expected"), + [ + ([[1, 2, 3]], [1, 2, 3]), + ([[1, 2, 3], None], [1, 2, 3, None]), + ([[1, 2, 3], []], [1, 2, 3, None]), + ], +) +def test_explode_series_default(values: list[Any], expected: list[Any]) -> None: + # Based on https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/operations/test_explode.py#L465-L470 + result = series(values).explode() + assert_equal_series(result, expected, "") + + +@pytest.mark.parametrize( + ("values", "expected"), + [ + ([[1, 2, 3], [1, 2], [1, 2]], [1, 2, 3, None, 1, 2]), + ([[1, 2, 3], [], [1, 2]], [1, 2, 3, None, 1, 2]), + ], +) +def test_explode_series_default_masked(values: list[Any], expected: list[Any]) -> None: + # Based on https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/operations/test_explode.py#L471-484 + result = ( + series(values) + .to_frame() + .select(nwp.when(series([True, False, True])).then(nwp.col(""))) + .to_series() + .explode() + ) + assert_equal_series(result, expected, "") + + +DROP_EMPTY: Final = {"empty_as_null": 
False} +DROP_NULLS: Final = {"keep_nulls": False} +DROP_BOTH: Final = {"empty_as_null": False, "keep_nulls": False} +DEFAULT: Final[Data] = {} + + +@pytest.mark.parametrize( + ("values", "kwds", "expected"), + [ + ([[1, 2, 3]], DROP_BOTH, [1, 2, 3]), + ([[1, 2, 3], None], DROP_NULLS, [1, 2, 3]), + ([[1, 2, 3], [None]], DROP_NULLS, [1, 2, 3, None]), + ([[1, 2, 3], []], DROP_EMPTY, [1, 2, 3]), + ([[1, 2, 3], [None]], DROP_EMPTY, [1, 2, 3, None]), + ], +) +def test_explode_series_options( + values: list[Any], kwds: dict[str, Any], expected: list[Any] +) -> None: + # Based on https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/operations/test_explode.py#L486-L505 + result = series(values).explode(**kwds) + assert_equal_series(result, expected, "") + + +A = ("a",) +BA = "b", "a" + +DEFAULT_A: Final = [1, 2, 3, None, 4, 5, 6, None] +DEFAULT_I: Final = [1, 1, 1, 2, 3, 3, 3, 4] +DEFAULT_B: Final = [None, "dog", "cat", None, "narwhal", None, "orca", None] +EMPTY_A: Final = [1, 2, 3, None, 4, 5, 6] +EMPTY_I: Final = [1, 1, 1, 2, 3, 3, 3] +EMPTY_B: Final = [None, "dog", "cat", None, "narwhal", None, "orca"] +NULLS_A: Final = [1, 2, 3, 4, 5, 6, None] +NULLS_I: Final = [1, 1, 1, 3, 3, 3, 4] +NULLS_B: Final = [None, "dog", "cat", "narwhal", None, "orca", None] +BOTH_A: Final = [1, 2, 3, 4, 5, 6] +BOTH_I: Final = [1, 1, 1, 3, 3, 3] +BOTH_B: Final = [None, "dog", "cat", "narwhal", None, "orca"] + + +@pytest.mark.parametrize( + ("columns", "kwds", "expected"), + [ + (A, DEFAULT, {"a": DEFAULT_A, "i": DEFAULT_I}), + (A, DROP_EMPTY, {"a": EMPTY_A, "i": EMPTY_I}), + (A, DROP_NULLS, {"a": NULLS_A, "i": NULLS_I}), + (A, DROP_BOTH, {"a": BOTH_A, "i": BOTH_I}), + (BA, DEFAULT, {"b": DEFAULT_B, "a": DEFAULT_A, "i": DEFAULT_I}), + (BA, DROP_EMPTY, {"b": EMPTY_B, "a": EMPTY_A, "i": EMPTY_I}), + (BA, DROP_NULLS, {"b": NULLS_B, "a": NULLS_A, "i": NULLS_I}), + (BA, DROP_BOTH, {"b": BOTH_B, "a": BOTH_A, "i": BOTH_I}), + ], +) +def 
test_explode_frame_options( + columns: Sequence[ColumnNameOrSelector], kwds: dict[str, Any], expected: Data +) -> None: + # Based on https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/operations/test_explode.py#L596-L616 + data = { + "a": [[1, 2, 3], None, [4, 5, 6], []], + "b": [[None, "dog", "cat"], None, ["narwhal", None, "orca"], []], + "i": [1, 2, 3, 4], + } + result = ( + dataframe(data) + .with_columns( + nwp.col("a").cast(nw.List(nw.Int32())), nwp.col("b").cast(nw.List(nw.String)) + ) + .select(*columns, "i") + .explode(columns, **kwds) + ) + assert_equal_data(result, expected) + + +def test_explode_frame_single_elements() -> None: + data = {"a": [[1], [2], [3]], "b": [[4], [5], [6]], "i": [0, 10, 20]} + df = dataframe(data).with_columns(nwp.col("a", "b").cast(nw.List(nw.Int32()))) + + result = df.explode("a") + expected = {"a": [1, 2, 3], "b": [[4], [5], [6]], "i": [0, 10, 20]} + assert_equal_data(result, expected) + + result = df.explode("b", "a") + expected = {"a": [1, 2, 3], "b": [4, 5, 6], "i": [0, 10, 20]} + assert_equal_data(result, expected) From 5ecd62f15eea295d4df0e53e822bb1595b935fab Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:53:51 +0000 Subject: [PATCH 167/215] test: Cover `DataFrame.to_struct` The compliant level was already covered in `hist` --- narwhals/_plan/dataframe.py | 2 +- tests/plan/compliant_test.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 2941c30a30..b18697a4b4 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -321,7 +321,7 @@ def to_dict( def to_series(self, index: int = 0) -> Series[NativeSeriesT]: return self._series(self._compliant.to_series(index)) - def to_struct(self, name: str = "") -> Series[NativeSeriesT]: # pragma: no cover + def to_struct(self, name: 
str = "") -> Series[NativeSeriesT]: return self._series(self._compliant.to_struct(name)) def to_polars(self) -> pl.DataFrame: diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index bd8821260c..e503cd63ec 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -698,6 +698,35 @@ def test_dataframe_from_dict_misc(data_small: Data) -> None: nwp.DataFrame.from_dict(data_small) # type: ignore[arg-type] +def test_dataframe_to_struct(data_small_af: Data) -> None: + pytest.importorskip("pyarrow") + + schema = { + "a": nw.String(), + "b": nw.Int64(), + "c": nw.Int64(), + "d": nw.Int64(), + "e": nw.Int64(), + "f": nw.Boolean(), + } + + df = dataframe(data_small_af).with_columns( + nwp.col(name).cast(dtype) for name, dtype in schema.items() + ) + result = df.to_struct("struct_series") + result_dtype = result.dtype + assert isinstance(result_dtype, nw.Struct) + result_schema = dict(result_dtype.to_schema()) + assert result_schema == schema + + expected = [ + {"a": "A", "b": 1, "c": 9, "d": 8, "e": None, "f": True}, + {"a": "B", "b": 2, "c": 2, "d": 7, "e": 9, "f": False}, + {"a": "A", "b": 3, "c": 4, "d": 8, "e": 7, "f": None}, + ] + assert_equal_series(result, expected, "struct_series") + + # TODO @dangotbanned: Split this up def test_series_misc() -> None: pytest.importorskip("pyarrow") From 33c94980b298445ba30ed18234aaea19f96ff8a6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 9 Dec 2025 15:09:39 +0000 Subject: [PATCH 168/215] remove outdated comments --- tests/plan/all_any_horizontal_test.py | 1 - tests/plan/rolling_expr_test.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/tests/plan/all_any_horizontal_test.py b/tests/plan/all_any_horizontal_test.py index df9a4859e0..3db55ec694 100644 --- a/tests/plan/all_any_horizontal_test.py +++ b/tests/plan/all_any_horizontal_test.py @@ -12,7 +12,6 @@ from tests.conftest import Data -# test_allh_iterator has a different 
length @pytest.fixture(scope="module") def data() -> Data: return { diff --git a/tests/plan/rolling_expr_test.py b/tests/plan/rolling_expr_test.py index 3c07d0cd44..2bff999053 100644 --- a/tests/plan/rolling_expr_test.py +++ b/tests/plan/rolling_expr_test.py @@ -30,7 +30,6 @@ def data() -> Data: } -# TODO @dangotbanned: Just reuse `rolling_options` for the tests? @pytest.mark.parametrize( ("window_size", "min_samples", "center", "ddof", "expected"), [ @@ -58,7 +57,6 @@ def test_rolling_var( assert_equal_data(result, {"var_std": expected}) -# TODO @dangotbanned: Just reuse `rolling_options` for the tests? @pytest.mark.parametrize( ("window_size", "min_samples", "center", "ddof", "expected"), [ From 03fbcfb2583a06c31ad023994c2bbfde5d3859de Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 9 Dec 2025 16:08:46 +0000 Subject: [PATCH 169/215] refactor: Move `Series.hist` impl -> `CompliantSeries.hist` Fixes https://github.com/narwhals-dev/narwhals/pull/3325#discussion_r2596493923 --- narwhals/_plan/compliant/series.py | 39 +++++++++++++++++++++++++++--- narwhals/_plan/series.py | 22 ++++++----------- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 6416adbc06..7f95b379ee 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -1,18 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Protocol +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol from narwhals._plan.compliant.typing import HasVersion from narwhals._plan.typing import NativeSeriesT -from narwhals._utils import Version, _StoresNative +from narwhals._utils import Version, _StoresNative, unstable if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Sequence import polars as pl from typing_extensions import Self, TypeAlias from 
narwhals._plan.compliant.accessors import SeriesStructNamespace + from narwhals._plan.compliant.dataframe import CompliantDataFrame + from narwhals._plan.dataframe import DataFrame from narwhals._plan.series import Series from narwhals._typing import _EagerAllowedImpl from narwhals.dtypes import DType @@ -194,6 +196,37 @@ def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: def to_polars(self) -> pl.Series: ... def unique(self, *, maintain_order: bool = False) -> Self: ... def zip_with(self, mask: Self, other: Self) -> Self: ... + @unstable + def hist( + self, + bins: Sequence[float] | None = None, + *, + bin_count: int | None = None, + include_breakpoint: bool = True, + include_category: bool = False, + _compatibility_behavior: Literal["narwhals", "polars"] = "narwhals", + ) -> CompliantDataFrame[Self, Incomplete, NativeSeriesT]: + from narwhals._plan.expressions import col as ir_col + + expr = ( + ir_col(self.name) + .to_narwhals(self.version) + .hist( + bins, + bin_count=bin_count, + include_breakpoint=include_breakpoint, + include_category=include_category, + ) + ) + df: DataFrame[Incomplete, NativeSeriesT] = ( + self.to_narwhals().to_frame().select(expr) + ) + if not include_breakpoint and not include_category: + if _compatibility_behavior == "narwhals": + df = df.rename({self.name: "count"}) + else: + df = df.to_series().struct.unnest() + return df._compliant @property def struct(self) -> SeriesStructNamespace[Self, Incomplete]: ... 
diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index db799ff37f..0f6a714044 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -282,21 +282,13 @@ def hist( include_category: bool = False, _compatibility_behavior: Literal["narwhals", "polars"] = "narwhals", ) -> DataFrame[Incomplete, NativeSeriesT_co]: - from narwhals._plan import functions as F - - result = self.to_frame().select( - F.col(self.name).hist( - bins, - bin_count=bin_count, - include_breakpoint=include_breakpoint, - include_category=include_category, - ) - ) - if not include_breakpoint and not include_category: - if _compatibility_behavior == "narwhals": - result = result.rename({self.name: "count"}) - return result - return result.to_series().struct.unnest() + return self._compliant.hist( + bins, + bin_count=bin_count, + include_breakpoint=include_breakpoint, + include_category=include_category, + _compatibility_behavior=_compatibility_behavior, + ).to_narwhals() def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Self: return type(self)( From ca22bac423d12ea7ca2a956b8a926d4eca1f0a97 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 9 Dec 2025 16:13:31 +0000 Subject: [PATCH 170/215] refactor(typing): Align type param positions between `Expr`, `Series` accessors Fixes https://github.com/narwhals-dev/narwhals/pull/3325#discussion_r2590855768 --- narwhals/_plan/arrow/series.py | 2 +- narwhals/_plan/compliant/accessors.py | 2 +- narwhals/_plan/compliant/series.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 8a59ce42cc..ed30d683ce 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -310,7 +310,7 @@ def struct(self) -> SeriesStructNamespace: return SeriesStructNamespace(self) -class SeriesStructNamespace(StructNamespace[ArrowSeries, "DataFrame"]): +class 
SeriesStructNamespace(StructNamespace["DataFrame", ArrowSeries]): def __init__(self, compliant: ArrowSeries, /) -> None: self._compliant: ArrowSeries = compliant diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 3ba16b4e0d..0c1eb0de28 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -90,6 +90,6 @@ def field( ) -> ExprT_co: ... -class SeriesStructNamespace(Protocol[SeriesT_co, DataFrameT_co]): +class SeriesStructNamespace(Protocol[DataFrameT_co, SeriesT_co]): def field(self, name: str) -> SeriesT_co: ... def unnest(self) -> DataFrameT_co: ... diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 7f95b379ee..23a06a53c7 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -229,4 +229,4 @@ def hist( return df._compliant @property - def struct(self) -> SeriesStructNamespace[Self, Incomplete]: ... + def struct(self) -> SeriesStructNamespace[Incomplete, Self]: ... 
From 75302141e3c7ef12a8769c2807421f0c3bdee5aa Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 9 Dec 2025 16:33:51 +0000 Subject: [PATCH 171/215] docs(typing): Highlight and explain cycling dependency issue --- narwhals/_plan/compliant/dataframe.py | 3 ++- narwhals/_plan/compliant/expr.py | 7 +++---- narwhals/_plan/compliant/series.py | 16 +++++++--------- narwhals/_plan/dataframe.py | 3 ++- narwhals/_plan/series.py | 18 ++++++++++-------- narwhals/_plan/typing.py | 10 ++++++++++ 6 files changed, 34 insertions(+), 23 deletions(-) diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index 2787f449e2..1e953d9c38 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -6,6 +6,7 @@ from narwhals._plan.compliant.group_by import Grouped from narwhals._plan.compliant.typing import ColumnT_co, HasVersion, SeriesT from narwhals._plan.typing import ( + IncompleteCyclic, IntoExpr, NativeDataFrameT, NativeFrameT_co, @@ -43,7 +44,7 @@ class CompliantFrame(HasVersion, Protocol[ColumnT_co, NativeFrameT_co]): implementation: ClassVar[Implementation] - def __narwhals_namespace__(self) -> Any: ... + def __narwhals_namespace__(self) -> IncompleteCyclic: ... def _evaluate_irs( self, nodes: Iterable[NamedIR[ir.ExprIR]], / ) -> Iterator[ColumnT_co]: ... 
diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 8f0fa2730b..e53e25977c 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -13,7 +13,7 @@ from narwhals._utils import Version if TYPE_CHECKING: - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan import expressions as ir from narwhals._plan.compliant.accessors import ( @@ -41,8 +41,7 @@ IsNull, Not, ) - -Incomplete: TypeAlias = Any + from narwhals._plan.typing import IncompleteCyclic class CompliantExpr(HasVersion, Protocol[FrameT_contra, SeriesT_co]): @@ -292,7 +291,7 @@ def gather_every( ) -> Self: ... def is_in_series( self, - node: FunctionExpr[boolean.IsInSeries[Incomplete]], + node: FunctionExpr[boolean.IsInSeries[IncompleteCyclic]], frame: FrameT_contra, name: str, ) -> Self: ... diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 23a06a53c7..4f3ba969bc 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -3,14 +3,14 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol from narwhals._plan.compliant.typing import HasVersion -from narwhals._plan.typing import NativeSeriesT +from narwhals._plan.typing import IncompleteCyclic, NativeSeriesT from narwhals._utils import Version, _StoresNative, unstable if TYPE_CHECKING: from collections.abc import Iterable, Sequence import polars as pl - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.compliant.accessors import SeriesStructNamespace from narwhals._plan.compliant.dataframe import CompliantDataFrame @@ -29,8 +29,6 @@ _1DArray, ) -Incomplete: TypeAlias = Any - class CompliantSeries(HasVersion, Protocol[NativeSeriesT]): implementation: ClassVar[_EagerAllowedImpl] @@ -79,7 +77,7 @@ def not_(self) -> Self: def pow(self, exponent: float | Self) -> Self: return self.__pow__(exponent) - def 
__narwhals_namespace__(self) -> Incomplete: ... + def __narwhals_namespace__(self) -> IncompleteCyclic: ... def __narwhals_series__(self) -> Self: return self @@ -185,7 +183,7 @@ def sample_n( def scatter(self, indices: Self, values: Self) -> Self: ... def slice(self, offset: int, length: int | None = None) -> Self: ... def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: ... - def to_frame(self) -> Incomplete: ... + def to_frame(self) -> IncompleteCyclic: ... def to_list(self) -> list[Any]: ... def to_narwhals(self) -> Series[NativeSeriesT]: from narwhals._plan.series import Series @@ -205,7 +203,7 @@ def hist( include_breakpoint: bool = True, include_category: bool = False, _compatibility_behavior: Literal["narwhals", "polars"] = "narwhals", - ) -> CompliantDataFrame[Self, Incomplete, NativeSeriesT]: + ) -> CompliantDataFrame[Self, IncompleteCyclic, NativeSeriesT]: from narwhals._plan.expressions import col as ir_col expr = ( @@ -218,7 +216,7 @@ def hist( include_category=include_category, ) ) - df: DataFrame[Incomplete, NativeSeriesT] = ( + df: DataFrame[IncompleteCyclic, NativeSeriesT] = ( self.to_narwhals().to_frame().select(expr) ) if not include_breakpoint and not include_category: @@ -229,4 +227,4 @@ def hist( return df._compliant @property - def struct(self) -> SeriesStructNamespace[Incomplete, Self]: ... + def struct(self) -> SeriesStructNamespace[IncompleteCyclic, Self]: ... 
diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index b18697a4b4..08f943978e 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -11,6 +11,7 @@ from narwhals._plan.series import Series from narwhals._plan.typing import ( ColumnNameOrSelector, + IncompleteCyclic, IntoExpr, IntoExprColumn, NativeDataFrameT, @@ -191,7 +192,7 @@ def _dataframe_from_dict( class DataFrame( BaseFrame[NativeDataFrameT_co], Generic[NativeDataFrameT_co, NativeSeriesT] ): - _compliant: CompliantDataFrame[Any, NativeDataFrameT_co, NativeSeriesT] + _compliant: CompliantDataFrame[IncompleteCyclic, NativeDataFrameT_co, NativeSeriesT] @property def implementation(self) -> _EagerAllowedImpl: diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 0f6a714044..390e4cc631 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -4,7 +4,13 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal from narwhals._plan._guards import is_series -from narwhals._plan.typing import NativeSeriesT, NativeSeriesT_co, OneOrIterable, SeriesT +from narwhals._plan.typing import ( + IncompleteCyclic, + NativeSeriesT, + NativeSeriesT_co, + OneOrIterable, + SeriesT, +) from narwhals._utils import ( Implementation, Version, @@ -20,7 +26,7 @@ from collections.abc import Iterator import polars as pl - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.dataframe import DataFrame @@ -34,8 +40,6 @@ TemporalLiteral, ) -Incomplete: TypeAlias = Any - class Series(Generic[NativeSeriesT_co]): _compliant: CompliantSeries[NativeSeriesT_co] @@ -102,9 +106,7 @@ def from_native( raise NotImplementedError(type(native)) - # NOTE: `Incomplete` until `CompliantSeries` can avoid a cyclic dependency back to `CompliantDataFrame` - # Currently an issue on `main` and leads to a lot of intermittent warnings - def to_frame(self) -> 
DataFrame[Incomplete, NativeSeriesT_co]: + def to_frame(self) -> DataFrame[IncompleteCyclic, NativeSeriesT_co]: import narwhals._plan.dataframe as _df # NOTE: Missing placeholder for `DataFrameV1` @@ -281,7 +283,7 @@ def hist( include_breakpoint: bool = True, include_category: bool = False, _compatibility_behavior: Literal["narwhals", "polars"] = "narwhals", - ) -> DataFrame[Incomplete, NativeSeriesT_co]: + ) -> DataFrame[IncompleteCyclic, NativeSeriesT_co]: return self._compliant.hist( bins, bin_count=bin_count, diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 9c4a6ca927..4b35e7722b 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -141,3 +141,13 @@ [^1]: `ByName`, `ByIndex` will never be ignored. """ + + +IncompleteCyclic: TypeAlias = "t.Any" +"""Placeholder for typing that introduces a cyclic dependency. + +Mainly for spelling `(Compliant)DataFrame` from within `(Compliant)Series`. + +On `main`, this works fine when running a type checker from the CLI - but causes +intermittent warnings when running in a language server. 
+""" From 66f116339197ea99c96519a14454e29f687f3fbf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 9 Dec 2025 16:48:47 +0000 Subject: [PATCH 172/215] tests: cov `Series.cast` --- narwhals/_plan/series.py | 2 +- tests/plan/compliant_test.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 390e4cc631..6fa0743316 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -128,7 +128,7 @@ def __iter__(self) -> Iterator[Any]: def alias(self, name: str) -> Self: return type(self)(self._compliant.alias(name)) - def cast(self, dtype: IntoDType) -> Self: # pragma: no cover + def cast(self, dtype: IntoDType) -> Self: return type(self)(self._compliant.cast(dtype)) def __len__(self) -> int: diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index e503cd63ec..d8ff90be6b 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -766,6 +766,18 @@ def test_series_misc() -> None: assert len(list(ser)) == len(values) +def test_series_cast() -> None: + pytest.importorskip("pyarrow") + ser = nwp.int_range(10, step=2, eager="pyarrow", dtype=nw.Int64) + assert ser.dtype == nw.Int64 + ser_float = ser.cast(nw.Float64) + assert ser_float.dtype == nw.Float64 + assert ser.dtype == nw.Int64 + result = ser_float + 0.5 + expected = [0.5, 2.5, 4.5, 6.5, 8.5] + assert_equal_series(result, expected, "literal") + + if TYPE_CHECKING: from typing_extensions import assert_type From 68c08c0cdcabb03470cba2a2935a990de6c2055c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 9 Dec 2025 19:25:01 +0000 Subject: [PATCH 173/215] test: A whole bunch of `str.*` method coverage Resolves https://github.com/narwhals-dev/narwhals/pull/3325#discussion_r2603068814 --- narwhals/_plan/arrow/expr.py | 1 - narwhals/_plan/expressions/strings.py | 6 +-- 
tests/plan/str_contains_test.py | 35 ++++++++++++++ tests/plan/str_split_test.py | 19 ++++++++ tests/plan/str_starts_ends_with_test.py | 37 +++++++++++++++ tests/plan/str_strip_chars_test.py | 16 +++++++ tests/plan/str_transform_case_test.py | 61 +++++++++++++++++++++++++ 7 files changed, 171 insertions(+), 4 deletions(-) create mode 100644 tests/plan/str_contains_test.py create mode 100644 tests/plan/str_split_test.py create mode 100644 tests/plan/str_starts_ends_with_test.py create mode 100644 tests/plan/str_strip_chars_test.py create mode 100644 tests/plan/str_transform_case_test.py diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 15ee7e5f91..3e350e4859 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -961,7 +961,6 @@ def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: contains = not_implemented() -# TODO @dangotbanned: Add tests for these, especially those using a different native function class ArrowStringNamespace( ExprStringNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 5dc730e8e5..42ea31bb70 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -170,7 +170,7 @@ def strip_chars(self, characters: str | None = None) -> Expr: def starts_with(self, prefix: str) -> Expr: return self._with_unary(self._ir.starts_with(prefix=prefix)) - def ends_with(self, suffix: str) -> Expr: # pragma: no cover + def ends_with(self, suffix: str) -> Expr: return self._with_unary(self._ir.ends_with(suffix=suffix)) def contains(self, pattern: str, *, literal: bool = False) -> Expr: @@ -185,7 +185,7 @@ def head(self, n: int = 5) -> Expr: def tail(self, n: int = 5) -> Expr: return self._with_unary(self._ir.tail(n)) - def split(self, by: str) -> Expr: # pragma: no cover + def split(self, by: str) -> Expr: return 
self._with_unary(self._ir.split(by=by)) def to_date(self, format: str | None = None) -> Expr: # pragma: no cover @@ -200,7 +200,7 @@ def to_lowercase(self) -> Expr: def to_uppercase(self) -> Expr: return self._with_unary(self._ir.to_uppercase()) - def to_titlecase(self) -> Expr: # pragma: no cover + def to_titlecase(self) -> Expr: return self._with_unary(self._ir.to_titlecase()) def zfill(self, length: int) -> Expr: diff --git a/tests/plan/str_contains_test.py b/tests/plan/str_contains_test.py new file mode 100644 index 0000000000..666992e557 --- /dev/null +++ b/tests/plan/str_contains_test.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"pets": ["cat", "dog", "rabbit and parrot", "dove", "Parrot|dove", None]} + + +@pytest.mark.parametrize( + ("pattern", "literal", "expected"), + [ + ("(?i)parrot|Dove", False, [False, False, True, True, True, None]), + ("parrot|Dove", False, [False, False, True, False, False, None]), + ("Parrot|dove", False, [False, False, False, True, True, None]), + ("Parrot|dove", True, [False, False, False, False, True, None]), + ], + ids=["case_insensitive", "case_sensitive-1", "case_sensitive-2", "literal"], +) +def test_str_contains( + data: Data, pattern: str, *, literal: bool, expected: list[bool | None] +) -> None: + result = dataframe(data).select( + nwp.col("pets").str.contains(pattern, literal=literal) + ) + assert_equal_data(result, {"pets": expected}) diff --git a/tests/plan/str_split_test.py b/tests/plan/str_split_test.py new file mode 100644 index 0000000000..548c909d33 --- /dev/null +++ b/tests/plan/str_split_test.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import 
assert_equal_data, dataframe + + +@pytest.mark.parametrize( + ("by", "expected"), + [ + ("_", [["foo bar"], ["foo", "bar"], ["foo", "bar", "baz"], ["foo,bar"]]), + (",", [["foo bar"], ["foo_bar"], ["foo_bar_baz"], ["foo", "bar"]]), + ], +) +def test_str_split(by: str, expected: list[list[str]]) -> None: + data = {"a": ["foo bar", "foo_bar", "foo_bar_baz", "foo,bar"]} + result = dataframe(data).select(nwp.col("a").str.split(by)) + assert_equal_data(result, {"a": expected}) diff --git a/tests/plan/str_starts_ends_with_test.py b/tests/plan/str_starts_ends_with_test.py new file mode 100644 index 0000000000..ebf3f4e0f9 --- /dev/null +++ b/tests/plan/str_starts_ends_with_test.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": ["Starts_with", "starts_with", "Ends_with", "ends_With", None]} + + +@pytest.mark.parametrize( + ("prefix", "expected"), + [ + ("start", [False, True, False, False, None]), + ("End", [False, False, True, False, None]), + ], +) +def test_str_starts_with(data: Data, prefix: str, expected: list[bool | None]) -> None: + result = dataframe(data).select(nwp.col("a").str.starts_with(prefix)) + assert_equal_data(result, {"a": expected}) + + +@pytest.mark.parametrize( + ("suffix", "expected"), + [("With", [False, False, False, True, None]), ("th", [True, True, True, True, None])], +) +def test_str_ends_with(data: Data, suffix: str, expected: list[bool | None]) -> None: + result = dataframe(data).select(nwp.col("a").str.ends_with(suffix)) + assert_equal_data(result, {"a": expected}) diff --git a/tests/plan/str_strip_chars_test.py b/tests/plan/str_strip_chars_test.py new file mode 100644 index 0000000000..9910d69ec1 --- /dev/null +++ b/tests/plan/str_strip_chars_test.py @@ -0,0 +1,16 @@ 
+from __future__ import annotations + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + + +@pytest.mark.parametrize( + ("characters", "expected"), + [(None, ["foobar", "bar", "baz"]), ("foo", ["bar", "bar\n", " baz"])], +) +def test_str_strip_chars(characters: str | None, expected: list[str]) -> None: + data = {"a": ["foobar", "bar\n", " baz"]} + result = dataframe(data).select(nwp.col("a").str.strip_chars(characters)) + assert_equal_data(result, {"a": expected}) diff --git a/tests/plan/str_transform_case_test.py b/tests/plan/str_transform_case_test.py new file mode 100644 index 0000000000..e7bf2be7c7 --- /dev/null +++ b/tests/plan/str_transform_case_test.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + StrData: TypeAlias = dict[str, list[str]] + + +@pytest.fixture(scope="module") +def data() -> StrData: + return { + "a": [ + "e.t. 
phone home", + "they're bill's friends from the UK", + "to infinity,and BEYOND!", + "with123numbers", + "__dunder__score_a1_.2b ?three", + ] + } + + +@pytest.fixture(scope="module") +def data_lower(data: StrData) -> StrData: + return {"a": [*data["a"], "SPECIAL CASE ß", "ΣPECIAL CAΣE"]} + + +@pytest.fixture(scope="module") +def expected_title(data: StrData) -> StrData: + return {"a": [s.title() for s in data["a"]]} + + +@pytest.fixture(scope="module") +def expected_upper(data: StrData) -> StrData: + return {"a": [s.upper() for s in data["a"]]} + + +@pytest.fixture(scope="module") +def expected_lower(data_lower: StrData) -> StrData: + return {"a": [s.lower() for s in data_lower["a"]]} + + +def test_str_to_titlecase(data: StrData, expected_title: StrData) -> None: + result = dataframe(data).select(nwp.col("a").str.to_titlecase()) + assert_equal_data(result, expected_title) + + +def test_str_to_uppercase(data: StrData, expected_upper: StrData) -> None: + result = dataframe(data).select(nwp.col("a").str.to_uppercase()) + assert_equal_data(result, expected_upper) + + +def test_str_to_lowercase(data_lower: StrData, expected_lower: StrData) -> None: + result = dataframe(data_lower).select(nwp.col("a").str.to_lowercase()) + assert_equal_data(result, expected_lower) From f50278d7a79001a49f510a0cbac695ed09ed2e6f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 9 Dec 2025 21:06:00 +0000 Subject: [PATCH 174/215] fix: Actually support `Sequence` in `Series.gather` Discovered while addressing https://github.com/narwhals-dev/narwhals/pull/3325#discussion_r2603034954 --- narwhals/_plan/series.py | 3 +-- tests/plan/compliant_test.py | 12 +++++++++--- tests/plan/gather_test.py | 28 +++++++++++++++++++++++++++- 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 6fa0743316..3efb825e29 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ 
-137,8 +137,7 @@ def __len__(self) -> int: def gather(self, indices: SizedMultiIndexSelector[Self]) -> Self: if len(indices) == 0: return self.slice(0, 0) - rows = indices._compliant if isinstance(indices, Series) else indices - return type(self)(self._compliant.gather(rows)) + return type(self)(self._compliant.gather(self._parse_into_compliant(indices))) def gather_every(self, n: int, offset: int = 0) -> Self: return type(self)(self._compliant.gather_every(n, offset)) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index d8ff90be6b..eb498f786d 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -759,11 +759,17 @@ def test_series_misc() -> None: assert_equal_series(is_not_nan & is_not_null, expected, name) assert ser.unique().drop_nans().drop_nulls().count() == 6 + assert len(list(ser)) == len(values) - assert_equal_series(ser.gather([0, 2, 4]).sort(), [1.0, 4.9, 7.1], name) - assert ser.gather([]).to_list() == [] - assert len(list(ser)) == len(values) +def test_series_sort() -> None: + ser = series([1.0, 7.1, None, 4.9]) + assert_equal_series(ser.sort(), [None, 1.0, 4.9, 7.1], "") + assert_equal_series(ser.sort(nulls_last=True), [1.0, 4.9, 7.1, None], "") + assert_equal_series(ser.sort(descending=True), [None, 7.1, 4.9, 1.0], "") + assert_equal_series( + ser.sort(descending=True, nulls_last=True), [7.1, 4.9, 1.0, None], "" + ) def test_series_cast() -> None: diff --git a/tests/plan/gather_test.py b/tests/plan/gather_test.py index f2e3453509..d8708297bf 100644 --- a/tests/plan/gather_test.py +++ b/tests/plan/gather_test.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest @@ -32,6 +32,32 @@ def test_gather_every_series(data: Data, n: int, offset: int, column: str) -> No assert_equal_series(result, expected, column) +@pytest.mark.parametrize( + ("column", "indices", "expected"), + [ + ("idx", [], 
[]), + ("name", [], []), + ("idx", [0, 4, 2], [0, 4, 2]), + ("name", [1, 5, 5], ["b", "f", "f"]), + pytest.param( + "idx", + [-1], + [9], + marks=pytest.mark.xfail( + reason="TODO: Handle negative indices", raises=IndexError + ), + ), + ("name", range(5, 7), ["f", "g"]), + ], +) +def test_gather_series( + data: Data, column: str, indices: Any, expected: list[Any] +) -> None: + ser = series(data[column]).alias(column) + result = ser.gather(indices) + assert_equal_series(result, expected, column) + + @pytest.mark.parametrize("n", [1, 2, 3]) @pytest.mark.parametrize("offset", [0, 1, 2, 3]) def test_gather_every_dataframe(data: Data, n: int, offset: int) -> None: From 937bc57e0b1e5d4f1630038eb10541d099e1f934 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 10 Dec 2025 18:44:40 +0000 Subject: [PATCH 175/215] feat(expr-ir): Add `linear_space` (#3349) --- narwhals/_plan/__init__.py | 2 + narwhals/_plan/arrow/functions.py | 53 +++-- narwhals/_plan/arrow/namespace.py | 34 ++- narwhals/_plan/compliant/namespace.py | 14 +- narwhals/_plan/expr.py | 2 +- narwhals/_plan/expressions/ranges.py | 8 + narwhals/_plan/functions.py | 117 +++++++++- narwhals/_plan/series.py | 2 + tests/plan/compliant_test.py | 35 +-- tests/plan/expr_parsing_test.py | 51 +---- tests/plan/range_test.py | 297 ++++++++++++++++++++++++++ tests/plan/utils.py | 4 +- 12 files changed, 495 insertions(+), 124 deletions(-) create mode 100644 tests/plan/range_test.py diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index 1528ea6ac6..03ef4b249c 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -15,6 +15,7 @@ format, int_range, len, + linear_space, lit, max, max_horizontal, @@ -47,6 +48,7 @@ "format", "int_range", "len", + "linear_space", "lit", "max", "max_horizontal", diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 1969cd33d3..4b0a65101e 100644 --- 
a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1435,43 +1435,38 @@ def date_range( def linear_space( start: float, end: float, num_samples: int, *, closed: ClosedInterval = "both" ) -> ChunkedArray[pc.NumericScalar]: - """Based on [`np.linspace`]. + """Based on [`new_linear_space_f64`]. - Use when implementing `hist`. - - [`np.linspace`]: https://github.com/numpy/numpy/blob/v2.3.0/numpy/_core/function_base.py#L26-L187 + [`new_linear_space_f64`]: https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-ops/src/series/ops/linear_space.rs#L62-L94 """ if num_samples < 0: msg = f"Number of samples, {num_samples}, must be non-negative." raise ValueError(msg) - if closed == "both": - range_end = num_samples - div = num_samples - 1 + if num_samples == 0: + return chunked_array([[]], F64) + if num_samples == 1: + if closed == "none": + value = (end + start) * 0.5 + elif closed in {"left", "both"}: + value = float(start) + else: + value = float(end) + return chunked_array([[value]], F64) + n = num_samples + span = float(end - start) + if closed == "none": + d = span / (n + 1) + start = start + d elif closed == "left": - range_end = num_samples - div = num_samples + d = span / n elif closed == "right": - range_end = num_samples + 1 - div = num_samples - elif closed == "none": - range_end = num_samples + 1 - div = num_samples + 1 - ca: ChunkedArray[pc.NumericScalar] = int_range(0, range_end).cast(F64) - delta = float(end - start) - if div > 0: - step = delta / div - if step == 0: - ca = truediv(ca, lit(div)) - ca = multiply(ca, lit(delta)) - else: - ca = multiply(ca, lit(step)) + start = start + span / n + d = span / n else: - ca = multiply(ca, lit(delta)) - if start != 0: - ca = add(ca, lit(start, F64)) - if closed in {"right", "none"}: - return ca.slice(1) - return ca + d = span / (n - 1) + ca: ChunkedArray[pc.NumericScalar] = multiply(int_range(0, n).cast(F64), lit(d)) + ca = add(ca, lit(start, F64)) + 
return ca # noqa: RET504 def repeat(value: ScalarAny | NonNestedLiteral, n: int) -> ArrayAny: diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 84efdc43bf..b2d6f60e21 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -11,6 +11,7 @@ from narwhals._plan._guards import is_tuple_of from narwhals._plan.arrow import functions as fn from narwhals._plan.compliant.namespace import EagerNamespace +from narwhals._plan.expressions.expr import RangeExpr from narwhals._plan.expressions.literal import is_literal_scalar from narwhals._utils import Implementation, Version from narwhals.exceptions import InvalidOperationError @@ -26,7 +27,7 @@ from narwhals._plan.expressions import expr, functions as F from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal from narwhals._plan.expressions.expr import FunctionExpr as FExpr, RangeExpr - from narwhals._plan.expressions.ranges import DateRange, IntRange + from narwhals._plan.expressions.ranges import DateRange, IntRange, LinearSpace from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.series import Series as NwSeries from narwhals._plan.typing import NonNestedLiteralT @@ -176,8 +177,13 @@ def concat_str( return self._scalar.from_native(result, name, self.version) return self._expr.from_native(result, name, self.version) + # TODO @dangotbanned: Refactor alongside `nwp.functions._ensure_range_scalar` + # Consider returning the supertype of inputs def _range_function_inputs( - self, node: RangeExpr, frame: Frame, valid_type: type[NonNestedLiteralT] + self, + node: RangeExpr, + frame: Frame, + valid_type: type[NonNestedLiteralT] | tuple[type[NonNestedLiteralT], ...], ) -> tuple[NonNestedLiteralT, NonNestedLiteralT]: start_: PythonLiteral end_: PythonLiteral @@ -198,8 +204,10 @@ def _range_function_inputs( ) raise InvalidOperationError(msg) if isinstance(start_, valid_type) and isinstance(end_, valid_type): - return 
start_, end_ - msg = f"All inputs for `{node.function}()` must resolve to {valid_type.__name__}, but got \n{start_!r}\n{end_!r}" + return start_, end_ # type: ignore[return-value] + valid_types = (valid_type,) if not isinstance(valid_type, tuple) else valid_type + tp_names = " | ".join(tp.__name__ for tp in valid_types) + msg = f"All inputs for `{node.function}()` must resolve to {tp_names}, but got \n{start_!r}\n{end_!r}" raise InvalidOperationError(msg) def _int_range( @@ -247,6 +255,24 @@ def date_range_eager( native = fn.date_range(start, end, interval, closed=closed) return self._series.from_native(native, name, version=self.version) + def linear_space(self, node: RangeExpr[LinearSpace], frame: Frame, name: str) -> Expr: + start, end = self._range_function_inputs(node, frame, (int, float)) + func = node.function + native = fn.linear_space(start, end, func.num_samples, closed=func.closed) + return self._expr.from_native(native, name, self.version) + + def linear_space_eager( + self, + start: float, + end: float, + num_samples: int, + *, + closed: ClosedInterval = "both", + name: str = "literal", + ) -> Series: + native = fn.linear_space(start, end, num_samples, closed=closed) + return self._series.from_native(native, name, version=self.version) + @overload def concat(self, items: Iterable[Frame], *, how: ConcatMethod) -> Frame: ... 
@overload diff --git a/narwhals/_plan/compliant/namespace.py b/narwhals/_plan/compliant/namespace.py index c3ebb47b55..a5a132e0c5 100644 --- a/narwhals/_plan/compliant/namespace.py +++ b/narwhals/_plan/compliant/namespace.py @@ -26,7 +26,7 @@ from narwhals._plan import expressions as ir from narwhals._plan.expressions import FunctionExpr, boolean, functions as F - from narwhals._plan.expressions.ranges import DateRange, IntRange + from narwhals._plan.expressions.ranges import DateRange, IntRange, LinearSpace from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.series import Series from narwhals.dtypes import IntegerType @@ -63,6 +63,9 @@ def date_range( def int_range( self, node: ir.RangeExpr[IntRange], frame: FrameT, name: str ) -> ExprT_co: ... + def linear_space( + self, node: ir.RangeExpr[LinearSpace], frame: FrameT, name: str + ) -> ExprT_co: ... def len(self, node: ir.Len, frame: FrameT, name: str) -> ScalarT_co: ... def lit( self, node: ir.Literal[Any], frame: FrameT, name: str @@ -159,6 +162,15 @@ def int_range_eager( dtype: IntegerType = Int64, name: str = "literal", ) -> SeriesT: ... + def linear_space_eager( + self, + start: float, + end: float, + num_samples: int, + *, + closed: ClosedInterval = "both", + name: str = "literal", + ) -> SeriesT: ... 
class LazyNamespace( diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 05d19b0180..6c8f5a6154 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -607,7 +607,7 @@ def name(self) -> ExprNameNamespace: >>> >>> renamed = nw.col("a", "b").name.suffix("_changed") >>> str(renamed._ir) - "RenameAlias(expr=RootSelector(selector=ByName(names=[a, b], require_all=True)), function=Suffix(suffix='_changed'))" + "RenameAlias(expr=RootSelector(selector=ByName(names=['a', 'b'], require_all=True)), function=Suffix(suffix='_changed'))" """ from narwhals._plan.expressions.name import ExprNameNamespace diff --git a/narwhals/_plan/expressions/ranges.py b/narwhals/_plan/expressions/ranges.py index 644a89dc1e..fba04afbbf 100644 --- a/narwhals/_plan/expressions/ranges.py +++ b/narwhals/_plan/expressions/ranges.py @@ -38,3 +38,11 @@ class DateRange(RangeFunction, options=FunctionOptions.row_separable): __slots__ = ("interval", "closed") # noqa: RUF023 interval: int closed: ClosedInterval + + +class LinearSpace(RangeFunction, options=FunctionOptions.row_separable): + """N-ary (start, end).""" + + __slots__ = ("num_samples", "closed") # noqa: RUF023 + num_samples: int + closed: ClosedInterval diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 334eea1e91..11900306aa 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -11,12 +11,18 @@ from narwhals._plan.exceptions import list_literal_error from narwhals._plan.expressions import functions as F from narwhals._plan.expressions.literal import ScalarLiteral, SeriesLiteral -from narwhals._plan.expressions.ranges import DateRange, IntRange, RangeFunction +from narwhals._plan.expressions.ranges import ( + DateRange, + IntRange, + LinearSpace, + RangeFunction, +) from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.when_then import When from narwhals._utils import ( Implementation, Version, + ensure_type, flatten, is_eager_allowed, 
qualified_type_name, @@ -314,6 +320,95 @@ def date_range( ) +@t.overload +def linear_space( + start: float | IntoExprColumn, + end: float | IntoExprColumn, + num_samples: int, + *, + closed: ClosedInterval = ..., + eager: t.Literal[False] = ..., +) -> Expr: ... +@t.overload +def linear_space( + start: float, + end: float, + num_samples: int, + *, + closed: ClosedInterval = ..., + eager: Arrow, +) -> Series[pa.ChunkedArray[t.Any]]: ... +@t.overload +def linear_space( + start: float, + end: float, + num_samples: int, + *, + closed: ClosedInterval = ..., + eager: IntoBackend[EagerAllowed], +) -> Series: ... +def linear_space( + start: float | IntoExprColumn, + end: float | IntoExprColumn, + num_samples: int, + *, + closed: ClosedInterval = "both", + eager: IntoBackend[EagerAllowed] | t.Literal[False] = False, +) -> Expr | Series: + """Create sequence of evenly-spaced points. + + Arguments: + start: Lower bound of the range. + end: Upper bound of the range. + num_samples: Number of samples in the output sequence. + closed: Define which sides of the interval are closed (inclusive). + eager: If set to `False` (default), then an expression is returned. + If set to an (eager) implementation ("pandas", "polars" or "pyarrow"), then + a `Series` is returned. + + Notes: + Unlike `pl.linear_space`, *currently* only numeric dtypes (and not temporal) are supported. 
+ + Examples: + >>> import narwhals._plan as nwp + >>> nwp.linear_space(start=0, end=1, num_samples=3, eager="pyarrow").to_list() + [0.0, 0.5, 1.0] + + >>> nwp.linear_space(0, 1, 3, closed="left", eager="pyarrow").to_list() + [0.0, 0.3333333333333333, 0.6666666666666666] + + >>> nwp.linear_space(0, 1, 3, closed="right", eager="pyarrow").to_list() + [0.3333333333333333, 0.6666666666666666, 1.0] + + >>> nwp.linear_space(0, 1, 3, closed="none", eager="pyarrow").to_list() + [0.25, 0.5, 0.75] + + >>> df = nwp.DataFrame.from_dict({"a": [1, 2, 3, 4, 5]}, backend="pyarrow") + >>> df.with_columns(nwp.linear_space(0, 10, 5).alias("ls")) + ┌──────────────────────┐ + | nw.DataFrame | + |----------------------| + |pyarrow.Table | + |a: int64 | + |ls: double | + |---- | + |a: [[1,2,3,4,5]] | + |ls: [[0,2.5,5,7.5,10]]| + └──────────────────────┘ + """ + ensure_type(num_samples, int, param_name="num_samples") + closed = _ensure_closed_interval(closed) + if eager: + ns = _eager_namespace(eager) + start, end = _ensure_range_scalar(start, end, (float, int), LinearSpace, eager) + return _linear_space_eager(start, end, num_samples, closed, ns) + return ( + LinearSpace(num_samples=num_samples, closed=closed) + .to_function_expr(*_parse.parse_into_seq_of_expr_ir(start, end)) + .to_narwhals() + ) + + @t.overload def _eager_namespace(backend: Arrow, /) -> _arrow.Namespace: ... 
@t.overload @@ -332,11 +427,12 @@ def _eager_namespace( raise ValueError(msg) -# NOTE: If anything beyond `{date,int}_range` are added, move to `RangeFunction` +# TODO @dangotbanned: Handle this in `RangeFunction` or `RangeExpr` +# NOTE: `ArrowNamespace._range_function_inputs` has some duplicated logic too def _ensure_range_scalar( start: t.Any, end: t.Any, - valid_type: type[NonNestedLiteralT], + valid_type: type[NonNestedLiteralT] | tuple[type[NonNestedLiteralT], ...], function: type[RangeFunction], eager: IntoBackend[EagerAllowed], ) -> tuple[NonNestedLiteralT, NonNestedLiteralT]: @@ -344,8 +440,10 @@ def _ensure_range_scalar( return start, end tp_start = qualified_type_name(start) tp_end = qualified_type_name(end) + valid_types = (valid_type,) if not isinstance(valid_type, tuple) else valid_type + tp_names = " | ".join(tp.__name__ for tp in valid_types) msg = ( - f"Expected `start` and `end` to be {valid_type.__name__} values since `eager={eager}`, but got: ({tp_start}, {tp_end})\n\n" + f"Expected `start` and `end` to be {tp_names} values since `eager={eager}`, but got: ({tp_start}, {tp_end})\n\n" f"Hint: Calling `nw.{get_dispatch_name(function)}` with expressions requires:\n" " - `eager=False`\n" " - a context such as `select` or `with_columns`" @@ -374,6 +472,17 @@ def _int_range_eager( return ns.int_range_eager(start, end, step, dtype=dtype).to_narwhals() +def _linear_space_eager( + start: float, + end: float, + num_samples: int, + closed: ClosedInterval, + ns: EagerNs[NativeSeriesT], + /, +) -> Series[NativeSeriesT]: + return ns.linear_space_eager(start, end, num_samples, closed=closed).to_narwhals() + + def _ensure_closed_interval(closed: ClosedInterval, /) -> ClosedInterval: closed_intervals = "left", "right", "none", "both" if closed not in closed_intervals: diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 3efb825e29..65a8265d9c 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -128,6 +128,8 @@ def 
__iter__(self) -> Iterator[Any]: def alias(self, name: str) -> Self: return type(self)(self._compliant.alias(name)) + rename = alias + def cast(self, dtype: IntoDType) -> Self: return type(self)(self._compliant.cast(dtype)) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index eb498f786d..37d6cc605e 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -11,7 +11,6 @@ pytest.importorskip("pyarrow") pytest.importorskip("numpy") -import datetime as dt import pyarrow as pa @@ -27,6 +26,7 @@ ) if TYPE_CHECKING: + import datetime as dt from collections.abc import Sequence from narwhals._plan.typing import ColumnNameOrSelector, OneOrIterable @@ -129,39 +129,6 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: .name.to_uppercase(), {"C": [2.0, 9.0, 4.0], "D": [7.0, 8.0, 8.0]}, ), - ([nwp.int_range(5)], {"literal": [0, 1, 2, 3, 4]}), - ([nwp.int_range(nwp.len())], {"literal": [0, 1, 2]}), - (nwp.int_range(nwp.len() * 5, 20).alias("lol"), {"lol": [15, 16, 17, 18, 19]}), - (nwp.int_range(nwp.col("b").min() + 4, nwp.col("d").last()), {"b": [5, 6, 7]}), - ( - [ - nwp.date_range( - dt.date(2020, 1, 1), - dt.date(2020, 4, 30), - interval="25d", - closed="none", - ) - ], - { - "literal": [ - dt.date(2020, 1, 26), - dt.date(2020, 2, 20), - dt.date(2020, 3, 16), - dt.date(2020, 4, 10), - ] - }, - ), - ( - ( - nwp.date_range( - dt.date(2021, 1, 30), - nwp.lit(18747, nw.Int32).cast(nw.Date), - interval="90d", - closed="left", - ).alias("date_range_cast_expr"), - {"date_range_cast_expr": [dt.date(2021, 1, 30)]}, - ) - ), (nwp.col("b") ** 2, {"b": [1, 4, 9]}), ( [2 ** nwp.col("b"), (nwp.lit(2.0) ** nwp.nth(1)).alias("lit")], diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 54cbf5de68..54665ac8a8 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -209,13 +209,7 @@ def test_date_range_invalid() -> None: nwp.date_range(start, end, interval="3y") -def test_int_range_eager() -> 
None: - series = nwp.int_range(50, eager="pyarrow") - assert isinstance(series, nwp.Series) - assert series.to_list() == list(range(50)) - series = nwp.int_range(50, eager=Implementation.PYARROW) - assert series.to_list() == list(range(50)) - +def test_int_range_eager_invalid() -> None: with pytest.raises(InvalidOperationError): nwp.int_range(nwp.len(), eager="pyarrow") # type: ignore[call-overload] with pytest.raises(InvalidOperationError): @@ -226,49 +220,6 @@ def test_int_range_eager() -> None: nwp.int_range(10, eager="duckdb") # type: ignore[call-overload] -def test_date_range_eager() -> None: - leap_year = 2024 - series_leap = nwp.date_range( - dt.date(leap_year, 2, 25), dt.date(leap_year, 3, 25), eager="pyarrow" - ) - series_regular = nwp.date_range( - dt.date(leap_year + 1, 2, 25), - dt.date(leap_year + 1, 3, 25), - interval=dt.timedelta(days=1), - eager="pyarrow", - ) - assert len(series_regular) == 29 - assert len(series_leap) == 30 - - expected = [ - dt.date(2000, 1, 1), - dt.date(2002, 9, 14), - dt.date(2005, 5, 28), - dt.date(2008, 2, 9), - dt.date(2010, 10, 23), - dt.date(2013, 7, 6), - dt.date(2016, 3, 19), - dt.date(2018, 12, 1), - dt.date(2021, 8, 14), - ] - - series = nwp.date_range( - dt.date(2000, 1, 1), dt.date(2023, 8, 31), interval="987d", eager="pyarrow" - ) - result = series.to_list() - assert result == expected - - expected = [dt.date(2006, 10, 14), dt.date(2013, 7, 27), dt.date(2020, 5, 9)] - result = nwp.date_range( - dt.date(2000, 1, 1), - dt.date(2023, 8, 31), - interval="354w", - closed="right", - eager="pyarrow", - ).to_list() - assert result == expected - - def test_date_range_eager_invalid() -> None: start, end = dt.date(2000, 1, 1), dt.date(2001, 1, 1) diff --git a/tests/plan/range_test.py b/tests/plan/range_test.py new file mode 100644 index 0000000000..7765f23d59 --- /dev/null +++ b/tests/plan/range_test.py @@ -0,0 +1,297 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Final, Literal + +import 
pytest + +from narwhals.exceptions import ShapeError +from tests.utils import PYARROW_VERSION + +if PYARROW_VERSION < (21,): + pytest.importorskip("numpy") +import datetime as dt + +import narwhals as nw +from narwhals import _plan as nwp +from tests.conftest import TEST_EAGER_BACKENDS +from tests.plan.utils import assert_equal_data, assert_equal_series, dataframe + +if TYPE_CHECKING: + from collections.abc import Sequence + + from narwhals.typing import ClosedInterval, EagerAllowed, IntoDType + + +@pytest.fixture(scope="module") +def data() -> dict[str, Any]: + """Variant of `compliant_test.data_small`, with only numeric data.""" + return { + "b": [1, 2, 3], + "c": [9, 2, 4], + "d": [8, 7, 8], + "e": [None, 9, 7], + "j": [12.1, None, 4.0], + "k": [42, 10, None], + "l": [4, 5, 6], + "m": [0, 1, 2], + } + + +_HAS_IMPLEMENTATION = frozenset((nw.Implementation.PYARROW, "pyarrow")) +"""Using to filter *the source* of `eager_backend` - which includes `polars` and `pandas` when available. + +For now, this lets some tests be written in a backend agnostic way. 
+""" + +_HAS_IMPLEMENTATION_IMPL = frozenset( + el for el in _HAS_IMPLEMENTATION if isinstance(el, nw.Implementation) +) +"""Filtered for heavily parametric tests.""" + + +@pytest.fixture( + scope="module", params=_HAS_IMPLEMENTATION.intersection(TEST_EAGER_BACKENDS) +) +def eager(request: pytest.FixtureRequest) -> EagerAllowed: + result: EagerAllowed = request.param + return result + + +@pytest.fixture( + scope="module", + params=_HAS_IMPLEMENTATION_IMPL.intersection(TEST_EAGER_BACKENDS).union([False]), +) +def backend(request: pytest.FixtureRequest) -> EagerAllowed | Literal[False]: + result: EagerAllowed | Literal[False] = request.param + return result + + +@pytest.fixture(scope="module", params=[2024, 2400]) +def leap_year(request: pytest.FixtureRequest) -> int: + result: int = request.param + return result + + +EXPECTED_DATE_1: Final = [ + dt.date(2020, 1, 26), + dt.date(2020, 2, 20), + dt.date(2020, 3, 16), + dt.date(2020, 4, 10), +] +EXPECTED_DATE_2: Final = [dt.date(2021, 1, 30)] +EXPECTED_DATE_3: Final = [ + dt.date(2000, 1, 1), + dt.date(2002, 9, 14), + dt.date(2005, 5, 28), + dt.date(2008, 2, 9), + dt.date(2010, 10, 23), + dt.date(2013, 7, 6), + dt.date(2016, 3, 19), + dt.date(2018, 12, 1), + dt.date(2021, 8, 14), +] +EXPECTED_DATE_4: Final = [ + dt.date(2006, 10, 14), + dt.date(2013, 7, 27), + dt.date(2020, 5, 9), +] + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + [ + nwp.date_range( + dt.date(2020, 1, 1), + dt.date(2020, 4, 30), + interval="25d", + closed="none", + ) + ], + {"literal": EXPECTED_DATE_1}, + ), + ( + ( + nwp.date_range( + dt.date(2021, 1, 30), + nwp.lit(18747, nw.Int32).cast(nw.Date), + interval="90d", + closed="left", + ).alias("date_range_cast_expr"), + {"date_range_cast_expr": EXPECTED_DATE_2}, + ) + ), + ], +) +def test_date_range( + expr: nwp.Expr | Sequence[nwp.Expr], + expected: dict[str, Any], + data: dict[str, list[dt.date]], +) -> None: + pytest.importorskip("pyarrow") + result = dataframe(data).select(expr) + 
assert_equal_data(result, expected) + + +def test_date_range_eager_leap(eager: EagerAllowed, leap_year: int) -> None: + series_leap = nwp.date_range( + dt.date(leap_year, 2, 25), dt.date(leap_year, 3, 25), eager=eager + ) + series_regular = nwp.date_range( + dt.date(leap_year + 1, 2, 25), + dt.date(leap_year + 1, 3, 25), + interval=dt.timedelta(days=1), + eager=eager, + ) + assert len(series_regular) == 29 + assert len(series_leap) == 30 + + +@pytest.mark.parametrize( + ("start", "end", "interval", "closed", "expected"), + [ + (dt.date(2000, 1, 1), dt.date(2023, 8, 31), "987d", "both", EXPECTED_DATE_3), + (dt.date(2000, 1, 1), dt.date(2023, 8, 31), "354w", "right", EXPECTED_DATE_4), + ], +) +def test_date_range_eager( + start: dt.date, + end: dt.date, + interval: str | dt.timedelta, + closed: ClosedInterval, + expected: list[dt.date], + eager: EagerAllowed, +) -> None: + ser = nwp.date_range(start, end, interval=interval, closed=closed, eager=eager) + result = ser.to_list() + assert result == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ([nwp.int_range(5)], {"literal": [0, 1, 2, 3, 4]}), + ([nwp.int_range(nwp.len())], {"literal": [0, 1, 2]}), + (nwp.int_range(nwp.len() * 5, 20).alias("lol"), {"lol": [15, 16, 17, 18, 19]}), + (nwp.int_range(nwp.col("b").min() + 4, nwp.col("d").last()), {"b": [5, 6, 7]}), + ], +) +def test_int_range( + expr: nwp.Expr | Sequence[nwp.Expr], expected: dict[str, Any], data: dict[str, Any] +) -> None: + pytest.importorskip("pyarrow") + result = dataframe(data).select(expr) + assert_equal_data(result, expected) + + +def test_int_range_eager(eager: EagerAllowed) -> None: + ser = nwp.int_range(50, eager=eager) + assert isinstance(ser, nwp.Series) + assert ser.to_list() == list(range(50)) + + +@pytest.mark.parametrize(("start", "end"), [(0, 0), (0, 1), (-1, 0), (-2.1, 3.4)]) +@pytest.mark.parametrize("num_samples", [0, 1, 2, 5, 1_000]) +@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +def 
test_linear_space_values( + start: float, + end: float, + num_samples: int, + interval: ClosedInterval, + *, + backend: EagerAllowed | Literal[False], +) -> None: + # NOTE: Adapted from https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/functions/range/test_linear_space.py#L19-L56 + if backend: + result = nwp.linear_space( + start, end, num_samples, closed=interval, eager=backend + ).rename("ls") + else: + result = ( + dataframe({}) + .select(ls=nwp.linear_space(start, end, num_samples, closed=interval)) + .to_series() + ) + + pytest.importorskip("numpy") + import numpy as np + + if interval == "both": + expected = np.linspace(start, end, num_samples) + elif interval == "left": + expected = np.linspace(start, end, num_samples, endpoint=False) + elif interval == "right": + expected = np.linspace(start, end, num_samples + 1)[1:] + elif interval == "none": + expected = np.linspace(start, end, num_samples + 2)[1:-1] + + assert_equal_series(result, expected, "ls") + + +def test_linear_space_expr() -> None: + # NOTE: Adapted from https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/functions/range/test_linear_space.py#L59-L68 + pytest.importorskip("pyarrow") + df = dataframe({"a": [1, 2, 3, 4, 5]}) + + result = df.select(nwp.linear_space(0, nwp.col("a").len(), 3)) + expected = df.select( + literal=nwp.Series.from_iterable( + [0.0, 2.5, 5.0], dtype=nw.Float64, backend="pyarrow" + ) + ) + assert_equal_data(result, expected) + + result = df.select(nwp.linear_space(nwp.col("a").len(), 0, 3)) + expected = df.select( + a=nwp.Series.from_iterable([5.0, 2.5, 0.0], dtype=nw.Float64, backend="pyarrow") + ) + assert_equal_data(result, expected) + + +# NOTE: More general "supertyping" behavior would need `pyarrow.unify_schemas` +# (https://arrow.apache.org/docs/14.0/python/generated/pyarrow.unify_schemas.html) +@pytest.mark.parametrize( + ("dtype_start", "dtype_end", "dtype_expected"), 
+ [ + pytest.param( + nw.Float32, + nw.Float32, + nw.Float32, + marks=pytest.mark.xfail( + reason="Didn't preserve `Float32` dtype, promoted to `Float64`", + raises=AssertionError, + ), + ), + (nw.Float32, nw.Float64, nw.Float64), + (nw.Float64, nw.Float32, nw.Float64), + (nw.Float64, nw.Float64, nw.Float64), + (nw.UInt8, nw.UInt32, nw.Float64), + (nw.Int16, nw.Int128, nw.Float64), + (nw.Int8, nw.Float64, nw.Float64), + ], +) +def test_linear_space_expr_numeric_dtype( + dtype_start: IntoDType, dtype_end: IntoDType, dtype_expected: IntoDType +) -> None: + # NOTE: Adapted from https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/functions/range/test_linear_space.py#L71-L95 + pytest.importorskip("pyarrow") + df = dataframe({}) + result = df.select( + ls=nwp.linear_space(nwp.lit(0, dtype=dtype_start), nwp.lit(1, dtype=dtype_end), 6) + ) + expected = df.select( + ls=nwp.Series.from_iterable( + [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=dtype_expected, backend="pyarrow" + ) + ) + assert result.get_column("ls").dtype == dtype_expected + assert_equal_data(result, expected) + + +def test_linear_space_expr_wrong_length() -> None: + # NOTE: Adapted from https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/functions/range/test_linear_space.py#L194-L199 + pytest.importorskip("pyarrow") + df = dataframe({"a": [1, 2, 3, 4, 5]}) + with pytest.raises(ShapeError, match="Expected object of length 6, got 5"): + df.with_columns(nwp.linear_space(0, 1, 6)) diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 70bf35541d..92446a48cd 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -220,8 +220,10 @@ def series(values: Iterable[Any], /) -> nwp.Series[pa.ChunkedArray[Any]]: def assert_equal_data( - result: nwp.DataFrame[Any, Any], expected: Mapping[str, Any] + result: nwp.DataFrame[Any, Any], expected: Mapping[str, Any] | nwp.DataFrame[Any, Any] ) -> None: + if isinstance(expected, 
nwp.DataFrame): + expected = expected.to_dict(as_series=False) _assert_equal_data(result.to_dict(as_series=False), expected) From d238ac64818965e5371ecd6f7e370140def5f1e1 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 10 Dec 2025 21:59:52 +0000 Subject: [PATCH 176/215] feat(expr-ir): Support `list.unique` Was a bit tricky, but quite possible --- narwhals/_plan/arrow/expr.py | 5 ++++- narwhals/_plan/arrow/functions.py | 37 +++++++++++++++++++++++++++++++ tests/plan/list_unique_test.py | 31 ++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 tests/plan/list_unique_test.py diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 3e350e4859..3b09efe207 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -957,7 +957,10 @@ def len(self, node: FExpr[lists.Len], frame: Frame, name: str) -> Expr | Scalar: def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: return self.unary(fn.list_get, node.function.index)(node, frame, name) - unique = not_implemented() + # TODO @dangotbanned: Add tests for scalar + def unique(self, node: FExpr[lists.Unique], frame: Frame, name: str) -> Expr | Scalar: + return self.unary(fn.list_unique)(node, frame, name) # type: ignore[arg-type] + contains = not_implemented() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 4b0a65101e..f7aace081e 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -605,6 +605,35 @@ def list_join( return pc.binary_join(native, separator) +# TODO @dangotbanned: Multiple cleanup jobs +# - Use Acero directly to avoid renaming columns +# - Maybe handle the filter in acero too? 
+def list_unique(native: ChunkedArrayAny) -> ChunkedArrayAny: + lengths = list_len(native) + is_not_valid = or_(is_null(lengths), eq(lengths, lit(0))) + is_valid = not_(is_not_valid) + + i, v = "index", "values" + + indexed = concat_horizontal_arrays( + [int_range(len(native), chunked=False), native], [i, v] + ) + return ( + concat_vertical_table( + [ + ExplodeBuilder(empty_as_null=False, keep_nulls=False) + .explode_column(indexed.filter(is_valid), v) # pyright: ignore[reportArgumentType] + .group_by(i) + .aggregate([(v, "hash_distinct", pa_options.count("all"))]) + .rename_columns([i, v]), + indexed.filter(is_not_valid), # pyright: ignore[reportArgumentType] + ] + ) + .sort_by(i) + .column(v) + ) + + def str_join( native: Arrow[StringScalar], separator: str, *, ignore_nulls: bool = True ) -> StringScalar: @@ -1619,6 +1648,14 @@ def chunked_array( return _chunked_array(array(data) if isinstance(data, pa.Scalar) else data, dtype) +def concat_horizontal_arrays( + arrays: Collection[ChunkedOrArrayAny], names: Collection[str] +) -> pa.Table: + table: Incomplete = pa.Table.from_arrays + result: pa.Table = table(arrays, names) + return result + + def concat_vertical_chunked( arrays: Iterable[ChunkedOrArrayAny], dtype: DataType | None = None, / ) -> ChunkedArrayAny: diff --git a/tests/plan/list_unique_test.py b/tests/plan/list_unique_test.py new file mode 100644 index 0000000000..91064a8ad1 --- /dev/null +++ b/tests/plan/list_unique_test.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_series, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [[2, 2, 3, None, None], None, [], [None]]} + + +def test_list_unique(data: Data) -> None: + df = dataframe(data).with_columns(nwp.col("a")) + ser = 
df.select(nwp.col("a").cast(nw.List(nw.Int32)).list.unique()).to_series() + result = ser.to_list() + assert len(result) == 4 + assert len(result[0]) == 3 + assert set(result[0]) == {2, 3, None} + assert result[1] is None + assert len(result[2]) == 0 + assert len(result[3]) == 1 + + assert_equal_series(ser.explode(), [2, 3, None, None, None, None], "a") From c2c3d78a5e707f6b961a85a760a9ff4a5f42b13f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 10 Dec 2025 22:23:38 +0000 Subject: [PATCH 177/215] fix(typing): Well that my goof, apparently --- narwhals/_plan/arrow/functions.py | 36 ++++++++++++++++--------------- narwhals/_plan/arrow/typing.py | 5 +++-- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index f7aace081e..c6e629229c 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -141,22 +141,22 @@ class MinMax(ir.AggExpr): IntoColumnAgg: TypeAlias = Callable[[str], ir.AggExpr] """Helper constructor for single-column aggregations.""" -is_null = t.cast("UnaryFunction[ScalarAny, BooleanScalar]", pc.is_null) -is_not_null = t.cast("UnaryFunction[ScalarAny,BooleanScalar]", pc.is_valid) -is_nan = t.cast("UnaryFunction[ScalarAny, BooleanScalar]", pc.is_nan) -is_finite = t.cast("UnaryFunction[ScalarAny, BooleanScalar]", pc.is_finite) -not_ = t.cast("UnaryFunction[ScalarAny, BooleanScalar]", pc.invert) +is_null = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_null) +is_not_null = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_valid) +is_nan = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_nan) +is_finite = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_finite) +not_ = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.invert) @overload -def is_not_nan(native: ChunkedArrayAny) -> ChunkedArray[BooleanScalar]: ... 
+def is_not_nan(native: ChunkedArrayAny) -> ChunkedArray[pa.BooleanScalar]: ... @overload -def is_not_nan(native: ScalarAny) -> BooleanScalar: ... +def is_not_nan(native: ScalarAny) -> pa.BooleanScalar: ... @overload -def is_not_nan(native: ChunkedOrScalarAny) -> ChunkedOrScalar[BooleanScalar]: ... +def is_not_nan(native: ChunkedOrScalarAny) -> ChunkedOrScalar[pa.BooleanScalar]: ... @overload -def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[BooleanScalar]: ... -def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[BooleanScalar]: +def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: ... +def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: return not_(is_nan(native)) @@ -483,7 +483,7 @@ def _iter_ensure_shape( raise ShapeError(msg) yield arr - def _predicate(self, lengths: ArrowAny, /) -> Arrow[BooleanScalar]: + def _predicate(self, lengths: ArrowAny, /) -> Arrow[pa.BooleanScalar]: """Return True for each sublist length that indicates the original sublist should be replaced with `[None]`.""" empty_as_null, keep_nulls = self.options.empty_as_null, self.options.keep_nulls if empty_as_null and keep_nulls: @@ -622,11 +622,11 @@ def list_unique(native: ChunkedArrayAny) -> ChunkedArrayAny: concat_vertical_table( [ ExplodeBuilder(empty_as_null=False, keep_nulls=False) - .explode_column(indexed.filter(is_valid), v) # pyright: ignore[reportArgumentType] + .explode_column(indexed.filter(is_valid), v) .group_by(i) .aggregate([(v, "hash_distinct", pa_options.count("all"))]) .rename_columns([i, v]), - indexed.filter(is_not_valid), # pyright: ignore[reportArgumentType] + indexed.filter(is_not_valid), ] ) .sort_by(i) @@ -1224,23 +1224,25 @@ def is_between( lower: ChunkedOrScalar[ScalarT] | NumericLiteral, upper: ChunkedOrScalar[ScalarT] | NumericLiteral, closed: ClosedInterval, -) -> ChunkedArray[BooleanScalar]: ... +) -> ChunkedArray[pa.BooleanScalar]: ... 
@t.overload def is_between( native: ChunkedOrScalar[ScalarT], lower: ChunkedOrScalar[ScalarT] | NumericLiteral, upper: ChunkedOrScalar[ScalarT] | NumericLiteral, closed: ClosedInterval, -) -> ChunkedOrScalar[BooleanScalar]: ... +) -> ChunkedOrScalar[pa.BooleanScalar]: ... def is_between( native: ChunkedOrScalar[ScalarT], lower: ChunkedOrScalar[ScalarT] | NumericLiteral, upper: ChunkedOrScalar[ScalarT] | NumericLiteral, closed: ClosedInterval, -) -> ChunkedOrScalar[BooleanScalar]: +) -> ChunkedOrScalar[pa.BooleanScalar]: fn_lhs, fn_rhs = _IS_BETWEEN[closed] low, high = (el if _is_arrow(el) else lit(el) for el in (lower, upper)) - out: ChunkedOrScalar[BooleanScalar] = and_(fn_lhs(native, low), fn_rhs(native, high)) + out: ChunkedOrScalar[pa.BooleanScalar] = and_( + fn_lhs(native, low), fn_rhs(native, high) + ) return out diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index 766053f76a..76d7430623 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -36,6 +36,7 @@ DateScalar: TypeAlias = "Scalar[Date32Type]" ListScalar: TypeAlias = "Scalar[pa.ListType[DataTypeT_co]]" BooleanScalar: TypeAlias = "Scalar[BoolType]" + """Only use this for a parameter type, not as a return type!""" NumericScalar: TypeAlias = "pc.NumericScalar" PrimitiveNumericType: TypeAlias = "types._Integer | types._Floating" @@ -176,11 +177,11 @@ def __call__( class BinaryComp( - BinaryFunction[ScalarPT_contra, "BooleanScalar"], Protocol[ScalarPT_contra] + BinaryFunction[ScalarPT_contra, "pa.BooleanScalar"], Protocol[ScalarPT_contra] ): ... -class BinaryLogical(BinaryFunction["BooleanScalar", "BooleanScalar"], Protocol): ... +class BinaryLogical(BinaryFunction["BooleanScalar", "pa.BooleanScalar"], Protocol): ... 
BinaryNumericTemporal: TypeAlias = BinaryFunction[ From 5f001b47d5aa9331e7de46b5cbf2b58d38c47d85 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 10 Dec 2025 22:58:51 +0000 Subject: [PATCH 178/215] test: Cover unique on `ListScalar` The error is getting raised on the `first()` call --- tests/plan/list_unique_test.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/plan/list_unique_test.py b/tests/plan/list_unique_test.py index 91064a8ad1..7212da6aff 100644 --- a/tests/plan/list_unique_test.py +++ b/tests/plan/list_unique_test.py @@ -29,3 +29,27 @@ def test_list_unique(data: Data) -> None: assert len(result[3]) == 1 assert_equal_series(ser.explode(), [2, 3, None, None, None, None], "a") + + +@pytest.mark.parametrize( + ("row", "expected"), + [ + ([None, "A", "B", "A", "A", "B"], [None, "A", "B"]), + pytest.param( + None, + None, + marks=pytest.mark.xfail( + reason="TODO: `object of type 'NoneType' has no len()`", raises=TypeError + ), + ), + ([], []), + ([None], [None]), + ], +) +def test_list_unique_scalar( + row: list[str | None] | None, expected: list[str | None] | None +) -> None: + data = {"a": [row]} + df = dataframe(data).select(nwp.col("a").cast(nw.List(nw.String)).first()) + result = df.select(nwp.col("a").list.unique()).to_series() + assert_equal_series(result, [expected], "a") From cd3b626033585e18b253da4e7c433cfe36548708 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 10 Dec 2025 23:19:22 +0000 Subject: [PATCH 179/215] fix: `pyarrow<14` compat Fixes https://github.com/narwhals-dev/narwhals/actions/runs/20115195424/job/57722811655#step:9:384 --- narwhals/_plan/arrow/functions.py | 34 +++++++++---------------------- narwhals/_plan/arrow/namespace.py | 8 ++++---- narwhals/_plan/arrow/series.py | 2 +- 3 files changed, 15 insertions(+), 29 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py 
b/narwhals/_plan/arrow/functions.py index c6e629229c..8cc0661f58 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -14,6 +14,7 @@ from narwhals._arrow.utils import ( cast_for_truediv, chunked_array as _chunked_array, + concat_tables as concat_tables, # noqa: PLC0414 floordiv_compat as _floordiv, narwhals_to_native_dtype as _dtype_native, ) @@ -31,7 +32,7 @@ from typing_extensions import Self, TypeAlias, TypeIs, TypeVarTuple, Unpack - from narwhals._arrow.typing import Incomplete, PromoteOptions + from narwhals._arrow.typing import Incomplete from narwhals._plan.arrow.acero import Field from narwhals._plan.arrow.typing import ( Array, @@ -615,11 +616,9 @@ def list_unique(native: ChunkedArrayAny) -> ChunkedArrayAny: i, v = "index", "values" - indexed = concat_horizontal_arrays( - [int_range(len(native), chunked=False), native], [i, v] - ) + indexed = concat_horizontal([int_range(len(native), chunked=False), native], [i, v]) return ( - concat_vertical_table( + concat_tables( [ ExplodeBuilder(empty_as_null=False, keep_nulls=False) .explode_column(indexed.filter(is_valid), v) @@ -1650,35 +1649,22 @@ def chunked_array( return _chunked_array(array(data) if isinstance(data, pa.Scalar) else data, dtype) -def concat_horizontal_arrays( +def concat_horizontal( arrays: Collection[ChunkedOrArrayAny], names: Collection[str] ) -> pa.Table: + """Concatenate `arrays` as columns in a new table.""" table: Incomplete = pa.Table.from_arrays result: pa.Table = table(arrays, names) return result -def concat_vertical_chunked( +def concat_vertical( arrays: Iterable[ChunkedOrArrayAny], dtype: DataType | None = None, / ) -> ChunkedArrayAny: + """Concatenate `arrays` into a new array.""" v_concat: Incomplete = pa.chunked_array - return v_concat(arrays, dtype) # type: ignore[no-any-return] - - -def concat_vertical_table( - tables: Iterable[pa.Table], /, promote_options: PromoteOptions = "none" -) -> pa.Table: - return pa.concat_tables(tables, 
promote_options=promote_options) - - -if BACKEND_VERSION >= (14,): - - def concat_diagonal(tables: Iterable[pa.Table]) -> pa.Table: - return pa.concat_tables(tables, promote_options="default") -else: - - def concat_diagonal(tables: Iterable[pa.Table]) -> pa.Table: - return pa.concat_tables(tables, promote=True) + result: ChunkedArrayAny = v_concat(arrays, dtype) + return result def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataType]]: diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index b2d6f60e21..534aab37d5 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -293,7 +293,7 @@ def concat( def _concat_diagonal(self, items: Iterable[Frame]) -> Frame: return self._dataframe.from_native( - fn.concat_vertical_table(df.native for df in items), self.version + fn.concat_tables((df.native for df in items), "default"), self.version ) def _concat_horizontal(self, items: Iterable[Frame | Series]) -> Frame: @@ -305,14 +305,14 @@ def gen(objs: Iterable[Frame | Series]) -> Iterator[tuple[ChunkedArrayAny, str]] yield from zip(item.native.itercolumns(), item.columns) arrays, names = zip(*gen(items)) - native = pa.Table.from_arrays(arrays, list(names)) + native = fn.concat_horizontal(arrays, names) return self._dataframe.from_native(native, self.version) def _concat_vertical(self, items: Iterable[Frame | Series]) -> Frame | Series: collected = items if isinstance(items, tuple) else tuple(items) if is_tuple_of(collected, self._series): sers = collected - chunked = fn.concat_vertical_chunked(ser.native for ser in sers) + chunked = fn.concat_vertical(ser.native for ser in sers) return sers[0]._with_native(chunked) if is_tuple_of(collected, self._dataframe): dfs = collected @@ -326,5 +326,5 @@ def _concat_vertical(self, items: Iterable[Frame | Series]) -> Frame | Series: f" - dataframe {i}: {cols_current}\n" ) raise TypeError(msg) - return 
df._with_native(fn.concat_vertical_table(df.native for df in dfs)) + return df._with_native(fn.concat_tables(df.native for df in dfs)) raise TypeError(items) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index ed30d683ce..3407cf3cca 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -213,7 +213,7 @@ def _rolling_center(self, window_size: int) -> tuple[Self, int]: fn.nulls_like(offset_right, native), ) offset = offset_left + offset_right - return self._with_native(fn.concat_vertical_chunked(arrays)), offset + return self._with_native(fn.concat_vertical(arrays)), offset def _rolling_sum(self, window_size: int, /) -> Self: cum_sum = self.cum_sum().fill_null_with_strategy("forward") From ced4ae1777cd9a3794df0a810ec26a2e69897c4b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 10 Dec 2025 23:43:25 +0000 Subject: [PATCH 180/215] use acero more directly in `list.unique` --- narwhals/_plan/arrow/functions.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 8cc0661f58..ac2d82e8dc 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -472,6 +472,13 @@ def explode_columns(self, native: pa.Table, subset: Collection[str], /) -> pa.Ta return with_arrays(native, zip(subset, chain([first_result], it))) + @classmethod + def explode_column_fast(cls, native: pa.Table, column_name: str, /) -> pa.Table: + """Explode a list-typed column in the context of `native`, ignoring empty and nulls.""" + return cls(empty_as_null=False, keep_nulls=False).explode_column( + native, column_name + ) + def _iter_ensure_shape( self, first_len: ChunkedArray[pa.UInt32Scalar], @@ -610,24 +617,19 @@ def list_join( # - Use Acero directly to avoid renaming columns # - Maybe handle the filter in acero too? 
def list_unique(native: ChunkedArrayAny) -> ChunkedArrayAny: + from narwhals._plan.arrow.acero import group_by_table + from narwhals._plan.arrow.group_by import AggSpec + lengths = list_len(native) is_not_valid = or_(is_null(lengths), eq(lengths, lit(0))) is_valid = not_(is_not_valid) i, v = "index", "values" - indexed = concat_horizontal([int_range(len(native), chunked=False), native], [i, v]) + table = ExplodeBuilder.explode_column_fast(indexed.filter(is_valid), v) + agg = AggSpec.from_function_expr(ir_unique(v), v) return ( - concat_tables( - [ - ExplodeBuilder(empty_as_null=False, keep_nulls=False) - .explode_column(indexed.filter(is_valid), v) - .group_by(i) - .aggregate([(v, "hash_distinct", pa_options.count("all"))]) - .rename_columns([i, v]), - indexed.filter(is_not_valid), - ] - ) + concat_tables([group_by_table(table, [i], [agg]), indexed.filter(is_not_valid)]) .sort_by(i) .column(v) ) @@ -1274,6 +1276,10 @@ def ir_min_max(name: str, /) -> MinMax: return MinMax(expr=ir.col(name)) +def ir_unique(name: str, /) -> ir.FunctionExpr[F.Unique]: + return F.Unique().to_function_expr(ir.col(name)) + + def _boolean_is_unique( indices: ChunkedArrayAny, aggregated: ChunkedStruct, / ) -> ChunkedArrayAny: From 2ce64d603730e111e47bf2aa622220cd7f23ca0d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 11:52:38 +0000 Subject: [PATCH 181/215] refactor: somewhat cleaner `list_unique` Mostly just aiming for readability improvements --- narwhals/_plan/arrow/functions.py | 36 ++++++++++++++++++------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index ac2d82e8dc..60bac0c966 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -395,6 +395,15 @@ def get_categories(native: ArrowAny) -> ChunkedArrayAny: class ExplodeBuilder: + """Tools for exploding lists. 
+ + The complexity of these operations increases with: + - Needing to preserve null/empty elements + - All variants are cheaper if this can be skipped + - Exploding in the context of a table + - Where a single column is much simpler than multiple + """ + options: ExplodeOptions def __init__(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> None: @@ -613,26 +622,23 @@ def list_join( return pc.binary_join(native, separator) -# TODO @dangotbanned: Multiple cleanup jobs -# - Use Acero directly to avoid renaming columns -# - Maybe handle the filter in acero too? +# TODO @dangotbanned: Docs, explain why some of the intermediate steps are needed +# hint: why not `replace_with_mask`? def list_unique(native: ChunkedArrayAny) -> ChunkedArrayAny: from narwhals._plan.arrow.acero import group_by_table from narwhals._plan.arrow.group_by import AggSpec - lengths = list_len(native) - is_not_valid = or_(is_null(lengths), eq(lengths, lit(0))) - is_valid = not_(is_not_valid) - - i, v = "index", "values" - indexed = concat_horizontal([int_range(len(native), chunked=False), native], [i, v]) - table = ExplodeBuilder.explode_column_fast(indexed.filter(is_valid), v) - agg = AggSpec.from_function_expr(ir_unique(v), v) - return ( - concat_tables([group_by_table(table, [i], [agg]), indexed.filter(is_not_valid)]) - .sort_by(i) - .column(v) + idx, v = "index", "values" + indexed = concat_horizontal([int_range(len(native), chunked=False), native], [idx, v]) + len_eq_0 = eq(list_len(native), lit(0)) + valid = indexed.filter(not_(len_eq_0)) + invalid = indexed.filter(or_(native.is_null(), len_eq_0)) + valid_unique = group_by_table( + ExplodeBuilder.explode_column_fast(valid, v), + [idx], + [AggSpec.from_expr_ir(ir_unique(v), v)], ) + return concat_tables([valid_unique, invalid]).sort_by(idx).column(v) def str_join( From a10bf73c4d0025e2e7103cb62764d6f8ba1036b3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 
12:57:15 +0000 Subject: [PATCH 182/215] fix: Avoid `pa.array` bug w/ `[]` --- narwhals/_plan/arrow/functions.py | 4 ++++ tests/plan/list_unique_test.py | 11 ++++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 60bac0c966..ff45f90f05 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -1628,6 +1628,8 @@ def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) +# TODO @dangotbanned: Report `ListScalar.values` bug upstream +# See `tests/plan/list_unique_test.py::test_list_unique_scalar[None-None]` @overload def array(data: ArrowAny, /) -> ArrayAny: ... @overload @@ -1651,6 +1653,8 @@ def array( if isinstance(data, pa.Array): return data if isinstance(data, pa.Scalar): + if isinstance(data, pa.ListScalar) and data.is_valid is False: + return pa.array([None], data.type) return pa.array([data], data.type) return pa.array(data, dtype) diff --git a/tests/plan/list_unique_test.py b/tests/plan/list_unique_test.py index 7212da6aff..6ac2b10a2e 100644 --- a/tests/plan/list_unique_test.py +++ b/tests/plan/list_unique_test.py @@ -31,17 +31,14 @@ def test_list_unique(data: Data) -> None: assert_equal_series(ser.explode(), [2, 3, None, None, None, None], "a") +# TODO @dangotbanned: Report `ListScalar.values` bug upstream +# - Returning `None` breaks: `__len__`, `__getitem__`, `__iter__` +# - Which breaks `pa.array([], pa.list_(pa.int64()))` @pytest.mark.parametrize( ("row", "expected"), [ ([None, "A", "B", "A", "A", "B"], [None, "A", "B"]), - pytest.param( - None, - None, - marks=pytest.mark.xfail( - reason="TODO: `object of type 'NoneType' has no len()`", raises=TypeError - ), - ), + (None, None), ([], []), ([None], [None]), ], ) From 27992a125112a86480a30426fcbcffa4f81d6e9a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> 
Date: Thu, 11 Dec 2025 13:04:17 +0000 Subject: [PATCH 183/215] perf: Add broadcast fastpath for 1 column --- narwhals/_plan/compliant/column.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/compliant/column.py b/narwhals/_plan/compliant/column.py index 49d49035f3..598f462286 100644 --- a/narwhals/_plan/compliant/column.py +++ b/narwhals/_plan/compliant/column.py @@ -60,7 +60,7 @@ def align( `default` must be provided when operating in a `with_columns` context. """ exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs)) - length = cls._length_required(exprs, default) + length = default if len(exprs) == 1 else cls._length_required(exprs, default) if length is None: for e in exprs: yield e.to_series() From 1519d7f4c71a21360dfd0dac67db2556059f4cc6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 14:34:08 +0000 Subject: [PATCH 184/215] perf: Add a fastpath for all valid Skips *quite* a lot of work: - Creating an index column - 2x table filters - `explode_column`, which has more steps - `concat_table` - `sort_by` --- narwhals/_plan/arrow/functions.py | 25 +++++++++++++++---------- tests/plan/list_unique_test.py | 19 ++++++++++++++++--- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index ff45f90f05..0f54ca3061 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -629,16 +629,21 @@ def list_unique(native: ChunkedArrayAny) -> ChunkedArrayAny: from narwhals._plan.arrow.group_by import AggSpec idx, v = "index", "values" - indexed = concat_horizontal([int_range(len(native), chunked=False), native], [idx, v]) - len_eq_0 = eq(list_len(native), lit(0)) - valid = indexed.filter(not_(len_eq_0)) - invalid = indexed.filter(or_(native.is_null(), len_eq_0)) - valid_unique = group_by_table( - ExplodeBuilder.explode_column_fast(valid, v), 
- [idx], - [AggSpec.from_expr_ir(ir_unique(v), v)], - ) - return concat_tables([valid_unique, invalid]).sort_by(idx).column(v) + names = idx, v + len_not_eq_0 = not_eq(list_len(native), lit(0)) + aggs = [AggSpec.from_expr_ir(ir_unique(v), v)] + can_fastpath = all_(len_not_eq_0, ignore_nulls=False).as_py() + if can_fastpath: + arrays = [_list_parent_indices(native), _list_explode(native)] + result = group_by_table(concat_horizontal(arrays, names), [idx], aggs) + else: + indexed = concat_horizontal([int_range(len(native)), native], names) + valid = indexed.filter(len_not_eq_0) + invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) + explode_with_index = ExplodeBuilder.explode_column_fast(valid, v) + valid_unique = group_by_table(explode_with_index, [idx], aggs) + result = concat_tables([valid_unique, invalid]).sort_by(idx) + return result.column(v) def str_join( diff --git a/tests/plan/list_unique_test.py b/tests/plan/list_unique_test.py index 6ac2b10a2e..a49b37030d 100644 --- a/tests/plan/list_unique_test.py +++ b/tests/plan/list_unique_test.py @@ -14,12 +14,15 @@ @pytest.fixture(scope="module") def data() -> Data: - return {"a": [[2, 2, 3, None, None], None, [], [None]]} + return { + "a": [[2, 2, 3, None, None], None, [], [None]], + "b": [[1, 2, 2], [3, 4], [5, 5, 5, 6], [7]], + } def test_list_unique(data: Data) -> None: - df = dataframe(data).with_columns(nwp.col("a")) - ser = df.select(nwp.col("a").cast(nw.List(nw.Int32)).list.unique()).to_series() + df = dataframe(data).select(nwp.col("a").cast(nw.List(nw.Int32))) + ser = df.select(nwp.col("a").list.unique()).to_series() result = ser.to_list() assert len(result) == 4 assert len(result[0]) == 3 @@ -50,3 +53,13 @@ def test_list_unique_scalar( df = dataframe(data).select(nwp.col("a").cast(nw.List(nw.String)).first()) result = df.select(nwp.col("a").list.unique()).to_series() assert_equal_series(result, [expected], "a") + + +def test_list_unique_all_valid(data: Data) -> None: + df = 
dataframe(data).select(nwp.col("b").cast(nw.List(nw.Int32))) + ser = df.select(nwp.col("b").list.unique()).to_series() + result = ser.to_list() + assert set(result[0]) == {1, 2} + assert set(result[1]) == {3, 4} + assert set(result[2]) == {5, 6} + assert set(result[3]) == {7} From 3ae41137c7c6e739fdc8c15cc8576e8f1386af54 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 15:05:10 +0000 Subject: [PATCH 185/215] test: Reveal `Scalar` bugs Lol --- tests/plan/list_unique_test.py | 55 +++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/tests/plan/list_unique_test.py b/tests/plan/list_unique_test.py index a49b37030d..e4d77c978b 100644 --- a/tests/plan/list_unique_test.py +++ b/tests/plan/list_unique_test.py @@ -8,6 +8,9 @@ import narwhals._plan as nwp from tests.plan.utils import assert_equal_series, dataframe +pytest.importorskip("pyarrow") +import pyarrow as pa + if TYPE_CHECKING: from tests.conftest import Data @@ -20,9 +23,13 @@ def data() -> Data: } +a = nwp.col("a") +b = nwp.col("b") + + def test_list_unique(data: Data) -> None: - df = dataframe(data).select(nwp.col("a").cast(nw.List(nw.Int32))) - ser = df.select(nwp.col("a").list.unique()).to_series() + df = dataframe(data).select(a.cast(nw.List(nw.Int32))) + ser = df.select(a.list.unique()).to_series() result = ser.to_list() assert len(result) == 4 assert len(result[0]) == 3 @@ -40,24 +47,52 @@ def test_list_unique(data: Data) -> None: @pytest.mark.parametrize( ("row", "expected"), [ - ([None, "A", "B", "A", "A", "B"], [None, "A", "B"]), - (None, None), - ([], []), - ([None], [None]), + pytest.param( + [None, "A", "B", "A", "A", "B"], + [None, "A", "B"], + marks=pytest.mark.xfail( + reason="Unsupported input type for function 'list_parent_indices': Scalar(list[null, A, B, A, A, B])", + raises=pa.ArrowNotImplementedError, + ), + ), + pytest.param( + None, + None, + marks=pytest.mark.xfail( + 
reason="object of type 'NoneType' has no len()", raises=TypeError + ), + ), + pytest.param( + [], + [], + marks=pytest.mark.xfail( + reason="Filter should be array-like", raises=pa.ArrowTypeError + ), + ), + pytest.param( + [None], + [None], + marks=pytest.mark.xfail( + reason="Unsupported input type for function 'list_parent_indices': Scalar(list[null])", + raises=pa.ArrowNotImplementedError, + ), + ), ], ) def test_list_unique_scalar( row: list[str | None] | None, expected: list[str | None] | None ) -> None: data = {"a": [row]} - df = dataframe(data).select(nwp.col("a").cast(nw.List(nw.String)).first()) - result = df.select(nwp.col("a").list.unique()).to_series() + df = dataframe(data).select(a.cast(nw.List(nw.String))) + # NOTE: Don't separate `first().list.unique()` + # The chain is required to force the transition from `Expr` -> `Scalar` + result = df.select(a.first().list.unique()).to_series() assert_equal_series(result, [expected], "a") def test_list_unique_all_valid(data: Data) -> None: - df = dataframe(data).select(nwp.col("b").cast(nw.List(nw.Int32))) - ser = df.select(nwp.col("b").list.unique()).to_series() + df = dataframe(data).select(b.cast(nw.List(nw.Int32))) + ser = df.select(b.list.unique()).to_series() result = ser.to_list() assert set(result[0]) == {1, 2} assert set(result[1]) == {3, 4} From 7b2e06d67e2cf5ac1ad903465044f63ccc3ba1e2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 15:49:50 +0000 Subject: [PATCH 186/215] fix: Handle 2/4 scalar `list.unique` cases --- narwhals/_plan/arrow/expr.py | 3 +-- narwhals/_plan/arrow/functions.py | 34 +++++++++++++++++++++++-------- narwhals/_plan/arrow/typing.py | 9 ++++---- tests/plan/list_unique_test.py | 18 ++-------------- 4 files changed, 33 insertions(+), 31 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 3b09efe207..88995e8abe 100644 --- a/narwhals/_plan/arrow/expr.py +++ 
b/narwhals/_plan/arrow/expr.py @@ -957,9 +957,8 @@ def len(self, node: FExpr[lists.Len], frame: Frame, name: str) -> Expr | Scalar: def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: return self.unary(fn.list_get, node.function.index)(node, frame, name) - # TODO @dangotbanned: Add tests for scalar def unique(self, node: FExpr[lists.Unique], frame: Frame, name: str) -> Expr | Scalar: - return self.unary(fn.list_unique)(node, frame, name) # type: ignore[arg-type] + return self.unary(fn.list_unique)(node, frame, name) contains = not_implemented() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 0f54ca3061..6483520248 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -65,7 +65,6 @@ DateScalar, IntegerScalar, IntegerType, - LargeStringType, ListArray, ListScalar, ListTypeT, @@ -306,7 +305,7 @@ def has_large_string(data_types: Iterable[DataType], /) -> bool: return any(pa.types.is_large_string(tp) for tp in data_types) -def string_type(data_types: Iterable[DataType] = (), /) -> StringType | LargeStringType: +def string_type(data_types: Iterable[DataType] = (), /) -> StringType: """Return a native string type, compatible with `data_types`. Until [apache/arrow#45717] is resolved, we need to upcast `string` to `large_string` when joining. @@ -624,7 +623,14 @@ def list_join( # TODO @dangotbanned: Docs, explain why some of the intermediate steps are needed # hint: why not `replace_with_mask`? -def list_unique(native: ChunkedArrayAny) -> ChunkedArrayAny: +@overload +def list_unique(native: ChunkedList) -> ChunkedList: ... +@overload +def list_unique(native: ListScalar) -> ListScalar: ... +@overload +def list_unique(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: ... 
+def list_unique(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: + """Get the unique/distinct values in the list.""" from narwhals._plan.arrow.acero import group_by_table from narwhals._plan.arrow.group_by import AggSpec @@ -634,18 +640,30 @@ def list_unique(native: ChunkedArrayAny) -> ChunkedArrayAny: aggs = [AggSpec.from_expr_ir(ir_unique(v), v)] can_fastpath = all_(len_not_eq_0, ignore_nulls=False).as_py() if can_fastpath: + if isinstance(native, pa.Scalar): + return implode(_list_explode(native).unique()) arrays = [_list_parent_indices(native), _list_explode(native)] result = group_by_table(concat_horizontal(arrays, names), [idx], aggs) else: - indexed = concat_horizontal([int_range(len(native)), native], names) - valid = indexed.filter(len_not_eq_0) - invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) + # TODO @dangotbanned: Fix these, they're legit + indexed = concat_horizontal([int_range(len(native)), native], names) # type: ignore[arg-type, list-item] + valid = indexed.filter(len_not_eq_0) # pyright: ignore[reportArgumentType] + invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) # type: ignore[union-attr] explode_with_index = ExplodeBuilder.explode_column_fast(valid, v) valid_unique = group_by_table(explode_with_index, [idx], aggs) result = concat_tables([valid_unique, invalid]).sort_by(idx) return result.column(v) +def implode(native: Arrow[Scalar[DataTypeT]]) -> ListScalar[DataTypeT]: + """Aggregate values into a list. + + The returned list itself is a scalar value of `list` dtype. 
+ """ + arr = array(native) + return pa.ListArray.from_arrays([0, len(arr)], arr)[0] + + def str_join( native: Arrow[StringScalar], separator: str, *, ignore_nulls: bool = True ) -> StringScalar: @@ -655,9 +673,7 @@ def str_join( return native if ignore_nulls and native.null_count: native = native.drop_null() - offsets = [0, len(native)] - scalar = pa.ListArray.from_arrays(offsets, array(native))[0] - return list_join(scalar, separator) + return list_join(implode(native), separator) def str_len_chars(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index 76d7430623..95f0c2afe6 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -18,8 +18,8 @@ Int16Type, Int32Type, Int64Type, - LargeStringType as LargeStringType, - StringType as StringType, + LargeStringType as _LargeStringType, + StringType as _StringType, Uint8Type, Uint16Type, Uint32Type, @@ -30,8 +30,9 @@ from narwhals._native import NativeDataFrame, NativeSeries from narwhals.typing import SizedMultiIndexSelector as _SizedMultiIndexSelector - StringScalar: TypeAlias = "Scalar[StringType | LargeStringType]" + StringType: TypeAlias = "_StringType | _LargeStringType" IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type" + StringScalar: TypeAlias = "Scalar[StringType]" IntegerScalar: TypeAlias = "Scalar[IntegerType]" DateScalar: TypeAlias = "Scalar[Date32Type]" ListScalar: TypeAlias = "Scalar[pa.ListType[DataTypeT_co]]" @@ -42,7 +43,7 @@ PrimitiveNumericType: TypeAlias = "types._Integer | types._Floating" NumericType: TypeAlias = "PrimitiveNumericType | types._Decimal" NumericOrTemporalType: TypeAlias = "NumericType | types._Temporal" - StringOrBinaryType: TypeAlias = "StringType | LargeStringType | lib.StringViewType | lib.BinaryType | lib.LargeBinaryType | lib.BinaryViewType" + StringOrBinaryType: TypeAlias = "StringType | 
lib.StringViewType | lib.BinaryType | lib.LargeBinaryType | lib.BinaryViewType" BasicType: TypeAlias = ( "NumericOrTemporalType | StringOrBinaryType | BoolType | lib.NullType" ) diff --git a/tests/plan/list_unique_test.py b/tests/plan/list_unique_test.py index e4d77c978b..1b399897f5 100644 --- a/tests/plan/list_unique_test.py +++ b/tests/plan/list_unique_test.py @@ -47,14 +47,7 @@ def test_list_unique(data: Data) -> None: @pytest.mark.parametrize( ("row", "expected"), [ - pytest.param( - [None, "A", "B", "A", "A", "B"], - [None, "A", "B"], - marks=pytest.mark.xfail( - reason="Unsupported input type for function 'list_parent_indices': Scalar(list[null, A, B, A, A, B])", - raises=pa.ArrowNotImplementedError, - ), - ), + pytest.param([None, "A", "B", "A", "A", "B"], [None, "A", "B"]), pytest.param( None, None, @@ -69,14 +62,7 @@ def test_list_unique(data: Data) -> None: reason="Filter should be array-like", raises=pa.ArrowTypeError ), ), - pytest.param( - [None], - [None], - marks=pytest.mark.xfail( - reason="Unsupported input type for function 'list_parent_indices': Scalar(list[null])", - raises=pa.ArrowNotImplementedError, - ), - ), + pytest.param([None], [None]), ], ) def test_list_unique_scalar( From e50a654550cd8c8112860dd756710f04e52f21ab Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 16:39:06 +0000 Subject: [PATCH 187/215] fix: Correctly handle `Scalar` in all `list.unique` cases Also has the benefit of skipping more work, since we don't need to group --- narwhals/_plan/arrow/functions.py | 18 ++++++++++-------- tests/plan/list_unique_test.py | 23 ++++------------------- 2 files changed, 14 insertions(+), 27 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 6483520248..89c094c28f 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -628,27 +628,29 @@ def list_unique(native: ChunkedList) -> ChunkedList: 
... @overload def list_unique(native: ListScalar) -> ListScalar: ... @overload -def list_unique(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: ... -def list_unique(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: +def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScalar]: ... +def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScalar]: """Get the unique/distinct values in the list.""" from narwhals._plan.arrow.acero import group_by_table from narwhals._plan.arrow.group_by import AggSpec + if isinstance(native, pa.Scalar): + scalar = t.cast("pa.ListScalar[Any]", native) + if scalar.is_valid and (len(scalar) > 1): + return implode(_list_explode(native).unique()) + return scalar idx, v = "index", "values" names = idx, v len_not_eq_0 = not_eq(list_len(native), lit(0)) aggs = [AggSpec.from_expr_ir(ir_unique(v), v)] can_fastpath = all_(len_not_eq_0, ignore_nulls=False).as_py() if can_fastpath: - if isinstance(native, pa.Scalar): - return implode(_list_explode(native).unique()) arrays = [_list_parent_indices(native), _list_explode(native)] result = group_by_table(concat_horizontal(arrays, names), [idx], aggs) else: - # TODO @dangotbanned: Fix these, they're legit - indexed = concat_horizontal([int_range(len(native)), native], names) # type: ignore[arg-type, list-item] - valid = indexed.filter(len_not_eq_0) # pyright: ignore[reportArgumentType] - invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) # type: ignore[union-attr] + indexed = concat_horizontal([int_range(len(native)), native], names) + valid = indexed.filter(len_not_eq_0) + invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) explode_with_index = ExplodeBuilder.explode_column_fast(valid, v) valid_unique = group_by_table(explode_with_index, [idx], aggs) result = concat_tables([valid_unique, invalid]).sort_by(idx) diff --git a/tests/plan/list_unique_test.py b/tests/plan/list_unique_test.py index 1b399897f5..7f82e593b5 100644 --- 
a/tests/plan/list_unique_test.py +++ b/tests/plan/list_unique_test.py @@ -8,9 +8,6 @@ import narwhals._plan as nwp from tests.plan.utils import assert_equal_series, dataframe -pytest.importorskip("pyarrow") -import pyarrow as pa - if TYPE_CHECKING: from tests.conftest import Data @@ -47,22 +44,10 @@ def test_list_unique(data: Data) -> None: @pytest.mark.parametrize( ("row", "expected"), [ - pytest.param([None, "A", "B", "A", "A", "B"], [None, "A", "B"]), - pytest.param( - None, - None, - marks=pytest.mark.xfail( - reason="object of type 'NoneType' has no len()", raises=TypeError - ), - ), - pytest.param( - [], - [], - marks=pytest.mark.xfail( - reason="Filter should be array-like", raises=pa.ArrowTypeError - ), - ), - pytest.param([None], [None]), + ([None, "A", "B", "A", "A", "B"], [None, "A", "B"]), + (None, None), + ([], []), + ([None], [None]), ], ) def test_list_unique_scalar( From 7eb40c76f3ac24b8e72fb7117f05ba15cbb6ac20 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 17:11:53 +0000 Subject: [PATCH 188/215] docs: Write an essay on `list.unique` --- narwhals/_plan/arrow/functions.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 89c094c28f..f09d34b886 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -621,8 +621,6 @@ def list_join( return pc.binary_join(native, separator) -# TODO @dangotbanned: Docs, explain why some of the intermediate steps are needed -# hint: why not `replace_with_mask`? @overload def list_unique(native: ChunkedList) -> ChunkedList: ... @overload @@ -630,7 +628,26 @@ def list_unique(native: ListScalar) -> ListScalar: ... @overload def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScalar]: ... 
def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScalar]: - """Get the unique/distinct values in the list.""" + """Get the unique/distinct values in the list. + + There's lots of tricky stuff going on in here, but for good reasons! + + Whenever possible, we want to avoid having to deal with these pesky guys: + + [["okay", None, "still fine"], None, []] + # ^^^^ ^^ + + - Those kinds of list elements are ignored natively + - `unique` is length-changing operation + - We can't use [`pc.replace_with_mask`] on a list + - We can't join when a table contains list columns [apache/arrow#43716] + + **But** - if we're lucky, and we got a non-awful list (or only one element) - then + most issues vanish. + + [`pc.replace_with_mask`]: https://arrow.apache.org/docs/python/generated/pyarrow.compute.replace_with_mask.html + [apache/arrow#43716]: https://github.com/apache/arrow/issues/43716 + """ from narwhals._plan.arrow.acero import group_by_table from narwhals._plan.arrow.group_by import AggSpec @@ -648,11 +665,16 @@ def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScal arrays = [_list_parent_indices(native), _list_explode(native)] result = group_by_table(concat_horizontal(arrays, names), [idx], aggs) else: + # Oh no - we caught a bad one! + # We need to split things into good/bad - and only work on the good stuff. 
+ # `int_range` is acting like `parent_indices`, but doesn't give up when it sees `None` or `[]` indexed = concat_horizontal([int_range(len(native)), native], names) valid = indexed.filter(len_not_eq_0) invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) + # To keep track of where we started, our index needs to be exploded with the list elements explode_with_index = ExplodeBuilder.explode_column_fast(valid, v) valid_unique = group_by_table(explode_with_index, [idx], aggs) + # And now, because we can't join - we do a poor man's version of one 😉 result = concat_tables([valid_unique, invalid]).sort_by(idx) return result.column(v) From bad7df1ba59f24616aba21b59de64a040ea0156c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 18:24:13 +0000 Subject: [PATCH 189/215] test: Add `list.contains` tests --- tests/plan/list_contains_test.py | 44 ++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 tests/plan/list_contains_test.py diff --git a/tests/plan/list_contains_test.py b/tests/plan/list_contains_test.py new file mode 100644 index 0000000000..c6284073b3 --- /dev/null +++ b/tests/plan/list_contains_test.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from narwhals._plan.typing import IntoExpr + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "a": [[2, 2, 3, None, None], None, [], [None]], + "b": [[1, 2, 2], [3, 4], [5, 5, 5, 6], [7]], + "c": [1, 3, None, 2], + } + + +a = nwp.col("a") +b = nwp.col("b") + + +@pytest.mark.xfail(reason="TODO: `ArrowExpr.list.contains`", raises=NotImplementedError) +@pytest.mark.parametrize( + ("item", "expected"), + [ + (2, [True, None, False, False]), + (4, [False, None, False, 
False]), + (nwp.col("c").last() + 1, [True, None, False, False]), + (nwp.lit(None, nw.Int32), [True, None, False, True]), + ], +) +def test_list_contains( + data: Data, item: IntoExpr, expected: list[bool | None] +) -> None: # pragma: no cover + df = dataframe(data).with_columns(a.cast(nw.List(nw.Int32))) + result = df.select(a.list.contains(item)) + assert_equal_data(result, {"a": expected}) From 8cf2173ecb4ff7e3932a148c9db42d4b3fd69da6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 18:46:18 +0000 Subject: [PATCH 190/215] test: Add `DataFrame.explode` fail case Will just use the simpler `ExplodeBuilder.explode` for this case --- tests/plan/explode_test.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/plan/explode_test.py b/tests/plan/explode_test.py index c7cc478d67..dd67027b38 100644 --- a/tests/plan/explode_test.py +++ b/tests/plan/explode_test.py @@ -22,6 +22,9 @@ from narwhals._plan.typing import ColumnNameOrSelector from tests.conftest import Data +pytest.importorskip("pyarrow") +import pyarrow as pa + @pytest.fixture(scope="module") def data() -> Data: @@ -55,6 +58,31 @@ def test_explode_frame_single_col( assert_equal_data(result, expected) +@pytest.mark.xfail( + reason="TODO: `Added column's length must match table's length. 
Expected length 0 but got length ...`", + raises=pa.ArrowInvalid, +) +@pytest.mark.parametrize( + ("column", "expected_values"), + [ + ("l2", [None, None, None, 3, 42]), + ("l3", [None, 1, 1, 2, 3]), + ("l4", [1, 2, 3, 123, 456]), + ("l5", [None, None, None, 83, 99]), + ], +) +def test_explode_frame_only_column( + column: str, expected_values: list[int | None], data: Data +) -> None: # pragma: no cover + result = ( + dataframe(data) + .select(nwp.col(column).cast(nw.List(nw.Int32()))) + .explode(column) + .sort(column) + ) + assert_equal_data(result, {column: expected_values}) + + @pytest.mark.parametrize( ("column", "more_columns", "expected"), [ From c6dc2a864530bd26da2c79d85cf8a1f85d050aa7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 18:54:20 +0000 Subject: [PATCH 191/215] fix: Support `explode` when only column in table# --- narwhals/_plan/arrow/functions.py | 2 ++ tests/plan/explode_test.py | 7 +------ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index f09d34b886..7abb6871f9 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -439,6 +439,8 @@ def explode( def explode_column(self, native: pa.Table, column_name: str, /) -> pa.Table: """Explode a list-typed column in the context of `native`.""" ca = native.column(column_name) + if native.num_columns == 1: + return native.from_arrays([self.explode(ca)], [column_name]) safe = self._fill_with_null(ca) if self.options.any() else ca exploded = _list_explode(safe) col_idx = native.schema.get_field_index(column_name) diff --git a/tests/plan/explode_test.py b/tests/plan/explode_test.py index dd67027b38..bb8601dc8c 100644 --- a/tests/plan/explode_test.py +++ b/tests/plan/explode_test.py @@ -23,7 +23,6 @@ from tests.conftest import Data pytest.importorskip("pyarrow") -import pyarrow as pa @pytest.fixture(scope="module") @@ -58,10 
+57,6 @@ def test_explode_frame_single_col( assert_equal_data(result, expected) -@pytest.mark.xfail( - reason="TODO: `Added column's length must match table's length. Expected length 0 but got length ...`", - raises=pa.ArrowInvalid, -) @pytest.mark.parametrize( ("column", "expected_values"), [ @@ -73,7 +68,7 @@ def test_explode_frame_single_col( ) def test_explode_frame_only_column( column: str, expected_values: list[int | None], data: Data -) -> None: # pragma: no cover +) -> None: result = ( dataframe(data) .select(nwp.col(column).cast(nw.List(nw.Int32()))) From 8fe333ecd61852d8877923af2420fe8e53387996 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:45:49 +0000 Subject: [PATCH 192/215] feat(expr-ir): Support `ArrowExpr.list.contains(IntoExpr)` --- narwhals/_plan/arrow/expr.py | 16 +++++++++++++- narwhals/_plan/arrow/functions.py | 35 +++++++++++++++++++++++++++++++ tests/plan/list_contains_test.py | 5 +---- 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 88995e8abe..987cd1395b 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -960,7 +960,21 @@ def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: def unique(self, node: FExpr[lists.Unique], frame: Frame, name: str) -> Expr | Scalar: return self.unary(fn.list_unique)(node, frame, name) - contains = not_implemented() + def contains( + self, node: FExpr[lists.Contains], frame: Frame, name: str + ) -> Expr | Scalar: + func = node.function + expr, other = func.unwrap_input(node) + prev = expr.dispatch(self.compliant, frame, name) + item = other.dispatch(self.compliant, frame, name) + if isinstance(prev, ArrowScalar): + msg = "TODO: `ArrowScalar.list.contains`" + raise NotImplementedError(msg) + if isinstance(item, ArrowExpr): + # Maybe one day, not now + raise NotImplementedError + result = 
fn.list_contains(prev.native, item.native) + return self.with_native(result, name) class ArrowStringNamespace( diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 7abb6871f9..e306260c68 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -436,6 +436,11 @@ def explode( return _list_explode(safe) return chunked_array(_list_explode(safe)) + def explode_with_indices(self, native: ChunkedList) -> pa.Table: + safe = self._fill_with_null(native) if self.options.any() else native + arrays = [_list_parent_indices(safe), _list_explode(safe)] + return concat_horizontal(arrays, ["idx", "values"]) + def explode_column(self, native: pa.Table, column_name: str, /) -> pa.Table: """Explode a list-typed column in the context of `native`.""" ca = native.column(column_name) @@ -681,6 +686,36 @@ def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScal return result.column(v) +# TODO @dangotbanned: Clean up +# TODO @dangotbanned: Support `native: ListScalar` +# NOTE: Both of these weren't able to support `[None]`, where, 2 in [None] should be False +# https://github.com/apache/arrow/issues/33295 +# https://github.com/apache/arrow/issues/47118#issuecomment-3075893244 +def list_contains( + native: ChunkedList, item: NonNestedLiteral | ScalarAny +) -> ChunkedArray[pa.BooleanScalar]: + # empty should always be False + # None should always be None + # `None` in `[None]` should be True + # Anything else in `[None]` should be false + ca = native + table = ExplodeBuilder(empty_as_null=False, keep_nulls=False).explode_with_indices(ca) + values = is_in(table.column(1), array(lit(item))) + name = table.field(1).name + contains = ( + table.set_column(1, name, values) + .group_by("idx") + .aggregate([(name, "hash_any", pa_options.scalar_aggregate(ignore_nulls=True))]) + .column(1) + ) + # Here's the really key part: this mask has the same result we want to return + # So by filling the `True`, we can 
flip those to `False` if needed + # But if we were already `None` or `False` - then that's sticky + propagate_invalid = not_eq(list_len(ca), lit(0)) + results = replace_with_mask(array(propagate_invalid), propagate_invalid, contains) + return chunked_array(results) + + def implode(native: Arrow[Scalar[DataTypeT]]) -> ListScalar[DataTypeT]: """Aggregate values into a list. diff --git a/tests/plan/list_contains_test.py b/tests/plan/list_contains_test.py index c6284073b3..4e32045a47 100644 --- a/tests/plan/list_contains_test.py +++ b/tests/plan/list_contains_test.py @@ -26,7 +26,6 @@ def data() -> Data: b = nwp.col("b") -@pytest.mark.xfail(reason="TODO: `ArrowExpr.list.contains`", raises=NotImplementedError) @pytest.mark.parametrize( ("item", "expected"), [ @@ -36,9 +35,7 @@ def data() -> Data: (nwp.lit(None, nw.Int32), [True, None, False, True]), ], ) -def test_list_contains( - data: Data, item: IntoExpr, expected: list[bool | None] -) -> None: # pragma: no cover +def test_list_contains(data: Data, item: IntoExpr, expected: list[bool | None]) -> None: df = dataframe(data).with_columns(a.cast(nw.List(nw.Int32))) result = df.select(a.list.contains(item)) assert_equal_data(result, {"a": expected}) From 1a69b686ec67fc839ce85063a3becbfb49cc28b7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 12 Dec 2025 13:41:55 +0000 Subject: [PATCH 193/215] fix: Raise and suggest solution for len(1) Series broadcast Discovered that `polars` has a more specific error for this, while trying to create a len(1) Series of `List(String)` Adds `Series.{first,last}` too --- narwhals/_plan/arrow/expr.py | 10 +++++++++- narwhals/_plan/arrow/series.py | 8 ++++++++ narwhals/_plan/compliant/series.py | 3 +++ narwhals/_plan/series.py | 7 +++++++ tests/plan/expr_parsing_test.py | 25 ++++++++++++++++++++++++- 5 files changed, 51 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 
987cd1395b..7f209aa933 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -49,7 +49,7 @@ not_implemented, qualified_type_name, ) -from narwhals.exceptions import InvalidOperationError +from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: from collections.abc import Callable, Mapping, Sequence @@ -404,8 +404,16 @@ def native(self) -> ChunkedArrayAny: def to_series(self) -> Series: return self._evaluated + # TODO @dangotbanned: Handle this `Series([...])` edge case higher up + # Can occur from a len(1) series passed to `with_columns`, which becomes a literal def broadcast(self, length: int, /) -> Series: if (actual_len := len(self)) != length: + if actual_len == 1: + msg = ( + f"Series {self.name}, length {actual_len} doesn't match the DataFrame height of {length}.\n\n" + "If you want an expression to be broadcasted, ensure it is a scalar (for instance by adding '.first()')." + ) + raise ShapeError(msg) raise shape_error(length, actual_len) return self._evaluated diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 3407cf3cca..d892f80c7f 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -30,6 +30,7 @@ Into1DArray, IntoDType, NonNestedLiteral, + PythonLiteral, _1DArray, ) @@ -305,6 +306,13 @@ def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Sel exploder = fn.ExplodeBuilder(empty_as_null=empty_as_null, keep_nulls=keep_nulls) return self._with_native(exploder.explode(self.native)) + def first(self) -> PythonLiteral: + return self.native[0].as_py() if len(self) else None + + def last(self) -> PythonLiteral: + ca = self.native + return ca[height - 1].as_py() if (height := len(ca)) else None + @property def struct(self) -> SeriesStructNamespace: return SeriesStructNamespace(self) diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 4f3ba969bc..faf14bf2f2 100644 --- 
a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -24,6 +24,7 @@ IntoDType, NonNestedLiteral, NumericLiteral, + PythonLiteral, SizedMultiIndexSelector, TemporalLiteral, _1DArray, @@ -139,6 +140,7 @@ def fill_null(self, value: NonNestedLiteral | Self) -> Self: ... def fill_null_with_strategy( self, strategy: FillNullStrategy, limit: int | None = None ) -> Self: ... + def first(self) -> PythonLiteral: ... def shift(self, n: int, *, fill_value: NonNestedLiteral = None) -> Self: ... def gather( self, @@ -159,6 +161,7 @@ def is_not_nan(self) -> Self: def is_not_null(self) -> Self: return self.is_null().__invert__() + def last(self) -> PythonLiteral: ... def rolling_mean( self, window_size: int, *, min_samples: int, center: bool = False ) -> Self: ... diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 65a8265d9c..a067a7ff4b 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -36,6 +36,7 @@ IntoDType, NonNestedLiteral, NumericLiteral, + PythonLiteral, SizedMultiIndexSelector, TemporalLiteral, ) @@ -266,6 +267,12 @@ def sum(self) -> float: def count(self) -> int: return self._compliant.count() + def first(self) -> PythonLiteral: + return self._compliant.first() + + def last(self) -> PythonLiteral: + return self._compliant.last() + def unique(self, *, maintain_order: bool = False) -> Self: return type(self)(self._compliant.unique(maintain_order=maintain_order)) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 54665ac8a8..34ab20a1c0 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -26,7 +26,7 @@ MultiOutputExpressionError, ShapeError, ) -from tests.plan.utils import assert_expr_ir_equal, re_compile +from tests.plan.utils import assert_equal_data, assert_expr_ir_equal, re_compile if TYPE_CHECKING: from contextlib import AbstractContextManager @@ -661,6 +661,29 @@ def test_mode_invalid() -> None: nwp.col("a").mode(keep="first") # type: 
ignore[arg-type] +def test_broadcast_len_1_series_invalid() -> None: + pytest.importorskip("pyarrow") + data = {"a": [1, 2, 3]} + values = [4] + df = nwp.DataFrame.from_dict(data, backend="pyarrow") + ser = nwp.Series.from_iterable(values, name="bad", backend="pyarrow") + with pytest.raises( + ShapeError, + match=re_compile( + r"series.+bad.+length.+1.+match.+DataFrame.+height.+3.+broadcasted.+\.first\(\)" + ), + ): + df.with_columns(ser) + + expected_series = {"a": [1, 2, 3], "literal": [4, 4, 4]} + # we can only preserve `Series.name` if we got a `lit(Series).first()`, not `lit(Series.first())` + expected_series_literal = {"a": [1, 2, 3], "bad": [4, 4, 4]} + + assert_equal_data(df.with_columns(ser.first()), expected_series) + assert_equal_data(df.with_columns(ser.last()), expected_series) + assert_equal_data(df.with_columns(nwp.lit(ser).first()), expected_series_literal) + + @pytest.mark.parametrize( ("window_size", "min_samples", "context"), [ From f10823f797839a7290f2e1ab88bbb45189ae328c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 12 Dec 2025 13:53:46 +0000 Subject: [PATCH 194/215] test: Add `test_list_contains_scalar` --- tests/plan/list_contains_test.py | 40 +++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/plan/list_contains_test.py b/tests/plan/list_contains_test.py index 4e32045a47..9bdb5be0b5 100644 --- a/tests/plan/list_contains_test.py +++ b/tests/plan/list_contains_test.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Final import pytest @@ -19,6 +19,7 @@ def data() -> Data: "a": [[2, 2, 3, None, None], None, [], [None]], "b": [[1, 2, 2], [3, 4], [5, 5, 5, 6], [7]], "c": [1, 3, None, 2], + "d": ["B", None, "A", "C"], } @@ -39,3 +40,40 @@ def test_list_contains(data: Data, item: IntoExpr, expected: list[bool | None]) df = 
dataframe(data).with_columns(a.cast(nw.List(nw.Int32))) result = df.select(a.list.contains(item)) assert_equal_data(result, {"a": expected}) + + +R1: Final[list[Any]] = [None, "A", "B", "A", "A", "B"] +R2: Final = None +R3: Final[list[Any]] = [] +R4: Final = [None] + + +@pytest.mark.xfail( + reason=" TODO: `ArrowScalar.list.contains`", raises=NotImplementedError +) +@pytest.mark.parametrize( + ("row", "item", "expected"), + [ + (R1, "A", True), + (R2, "A", None), + (R3, "A", False), + (R4, "A", False), + (R1, None, True), + (R2, None, None), + (R3, None, False), + (R4, None, True), + (R1, "C", False), + (R2, "C", None), + (R3, "C", False), + (R4, "C", False), + ], +) +def test_list_contains_scalar( + row: list[str | None] | None, + item: IntoExpr, + expected: bool | None, # noqa: FBT001 +) -> None: # pragma: no cover + data = {"a": [row]} + df = dataframe(data).select(a.cast(nw.List(nw.String))) + result = df.select(a.first().list.contains(item)) + assert_equal_data(result, {"a": expected}) From 61cfa529d7e32a468321e54fdb0213353967c1cc Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 12 Dec 2025 14:54:20 +0000 Subject: [PATCH 195/215] feat(expr-ir): Support `ArrowScalar.list.contains` Just want the hacky version in first, before cleaning up all the paths --- narwhals/_plan/arrow/expr.py | 12 +++++------- narwhals/_plan/arrow/functions.py | 19 ++++++++++++++++--- tests/plan/list_contains_test.py | 11 +++-------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 7f209aa933..8d4a60de33 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -452,13 +452,15 @@ def filter(self, node: ir.Filter, frame: Frame, name: str) -> Expr: def first(self, node: First, frame: Frame, name: str) -> Scalar: prev = self._dispatch_expr(node.expr, frame, name) native = prev.native - result = native[0] if len(prev) else fn.lit(None, 
native.type) + result: NativeScalar = native[0] if len(prev) else fn.lit(None, native.type) return self._with_native(result, name) def last(self, node: Last, frame: Frame, name: str) -> Scalar: prev = self._dispatch_expr(node.expr, frame, name) native = prev.native - result = native[len_ - 1] if (len_ := len(prev)) else fn.lit(None, native.type) + result: NativeScalar = ( + native[len_ - 1] if (len_ := len(prev)) else fn.lit(None, native.type) + ) return self._with_native(result, name) def arg_min(self, node: ArgMin, frame: Frame, name: str) -> Scalar: @@ -975,14 +977,10 @@ def contains( expr, other = func.unwrap_input(node) prev = expr.dispatch(self.compliant, frame, name) item = other.dispatch(self.compliant, frame, name) - if isinstance(prev, ArrowScalar): - msg = "TODO: `ArrowScalar.list.contains`" - raise NotImplementedError(msg) if isinstance(item, ArrowExpr): # Maybe one day, not now raise NotImplementedError - result = fn.list_contains(prev.native, item.native) - return self.with_native(result, name) + return self.with_native(fn.list_contains(prev.native, item.native), name) class ArrowStringNamespace( diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index e306260c68..c7a8f0daa1 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -687,17 +687,24 @@ def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScal # TODO @dangotbanned: Clean up -# TODO @dangotbanned: Support `native: ListScalar` # NOTE: Both of these weren't able to support `[None]`, where, 2 in [None] should be False # https://github.com/apache/arrow/issues/33295 # https://github.com/apache/arrow/issues/47118#issuecomment-3075893244 def list_contains( - native: ChunkedList, item: NonNestedLiteral | ScalarAny -) -> ChunkedArray[pa.BooleanScalar]: + native: ChunkedOrScalar[ListScalar], item: NonNestedLiteral | ScalarAny +) -> ChunkedOrScalar[pa.BooleanScalar]: # empty should always be False # None should 
always be None # `None` in `[None]` should be True # Anything else in `[None]` should be false + if isinstance(native, pa.Scalar): + scalar = t.cast("pa.ListScalar[Any]", native) + if scalar.is_valid: + if len(scalar): + other = array(lit(item).cast(scalar.type.value_type)) + return any_(is_in(_list_explode(scalar), other)) + return lit(False, BOOL) + return lit(None, BOOL) ca = native table = ExplodeBuilder(empty_as_null=False, keep_nulls=False).explode_with_indices(ca) values = is_in(table.column(1), array(lit(item))) @@ -1706,6 +1713,12 @@ def hist_zeroed_data( return {"breakpoint": bp, "count": zeros(n)} +@overload +def lit(value: Any) -> NativeScalar: ... +@overload +def lit(value: Any, dtype: BoolType) -> pa.BooleanScalar: ... +@overload +def lit(value: Any, dtype: DataType | None = ...) -> NativeScalar: ... def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) diff --git a/tests/plan/list_contains_test.py b/tests/plan/list_contains_test.py index 9bdb5be0b5..0761434e97 100644 --- a/tests/plan/list_contains_test.py +++ b/tests/plan/list_contains_test.py @@ -48,9 +48,6 @@ def test_list_contains(data: Data, item: IntoExpr, expected: list[bool | None]) R4: Final = [None] -@pytest.mark.xfail( - reason=" TODO: `ArrowScalar.list.contains`", raises=NotImplementedError -) @pytest.mark.parametrize( ("row", "item", "expected"), [ @@ -69,11 +66,9 @@ def test_list_contains(data: Data, item: IntoExpr, expected: list[bool | None]) ], ) def test_list_contains_scalar( - row: list[str | None] | None, - item: IntoExpr, - expected: bool | None, # noqa: FBT001 -) -> None: # pragma: no cover + row: list[str | None] | None, item: IntoExpr, *, expected: bool | None +) -> None: data = {"a": [row]} df = dataframe(data).select(a.cast(nw.List(nw.String))) result = df.select(a.first().list.contains(item)) - assert_equal_data(result, {"a": expected}) + assert_equal_data(result, {"a": [expected]}) From 
1e5826485d53d09efeaf0414a5f0442354452ff8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 12 Dec 2025 14:59:47 +0000 Subject: [PATCH 196/215] chore: Update cov --- narwhals/_plan/expressions/lists.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 40c2364894..8dd1749aa4 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -29,16 +29,14 @@ class Get(ListFunction): class Contains(ListFunction): """N-ary (expr, item).""" - def unwrap_input( - self, node: FExpr[Self], / - ) -> tuple[ExprIR, ExprIR]: # pragma: no cover + def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: expr, item = node.input return expr, item class IRListNamespace(IRNamespace): len: ClassVar = Len - unique: ClassVar = Unique # pragma: no cover + unique: ClassVar = Unique contains: ClassVar = Contains get: ClassVar = Get @@ -51,7 +49,7 @@ def _ir_namespace(self) -> type[IRListNamespace]: def len(self) -> Expr: return self._with_unary(self._ir.len()) - def unique(self) -> Expr: # pragma: no cover + def unique(self) -> Expr: return self._with_unary(self._ir.unique()) def get(self, index: int) -> Expr: From 38a698e3ab1a70a88b1a557f33569125b559e026 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:10:33 +0000 Subject: [PATCH 197/215] docs: Add a note on `CompliantExpr.dt`] I've tried adding the `not_implemented` 3 times now and kept forgetting why it wasn't there yet --- narwhals/_plan/compliant/expr.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index e53e25977c..39a7a912a1 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -280,6 +280,11 @@ def struct( self, ) -> ExprStructNamespace[FrameT_contra, 
CompliantExpr[FrameT_contra, SeriesT_co]]: ... + # NOTE: This test has a case for detecting `Expr` impl, but missing `CompliantExpr` member + # `tests/plan/dispatch_test.py::test_dispatch` + # TODO @dangotbanned: Update that logic when `dt` namespace is actually implemented + # dt: not_implemented = not_implemented()` + class EagerExpr( EagerBroadcast[SeriesT], From 30e5f48ac302c340740a35f9d90be3da726ceb1d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 12 Dec 2025 16:08:25 +0000 Subject: [PATCH 198/215] feat: Expose `Series.struct.schema` `ArrowSeries.struct.unnest` depends on this for backcompat I'd rather this was covered in all cases --- narwhals/_plan/arrow/series.py | 5 +++++ narwhals/_plan/compliant/accessors.py | 3 +++ narwhals/_plan/series.py | 8 ++++++++ tests/plan/hist_test.py | 3 +++ 4 files changed, 19 insertions(+) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index d892f80c7f..7f7d732064 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -14,6 +14,7 @@ from narwhals._plan.expressions import functions as F from narwhals._utils import Version, generate_repr from narwhals.dependencies import is_numpy_array_1d +from narwhals.schema import Schema if TYPE_CHECKING: from collections.abc import Iterable @@ -360,3 +361,7 @@ def unnest(self) -> DataFrame: # name overriding *may* be wrong def field(self, name: str) -> ArrowSeries: return self.with_native(fn.struct_field(self.native, name), name) + + @property + def schema(self) -> Schema: + return Schema.from_arrow(fn.struct_schema(self.native)) diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 0c1eb0de28..9f1a42ef13 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -13,6 +13,7 @@ from narwhals._plan.expressions import FunctionExpr as FExpr, lists, strings from narwhals._plan.expressions.categorical 
import GetCategories from narwhals._plan.expressions.struct import FieldByName + from narwhals.schema import Schema class ExprCatNamespace(Protocol[FrameT_contra, ExprT_co]): @@ -93,3 +94,5 @@ def field( class SeriesStructNamespace(Protocol[DataFrameT_co, SeriesT_co]): def field(self, name: str) -> SeriesT_co: ... def unnest(self) -> DataFrameT_co: ... + @property + def schema(self) -> Schema: ... diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index a067a7ff4b..075fe569d7 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -32,6 +32,7 @@ from narwhals._plan.dataframe import DataFrame from narwhals._typing import EagerAllowed, IntoBackend, _EagerAllowedImpl from narwhals.dtypes import DType + from narwhals.schema import Schema from narwhals.typing import ( IntoDType, NonNestedLiteral, @@ -321,6 +322,13 @@ def unnest(self) -> DataFrame[Any, Any]: ) return result + def field(self, name: str) -> SeriesT: # pragma: no cover + return type(self._series)(self._series._compliant.struct.field(name)) + + @property + def schema(self) -> Schema: + return self._series._compliant.struct.schema + class SeriesV1(Series[NativeSeriesT_co]): _version: ClassVar[Version] = Version.V1 diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py index 28dfbfdd06..fdd42c40c4 100644 --- a/tests/plan/hist_test.py +++ b/tests/plan/hist_test.py @@ -276,6 +276,7 @@ def test_hist_expr_breakpoint( dtype_count = nw.Int64 dtype_struct = nw.Struct({"breakpoint": dtype_breakpoint, "count": dtype_count}) + schema_struct = nw.Schema({"breakpoint": nw.Float64(), "count": nw.Int64()}) expected_schema = nw.Schema( [ ("int", dtype_struct), @@ -308,6 +309,8 @@ def test_hist_expr_breakpoint( } assert result_schema == expected_schema assert_equal_data(result, expected_data) + for ser in result.iter_columns(): + assert ser.struct.schema == schema_struct @bin_count_cases From 1f9da6864292d14f90b8a0b097198f2fb7376933 Mon Sep 17 00:00:00 2001 From: dangotbanned 
<125183946+dangotbanned@users.noreply.github.com> Date: Fri, 12 Dec 2025 19:38:32 +0000 Subject: [PATCH 199/215] refactor: Pull out tools from list ops --- narwhals/_plan/arrow/functions.py | 114 +++++++++++++++++------------- narwhals/_plan/arrow/group_by.py | 29 +++++++- narwhals/_plan/arrow/series.py | 4 +- 3 files changed, 94 insertions(+), 53 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index c7a8f0daa1..7a675382a5 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -655,7 +655,6 @@ def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScal [`pc.replace_with_mask`]: https://arrow.apache.org/docs/python/generated/pyarrow.compute.replace_with_mask.html [apache/arrow#43716]: https://github.com/apache/arrow/issues/43716 """ - from narwhals._plan.arrow.acero import group_by_table from narwhals._plan.arrow.group_by import AggSpec if isinstance(native, pa.Scalar): @@ -666,61 +665,46 @@ def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScal idx, v = "index", "values" names = idx, v len_not_eq_0 = not_eq(list_len(native), lit(0)) - aggs = [AggSpec.from_expr_ir(ir_unique(v), v)] can_fastpath = all_(len_not_eq_0, ignore_nulls=False).as_py() if can_fastpath: arrays = [_list_parent_indices(native), _list_explode(native)] - result = group_by_table(concat_horizontal(arrays, names), [idx], aggs) - else: - # Oh no - we caught a bad one! - # We need to split things into good/bad - and only work on the good stuff. 
- # `int_range` is acting like `parent_indices`, but doesn't give up when it see's `None` or `[]` - indexed = concat_horizontal([int_range(len(native)), native], names) - valid = indexed.filter(len_not_eq_0) - invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) - # To keep track of where we started, our index needs to be exploded with the list elements - explode_with_index = ExplodeBuilder.explode_column_fast(valid, v) - valid_unique = group_by_table(explode_with_index, [idx], aggs) - # And now, because we can't join - we do a poor man's version of one 😉 - result = concat_tables([valid_unique, invalid]).sort_by(idx) - return result.column(v) - - -# TODO @dangotbanned: Clean up -# NOTE: Both of these weren't able to support `[None]`, where, 2 in [None] should be False -# https://github.com/apache/arrow/issues/33295 -# https://github.com/apache/arrow/issues/47118#issuecomment-3075893244 + return AggSpec.unique(v).over_index(concat_horizontal(arrays, names), idx) + # Oh no - we caught a bad one! + # We need to split things into good/bad - and only work on the good stuff. 
+ # `int_range` is acting like `parent_indices`, but doesn't give up when it see's `None` or `[]` + indexed = concat_horizontal([int_range(len(native)), native], names) + valid = indexed.filter(len_not_eq_0) + invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) + # To keep track of where we started, our index needs to be exploded with the list elements + explode_with_index = ExplodeBuilder.explode_column_fast(valid, v) + valid_unique = AggSpec.unique(v).over(explode_with_index, [idx]) + # And now, because we can't join - we do a poor man's version of one 😉 + return concat_tables([valid_unique, invalid]).sort_by(idx).column(v) + + def list_contains( native: ChunkedOrScalar[ListScalar], item: NonNestedLiteral | ScalarAny ) -> ChunkedOrScalar[pa.BooleanScalar]: - # empty should always be False - # None should always be None - # `None` in `[None]` should be True - # Anything else in `[None]` should be false + from narwhals._plan.arrow.group_by import AggSpec + if isinstance(native, pa.Scalar): scalar = t.cast("pa.ListScalar[Any]", native) if scalar.is_valid: if len(scalar): - other = array(lit(item).cast(scalar.type.value_type)) - return any_(is_in(_list_explode(scalar), other)) + value_type = scalar.type.value_type + return any_(eq_missing(_list_explode(scalar), lit(item).cast(value_type))) return lit(False, BOOL) return lit(None, BOOL) - ca = native - table = ExplodeBuilder(empty_as_null=False, keep_nulls=False).explode_with_indices(ca) - values = is_in(table.column(1), array(lit(item))) - name = table.field(1).name - contains = ( - table.set_column(1, name, values) - .group_by("idx") - .aggregate([(name, "hash_any", pa_options.scalar_aggregate(ignore_nulls=True))]) - .column(1) - ) + builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) + tbl = builder.explode_with_indices(native) + idx, name = tbl.column_names + contains = eq_missing(tbl.column(name), item) + l_contains = AggSpec.any(name).over_index(tbl.set_column(1, name, contains), idx) 
# Here's the really key part: this mask has the same result we want to return # So by filling the `True`, we can flip those to `False` if needed # But if we were already `None` or `False` - then that's sticky - propagate_invalid = not_eq(list_len(ca), lit(0)) - results = replace_with_mask(array(propagate_invalid), propagate_invalid, contains) - return chunked_array(results) + propagate_invalid: ChunkedArray[pa.BooleanScalar] = not_eq(list_len(native), lit(0)) + return replace_with_mask(propagate_invalid, propagate_invalid, l_contains) def implode(native: Arrow[Scalar[DataTypeT]]) -> ListScalar[DataTypeT]: @@ -1306,11 +1290,15 @@ def replace_strict_default( def replace_with_mask( native: ChunkedOrArrayT, mask: Predicate, replacements: ChunkedOrArrayAny ) -> ChunkedOrArrayT: - if not isinstance(mask, pa.BooleanArray): - mask = t.cast("pa.BooleanArray", array(mask)) - if not isinstance(replacements, pa.Array): - replacements = array(replacements) - result: ChunkedOrArrayT = pc.replace_with_mask(native, mask, replacements) + """Replace elements of `native`, at positions defined by `mask`. + + The length of `replacements` must equal the number of `True` values in `mask`. + """ + if isinstance(native, pa.ChunkedArray): + args = [array(p) for p in (native, mask, replacements)] + return chunked_array(pc.call_function("replace_with_mask", args)) + args = [native, array(mask), array(replacements)] + result: ChunkedOrArrayT = pc.call_function("replace_with_mask", args) return result @@ -1367,12 +1355,38 @@ def is_in(values: ArrowAny, /, other: ChunkedOrArrayAny) -> ArrowAny: return is_in_(values, other) # type: ignore[no-any-return] -def ir_min_max(name: str, /) -> MinMax: - return MinMax(expr=ir.col(name)) +@t.overload +def eq_missing( + native: ChunkedArrayAny, other: NonNestedLiteral | ArrowAny +) -> ChunkedArray[pa.BooleanScalar]: ... +@t.overload +def eq_missing( + native: ArrayAny, other: NonNestedLiteral | ArrowAny +) -> Array[pa.BooleanScalar]: ... 
+@t.overload +def eq_missing( + native: ScalarAny, other: NonNestedLiteral | ArrowAny +) -> pa.BooleanScalar: ... +@t.overload +def eq_missing( + native: ChunkedOrScalarAny, other: NonNestedLiteral | ArrowAny +) -> ChunkedOrScalarAny: ... +def eq_missing(native: ArrowAny, other: NonNestedLiteral | ArrowAny) -> ArrowAny: + """Equivalent to `native == other` where `None == None`. + This differs from default `eq` where null values are propagated. -def ir_unique(name: str, /) -> ir.FunctionExpr[F.Unique]: - return F.Unique().to_function_expr(ir.col(name)) + Note: + Unique to `pyarrow`, this wrapper will ensure `None` uses `native.type`. + """ + if isinstance(other, (pa.Array, pa.ChunkedArray)): + return is_in(native, other) + item = array(other if isinstance(other, pa.Scalar) else lit(other, native.type)) + return is_in(native, item) + + +def ir_min_max(name: str, /) -> MinMax: + return MinMax(expr=ir.col(name)) def _boolean_is_unique( diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 29d87f866e..849572f864 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -17,7 +17,7 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterator, Mapping, Sequence + from collections.abc import Iterable, Iterator, Mapping, Sequence from typing_extensions import Self, TypeAlias @@ -136,6 +136,33 @@ def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self: fn_name = SUPPORTED_IR[type(expr)] return cls(expr.name if isinstance(expr, ir.Column) else (), fn_name, name=name) + # NOTE: Fast-paths for single column rewrites + @classmethod + def _from_function(cls, tp: type[ir.Function], name: str) -> Self: + return cls(name, SUPPORTED_FUNCTION[tp], options.FUNCTION.get(tp), name) + + @classmethod + def any(cls, name: str) -> Self: + return cls._from_function(ir.boolean.Any, name) + + @classmethod + def unique(cls, name: str) -> Self: + return 
cls._from_function(ir.functions.Unique, name) + + def over(self, native: pa.Table, keys: Iterable[acero.Field]) -> pa.Table: + """Sugar for `native.group_by(keys).aggregate([self])`. + + Returns a table with columns named: `[*keys, self.name]` + """ + return acero.group_by_table(native, keys, [self]) + + def over_index(self, native: pa.Table, index_column: str) -> ChunkedArrayAny: + """Execute this aggregation over `index_column`. + + Returns a single, (unnamed) array, representing the aggregation results. + """ + return acero.group_by_table(native, [index_column], [self]).column(self.name) + def group_by_error( column_name: str, expr: ir.ExprIR, reason: Literal["too complex"] | None = None diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 7f7d732064..f92e6497c3 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -100,8 +100,8 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: def scatter(self, indices: Self, values: Self) -> Self: mask = fn.is_in(fn.int_range(len(self), chunked=False), indices.native) - replacements = fn.array(values._gather(pc.sort_indices(indices.native))) - return self._with_native(pc.replace_with_mask(self.native, mask, replacements)) + replacements = values._gather(pc.sort_indices(indices.native)) + return self._with_native(fn.replace_with_mask(self.native, mask, replacements)) def is_in(self, other: Self) -> Self: return self._with_native(fn.is_in(self.native, other.native)) From 2d84d14e86b9931a0bf3d1a3270273cd000ed46a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 12 Dec 2025 23:05:27 +0000 Subject: [PATCH 200/215] feat(DRAFT): `Expr.list.join(ignore_nulls=True)` progress TIL: `pyarrow.compute.and_not` exists --- narwhals/_plan/arrow/functions.py | 89 ++++++++++++++++++++++++++++--- 1 file changed, 82 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py 
b/narwhals/_plan/arrow/functions.py index 7a675382a5..047eae4349 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -436,7 +436,7 @@ def explode( return _list_explode(safe) return chunked_array(_list_explode(safe)) - def explode_with_indices(self, native: ChunkedList) -> pa.Table: + def explode_with_indices(self, native: ChunkedList | ListArray) -> pa.Table: safe = self._fill_with_null(native) if self.options.any() else native arrays = [_list_parent_indices(safe), _list_explode(safe)] return concat_horizontal(arrays, ["idx", "values"]) @@ -620,12 +620,79 @@ def list_join( Each list of values in the first input is joined using each second input as separator. If any input list is null or contains a null, the corresponding output will be null. + + Edge cases: + + >>> import polars as pl + >>> data = { + ... "s": [ + ... ["a", "b", "c"], + ... ["x", "y"], + ... ["1", None, "3"], + ... [None], + ... None, + ... [], + ... [None, None], # <-- everything works except this, for now + ... ] + ... } + >>> s = pl.col("s") + >>> result = pl.DataFrame(data).select( + ... s, + ... ignore_nulls=s.list.join("-", ignore_nulls=True), + ... propagate_nulls=s.list.join("-", ignore_nulls=False), + ... 
) + >>> result + shape: (7, 3) + ┌──────────────────┬──────────────┬─────────────────┐ + │ s ┆ ignore_nulls ┆ propagate_nulls │ + │ --- ┆ --- ┆ --- │ + │ list[str] ┆ str ┆ str │ + ╞══════════════════╪══════════════╪═════════════════╡ + │ ["a", "b", "c"] ┆ a-b-c ┆ a-b-c │ + │ ["x", "y"] ┆ x-y ┆ x-y │ + │ ["1", null, "3"] ┆ 1-3 ┆ null │ + │ [null] ┆ ┆ null │ + │ null ┆ null ┆ null │ + │ [] ┆ ┆ │ + │ [null, null] ┆ ┆ null │ + └──────────────────┴──────────────┴─────────────────┘ """ - if ignore_nulls: - # NOTE: `polars` default is `True`, will need to handle that if this becomes api - msg = "TODO: `ArrowExpr.list.join(ignore_nulls=True)`" + join = t.cast( + "Callable[[Any, Any], ChunkedArray[StringScalar] | pa.StringArray]", + pc.binary_join, + ) + if not ignore_nulls: + return pc.binary_join(native, separator) + # NOTE: `polars` default is `True` + if isinstance(native, pa.Scalar): + to_join = ( + implode(_list_explode(native).drop_null()) if native.is_valid else native + ) + return pc.binary_join(to_join, separator) + result = join(native, separator) + if not result.null_count: + # if we got here and there were no nulls, then we're done + return result + todo_mask = pc.and_not(result.is_null(), native.is_null()) + todo_lists = native.filter(todo_mask) + list_len_1: ChunkedOrArrayAny = eq(list_len(todo_lists), lit(1)) # pyright: ignore[reportAssignmentType] + only_single_null = any_(list_len_1).as_py() + if only_single_null: + todo_lists = when_then(list_len_1, lit([""], todo_lists.type), todo_lists) + builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) + replacements = join( + builder.explode_with_indices(todo_lists) + .drop_null() + .group_by("idx") + .aggregate([("values", "hash_list")]) + .column(1), + separator, + ) + if len(replacements) != len(list_len_1): + # probably do-able, but the edge cases here are getting hairy + msg = f"TODO: `ArrowExpr.list.join` w/ `[None, None , ...]` element\n{native!r}" raise NotImplementedError(msg) - return 
pc.binary_join(native, separator) + return replace_with_mask(result, todo_mask, replacements) @overload @@ -1287,9 +1354,17 @@ def replace_strict_default( return chunked_array(result) if isinstance(native, pa.ChunkedArray) else result[0] +@overload def replace_with_mask( native: ChunkedOrArrayT, mask: Predicate, replacements: ChunkedOrArrayAny -) -> ChunkedOrArrayT: +) -> ChunkedOrArrayT: ... +@overload +def replace_with_mask( + native: ChunkedOrArrayAny, mask: Predicate, replacements: ChunkedOrArrayAny +) -> ChunkedOrArrayAny: ... +def replace_with_mask( + native: ChunkedOrArrayAny, mask: Predicate, replacements: ChunkedOrArrayAny +) -> ChunkedOrArrayAny: """Replace elements of `native`, at positions defined by `mask`. The length of `replacements` must equal the number of `True` values in `mask`. @@ -1298,7 +1373,7 @@ def replace_with_mask( args = [array(p) for p in (native, mask, replacements)] return chunked_array(pc.call_function("replace_with_mask", args)) args = [native, array(mask), array(replacements)] - result: ChunkedOrArrayT = pc.call_function("replace_with_mask", args) + result: ChunkedOrArrayAny = pc.call_function("replace_with_mask", args) return result From fbb71a64b93f4f2f711e46189f41fbe4a41763ee Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 13:20:20 +0000 Subject: [PATCH 201/215] fix: Handle `[None, None , ...]` in `list.join(ignore_nulls=True)` --- narwhals/_plan/arrow/functions.py | 44 +++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 047eae4349..ce80c092e6 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -592,6 +592,11 @@ def list_get(native: ArrowAny, index: int) -> ArrowAny: return result +# TODO @dangotbanned: Raise a feature request for `pc.binary_join(strings, separator, *, options: JoinOptions)` +# - Default 
for `binary_join_element_wise` is the only behavior available (here) currently +# - Working around it is a **slog** +# TODO @dangotbanned: Major de-uglyify +# Everything is functional, need to add tests before simplifying @t.overload def list_join( native: ChunkedList[StringType], @@ -632,7 +637,7 @@ def list_join( ... [None], ... None, ... [], - ... [None, None], # <-- everything works except this, for now + ... [None, None], ... ] ... } >>> s = pl.col("s") @@ -678,20 +683,32 @@ def list_join( list_len_1: ChunkedOrArrayAny = eq(list_len(todo_lists), lit(1)) # pyright: ignore[reportAssignmentType] only_single_null = any_(list_len_1).as_py() if only_single_null: + # `[None]` todo_lists = when_then(list_len_1, lit([""], todo_lists.type), todo_lists) builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) - replacements = join( - builder.explode_with_indices(todo_lists) - .drop_null() - .group_by("idx") - .aggregate([("values", "hash_list")]) - .column(1), - separator, + pre_drop_null = builder.explode_with_indices(todo_lists) + implode_by_idx = ( + pre_drop_null.drop_null().group_by("idx").aggregate([("values", "hash_list")]) ) + replacements = join(implode_by_idx.column(1), separator) if len(replacements) != len(list_len_1): - # probably do-able, but the edge cases here are getting hairy - msg = f"TODO: `ArrowExpr.list.join` w/ `[None, None , ...]` element\n{native!r}" - raise NotImplementedError(msg) + # `[None, ..., None]` + # This is a very unlucky case to hit, because + # - we can detect the issue earlier + # - but we can't join a table with a list in it + # So this is after-the-fact and messy + empty = lit("", todo_lists.type.value_type) + replacements = ( + implode_by_idx.select(["idx"]) + .append_column("values", replacements) + .join( + to_table(pre_drop_null.column("idx").unique(), "idx"), + "idx", + join_type="full outer", + ) + .column("values") + .fill_null(empty) + ) return replace_with_mask(result, todo_mask, replacements) @@ -1867,6 +1884,11 
@@ def concat_vertical( return result +def to_table(array: ChunkedOrArrayAny, name: str = "") -> pa.Table: + """Equivalent to `Series.to_frame`, but with an option to insert a name for the column.""" + return concat_horizontal((array,), (name,)) + + def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataType]]: return ( (first := next(iter(obj.items())), None) From 4e2da365906199ea9b682b69f0cc95d9e45ef085 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 13:21:13 +0000 Subject: [PATCH 202/215] chore: remove outdated comment --- narwhals/_plan/arrow/series.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index f92e6497c3..7390e7f7d5 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -347,9 +347,6 @@ def unnest(self) -> DataFrame: if len(native): table = pa.Table.from_struct_array(native) else: - # TODO @dangotbanned: Report empty bug upstream, no option to pass a schema to resolve the error - # `ValueError: Must pass schema, or at least one RecordBatch` - # https://github.com/apache/arrow/blob/b2e8f2505ba3eafe65a78ece6ae87fa7d0c1c133/python/pyarrow/table.pxi#L4943-L4949 table = fn.struct_schema(native).empty_table() else: # pragma: no cover # NOTE: Too strict, doesn't allow `Array[StructScalar]` From f9eccc28616668dc5a8a127de37fee2ac54b8a8c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 14:05:53 +0000 Subject: [PATCH 203/215] test: Start adding `list.join` tests --- narwhals/_plan/arrow/expr.py | 2 + narwhals/_plan/arrow/functions.py | 38 +-------------- narwhals/_plan/compliant/accessors.py | 3 ++ narwhals/_plan/expressions/lists.py | 13 ++++++ tests/plan/list_join_test.py | 66 +++++++++++++++++++++++++++ 5 files changed, 85 insertions(+), 37 deletions(-) create mode 100644 tests/plan/list_join_test.py diff 
--git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 8d4a60de33..55213f5239 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -982,6 +982,8 @@ def contains( raise NotImplementedError return self.with_native(fn.list_contains(prev.native, item.native), name) + join = not_implemented() + class ArrowStringNamespace( ExprStringNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index ce80c092e6..d2c6b3f82d 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -596,7 +596,7 @@ def list_get(native: ArrowAny, index: int) -> ArrowAny: # - Default for `binary_join_element_wise` is the only behavior available (here) currently # - Working around it is a **slog** # TODO @dangotbanned: Major de-uglyify -# Everything is functional, need to add tests before simplifying +# Everything is functional, need to finish adding tests before simplifying @t.overload def list_join( native: ChunkedList[StringType], @@ -625,42 +625,6 @@ def list_join( Each list of values in the first input is joined using each second input as separator. If any input list is null or contains a null, the corresponding output will be null. - - Edge cases: - - >>> import polars as pl - >>> data = { - ... "s": [ - ... ["a", "b", "c"], - ... ["x", "y"], - ... ["1", None, "3"], - ... [None], - ... None, - ... [], - ... [None, None], - ... ] - ... } - >>> s = pl.col("s") - >>> result = pl.DataFrame(data).select( - ... s, - ... ignore_nulls=s.list.join("-", ignore_nulls=True), - ... propagate_nulls=s.list.join("-", ignore_nulls=False), - ... 
) - >>> result - shape: (7, 3) - ┌──────────────────┬──────────────┬─────────────────┐ - │ s ┆ ignore_nulls ┆ propagate_nulls │ - │ --- ┆ --- ┆ --- │ - │ list[str] ┆ str ┆ str │ - ╞══════════════════╪══════════════╪═════════════════╡ - │ ["a", "b", "c"] ┆ a-b-c ┆ a-b-c │ - │ ["x", "y"] ┆ x-y ┆ x-y │ - │ ["1", null, "3"] ┆ 1-3 ┆ null │ - │ [null] ┆ ┆ null │ - │ null ┆ null ┆ null │ - │ [] ┆ ┆ │ - │ [null, null] ┆ ┆ null │ - └──────────────────┴──────────────┴─────────────────┘ """ join = t.cast( "Callable[[Any, Any], ChunkedArray[StringScalar] | pa.StringArray]", diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py index 9f1a42ef13..26df3ff3ba 100644 --- a/narwhals/_plan/compliant/accessors.py +++ b/narwhals/_plan/compliant/accessors.py @@ -35,6 +35,9 @@ def len( def unique( self, node: FExpr[lists.Unique], frame: FrameT_contra, name: str ) -> ExprT_co: ... + def join( + self, node: FExpr[lists.Join], frame: FrameT_contra, name: str + ) -> ExprT_co: ... class ExprStringNamespace(Protocol[FrameT_contra, ExprT_co]): diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 8dd1749aa4..e35b4fb41e 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -25,6 +25,12 @@ class Unique(ListFunction): ... 
class Get(ListFunction): __slots__ = ("index",) index: int +class Join(ListFunction): + """Join all string items in a sublist and place a separator between them.""" + + __slots__ = ("ignore_nulls", "separator") + separator: str + ignore_nulls: bool # fmt: on class Contains(ListFunction): """N-ary (expr, item).""" @@ -39,6 +45,7 @@ class IRListNamespace(IRNamespace): unique: ClassVar = Unique contains: ClassVar = Contains get: ClassVar = Get + join: ClassVar = Join class ExprListNamespace(ExprNamespace[IRListNamespace]): @@ -65,3 +72,9 @@ def contains(self, item: IntoExpr) -> Expr: if not item_ir.is_scalar: raise function_arg_non_scalar_error(contains, "item", item_ir) return self._expr._from_ir(contains.to_function_expr(self._expr._ir, item_ir)) + + def join(self, separator: str, *, ignore_nulls: bool = True) -> Expr: + ensure_type(separator, str, param_name="separator") + return self._with_unary( + self._ir.join(separator=separator, ignore_nulls=ignore_nulls) + ) diff --git a/tests/plan/list_join_test.py b/tests/plan/list_join_test.py new file mode 100644 index 0000000000..a1a468324c --- /dev/null +++ b/tests/plan/list_join_test.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "a": [ + ["a", "b", "c"], + [None, None, None], + [None, None, "1", "2", None, "3", None], + ["x", "y"], + ["1", None, "3"], + [None], + None, + [], + [None, None], + ] + } + + +a = nwp.col("a") + + +@pytest.mark.xfail( + reason="TODO: `list.join` is not yet implemented for 'ArrowExpr'", + raises=NotImplementedError, +) +@pytest.mark.parametrize( + ("separator", "ignore_nulls", "expected"), + [ + ("-", False, ["a-b-c", None, None, "x-y", None, None, None, "", None]), + ("-", True, ["a-b-c", "", 
"1-2-3", "x-y", "1-3", "", None, "", ""]), + ("", False, ["abc", None, None, "xy", None, None, None, "", None]), + ("", True, ["abc", "", "123", "xy", "13", "", None, "", ""]), + ], + ids=[ + "hyphen-propagate-nulls", + "hyphen-ignore-nulls", + "empty-propagate-nulls", + "empty-ignore-nulls", + ], +) +def test_list_join( + data: Data, separator: str, *, ignore_nulls: bool, expected: list[str | None] +) -> None: # pragma: no cover + df = dataframe(data).with_columns(a.cast(nw.List(nw.String))) + expr = a.list.join(separator, ignore_nulls=ignore_nulls) + result = df.select(expr) + assert_equal_data(result, {"a": expected}) + + +@pytest.mark.xfail +def test_list_join_scalar() -> None: # pragma: no cover + msg = "TODO: Add non-duplicated tests for this" + raise NotImplementedError(msg) From 8c66f4fcca6a3170c97269e871cdac754334d00e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 14:29:43 +0000 Subject: [PATCH 204/215] hook up `list.join` and unxfail some tests I added some extra cases that I hadn't considered while debugging Good news is they failed --- narwhals/_plan/arrow/expr.py | 6 +++++- narwhals/_plan/arrow/functions.py | 7 +++++++ tests/plan/list_join_test.py | 25 ++++++++++++++++++------- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 55213f5239..adca054aa4 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -982,7 +982,11 @@ def contains( raise NotImplementedError return self.with_native(fn.list_contains(prev.native, item.native), name) - join = not_implemented() + def join(self, node: FExpr[lists.Join], frame: Frame, name: str) -> Expr | Scalar: + separator, ignore_nulls = node.function.separator, node.function.ignore_nulls + return self.unary(fn.list_join, separator, ignore_nulls=ignore_nulls)( + node, frame, name + ) class ArrowStringNamespace( diff --git 
a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index d2c6b3f82d..e03bbac395 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -618,6 +618,13 @@ def list_join( *, ignore_nulls: bool = ..., ) -> pa.StringScalar: ... +@t.overload +def list_join( + native: ChunkedOrScalar[ListScalar[StringType]], + separator: str, + *, + ignore_nulls: bool = ..., +) -> ChunkedOrScalar[StringScalar]: ... def list_join( native: ArrowAny, separator: Arrow[StringScalar] | str, *, ignore_nulls: bool = False ) -> ArrowAny: diff --git a/tests/plan/list_join_test.py b/tests/plan/list_join_test.py index a1a468324c..7d6a44029b 100644 --- a/tests/plan/list_join_test.py +++ b/tests/plan/list_join_test.py @@ -31,18 +31,29 @@ def data() -> Data: a = nwp.col("a") - -@pytest.mark.xfail( - reason="TODO: `list.join` is not yet implemented for 'ArrowExpr'", - raises=NotImplementedError, +# TODO @dangotbanned: Ensure the final branch works when replacements are mixed +XFAIL_INCORRECT_RESULTS = pytest.mark.xfail( + reason="Returned out-of-order post-join", raises=AssertionError ) + + @pytest.mark.parametrize( ("separator", "ignore_nulls", "expected"), [ ("-", False, ["a-b-c", None, None, "x-y", None, None, None, "", None]), - ("-", True, ["a-b-c", "", "1-2-3", "x-y", "1-3", "", None, "", ""]), + pytest.param( + "-", + True, + ["a-b-c", "", "1-2-3", "x-y", "1-3", "", None, "", ""], + marks=XFAIL_INCORRECT_RESULTS, + ), ("", False, ["abc", None, None, "xy", None, None, None, "", None]), - ("", True, ["abc", "", "123", "xy", "13", "", None, "", ""]), + pytest.param( + "", + True, + ["abc", "", "123", "xy", "13", "", None, "", ""], + marks=XFAIL_INCORRECT_RESULTS, + ), ], ids=[ "hyphen-propagate-nulls", @@ -53,7 +64,7 @@ def data() -> Data: ) def test_list_join( data: Data, separator: str, *, ignore_nulls: bool, expected: list[str | None] -) -> None: # pragma: no cover +) -> None: df = dataframe(data).with_columns(a.cast(nw.List(nw.String))) expr 
= a.list.join(separator, ignore_nulls=ignore_nulls) result = df.select(expr) From 679e6534839fd41aa2937549b1f639e9c163f99a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 14:36:05 +0000 Subject: [PATCH 205/215] fix: Ensure table join doesn't break order The tiniest of fixes --- narwhals/_plan/arrow/functions.py | 1 + tests/plan/list_join_test.py | 19 ++----------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index e03bbac395..bb78f3104e 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -677,6 +677,7 @@ def list_join( "idx", join_type="full outer", ) + .sort_by("idx") .column("values") .fill_null(empty) ) diff --git a/tests/plan/list_join_test.py b/tests/plan/list_join_test.py index 7d6a44029b..60efbd3c5c 100644 --- a/tests/plan/list_join_test.py +++ b/tests/plan/list_join_test.py @@ -31,29 +31,14 @@ def data() -> Data: a = nwp.col("a") -# TODO @dangotbanned: Ensure the final branch works when replacements are mixed -XFAIL_INCORRECT_RESULTS = pytest.mark.xfail( - reason="Returned out-of-order post-join", raises=AssertionError -) - @pytest.mark.parametrize( ("separator", "ignore_nulls", "expected"), [ ("-", False, ["a-b-c", None, None, "x-y", None, None, None, "", None]), - pytest.param( - "-", - True, - ["a-b-c", "", "1-2-3", "x-y", "1-3", "", None, "", ""], - marks=XFAIL_INCORRECT_RESULTS, - ), + ("-", True, ["a-b-c", "", "1-2-3", "x-y", "1-3", "", None, "", ""]), ("", False, ["abc", None, None, "xy", None, None, None, "", None]), - pytest.param( - "", - True, - ["abc", "", "123", "xy", "13", "", None, "", ""], - marks=XFAIL_INCORRECT_RESULTS, - ), + ("", True, ["abc", "", "123", "xy", "13", "", None, "", ""]), ], ids=[ "hyphen-propagate-nulls", From e5410b3af7ea6e79d3ba53bbfd721c1ae961de25 Mon Sep 17 00:00:00 2001 From: dangotbanned 
<125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 15:53:22 +0000 Subject: [PATCH 206/215] test: Add `test_list_join_scalar` --- tests/plan/list_join_test.py | 73 +++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/tests/plan/list_join_test.py b/tests/plan/list_join_test.py index 60efbd3c5c..27268444d0 100644 --- a/tests/plan/list_join_test.py +++ b/tests/plan/list_join_test.py @@ -9,24 +9,31 @@ from tests.plan.utils import assert_equal_data, dataframe if TYPE_CHECKING: + from typing import Final, TypeVar + + from typing_extensions import TypeAlias + from tests.conftest import Data + T = TypeVar("T") + SubList: TypeAlias = list[T] | list[T | None] | list[None] | None + SubListStr: TypeAlias = SubList[str] + + +R1: Final[SubListStr] = ["a", "b", "c"] +R2: Final[SubListStr] = [None, None, None] +R3: Final[SubListStr] = [None, None, "1", "2", None, "3", None] +R4: Final[SubListStr] = ["x", "y"] +R5: Final[SubListStr] = ["1", None, "3"] +R6: Final[SubListStr] = [None] +R7: Final[SubListStr] = None +R8: Final[SubListStr] = [] +R9: Final[SubListStr] = [None, None] + @pytest.fixture(scope="module") def data() -> Data: - return { - "a": [ - ["a", "b", "c"], - [None, None, None], - [None, None, "1", "2", None, "3", None], - ["x", "y"], - ["1", None, "3"], - [None], - None, - [], - [None, None], - ] - } + return {"a": [R1, R2, R3, R4, R5, R6, R7, R8, R9]} a = nwp.col("a") @@ -41,10 +48,10 @@ def data() -> Data: ("", True, ["abc", "", "123", "xy", "13", "", None, "", ""]), ], ids=[ - "hyphen-propagate-nulls", - "hyphen-ignore-nulls", - "empty-propagate-nulls", - "empty-ignore-nulls", + "hyphen-propagate_nulls", + "hyphen-ignore_nulls", + "empty-propagate_nulls", + "empty-ignore_nulls", ], ) def test_list_join( @@ -56,7 +63,31 @@ def test_list_join( assert_equal_data(result, {"a": expected}) -@pytest.mark.xfail -def test_list_join_scalar() -> None: # pragma: no cover - msg = "TODO: Add non-duplicated tests 
for this" - raise NotImplementedError(msg) +@pytest.mark.parametrize( + "ignore_nulls", [True, False], ids=["ignore_nulls", "propagate_nulls"] +) +@pytest.mark.parametrize("separator", ["?", "", " "], ids=["question", "empty", "space"]) +@pytest.mark.parametrize( + "row", [R1, R2, R3, R4, R5, R6, R7, R8, R9], ids=[f"row-{i}" for i in range(1, 10)] +) +def test_list_join_scalar(row: SubListStr, separator: str, *, ignore_nulls: bool) -> None: + data = {"a": [row]} + df = dataframe(data).select(a.cast(nw.List(nw.String))) + expr = a.first().list.join(separator, ignore_nulls=ignore_nulls) + result = df.select(expr) + expected: str | None + if row is None: + expected = None + elif row == []: + expected = "" + elif any(el is None for el in row): + if not ignore_nulls: + expected = None + elif all(el is None for el in row): + expected = "" + else: + expected = separator.join(el for el in row if el is not None) + else: + expected = separator.join(el for el in row if el is not None) + + assert_equal_data(result, {"a": [expected]}) From 1344e3dfe532b096b85840c8bb0900e0c3e52bbd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 16:17:10 +0000 Subject: [PATCH 207/215] chore: fix cov --- tests/plan/meta_test.py | 2 +- tests/plan/range_test.py | 4 ++-- tests/plan/temp_test.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index 503dd26b3e..5385c65e17 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -30,7 +30,7 @@ raises=AssertionError, ), ) - else: + else: # pragma: no cover marks = () OVER_CASE = pytest.param( nwp.col("a").last().over("b", order_by="c"), diff --git a/tests/plan/range_test.py b/tests/plan/range_test.py index 7765f23d59..0dc02bc5ad 100644 --- a/tests/plan/range_test.py +++ b/tests/plan/range_test.py @@ -7,7 +7,7 @@ from narwhals.exceptions import ShapeError from tests.utils import PYARROW_VERSION -if 
PYARROW_VERSION < (21,): +if PYARROW_VERSION < (21,): # pragma: no cover pytest.importorskip("numpy") import datetime as dt @@ -223,7 +223,7 @@ def test_linear_space_values( expected = np.linspace(start, end, num_samples, endpoint=False) elif interval == "right": expected = np.linspace(start, end, num_samples + 1)[1:] - elif interval == "none": + else: expected = np.linspace(start, end, num_samples + 2)[1:-1] assert_equal_series(result, expected, "ls") diff --git a/tests/plan/temp_test.py b/tests/plan/temp_test.py index 9dd7a0e42f..58873be4d1 100644 --- a/tests/plan/temp_test.py +++ b/tests/plan/temp_test.py @@ -66,14 +66,14 @@ def test_temp_column_names_sources(source: _StoresColumns | Iterable[str]) -> No @given(n_chars=st.integers(6, 106)) @pytest.mark.slow -def test_temp_column_name_n_chars(n_chars: int) -> None: +def test_temp_column_name_n_chars(n_chars: int) -> None: # pragma: no cover name = temp.column_name(_COLUMNS, n_chars=n_chars) assert name not in _COLUMNS @given(n_new_names=st.integers(10_000, 100_000)) @pytest.mark.slow -def test_temp_column_names_always_new_names(n_new_names: int) -> None: +def test_temp_column_names_always_new_names(n_new_names: int) -> None: # pragma: no cover it = temp.column_names(_COLUMNS) new_names = set(islice(it, n_new_names)) assert len(new_names) == n_new_names From e13ef6a81bfa5a6bd52012e95a47eb4ce062823b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 16:48:27 +0000 Subject: [PATCH 208/215] refactor: start cleaning up that mess you made --- narwhals/_plan/arrow/functions.py | 37 ++++++++++++++++--------------- narwhals/_plan/arrow/group_by.py | 6 +++++ 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index bb78f3104e..481d8f6547 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -596,7 +596,7 @@ def list_get(native: ArrowAny, 
index: int) -> ArrowAny: # - Default for `binary_join_element_wise` is the only behavior available (here) currently # - Working around it is a **slog** # TODO @dangotbanned: Major de-uglyify -# Everything is functional, need to finish adding tests before simplifying +# Everything is functional, need to simplifying @t.overload def list_join( native: ChunkedList[StringType], @@ -633,6 +633,8 @@ def list_join( Each list of values in the first input is joined using each second input as separator. If any input list is null or contains a null, the corresponding output will be null. """ + from narwhals._plan.arrow.group_by import AggSpec + join = t.cast( "Callable[[Any, Any], ChunkedArray[StringScalar] | pa.StringArray]", pc.binary_join, @@ -649,39 +651,38 @@ def list_join( if not result.null_count: # if we got here and there were no nulls, then we're done return result - todo_mask = pc.and_not(result.is_null(), native.is_null()) - todo_lists = native.filter(todo_mask) - list_len_1: ChunkedOrArrayAny = eq(list_len(todo_lists), lit(1)) # pyright: ignore[reportAssignmentType] + is_null_sensitive = pc.and_not(result.is_null(), native.is_null()) + lists = native.filter(is_null_sensitive) + list_len_1: ChunkedOrArrayAny = eq(list_len(lists), lit(1)) # pyright: ignore[reportAssignmentType] only_single_null = any_(list_len_1).as_py() if only_single_null: # `[None]` - todo_lists = when_then(list_len_1, lit([""], todo_lists.type), todo_lists) + lists = when_then(list_len_1, lit([""], lists.type), lists) + idx, v = "idx", "values" builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) - pre_drop_null = builder.explode_with_indices(todo_lists) - implode_by_idx = ( - pre_drop_null.drop_null().group_by("idx").aggregate([("values", "hash_list")]) - ) - replacements = join(implode_by_idx.column(1), separator) + explode_w_idx = builder.explode_with_indices(lists) + implode_by_idx = AggSpec.implode(v).over(explode_w_idx.drop_null(), [idx]) + replacements = 
join(implode_by_idx.column(v), separator) if len(replacements) != len(list_len_1): # `[None, ..., None]` # This is a very unlucky case to hit, because # - we can detect the issue earlier # - but we can't join a table with a list in it # So this is after-the-fact and messy - empty = lit("", todo_lists.type.value_type) + empty = lit("", lists.type.value_type) replacements = ( - implode_by_idx.select(["idx"]) - .append_column("values", replacements) + implode_by_idx.select([idx]) + .append_column(v, replacements) .join( - to_table(pre_drop_null.column("idx").unique(), "idx"), - "idx", + to_table(explode_w_idx.column(idx).unique(), idx), + idx, join_type="full outer", ) - .sort_by("idx") - .column("values") + .sort_by(idx) + .column(v) .fill_null(empty) ) - return replace_with_mask(result, todo_mask, replacements) + return replace_with_mask(result, is_null_sensitive, replacements) @overload diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 849572f864..af918a3e4c 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -149,6 +149,12 @@ def any(cls, name: str) -> Self: def unique(cls, name: str) -> Self: return cls._from_function(ir.functions.Unique, name) + @classmethod + def implode(cls, name: str) -> Self: + # TODO @dangotbanned: Replace with `agg.Implode` (via `_from_agg`) once both have been dded + # https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-plan/src/dsl/expr/mod.rs#L44 + return cls(name, SUPPORTED_IR[ir.Column], None, name) + def over(self, native: pa.Table, keys: Iterable[acero.Field]) -> pa.Table: """Sugar for `native.group_by(keys).aggregate([self])`. 
From f0ace30c840d722968d7152f0e6391657f922bda Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 17:02:53 +0000 Subject: [PATCH 209/215] refactor: `select().append_column()` -> `set_column()` --- narwhals/_plan/arrow/functions.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 481d8f6547..8241a202ba 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -130,6 +130,9 @@ F64: Final = pa.float64() BOOL: Final = pa.bool_() +EMPTY: Final = "" +"""The empty string.""" + class MinMax(ir.AggExpr): """Returns a `Struct({'min': ..., 'max': ...})`. @@ -657,7 +660,7 @@ def list_join( only_single_null = any_(list_len_1).as_py() if only_single_null: # `[None]` - lists = when_then(list_len_1, lit([""], lists.type), lists) + lists = when_then(list_len_1, lit([EMPTY], lists.type), lists) idx, v = "idx", "values" builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) explode_w_idx = builder.explode_with_indices(lists) @@ -669,10 +672,9 @@ def list_join( # - we can detect the issue earlier # - but we can't join a table with a list in it # So this is after-the-fact and messy - empty = lit("", lists.type.value_type) + empty = lit(EMPTY, lists.type.value_type) replacements = ( - implode_by_idx.select([idx]) - .append_column(v, replacements) + implode_by_idx.set_column(1, v, replacements) .join( to_table(explode_w_idx.column(idx).unique(), idx), idx, From 0c12d4ab07d064759525ad1cd9da48b63a91660c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 17:07:55 +0000 Subject: [PATCH 210/215] perf: Rewrite `"full outer"` join as `"left outer"` --- narwhals/_plan/arrow/functions.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py 
b/narwhals/_plan/arrow/functions.py index 8241a202ba..1cbc68d75f 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -674,12 +674,8 @@ def list_join( # So this is after-the-fact and messy empty = lit(EMPTY, lists.type.value_type) replacements = ( - implode_by_idx.set_column(1, v, replacements) - .join( - to_table(explode_w_idx.column(idx).unique(), idx), - idx, - join_type="full outer", - ) + to_table(explode_w_idx.column(idx).unique(), idx) + .join(implode_by_idx.set_column(1, v, replacements), idx) .sort_by(idx) .column(v) .fill_null(empty) From ee676b30db950363d3324958ef4bb2a70f95b0b0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 18:56:57 +0000 Subject: [PATCH 211/215] test: Add some fast path cases --- tests/plan/list_join_test.py | 69 ++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/plan/list_join_test.py b/tests/plan/list_join_test.py index 27268444d0..f826377555 100644 --- a/tests/plan/list_join_test.py +++ b/tests/plan/list_join_test.py @@ -9,6 +9,7 @@ from tests.plan.utils import assert_equal_data, dataframe if TYPE_CHECKING: + from collections.abc import Sequence from typing import Final, TypeVar from typing_extensions import TypeAlias @@ -91,3 +92,71 @@ def test_list_join_scalar(row: SubListStr, separator: str, *, ignore_nulls: bool expected = separator.join(el for el in row if el is not None) assert_equal_data(result, {"a": [expected]}) + + +@pytest.mark.parametrize( + ("rows", "expected"), + [ + ([R1, R4, ["all", "okay"]], ["a b c", "x y", "all okay"]), + ( + [ + None, + ["no", "nulls", "inside"], + None, + None, + ["only", "on", "validity"], + None, + ], + [None, "no nulls inside", None, None, "only on validity", None], + ), + ( + [["just", "empty", "lists"], [], [], ["nothing", "fancy"], []], + ["just empty lists", "", "", "nothing fancy", ""], + ), + ([None, None, None], [None, None, None]), + ( + [ + 
["every", None, "null"], + None, + [None, "is", "lonely"], + ["not", "even"], + ["a", "single", None, "friend"], + [None], + ], + ["every null", None, "is lonely", "not even", "a single friend", ""], + ), + ( + [ + ["even", None, "this"], + [], + [None], + None, + [None], + [None, "can", "be", "cheap"], + [], + None, + [None], + ], + ["even this", "", "", None, "", "can be cheap", "", None, ""], + ), + ], + ids=[ + "full", + "no-nulls-inside", + "only-empty-lists", + "full-null", + "max-1-null", + "mixed-bag", + ], +) +def test_list_join_ignore_nulls_fastpaths( + rows: Sequence[SubListStr], expected: list[str | None] +) -> None: + # When we don't need to handle *every* edge case at the same time ... + # ... things can be simpler + separator = " " + data = {"a": list(rows)} + df = dataframe(data).with_columns(a.cast(nw.List(nw.String))) + expr = a.list.join(separator) + result = df.select(expr) + assert_equal_data(result, {"a": expected}) From acfe7569e68c9337bb93f7bd52442b9b4bb37fce Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 19:38:20 +0000 Subject: [PATCH 212/215] refactor: Split out `list_join_scalar` --- narwhals/_plan/arrow/expr.py | 12 ++++-- narwhals/_plan/arrow/functions.py | 64 ++++++++++++++++++++----------- 2 files changed, 50 insertions(+), 26 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index adca054aa4..0a95b18a84 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -984,9 +984,15 @@ def contains( def join(self, node: FExpr[lists.Join], frame: Frame, name: str) -> Expr | Scalar: separator, ignore_nulls = node.function.separator, node.function.ignore_nulls - return self.unary(fn.list_join, separator, ignore_nulls=ignore_nulls)( - node, frame, name - ) + previous = node.input[0].dispatch(self.compliant, frame, name) + result: ChunkedOrScalarAny + if isinstance(previous, ArrowExpr): + result = 
fn.list_join(previous.native, separator, ignore_nulls=ignore_nulls) + else: + result = fn.list_join_scalar( + previous.native, separator, ignore_nulls=ignore_nulls + ) + return self.with_native(result, name) class ArrowStringNamespace( diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 1cbc68d75f..7cf5c5d4c8 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -616,21 +616,17 @@ def list_join( ) -> pa.StringArray: ... @t.overload def list_join( - native: ListScalar[StringType], - separator: Arrow[StringScalar] | str, - *, - ignore_nulls: bool = ..., -) -> pa.StringScalar: ... -@t.overload -def list_join( - native: ChunkedOrScalar[ListScalar[StringType]], + native: ChunkedOrArray[ListScalar[StringType]], separator: str, *, ignore_nulls: bool = ..., -) -> ChunkedOrScalar[StringScalar]: ... +) -> ChunkedOrArray[StringScalar]: ... def list_join( - native: ArrowAny, separator: Arrow[StringScalar] | str, *, ignore_nulls: bool = False -) -> ArrowAny: + native: ChunkedOrArrayAny, + separator: Arrow[StringScalar] | str, + *, + ignore_nulls: bool = True, +) -> ChunkedOrArrayAny: """Join all string items in a sublist and place a separator between them. Each list of values in the first input is joined using each second input as separator. 
@@ -642,17 +638,9 @@ def list_join( "Callable[[Any, Any], ChunkedArray[StringScalar] | pa.StringArray]", pc.binary_join, ) - if not ignore_nulls: - return pc.binary_join(native, separator) - # NOTE: `polars` default is `True` - if isinstance(native, pa.Scalar): - to_join = ( - implode(_list_explode(native).drop_null()) if native.is_valid else native - ) - return pc.binary_join(to_join, separator) result = join(native, separator) - if not result.null_count: - # if we got here and there were no nulls, then we're done + if not ignore_nulls or not result.null_count: + # nice, no work for us then return result is_null_sensitive = pc.and_not(result.is_null(), native.is_null()) lists = native.filter(is_null_sensitive) @@ -683,6 +671,23 @@ def list_join( return replace_with_mask(result, is_null_sensitive, replacements) +def list_join_scalar( + native: ListScalar[StringType], + separator: StringScalar | str, + *, + ignore_nulls: bool = True, +) -> StringScalar: + """Join all string items in a `ListScalar` and place a separator between them. + + Note: + Consider using `list_join` or `str_join` if you don't already have `native` in this shape. + """ + if ignore_nulls and native.is_valid: + native = implode(_list_explode(native).drop_null()) + result: StringScalar = pc.call_function("binary_join", [native, separator]) + return result + + @overload def list_unique(native: ChunkedList) -> ChunkedList: ... @overload @@ -780,7 +785,7 @@ def str_join( return native if ignore_nulls and native.null_count: native = native.drop_null() - return list_join(implode(native), separator) + return list_join_scalar(implode(native), separator, ignore_nulls=False) def str_len_chars(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: @@ -861,6 +866,10 @@ def _str_split( return result +@t.overload +def str_split( + native: ChunkedArrayAny, by: str, *, literal: bool = ... +) -> ChunkedArray[ListScalar]: ... @t.overload def str_split( native: ChunkedOrScalarAny, by: str, *, literal: bool = ... 
@@ -873,6 +882,15 @@ def str_split(native: ArrowAny, by: str, *, literal: bool = True) -> Arrow[ListS return _str_split(native, by, literal=literal) +@t.overload +def str_splitn( + native: ChunkedArrayAny, + by: str, + n: int, + *, + literal: bool = ..., + as_struct: bool = ..., +) -> ChunkedArray[ListScalar]: ... @t.overload def str_splitn( native: ChunkedOrScalarAny, @@ -961,7 +979,7 @@ def str_replace_vector( list_split_by = str_split(match, pattern, literal=literal) else: list_split_by = str_splitn(match, pattern, n + 1, literal=literal) - replaced = list_join(list_split_by, match_replacements) + replaced = list_join(list_split_by, match_replacements, ignore_nulls=False) if all_(has_match, ignore_nulls=False).as_py(): return chunked_array(replaced) return replace_with_mask(native, has_match, array(replaced)) From e68d9ab9b12562848602e7a0d2f7baf80bc0576a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Dec 2025 21:56:23 +0000 Subject: [PATCH 213/215] When all seems lost, explain why it be that way --- narwhals/_plan/arrow/functions.py | 62 +++++++++++++++++++------------ narwhals/_plan/arrow/typing.py | 1 + tests/plan/list_join_test.py | 4 +- 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 7cf5c5d4c8..e7784c05e7 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -79,6 +79,7 @@ StringScalar, StringType, StructArray, + UInt32Type, UnaryFunction, UnaryNumeric, VectorFunction, @@ -126,6 +127,7 @@ """`pyarrow.compute.utf8_zero_fill` added in https://github.com/apache/arrow/pull/46815""" # NOTE: Common data type instances to share +UI32: Final = pa.uint32() I64: Final = pa.int64() F64: Final = pa.float64() BOOL: Final = pa.bool_() @@ -595,11 +597,15 @@ def list_get(native: ArrowAny, index: int) -> ArrowAny: return result +_list_join = t.cast( + "Callable[[ChunkedOrArrayAny, 
Arrow[StringScalar] | str], ChunkedArray[StringScalar] | pa.StringArray]", + pc.binary_join, +) + + # TODO @dangotbanned: Raise a feature request for `pc.binary_join(strings, separator, *, options: JoinOptions)` # - Default for `binary_join_element_wise` is the only behavior available (here) currently # - Working around it is a **slog** -# TODO @dangotbanned: Major de-uglyify -# Everything is functional, need to simplifying @t.overload def list_join( native: ChunkedList[StringType], @@ -634,39 +640,45 @@ def list_join( """ from narwhals._plan.arrow.group_by import AggSpec - join = t.cast( - "Callable[[Any, Any], ChunkedArray[StringScalar] | pa.StringArray]", - pc.binary_join, - ) - result = join(native, separator) + # (1): Try to return *as-is* from `pc.binary_join` + result = _list_join(native, separator) if not ignore_nulls or not result.null_count: - # nice, no work for us then return result is_null_sensitive = pc.and_not(result.is_null(), native.is_null()) + if array(is_null_sensitive, BOOL).true_count == 0: + return result + + # (2): Deal with only the bad kids lists = native.filter(is_null_sensitive) - list_len_1: ChunkedOrArrayAny = eq(list_len(lists), lit(1)) # pyright: ignore[reportAssignmentType] - only_single_null = any_(list_len_1).as_py() - if only_single_null: - # `[None]` - lists = when_then(list_len_1, lit([EMPTY], lists.type), lists) + + # (2.1): We know that `[None]` should join as `""`, and that is the only length-1 list we could have after the filter + list_len_eq_1 = eq(list_len(lists), lit(1, UI32)) + has_a_len_1_null = any_(list_len_eq_1).as_py() + if has_a_len_1_null: + lists = when_then(list_len_eq_1, lit([EMPTY], lists.type), lists) + + # (2.2): Everything left falls into one of these boxes: + # - (2.1): `[""]` + # - (2.2): `["something", (str | None)*, None]` <--- We fix this here and hope for the best + # - (2.3): `[None, (None)*, None]` idx, v = "idx", "values" builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) 
explode_w_idx = builder.explode_with_indices(lists) implode_by_idx = AggSpec.implode(v).over(explode_w_idx.drop_null(), [idx]) - replacements = join(implode_by_idx.column(v), separator) - if len(replacements) != len(list_len_1): - # `[None, ..., None]` - # This is a very unlucky case to hit, because - # - we can detect the issue earlier - # - but we can't join a table with a list in it - # So this is after-the-fact and messy - empty = lit(EMPTY, lists.type.value_type) + replacements = _list_join(implode_by_idx.column(v), separator) + + # (2.3): The cursed box 😨 + if len(replacements) != len(lists): + # This is a very unlucky case to hit, because we *can* detect the issue earlier + # but we *can't* join a table with a list in it. So we deal with the fallout now ... + # The end result is identical to (2.1) + indices_all = to_table(explode_w_idx.column(idx).unique(), idx) + indices_repaired = implode_by_idx.set_column(1, v, replacements) replacements = ( - to_table(explode_w_idx.column(idx).unique(), idx) - .join(implode_by_idx.set_column(1, v, replacements), idx) + indices_all.join(indices_repaired, idx) .sort_by(idx) .column(v) - .fill_null(empty) + .fill_null(lit(EMPTY, lists.type.value_type)) ) return replace_with_mask(result, is_null_sensitive, replacements) @@ -1813,6 +1825,8 @@ def lit(value: Any) -> NativeScalar: ... @overload def lit(value: Any, dtype: BoolType) -> pa.BooleanScalar: ... @overload +def lit(value: Any, dtype: UInt32Type) -> pa.UInt32Scalar: ... +@overload def lit(value: Any, dtype: DataType | None = ...) -> NativeScalar: ... 
def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index 95f0c2afe6..c2befc214b 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -30,6 +30,7 @@ from narwhals._native import NativeDataFrame, NativeSeries from narwhals.typing import SizedMultiIndexSelector as _SizedMultiIndexSelector + UInt32Type: TypeAlias = "Uint32Type" StringType: TypeAlias = "_StringType | _LargeStringType" IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type" StringScalar: TypeAlias = "Scalar[StringType]" diff --git a/tests/plan/list_join_test.py b/tests/plan/list_join_test.py index f826377555..881700f6c7 100644 --- a/tests/plan/list_join_test.py +++ b/tests/plan/list_join_test.py @@ -132,7 +132,7 @@ def test_list_join_scalar(row: SubListStr, separator: str, *, ignore_nulls: bool [None], None, [None], - [None, "can", "be", "cheap"], + [None, "can", "be", None, "cheap"], [], None, [None], @@ -141,7 +141,7 @@ def test_list_join_scalar(row: SubListStr, separator: str, *, ignore_nulls: bool ), ], ids=[ - "full", + "all-good", "no-nulls-inside", "only-empty-lists", "full-null", From 73ab21cc488caf94698305d7fe1215df35300661 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 12:55:28 +0000 Subject: [PATCH 214/215] Update `pc.binary_join` comment --- narwhals/_plan/arrow/functions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index e7784c05e7..d2d892582f 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -603,9 +603,7 @@ def list_get(native: ArrowAny, index: int) -> ArrowAny: ) -# TODO @dangotbanned: Raise a feature request for 
`pc.binary_join(strings, separator, *, options: JoinOptions)` -# - Default for `binary_join_element_wise` is the only behavior available (here) currently -# - Working around it is a **slog** +# NOTE: Raised for native null-handling (https://github.com/apache/arrow/issues/48477) @t.overload def list_join( native: ChunkedList[StringType], From 7bebe12111537863b7ae120f491f8ee8f283f3ea Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Dec 2025 13:23:14 +0000 Subject: [PATCH 215/215] refactor: Avoid inline circular import Resolves (https://github.com/narwhals-dev/narwhals/pull/3325#discussion_r2617043847) --- narwhals/_plan/arrow/functions.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index d2d892582f..42b44983b5 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -481,16 +481,20 @@ def explode_columns(self, native: pa.Table, subset: Collection[str], /) -> pa.Ta _list_explode(arr) for arr in self._iter_ensure_shape(first_len, arrays[1:]) ) + column_names = native.column_names + result = native first_result = _list_explode(first_safe) - if len(first_result) != len(native): - gathered = native.drop_columns(subset).take(_list_parent_indices(first_safe)) + if len(first_result) == len(native): + # fastpath for all length-1 lists + # if only the first is length-1, then the others raise during iteration on either branch for name, arr in zip(subset, chain([first_result], it)): - gathered = gathered.append_column(name, arr) - return gathered.select(native.column_names) - # NOTE: Not too happy about this import - from narwhals._plan.arrow.dataframe import with_arrays - - return with_arrays(native, zip(subset, chain([first_result], it))) + result = result.set_column(column_names.index(name), name, arr) + else: + result = 
result.drop_columns(subset).take(_list_parent_indices(first_safe)) + for name, arr in zip(subset, chain([first_result], it)): + result = result.append_column(name, arr) + result = result.select(column_names) + return result @classmethod def explode_column_fast(cls, native: pa.Table, column_name: str, /) -> pa.Table: