diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index c53d1576ba..03ef4b249c 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -7,12 +7,15 @@ all, all_horizontal, any_horizontal, + coalesce, col, concat_str, date_range, exclude, + format, int_range, len, + linear_space, lit, max, max_horizontal, @@ -37,12 +40,15 @@ "all", "all_horizontal", "any_horizontal", + "coalesce", "col", "concat_str", "date_range", "exclude", + "format", "int_range", "len", + "linear_space", "lit", "max", "max_horizontal", diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index 7676de82bb..98eb51bec9 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -182,7 +182,20 @@ def _method_name(tp: type[ExprIRT | FunctionT]) -> str: def get_dispatch_name(expr: ExprIR | type[Function], /) -> str: - """Return the synthesized method name for `expr`.""" - return ( - repr(expr.function) if is_function_expr(expr) else expr.__expr_ir_dispatch__.name - ) + """Return the synthesized method name for `expr`. + + Note: + Refers to the `Compliant*` method name, which may be *either* more general + or more specialized than what the user called. 
+ """ + dispatch: Dispatcher[Any] + if is_function_expr(expr): + from narwhals._plan import expressions as ir + + if isinstance(expr, (ir.RollingExpr, ir.AnonymousExpr)): + dispatch = expr.__expr_ir_dispatch__ + else: + dispatch = expr.function.__expr_ir_dispatch__ + else: + dispatch = expr.__expr_ir_dispatch__ + return dispatch.name diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py index 8b71a4dd8d..8a8caafda2 100644 --- a/narwhals/_plan/_function.py +++ b/narwhals/_plan/_function.py @@ -52,8 +52,8 @@ def __init_subclass__( **kwds: Any, ) -> None: super().__init_subclass__(*args, **kwds) - if accessor: - config = replace(config or FEOptions.default(), accessor_name=accessor) + if accessor_name := accessor or cls.__expr_ir_config__.accessor_name: + config = replace(config or FEOptions.default(), accessor_name=accessor_name) if options: cls._function_options = staticmethod(options) if config: diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index 8231af57ac..30503c55c3 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -24,7 +24,7 @@ NativeSeriesT, Seq, ) - from narwhals.typing import NonNestedLiteral + from narwhals.typing import NonNestedLiteral, PythonLiteral T = TypeVar("T") @@ -38,6 +38,7 @@ bytes, Decimal, ) +_PYTHON_LITERAL_TPS = (*_NON_NESTED_LITERAL_TPS, list, tuple, type(None)) def _ir(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 @@ -68,6 +69,10 @@ def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS) +def is_python_literal(obj: Any) -> TypeIs[PythonLiteral]: + return isinstance(obj, _PYTHON_LITERAL_TPS) + + def is_expr(obj: Any) -> TypeIs[Expr]: return isinstance(obj, _expr().Expr) diff --git a/narwhals/_plan/_parse.py b/narwhals/_plan/_parse.py index 5b23cefde4..ecdfc92e26 100644 --- a/narwhals/_plan/_parse.py +++ b/narwhals/_plan/_parse.py @@ -18,7 +18,11 @@ is_selector, ) from narwhals._plan.common import 
flatten_hash_safe -from narwhals._plan.exceptions import invalid_into_expr_error, is_iterable_error +from narwhals._plan.exceptions import ( + invalid_into_expr_error, + is_iterable_error, + list_literal_error, +) from narwhals._utils import qualified_type_name from narwhals.dependencies import get_polars from narwhals.exceptions import InvalidOperationError @@ -127,7 +131,7 @@ def parse_into_expr_ir( expr = col(input) elif isinstance(input, list): if list_as_series is None: - raise TypeError(input) # pragma: no cover + raise list_literal_error(input) expr = lit(list_as_series(input)) else: expr = lit(input, dtype=dtype) @@ -331,7 +335,7 @@ def _combine_predicates(predicates: Iterator[ExprIR], /) -> ExprIR: inputs = (first,) else: return first - return AllHorizontal().to_function_expr(*inputs) + return AllHorizontal(ignore_nulls=False).to_function_expr(*inputs) def _is_iterable(obj: Iterable[T] | Any) -> TypeIs[Iterable[T]]: diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 541d38671a..82102e189d 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -18,7 +18,7 @@ import operator from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Any, Final, Union, cast +from typing import TYPE_CHECKING, Any, Final, Literal, Union, cast import pyarrow as pa # ignore-banned-import import pyarrow.acero as pac @@ -61,7 +61,9 @@ """ Target: TypeAlias = OneOrSeq[Field] -Aggregation: TypeAlias = "_Aggregation" +Aggregation: TypeAlias = Union[ + "_Aggregation", Literal["hash_kurtosis", "hash_skew", "kurtosis", "skew"] +] AggregateOptions: TypeAlias = "_AggregateOptions" Opts: TypeAlias = "AggregateOptions | None" OutputName: TypeAlias = str diff --git a/narwhals/_plan/arrow/common.py b/narwhals/_plan/arrow/common.py index fdbe173f2c..bd6f4b8e2e 100644 --- a/narwhals/_plan/arrow/common.py +++ b/narwhals/_plan/arrow/common.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, 
Generic -from narwhals._plan.arrow.functions import BACKEND_VERSION +from narwhals._plan.arrow.functions import BACKEND_VERSION, random_indices from narwhals._typing_compat import TypeVar from narwhals._utils import Implementation, Version, _StoresNative @@ -43,6 +43,10 @@ def _with_native(self, native: NativeT) -> Self: msg = f"{type(self).__name__}._with_native" raise NotImplementedError(msg) + def __len__(self) -> int: + msg = f"{type(self).__name__}.__len__" + raise NotImplementedError(msg) + if BACKEND_VERSION >= (18,): def _gather(self, indices: Indices) -> NativeT: @@ -57,5 +61,14 @@ def gather(self, indices: Indices | _StoresNative[ChunkedArrayAny]) -> Self: ca = self._gather(indices.native if is_series(indices) else indices) return self._with_native(ca) + def gather_every(self, n: int, offset: int = 0) -> Self: + return self._with_native(self.native[offset::n]) + def slice(self, offset: int, length: int | None = None) -> Self: return self._with_native(self.native.slice(offset=offset, length=length)) + + def sample_n( + self, n: int = 1, *, with_replacement: bool = False, seed: int | None = None + ) -> Self: + mask = random_indices(len(self), n, with_replacement=with_replacement, seed=seed) + return self.gather(mask) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 2c0b2e37e8..ecef91b772 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -16,6 +16,7 @@ from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.compliant.dataframe import EagerDataFrame from narwhals._plan.compliant.typing import namespace +from narwhals._plan.exceptions import shape_error from narwhals._plan.expressions import NamedIR from narwhals._utils import Version, generate_repr from narwhals.schema import Schema @@ -24,16 +25,18 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence import polars as pl - from typing_extensions import Self + from typing_extensions import 
Self, TypeAlias - from narwhals._plan.arrow.typing import ChunkedArrayAny + from narwhals._plan.arrow.typing import ChunkedArrayAny, ChunkedOrArrayAny from narwhals._plan.compliant.group_by import GroupByResolver from narwhals._plan.expressions import ExprIR, NamedIR - from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.options import ExplodeOptions, SortMultipleOptions from narwhals._plan.typing import NonCrossJoinStrategy from narwhals.dtypes import DType from narwhals.typing import IntoSchema +Incomplete: TypeAlias = Any + class ArrowDataFrame( FrameSeries["pa.Table"], EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"] @@ -48,6 +51,10 @@ def _with_native(self, native: pa.Table) -> Self: def _group_by(self) -> type[GroupBy]: return GroupBy + @property + def shape(self) -> tuple[int, int]: + return self.native.shape + def group_by_resolver(self, resolver: GroupByResolver, /) -> GroupBy: return self._group_by.from_resolver(self, resolver) @@ -68,11 +75,16 @@ def __len__(self) -> int: @classmethod def from_dict( - cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None + cls, + data: Mapping[str, Any], + /, + *, + schema: IntoSchema | None = None, + version: Version = Version.MAIN, ) -> Self: pa_schema = Schema(schema).to_arrow() if schema is not None else schema native = pa.Table.from_pydict(data, schema=pa_schema) - return cls.from_native(native, version=Version.MAIN) + return cls.from_native(native, version=version) def iter_columns(self) -> Iterator[Series]: for name, series in zip(self.columns, self.native.itercolumns()): @@ -96,10 +108,12 @@ def to_polars(self) -> pl.DataFrame: return pl.DataFrame(self.native) - def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[Series]: - ns = namespace(self) - from_named_ir = ns._expr.from_named_ir - yield from ns._expr.align(from_named_ir(e, self) for e in nodes) + def _evaluate_irs( + self, nodes: Iterable[NamedIR[ExprIR]], /, *, length: int | None = None + ) 
-> Iterator[Series]: + expr = namespace(self)._expr + from_named_ir = expr.from_named_ir + yield from expr.align((from_named_ir(e, self) for e in nodes), default=length) def sort(self, by: Sequence[str], options: SortMultipleOptions | None = None) -> Self: return self.gather(fn.sort_indices(self.native, *by, options=options)) @@ -121,6 +135,19 @@ def with_row_index_by( column = fn.unsort_indices(indices) return self._with_native(self.native.add_column(0, name, column)) + def to_struct(self, name: str = "") -> Series: + native = self.native + if fn.TO_STRUCT_ARRAY_ACCEPTS_EMPTY: + struct = native.to_struct_array() + elif fn.HAS_FROM_TO_STRUCT_ARRAY: + if len(native): + struct = native.to_struct_array() + else: + struct = fn.chunked_array([], pa.struct(native.schema)) + else: + struct = fn.struct(native.column_names, native.columns) + return Series.from_native(struct, name, version=self.version) + def get_column(self, name: str) -> Series: chunked = self.native.column(name) return Series.from_native(chunked, name, version=self.version) @@ -136,6 +163,12 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: native = self.native.filter(~to_drop) return self._with_native(native) + def explode(self, subset: Sequence[str], options: ExplodeOptions) -> Self: + builder = fn.ExplodeBuilder.from_options(options) + if len(subset) == 1: + return self._with_native(builder.explode_column(self.native, subset[0])) + return self._with_native(builder.explode_columns(self.native, subset)) + def rename(self, mapping: Mapping[str, str]) -> Self: names: dict[str, str] | list[str] if fn.BACKEND_VERSION >= (17,): @@ -144,20 +177,26 @@ def rename(self, mapping: Mapping[str, str]) -> Self: names = [mapping.get(c, c) for c in self.columns] return self._with_native(self.native.rename_columns(names)) - # NOTE: Use instead of `with_columns` for trivial cases + def with_series(self, series: Series) -> Self: + """Add a new column or replace an existing one. 
+ + Uses similar semantics as `with_columns`, but: + - for a single named `Series` + - no broadcasting (use `Scalar.broadcast` instead) + - no length checking (use `with_series_checked` instead) + """ + return self._with_native(with_array(self.native, series.name, series.native)) + + def with_series_checked(self, series: Series) -> Self: + expected, actual = len(self), len(series) + if len(series) != len(self): + raise shape_error(expected, actual) + return self.with_series(series) + def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: - native = self.native - columns = self.columns height = len(self) - for into_series in exprs: - name = into_series.name - chunked = into_series.broadcast(height).native - if name in columns: - i = columns.index(name) - native = native.set_column(i, name, chunked) - else: - native = native.append_column(name, chunked) - return self._with_native(native) + names_and_columns = ((e.name, e.broadcast(height).native) for e in exprs) + return self._with_native(with_arrays(self.native, names_and_columns)) def select_names(self, *column_names: str) -> Self: return self._with_native(self.native.select(list(column_names))) @@ -200,3 +239,22 @@ def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[S from_native = self._with_native partitions = partition_by(self.native, by, include_key=include_key) return [from_native(df) for df in partitions] + + +def with_array(table: pa.Table, name: str, column: ChunkedOrArrayAny) -> pa.Table: + column_names = table.column_names + if name in column_names: + return table.set_column(column_names.index(name), name, column) + return table.append_column(name, column) + + +def with_arrays( + table: pa.Table, names_and_columns: Iterable[tuple[str, ChunkedOrArrayAny]], / +) -> pa.Table: + column_names = table.column_names + for name, column in names_and_columns: + if name in column_names: + table = table.set_column(column_names.index(name), name, column) + else: + table = 
table.append_column(name, column) + return table diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 31181b7a95..0a95b18a84 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -1,41 +1,75 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol, overload +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import narwhals_to_native_dtype -from narwhals._plan import expressions as ir -from narwhals._plan._guards import is_function_expr, is_seq_column +from narwhals._plan import common, expressions as ir +from narwhals._plan._guards import ( + is_function_expr, + is_iterable_reject, + is_python_literal, + is_seq_column, +) from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co from narwhals._plan.common import temp +from narwhals._plan.compliant.accessors import ( + ExprCatNamespace, + ExprListNamespace, + ExprStringNamespace, + ExprStructNamespace, +) from narwhals._plan.compliant.column import ExprDispatch from narwhals._plan.compliant.expr import EagerExpr from narwhals._plan.compliant.scalar import EagerScalar from narwhals._plan.compliant.typing import namespace +from narwhals._plan.exceptions import shape_error +from narwhals._plan.expressions import FunctionExpr as FExpr, functions as F from narwhals._plan.expressions.boolean import ( + IsDuplicated, IsFirstDistinct, IsInExpr, IsInSeq, IsInSeries, IsLastDistinct, + IsNotNan, + IsNotNull, + IsUnique, ) from narwhals._plan.expressions.functions import NullCount -from narwhals._utils import Implementation, Version, _StoresNative, not_implemented +from narwhals._utils import ( + Implementation, 
+ Version, + _StoresNative, + not_implemented, + qualified_type_name, +) from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Callable, Mapping, Sequence from typing_extensions import Self, TypeAlias - from narwhals._arrow.typing import Incomplete from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.arrow.typing import ChunkedArrayAny, P, VectorFunction + from narwhals._plan.arrow.typing import ( + ChunkedArrayAny, + P, + UnaryFunctionP, + VectorFunction, + ) + from narwhals._plan.expressions import ( + BinaryExpr, + FunctionExpr as FExpr, + lists, + strings, + ) from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, @@ -61,19 +95,21 @@ IsNull, Not, ) - from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr as FExpr + from narwhals._plan.expressions.categorical import GetCategories from narwhals._plan.expressions.functions import ( Abs, CumAgg, Diff, + FillNan, FillNull, NullCount, Pow, Rank, Shift, ) + from narwhals._plan.expressions.struct import FieldByName from narwhals._plan.typing import Seq - from narwhals.typing import Into1DArray, IntoDType, PythonLiteral + from narwhals.typing import IntoDType, PythonLiteral Expr: TypeAlias = "ArrowExpr" Scalar: TypeAlias = "ArrowScalar" @@ -100,7 +136,7 @@ def pow(self, node: FExpr[Pow], frame: Frame, name: str) -> StoresNativeT_co: base, exponent = node.function.unwrap_input(node) base_ = base.dispatch(self, frame, "base").native exponent_ = exponent.dispatch(self, frame, "exponent").native - return self._with_native(pc.power(base_, exponent_), name) + return self._with_native(fn.power(base_, exponent_), name) def fill_null( self, node: FExpr[FillNull], frame: Frame, name: str @@ -110,6 +146,12 @@ def fill_null( value_ = value.dispatch(self, frame, "value").native return 
self._with_native(pc.fill_null(native, value_), name) + def fill_nan(self, node: FExpr[FillNan], frame: Frame, name: str) -> StoresNativeT_co: + expr, value = node.function.unwrap_input(node) + native = expr.dispatch(self, frame, name).native + value_ = value.dispatch(self, frame, "value").native + return self._with_native(fn.fill_nan(native, value_), name) + def is_between( self, node: FExpr[IsBetween], frame: Frame, name: str ) -> StoresNativeT_co: @@ -120,20 +162,49 @@ def is_between( result = fn.is_between(native, lower, upper, node.function.closed) return self._with_native(result, name) + @overload def _unary_function( - self, fn_native: Callable[[Any], Any], / + self, fn_native: UnaryFunctionP[P], /, *args: P.args, **kwds: P.kwargs + ) -> Callable[[FExpr[Any], Frame, str], StoresNativeT_co]: ... + @overload + def _unary_function( + self, fn_native: Callable[[ChunkedOrScalarAny], ChunkedOrScalarAny], / + ) -> Callable[[FExpr[Any], Frame, str], StoresNativeT_co]: ... + def _unary_function( + self, fn_native: UnaryFunctionP[P], /, *args: P.args, **kwds: P.kwargs ) -> Callable[[FExpr[Any], Frame, str], StoresNativeT_co]: - def func(node: FExpr[Any], frame: Frame, name: str) -> StoresNativeT_co: + """Return a function with the signature `(node, frame, name)`. + + Handles dispatching prior expressions, and rewrapping the result of this one. + + Arity refers to the number of expression inputs to a function (after expanding). 
+ + So a **unary** function will look like: + + col("a").round(2) + + Which unravels to: + + FunctionExpr( + input=(Column(name="a"),), + # ^ length-1 tuple + function=Round(decimals=2), + # ^ non-expression argument + options=..., + ) + """ + + def func(node: FExpr[Any], frame: Frame, name: str, /) -> StoresNativeT_co: native = node.input[0].dispatch(self, frame, name).native - return self._with_native(fn_native(native), name) + return self._with_native(fn_native(native, *args, **kwds), name) return func def abs(self, node: FExpr[Abs], frame: Frame, name: str) -> StoresNativeT_co: - return self._unary_function(pc.abs)(node, frame, name) + return self._unary_function(fn.abs_)(node, frame, name) def not_(self, node: FExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: - return self._unary_function(pc.invert)(node, frame, name) + return self._unary_function(fn.not_)(node, frame, name) def all(self, node: FExpr[All], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.all_)(node, frame, name) @@ -153,24 +224,21 @@ def is_in_expr( ) -> StoresNativeT_co: expr, other = node.function.unwrap_input(node) right = other.dispatch(self, frame, name).native - if isinstance(right, pa.Scalar): - right = fn.array(right) - result = fn.is_in(expr.dispatch(self, frame, name).native, right) + arr = fn.array(right) if isinstance(right, pa.Scalar) else right + result = fn.is_in(expr.dispatch(self, frame, name).native, arr) return self._with_native(result, name) def is_in_series( self, node: FExpr[IsInSeries[ChunkedArrayAny]], frame: Frame, name: str ) -> StoresNativeT_co: - native = node.input[0].dispatch(self, frame, name).native other = node.function.other.unwrap().to_native() - return self._with_native(fn.is_in(native, other), name) + return self._unary_function(fn.is_in, other)(node, frame, name) def is_in_seq( self, node: FExpr[IsInSeq], frame: Frame, name: str ) -> StoresNativeT_co: - native = node.input[0].dispatch(self, frame, name).native other = 
fn.array(node.function.other) - return self._with_native(fn.is_in(native, other), name) + return self._unary_function(fn.is_in, other)(node, frame, name) def is_nan(self, node: FExpr[IsNan], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.is_nan)(node, frame, name) @@ -178,6 +246,16 @@ def is_nan(self, node: FExpr[IsNan], frame: Frame, name: str) -> StoresNativeT_c def is_null(self, node: FExpr[IsNull], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.is_null)(node, frame, name) + def is_not_nan( + self, node: FExpr[IsNotNan], frame: Frame, name: str + ) -> StoresNativeT_co: + return self._unary_function(fn.is_not_nan)(node, frame, name) + + def is_not_null( + self, node: FExpr[IsNotNull], frame: Frame, name: str + ) -> StoresNativeT_co: + return self._unary_function(fn.is_not_null)(node, frame, name) + def binary_expr(self, node: BinaryExpr, frame: Frame, name: str) -> StoresNativeT_co: lhs, rhs = ( node.left.dispatch(self, frame, name), @@ -195,13 +273,70 @@ def ternary_expr( result = pc.if_else(when.native, then.native, otherwise.native) return self._with_native(result, name) - exp = not_implemented() # type: ignore[misc] - log = not_implemented() # type: ignore[misc] - sqrt = not_implemented() # type: ignore[misc] - round = not_implemented() # type: ignore[misc] - clip = not_implemented() # type: ignore[misc] - drop_nulls = not_implemented() # type: ignore[misc] - replace_strict = not_implemented() # type: ignore[misc] + def log(self, node: FExpr[F.Log], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(fn.log, node.function.base)(node, frame, name) + + def exp(self, node: FExpr[F.Exp], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(fn.exp)(node, frame, name) + + def sqrt(self, node: FExpr[F.Sqrt], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(fn.sqrt)(node, frame, name) + + def round(self, node: FExpr[F.Round], frame: 
Frame, name: str) -> StoresNativeT_co: + return self._unary_function(fn.round, node.function.decimals)(node, frame, name) + + def ceil(self, node: FExpr[F.Ceil], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(fn.ceil)(node, frame, name) + + def floor(self, node: FExpr[F.Floor], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(fn.floor)(node, frame, name) + + def clip(self, node: FExpr[F.Clip], frame: Frame, name: str) -> StoresNativeT_co: + expr, lower, upper = node.function.unwrap_input(node) + result = fn.clip( + expr.dispatch(self, frame, name).native, + lower.dispatch(self, frame, name).native, + upper.dispatch(self, frame, name).native, + ) + return self._with_native(result, name) + + def clip_lower( + self, node: FExpr[F.ClipLower], frame: Frame, name: str + ) -> StoresNativeT_co: + expr, other = node.function.unwrap_input(node) + result = fn.clip_lower( + expr.dispatch(self, frame, name).native, + other.dispatch(self, frame, name).native, + ) + return self._with_native(result, name) + + def clip_upper( + self, node: FExpr[F.ClipUpper], frame: Frame, name: str + ) -> StoresNativeT_co: + expr, other = node.function.unwrap_input(node) + result = fn.clip_upper( + expr.dispatch(self, frame, name).native, + other.dispatch(self, frame, name).native, + ) + return self._with_native(result, name) + + def replace_strict( + self, node: FExpr[F.ReplaceStrict], frame: Frame, name: str + ) -> StoresNativeT_co: + old, new = node.function.old, node.function.new + dtype = fn.dtype_native(node.function.return_dtype, self.version) + return self._unary_function(fn.replace_strict, old, new, dtype)(node, frame, name) + + def replace_strict_default( + self, node: FExpr[F.ReplaceStrictDefault], frame: Frame, name: str + ) -> StoresNativeT_co: + func = node.function + expr, default_ = func.unwrap_input(node) + native = expr.dispatch(self, frame, name).native + default = default_.dispatch(self, frame, name).native + dtype = 
fn.dtype_native(func.return_dtype, self.version) + result = fn.replace_strict_default(native, func.old, func.new, default, dtype) + return self._with_native(result, name) class ArrowExpr( # type: ignore[misc] @@ -238,7 +373,7 @@ def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Sel def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Self: if isinstance(result, pa.Scalar): return ArrowScalar.from_native(result, name, version=self.version) - return self.from_native(result, name or self.name, self.version) + return self.from_native(result, name, self.version) # NOTE: I'm not sure what I meant by # > "isn't natively supported on `ChunkedArray`" @@ -269,10 +404,17 @@ def native(self) -> ChunkedArrayAny: def to_series(self) -> Series: return self._evaluated + # TODO @dangotbanned: Handle this `Series([...])` edge case higher up + # Can occur from a len(1) series passed to `with_columns`, which becomes a literal def broadcast(self, length: int, /) -> Series: if (actual_len := len(self)) != length: - msg = f"Expected object of length {length}, got {actual_len}." - raise ShapeError(msg) + if actual_len == 1: + msg = ( + f"Series {self.name}, length {actual_len} doesn't match the DataFrame height of {length}.\n\n" + "If you want an expression to be broadcasted, ensure it is a scalar (for instance by adding '.first()')." 
+ ) + raise ShapeError(msg) + raise shape_error(length, actual_len) return self._evaluated def __len__(self) -> int: @@ -310,13 +452,15 @@ def filter(self, node: ir.Filter, frame: Frame, name: str) -> Expr: def first(self, node: First, frame: Frame, name: str) -> Scalar: prev = self._dispatch_expr(node.expr, frame, name) native = prev.native - result = native[0] if len(prev) else fn.lit(None, native.type) + result: NativeScalar = native[0] if len(prev) else fn.lit(None, native.type) return self._with_native(result, name) def last(self, node: Last, frame: Frame, name: str) -> Scalar: prev = self._dispatch_expr(node.expr, frame, name) native = prev.native - result = native[len_ - 1] if (len_ := len(prev)) else fn.lit(None, native.type) + result: NativeScalar = ( + native[len_ - 1] if (len_ := len(prev)) else fn.lit(None, native.type) + ) return self._with_native(result, name) def arg_min(self, node: ArgMin, frame: Frame, name: str) -> Scalar: @@ -381,21 +525,17 @@ def min(self, node: Min, frame: Frame, name: str) -> Scalar: result: NativeScalar = fn.min_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def null_count(self, node: FExpr[NullCount], frame: Frame, name: str) -> Scalar: + def null_count(self, node: FExpr[F.NullCount], frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.input[0], frame, name).native return self._with_native(fn.null_count(native), name) - # TODO @dangotbanned: top-level, complex-ish nodes - # - [ ] Over - # - [x] `over_ordered` - # - [x] `group_by`, `join` - # - [x] `over` (with partitions) - # - [x] `over_ordered` (with partitions) - # - [ ] fix: join on nulls after https://github.com/narwhals-dev/narwhals/issues/3300 - # - [ ] `map_batches` - # - [x] elementwise - # - [ ] scalar - # - [ ] `rolling_expr` has 4 variants + def kurtosis(self, node: FExpr[F.Kurtosis], frame: Frame, name: str) -> Scalar: + native = self._dispatch_expr(node.input[0], frame, name).native + return 
self._with_native(fn.kurtosis_skew(native, "kurtosis"), name) + + def skew(self, node: FExpr[F.Skew], frame: Frame, name: str) -> Scalar: + native = self._dispatch_expr(node.input[0], frame, name).native + return self._with_native(fn.kurtosis_skew(native, "skew"), name) def over( self, @@ -408,9 +548,9 @@ def over( expr = node.expr by = node.partition_by if is_function_expr(expr) and isinstance( - expr.function, (IsFirstDistinct, IsLastDistinct) + expr.function, (IsFirstDistinct, IsLastDistinct, IsUnique, IsDuplicated) ): - return self._is_first_last_distinct( + return self._boolean_length_preserving( expr, frame, name, by, sort_indices=sort_indices ) resolved = frame._grouper.by_irs(*by).agg_irs(expr.alias(name)).resolve(frame) @@ -429,15 +569,17 @@ def over_ordered( return evaluated return self.from_series(evaluated.broadcast(len(frame)).gather(indices)) - def _is_first_last_distinct( + def _boolean_length_preserving( self, - node: FExpr[IsFirstDistinct | IsLastDistinct], + node: FExpr[IsFirstDistinct | IsLastDistinct | IsUnique | IsDuplicated], frame: Frame, name: str, partition_by: Seq[ir.ExprIR] = (), *, sort_indices: pa.UInt64Array | None = None, ) -> Self: + # NOTE: This subset of functions can be expressed as a mask applied to indices + into_column_agg, mask = fn.BOOLEAN_LENGTH_PRESERVING[type(node.function)] idx_name = temp.column_name(frame) df = frame._with_columns([node.input[0].dispatch(self, frame, name)]) if sort_indices is not None: @@ -445,32 +587,44 @@ def _is_first_last_distinct( df = df._with_native(df.native.add_column(0, idx_name, column)) else: df = df.with_row_index(idx_name) - agg = fn.IS_FIRST_LAST_DISTINCT[type(node.function)](idx_name) + agg_node = into_column_agg(idx_name) if not (partition_by or sort_indices is not None): - distinct = df.group_by_names((name,)).agg((ir.named_ir(idx_name, agg),)) + aggregated = df.group_by_names((name,)).agg( + (ir.named_ir(idx_name, agg_node),) + ) else: - distinct = df.group_by_agg_irs((ir.col(name), 
*partition_by), agg) + aggregated = df.group_by_agg_irs((ir.col(name), *partition_by), agg_node) index = df.to_series().alias(name) - return self.from_series(index.is_in(distinct.get_column(idx_name))) + final_result = mask(index.native, aggregated.get_column(idx_name).native) + return self.from_series(index._with_native(final_result)) - # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` - def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: - if node.is_scalar: - # NOTE: Just trying to avoid redoing the whole API for `Series` - msg = "Only elementwise is currently supported" - raise NotImplementedError(msg) + # NOTE: Can't implement in `EagerExpr` (like on `main`) + # The version here is missing `__narwhals_namespace__` + def map_batches( + self, node: ir.AnonymousExpr, frame: Frame, name: str + ) -> Self | Scalar: series = self._dispatch_expr(node.input[0], frame, name) udf = node.function.function - result: Series | Into1DArray = udf(series) - if not isinstance(result, Series): - result = Series.from_numpy(result, name, version=self.version) + udf_result: Series | Iterable[Any] | Any = udf(series) + if node.is_scalar: + return ArrowScalar.from_unknown( + udf_result, name, dtype=node.function.return_dtype, version=self.version + ) + if isinstance(udf_result, Series): + result = udf_result + elif isinstance(udf_result, Iterable) and not is_iterable_reject(udf_result): + result = Series.from_iterable(udf_result, name=name, version=self.version) + else: + msg = ( + "`map_batches` with `returns_scalar=False` must return a Series; " + f"found '{qualified_type_name(udf_result)}'.\n\nIf `returns_scalar` " + "is set to `True`, a returned value can be a scalar value." 
+ ) + raise TypeError(msg) if dtype := node.function.return_dtype: result = result.cast(dtype) return self.from_series(result) - def rolling_expr(self, node: ir.RollingExpr, frame: Frame, name: str) -> Self: - raise NotImplementedError - def shift(self, node: FExpr[Shift], frame: Frame, name: str) -> Self: return self._vector_function(fn.shift, node.function.n)(node, frame, name) @@ -482,32 +636,145 @@ def rank(self, node: FExpr[Rank], frame: Frame, name: str) -> Self: def _cumulative(self, node: FExpr[CumAgg], frame: Frame, name: str) -> Self: native = self._dispatch_expr(node.input[0], frame, name).native - func = fn.CUMULATIVE[type(node.function)] - if not node.function.reverse: - result = func(native) - else: - result = fn.reverse(func(fn.reverse(native))) - return self._with_native(result, name) + return self._with_native(fn.cumulative(native, node.function), name) + + def unique(self, node: FExpr[F.Unique], frame: Frame, name: str) -> Self: + return self.from_series(self._dispatch_expr(node.input[0], frame, name).unique()) + + def gather_every(self, node: FExpr[F.GatherEvery], frame: Frame, name: str) -> Self: + series = self._dispatch_expr(node.input[0], frame, name) + n, offset = node.function.n, node.function.offset + return self.from_series(series.gather_every(n=n, offset=offset)) + + def sample_n(self, node: FExpr[F.SampleN], frame: Frame, name: str) -> Self: + series = self._dispatch_expr(node.input[0], frame, name) + func = node.function + n, replace, seed = func.n, func.with_replacement, func.seed + result = series.sample_n(n, with_replacement=replace, seed=seed) + return self.from_series(result) + + def sample_frac(self, node: FExpr[F.SampleFrac], frame: Frame, name: str) -> Self: + series = self._dispatch_expr(node.input[0], frame, name) + func = node.function + fraction, replace, seed = func.fraction, func.with_replacement, func.seed + result = series.sample_frac(fraction, with_replacement=replace, seed=seed) + return self.from_series(result) + + 
def drop_nulls(self, node: FExpr[F.DropNulls], frame: Frame, name: str) -> Self: + return self._vector_function(fn.drop_nulls)(node, frame, name) + + def mode_all(self, node: FExpr[F.ModeAll], frame: Frame, name: str) -> Self: + return self._vector_function(fn.mode_all)(node, frame, name) + + def mode_any(self, node: FExpr[F.ModeAny], frame: Frame, name: str) -> Scalar: + native = self._dispatch_expr(node.input[0], frame, name).native + return self._with_native(fn.mode_any(native), name) + + def fill_null_with_strategy( + self, node: FExpr[F.FillNullWithStrategy], frame: Frame, name: str + ) -> Self: + native = self._dispatch_expr(node.input[0], frame, name).native + strategy, limit = node.function.strategy, node.function.limit + func = fn.fill_null_with_strategy + return self._with_native(func(native, strategy, limit), name) cum_count = _cumulative cum_min = _cumulative cum_max = _cumulative cum_prod = _cumulative cum_sum = _cumulative - is_first_distinct = _is_first_last_distinct - is_last_distinct = _is_first_last_distinct + is_first_distinct = _boolean_length_preserving + is_last_distinct = _boolean_length_preserving + is_duplicated = _boolean_length_preserving + is_unique = _boolean_length_preserving + + _ROLLING: ClassVar[Mapping[type[F.RollingWindow], Callable[..., Series]]] = { + F.RollingSum: Series.rolling_sum, + F.RollingMean: Series.rolling_mean, + F.RollingVar: Series.rolling_var, + F.RollingStd: Series.rolling_std, + } + + def rolling_expr( + self, node: ir.RollingExpr[F.RollingWindow], frame: Frame, name: str + ) -> Self: + s = self._dispatch_expr(node.input[0], frame, name) + roll_options = node.function.options + size = roll_options.window_size + samples = roll_options.min_samples + center = roll_options.center + op = type(node.function) + method = self._ROLLING[op] + if op in {F.RollingSum, F.RollingMean}: + return self.from_series(method(s, size, min_samples=samples, center=center)) + ddof = roll_options.ddof + result = method(s, size, 
min_samples=samples, center=center, ddof=ddof) + return self.from_series(result) + + # NOTE: Should not be returning a struct when all `include_*` are false + # https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-ops/src/chunked_array/hist.rs#L223-L223 + def _hist_finish(self, data: Mapping[str, Any], name: str) -> Self: + ns = namespace(self) + if len(data) == 1: + count = next(iter(data.values())) + series = ns._series.from_iterable(count, version=self.version, name=name) + else: + series = ns._dataframe.from_dict(data, version=self.version).to_struct(name) + return self.from_series(series) + + def hist_bins(self, node: FExpr[F.HistBins], frame: Frame, name: str) -> Self: + native = self._dispatch_expr(node.input[0], frame, name).native + func = node.function + bins = func.bins + include = func.include_breakpoint + if len(bins) <= 1: + data = func.empty_data + elif fn.is_only_nulls(native, nan_is_null=True): + data = fn.hist_zeroed_data(bins, include_breakpoint=include) + else: + data = fn.hist_bins(native, bins, include_breakpoint=include) + return self._hist_finish(data, name) + + def hist_bin_count( + self, node: FExpr[F.HistBinCount], frame: Frame, name: str + ) -> Self: + native = self._dispatch_expr(node.input[0], frame, name).native + func = node.function + bin_count = func.bin_count + include = func.include_breakpoint + if bin_count == 0: + data = func.empty_data + elif fn.is_only_nulls(native, nan_is_null=True): + data = fn.hist_zeroed_data(bin_count, include_breakpoint=include) + else: + # NOTE: `Decimal` is not supported, but excluding it from the typing is surprisingly complicated + # https://docs.rs/polars-core/0.52.0/polars_core/datatypes/enum.DataType.html#method.is_primitive_numeric + lower: NativeScalar = fn.min_(native) + upper: NativeScalar = fn.max_(native) + if lower.equals(upper): + # All data points are identical - use unit interval + rhs = fn.lit(0.5) + lower, upper = fn.sub(lower, rhs), 
fn.add(upper, rhs) + bins = fn.linear_space(lower.as_py(), upper.as_py(), bin_count + 1) + data = fn.hist_bins(native, bins, include_breakpoint=include) + return self._hist_finish(data, name) # ewm_mean = not_implemented() # noqa: ERA001 - hist_bins = not_implemented() - hist_bin_count = not_implemented() - mode = not_implemented() - unique = not_implemented() - fill_null_with_strategy = not_implemented() - kurtosis = not_implemented() - skew = not_implemented() - gather_every = not_implemented() - is_duplicated = not_implemented() - is_unique = not_implemented() + @property + def cat(self) -> ArrowCatNamespace[Expr]: + return ArrowCatNamespace(self) + + @property + def list(self) -> ArrowListNamespace[Expr]: + return ArrowListNamespace(self) + + @property + def str(self) -> ArrowStringNamespace[Expr]: + return ArrowStringNamespace(self) + + @property + def struct(self) -> ArrowStructNamespace[Expr]: + return ArrowStructNamespace(self) class ArrowScalar( @@ -559,6 +826,23 @@ def from_series(cls, series: Series) -> Self: msg = f"Too long {len(series)!r}" raise InvalidOperationError(msg) + @classmethod + def from_unknown( + cls, + value: Any, + name: str = "literal", + /, + *, + dtype: IntoDType | None = None, + version: Version = Version.MAIN, + ) -> Self: + if isinstance(value, pa.Scalar): + return cls.from_native(value, name, version) + if is_python_literal(value): + return cls.from_python(value, name, dtype=dtype, version=version) + native = fn.lit(value, fn.dtype_native(dtype, version)) + return cls.from_native(native, name, version) + def _dispatch_expr(self, node: ir.ExprIR, frame: Frame, name: str) -> Series: msg = f"Expected unreachable, but hit at: {node!r}" raise InvalidOperationError(msg) @@ -574,17 +858,15 @@ def to_series(self) -> Series: return self.broadcast(1) def to_python(self) -> PythonLiteral: - return self.native.as_py() # type: ignore[no-any-return] + result: PythonLiteral = self.native.as_py() + return result def broadcast(self, length: int) 
-> Series: scalar = self.native if length == 1: chunked = fn.chunked_array(scalar) else: - # NOTE: Same issue as `pa.scalar` overlapping overloads - # https://github.com/zen-xu/pyarrow-stubs/pull/209 - pa_repeat: Incomplete = pa.repeat - chunked = fn.chunked_array(pa_repeat(scalar, length)) + chunked = fn.chunked_array(fn.repeat_unchecked(scalar, length)) return Series.from_native(chunked, self.name, version=self.version) def count(self, node: Count, frame: Frame, name: str) -> Scalar: @@ -595,6 +877,31 @@ def null_count(self, node: FExpr[NullCount], frame: Frame, name: str) -> Self: native = node.input[0].dispatch(self, frame, name).native return self._with_native(pa.scalar(0 if native.is_valid else 1), name) + def drop_nulls( # type: ignore[override] + self, node: FExpr[F.DropNulls], frame: Frame, name: str + ) -> Scalar | Expr: + previous = node.input[0].dispatch(self, frame, name) + if previous.native.is_valid: + return previous + chunked = fn.chunked_array([[]], previous.native.type) + return ArrowExpr.from_native(chunked, name, version=self.version) + + @property + def cat(self) -> ArrowCatNamespace[Scalar]: + return ArrowCatNamespace(self) + + @property + def list(self) -> ArrowListNamespace[Scalar]: + return ArrowListNamespace(self) + + @property + def str(self) -> ArrowStringNamespace[Scalar]: + return ArrowStringNamespace(self) + + @property + def struct(self) -> ArrowStructNamespace[Scalar]: + return ArrowStructNamespace(self) + filter = not_implemented() over = not_implemented() over_ordered = not_implemented() @@ -608,3 +915,176 @@ def null_count(self, node: FExpr[NullCount], frame: Frame, name: str) -> Self: cum_min = not_implemented() cum_max = not_implemented() cum_prod = not_implemented() + + +ExprOrScalarT = TypeVar("ExprOrScalarT", ArrowExpr, ArrowScalar) + + +class ArrowAccessor(Generic[ExprOrScalarT]): + def __init__(self, compliant: ExprOrScalarT, /) -> None: + self._compliant: ExprOrScalarT = compliant + + @property + def compliant(self) -> 
ExprOrScalarT: + return self._compliant + + def __narwhals_namespace__(self) -> ArrowNamespace: + return namespace(self.compliant) + + @property + def version(self) -> Version: + return self.compliant.version + + def with_native(self, native: ChunkedOrScalarAny, name: str, /) -> Expr | Scalar: + return self.compliant._with_native(native, name) + + @overload + def unary( + self, fn_native: UnaryFunctionP[P], /, *args: P.args, **kwds: P.kwargs + ) -> Callable[[FExpr[Any], Frame, str], Expr | Scalar]: ... + @overload + def unary( + self, fn_native: Callable[[ChunkedOrScalarAny], ChunkedOrScalarAny], / + ) -> Callable[[FExpr[Any], Frame, str], Expr | Scalar]: ... + def unary( + self, fn_native: UnaryFunctionP[P], /, *args: P.args, **kwds: P.kwargs + ) -> Callable[[FExpr[Any], Frame, str], Expr | Scalar]: + return self.compliant._unary_function(fn_native, *args, **kwds) + + +class ArrowCatNamespace(ExprCatNamespace["Frame", "Expr"], ArrowAccessor[ExprOrScalarT]): + def get_categories(self, node: FExpr[GetCategories], frame: Frame, name: str) -> Expr: + native = node.input[0].dispatch(self.compliant, frame, name).native + return ArrowExpr.from_native(fn.get_categories(native), name, self.version) + + +class ArrowListNamespace( + ExprListNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] +): + def len(self, node: FExpr[lists.Len], frame: Frame, name: str) -> Expr | Scalar: + return self.unary(fn.list_len)(node, frame, name) + + def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: + return self.unary(fn.list_get, node.function.index)(node, frame, name) + + def unique(self, node: FExpr[lists.Unique], frame: Frame, name: str) -> Expr | Scalar: + return self.unary(fn.list_unique)(node, frame, name) + + def contains( + self, node: FExpr[lists.Contains], frame: Frame, name: str + ) -> Expr | Scalar: + func = node.function + expr, other = func.unwrap_input(node) + prev = expr.dispatch(self.compliant, frame, name) + item = 
other.dispatch(self.compliant, frame, name) + if isinstance(item, ArrowExpr): + # Maybe one day, not now + raise NotImplementedError + return self.with_native(fn.list_contains(prev.native, item.native), name) + + def join(self, node: FExpr[lists.Join], frame: Frame, name: str) -> Expr | Scalar: + separator, ignore_nulls = node.function.separator, node.function.ignore_nulls + previous = node.input[0].dispatch(self.compliant, frame, name) + result: ChunkedOrScalarAny + if isinstance(previous, ArrowExpr): + result = fn.list_join(previous.native, separator, ignore_nulls=ignore_nulls) + else: + result = fn.list_join_scalar( + previous.native, separator, ignore_nulls=ignore_nulls + ) + return self.with_native(result, name) + + +class ArrowStringNamespace( + ExprStringNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] +): + def len_chars( + self, node: FExpr[strings.LenChars], frame: Frame, name: str + ) -> Expr | Scalar: + return self.unary(fn.str_len_chars)(node, frame, name) + + def slice(self, node: FExpr[strings.Slice], frame: Frame, name: str) -> Expr | Scalar: + offset, length = node.function.offset, node.function.length + return self.unary(fn.str_slice, offset, length)(node, frame, name) + + def zfill(self, node: FExpr[strings.ZFill], frame: Frame, name: str) -> Expr | Scalar: + return self.unary(fn.str_zfill, node.function.length)(node, frame, name) + + def contains( + self, node: FExpr[strings.Contains], frame: Frame, name: str + ) -> Expr | Scalar: + pattern, literal = node.function.pattern, node.function.literal + return self.unary(fn.str_contains, pattern, literal=literal)(node, frame, name) + + def ends_with( + self, node: FExpr[strings.EndsWith], frame: Frame, name: str + ) -> Expr | Scalar: + return self.unary(fn.str_ends_with, node.function.suffix)(node, frame, name) + + def replace( + self, node: FExpr[strings.Replace], frame: Frame, name: str + ) -> Expr | Scalar: + func = node.function + pattern, literal, n = (func.pattern, func.literal, 
func.n) + expr, other = func.unwrap_input(node) + prev = expr.dispatch(self.compliant, frame, name) + value = other.dispatch(self.compliant, frame, name) + if isinstance(value, ArrowScalar): + result = fn.str_replace( + prev.native, pattern, value.native.as_py(), literal=literal, n=n + ) + elif isinstance(prev, ArrowExpr): + result = fn.str_replace_vector( + prev.native, pattern, value.native, literal=literal, n=n + ) + else: + # not sure this even makes sense + msg = "TODO: `ArrowScalar.str.replace(value: ArrowExpr)`" + raise NotImplementedError(msg) + return self.with_native(result, name) + + def replace_all( + self, node: FExpr[strings.ReplaceAll], frame: Frame, name: str + ) -> Expr | Scalar: + rewrite: FExpr[Any] = common.replace( + node, function=node.function.to_replace_n(-1) + ) + return self.replace(rewrite, frame, name) + + def split(self, node: FExpr[strings.Split], frame: Frame, name: str) -> Expr | Scalar: + return self.unary(fn.str_split, node.function.by)(node, frame, name) + + def starts_with( + self, node: FExpr[strings.StartsWith], frame: Frame, name: str + ) -> Expr | Scalar: + return self.unary(fn.str_starts_with, node.function.prefix)(node, frame, name) + + def strip_chars( + self, node: FExpr[strings.StripChars], frame: Frame, name: str + ) -> Expr | Scalar: + return self.unary(fn.str_strip_chars, node.function.characters)(node, frame, name) + + def to_uppercase( + self, node: FExpr[strings.ToUppercase], frame: Frame, name: str + ) -> Expr | Scalar: + return self.unary(fn.str_to_uppercase)(node, frame, name) + + def to_lowercase( + self, node: FExpr[strings.ToLowercase], frame: Frame, name: str + ) -> Expr | Scalar: + return self.unary(fn.str_to_lowercase)(node, frame, name) + + def to_titlecase( + self, node: FExpr[strings.ToTitlecase], frame: Frame, name: str + ) -> Expr | Scalar: + return self.unary(fn.str_to_titlecase)(node, frame, name) + + to_date = not_implemented() + to_datetime = not_implemented() + + +class ArrowStructNamespace( + 
ExprStructNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] +): + def field(self, node: FExpr[FieldByName], frame: Frame, name: str) -> Expr | Scalar: + return self.unary(fn.struct_field, node.function.name)(node, frame, name) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index a17cb0ebc7..42b44983b5 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -2,8 +2,10 @@ from __future__ import annotations +import math import typing as t -from collections.abc import Callable, Sequence +from collections.abc import Callable, Collection, Iterator, Sequence +from itertools import chain from typing import TYPE_CHECKING, Any, Final, Literal, overload import pyarrow as pa # ignore-banned-import @@ -12,75 +14,156 @@ from narwhals._arrow.utils import ( cast_for_truediv, chunked_array as _chunked_array, - floordiv_compat as floordiv, + concat_tables as concat_tables, # noqa: PLC0414 + floordiv_compat as _floordiv, + narwhals_to_native_dtype as _dtype_native, ) -from narwhals._plan import expressions as ir +from narwhals._plan import common, expressions as ir +from narwhals._plan._guards import is_non_nested_literal from narwhals._plan.arrow import options as pa_options from narwhals._plan.expressions import functions as F, operators as ops -from narwhals._utils import Implementation +from narwhals._plan.options import ExplodeOptions +from narwhals._utils import Implementation, Version, no_default +from narwhals.exceptions import ShapeError if TYPE_CHECKING: import datetime as dt from collections.abc import Iterable, Mapping - from typing_extensions import TypeAlias, TypeIs + from typing_extensions import Self, TypeAlias, TypeIs, TypeVarTuple, Unpack - from narwhals._arrow.typing import Incomplete, PromoteOptions + from narwhals._arrow.typing import Incomplete + from narwhals._plan.arrow.acero import Field from narwhals._plan.arrow.typing import ( Array, ArrayAny, + Arrow, ArrowAny, + 
ArrowListT, + ArrowT, BinaryComp, + BinaryFunction, BinaryLogical, BinaryNumericTemporal, BinOp, + BooleanLengthPreserving, + BooleanScalar, + BoolType, ChunkedArray, ChunkedArrayAny, + ChunkedList, ChunkedOrArray, ChunkedOrArrayAny, ChunkedOrArrayT, ChunkedOrScalar, ChunkedOrScalarAny, + ChunkedOrScalarT, + ChunkedStruct, DataType, DataTypeRemap, DataTypeT, DateScalar, IntegerScalar, IntegerType, - LargeStringType, + ListArray, + ListScalar, + ListTypeT, NativeScalar, + NonListTypeT, + NumericScalar, + Predicate, + SameArrowT, Scalar, ScalarAny, ScalarT, StringScalar, StringType, + StructArray, + UInt32Type, UnaryFunction, + UnaryNumeric, + VectorFunction, ) + from narwhals._plan.compliant.typing import SeriesT from narwhals._plan.options import RankOptions, SortMultipleOptions, SortOptions - from narwhals.typing import ClosedInterval, IntoArrowSchema, PythonLiteral + from narwhals._plan.typing import Seq + from narwhals._typing import NoDefault + from narwhals.typing import ( + ClosedInterval, + FillNullStrategy, + IntoArrowSchema, + IntoDType, + NonNestedLiteral, + NumericLiteral, + PythonLiteral, + ) + + Ts = TypeVarTuple("Ts") BACKEND_VERSION = Implementation.PYARROW._backend_version() """Static backend version for `pyarrow`.""" RANK_ACCEPTS_CHUNKED: Final = BACKEND_VERSION >= (14,) +HAS_FROM_TO_STRUCT_ARRAY: Final = BACKEND_VERSION >= (15,) +"""`pyarrow.Table.{from,to}_struct_array` added in https://github.com/apache/arrow/pull/38520""" + +HAS_STRUCT_TYPE_FIELDS: Final = BACKEND_VERSION >= (18,) +"""`pyarrow.StructType.fields` added in https://github.com/apache/arrow/pull/43481""" + HAS_SCATTER: Final = BACKEND_VERSION >= (20,) """`pyarrow.compute.scatter` added in https://github.com/apache/arrow/pull/44394""" +HAS_KURTOSIS_SKEW = BACKEND_VERSION >= (20,) +"""`pyarrow.compute.{kurtosis,skew}` added in https://github.com/apache/arrow/pull/45677""" + HAS_ARANGE: Final = BACKEND_VERSION >= (21,) """`pyarrow.arange` added in 
https://github.com/apache/arrow/pull/46778""" +TO_STRUCT_ARRAY_ACCEPTS_EMPTY: Final = BACKEND_VERSION >= (21,) +"""`pyarrow.Table.to_struct_array` fixed in https://github.com/apache/arrow/pull/46357""" + +HAS_ZFILL: Final = BACKEND_VERSION >= (21,) +"""`pyarrow.compute.utf8_zero_fill` added in https://github.com/apache/arrow/pull/46815""" +# NOTE: Common data type instances to share +UI32: Final = pa.uint32() I64: Final = pa.int64() F64: Final = pa.float64() +BOOL: Final = pa.bool_() + +EMPTY: Final = "" +"""The empty string.""" + + +class MinMax(ir.AggExpr): + """Returns a `Struct({'min': ..., 'max': ...})`. + + https://arrow.apache.org/docs/python/generated/pyarrow.compute.min_max.html#pyarrow.compute.min_max + """ + IntoColumnAgg: TypeAlias = Callable[[str], ir.AggExpr] """Helper constructor for single-column aggregations.""" -is_null = pc.is_null -is_not_null = t.cast("UnaryFunction[ScalarAny,pa.BooleanScalar]", pc.is_valid) -is_nan = pc.is_nan -is_finite = pc.is_finite +is_null = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_null) +is_not_null = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_valid) +is_nan = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_nan) +is_finite = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_finite) +not_ = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.invert) + + +@overload +def is_not_nan(native: ChunkedArrayAny) -> ChunkedArray[pa.BooleanScalar]: ... +@overload +def is_not_nan(native: ScalarAny) -> pa.BooleanScalar: ... +@overload +def is_not_nan(native: ChunkedOrScalarAny) -> ChunkedOrScalar[pa.BooleanScalar]: ... +@overload +def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: ... 
+def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: + return not_(is_nan(native)) + and_ = t.cast("BinaryLogical", pc.and_kleene) or_ = t.cast("BinaryLogical", pc.or_kleene) @@ -95,15 +178,22 @@ add = t.cast("BinaryNumericTemporal", pc.add) -sub = pc.subtract +sub = t.cast("BinaryNumericTemporal", pc.subtract) multiply = pc.multiply +power = t.cast("BinaryFunction[NumericScalar, NumericScalar]", pc.power) +floordiv = _floordiv +abs_ = t.cast("UnaryNumeric", pc.abs) +exp = t.cast("UnaryNumeric", pc.exp) +sqrt = t.cast("UnaryNumeric", pc.sqrt) +ceil = t.cast("UnaryNumeric", pc.ceil) +floor = t.cast("UnaryNumeric", pc.floor) -def truediv(lhs: Any, rhs: Any) -> Any: +def truediv(lhs: Incomplete, rhs: Incomplete) -> Incomplete: return pc.divide(*cast_for_truediv(lhs, rhs)) -def modulus(lhs: Any, rhs: Any) -> Any: +def modulus(lhs: Incomplete, rhs: Incomplete) -> Incomplete: floor_div = floordiv(lhs, rhs) return sub(lhs, multiply(floor_div, rhs)) @@ -131,16 +221,44 @@ def modulus(lhs: Any, rhs: Any) -> Any: ops.ExclusiveOr: xor, } + +def bin_op( + function: Callable[[Any, Any], Any], /, *, reflect: bool = False +) -> Callable[[SeriesT, Any], SeriesT]: + """Attach a binary operator to `ArrowSeries`.""" + + def f(self: SeriesT, other: SeriesT | Any, /) -> SeriesT: + right = other.native if isinstance(other, type(self)) else lit(other) + return self._with_native(function(self.native, right)) + + def f_reflect(self: SeriesT, other: SeriesT | Any, /) -> SeriesT: + if isinstance(other, type(self)): + name = other.name + right: ArrowAny = other.native + else: + name = "literal" + right = lit(other) + return self.from_native(function(right, self.native), name, version=self.version) + + return f_reflect if reflect else f + + _IS_BETWEEN: Mapping[ClosedInterval, tuple[BinaryComp, BinaryComp]] = { "left": (gt_eq, lt), "right": (gt, lt_eq), "none": (gt, lt), "both": (gt_eq, lt_eq), } -IS_FIRST_LAST_DISTINCT: Mapping[type[ir.boolean.BooleanFunction], IntoColumnAgg] 
= { - ir.boolean.IsFirstDistinct: ir.min, - ir.boolean.IsLastDistinct: ir.max, -} + + +@t.overload +def dtype_native(dtype: IntoDType, version: Version) -> pa.DataType: ... +@t.overload +def dtype_native(dtype: None, version: Version) -> None: ... +@t.overload +def dtype_native(dtype: IntoDType | None, version: Version) -> pa.DataType | None: ... +def dtype_native(dtype: IntoDType | None, version: Version) -> pa.DataType | None: + return dtype if dtype is None else _dtype_native(dtype, version) @t.overload @@ -192,7 +310,7 @@ def has_large_string(data_types: Iterable[DataType], /) -> bool: return any(pa.types.is_large_string(tp) for tp in data_types) -def string_type(data_types: Iterable[DataType] = (), /) -> StringType | LargeStringType: +def string_type(data_types: Iterable[DataType] = (), /) -> StringType: """Return a native string type, compatible with `data_types`. Until [apache/arrow#45717] is resolved, we need to upcast `string` to `large_string` when joining. @@ -202,23 +320,777 @@ def string_type(data_types: Iterable[DataType] = (), /) -> StringType | LargeStr return pa.large_string() if has_large_string(data_types) else pa.string() -def any_(native: Any) -> pa.BooleanScalar: - return pc.any(native, min_count=0) +# NOTE: `mypy` isn't happy, but this broadcasting behavior is worth documenting +@t.overload +def struct(names: Iterable[str], columns: Iterable[ChunkedArrayAny]) -> ChunkedStruct: ... +@t.overload +def struct(names: Iterable[str], columns: Iterable[ArrayAny]) -> pa.StructArray: ... +@t.overload +def struct( # type: ignore[overload-overlap] + names: Iterable[str], columns: Iterable[ScalarAny] | Iterable[NonNestedLiteral] +) -> pa.StructScalar: ... +@t.overload +def struct( # type: ignore[overload-overlap] + names: Iterable[str], columns: Iterable[ChunkedArrayAny | NonNestedLiteral] +) -> ChunkedStruct: ... +@t.overload +def struct( + names: Iterable[str], columns: Iterable[ArrayAny | NonNestedLiteral] +) -> pa.StructArray: ... 
+@t.overload
+def struct(names: Iterable[str], columns: Iterable[ArrowAny]) -> Incomplete: ...
+def struct(names: Iterable[str], columns: Iterable[Incomplete]) -> Incomplete:
+    """Collect columns into a struct.
+
+    Arguments:
+        names: Names of the struct fields to create.
+        columns: Value(s) to collect into a struct. Scalars will be broadcast unless all
+            inputs are scalar.
+    """
+    return pc.make_struct(
+        *columns, options=pc.MakeStructOptions(common.ensure_seq_str(names))
+    )
+
+
+def struct_schema(native: Arrow[pa.StructScalar] | pa.StructType) -> pa.Schema:
+    """Get the struct definition as a schema."""
+    tp = native.type if _is_arrow(native) else native
+    fields = tp.fields if HAS_STRUCT_TYPE_FIELDS else list(tp)
+    return pa.schema(fields)
+
+
+@t.overload
+def struct_field(native: ChunkedStruct, field: Field, /) -> ChunkedArrayAny: ...
+@t.overload
+def struct_field(native: StructArray, field: Field, /) -> ArrayAny: ...
+@t.overload
+def struct_field(native: pa.StructScalar, field: Field, /) -> ScalarAny: ...
+@t.overload
+def struct_field(native: SameArrowT, field: Field, /) -> SameArrowT: ...
+@t.overload
+def struct_field(native: ChunkedOrScalarAny, field: Field, /) -> ChunkedOrScalarAny: ...
+def struct_field(native: ArrowAny, field: Field, /) -> ArrowAny:
+    """Retrieve one `Struct` field."""
+    func = t.cast("Callable[[Any,Any], ArrowAny]", pc.struct_field)
+    return func(native, field)
+
+
+@t.overload
+def struct_fields(native: ChunkedStruct, *fields: Field) -> Seq[ChunkedArrayAny]: ...
+@t.overload
+def struct_fields(native: StructArray, *fields: Field) -> Seq[ArrayAny]: ...
+@t.overload
+def struct_fields(native: pa.StructScalar, *fields: Field) -> Seq[ScalarAny]: ...
+@t.overload
+def struct_fields(native: SameArrowT, *fields: Field) -> Seq[SameArrowT]: ...
+def struct_fields(native: ArrowAny, *fields: Field) -> Seq[ArrowAny]: + """Retrieve multiple `Struct` fields.""" + func = t.cast("Callable[[Any,Any], ArrowAny]", pc.struct_field) + return tuple(func(native, name) for name in fields) + + +def get_categories(native: ArrowAny) -> ChunkedArrayAny: + da: Incomplete + if isinstance(native, pa.ChunkedArray): + da = native.unify_dictionaries().chunk(0) + else: + da = native + return chunked_array(da.dictionary) + + +class ExplodeBuilder: + """Tools for exploding lists. + + The complexity of these operations increases with: + - Needing to preserve null/empty elements + - All variants are cheaper if this can be skipped + - Exploding in the context of a table + - Where a single column is much simpler than multiple + """ + + options: ExplodeOptions + + def __init__(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> None: + self.options = ExplodeOptions(empty_as_null=empty_as_null, keep_nulls=keep_nulls) + + @classmethod + def from_options(cls, options: ExplodeOptions, /) -> Self: + obj = cls.__new__(cls) + obj.options = options + return obj + + @t.overload + def explode( + self, native: ChunkedList[DataTypeT] | ListScalar[DataTypeT] + ) -> ChunkedArray[Scalar[DataTypeT]]: ... + @t.overload + def explode(self, native: ListArray[DataTypeT]) -> Array[Scalar[DataTypeT]]: ... + @t.overload + def explode( + self, native: Arrow[ListScalar[DataTypeT]] + ) -> ChunkedOrArray[Scalar[DataTypeT]]: ... + def explode( + self, native: Arrow[ListScalar[DataTypeT]] + ) -> ChunkedOrArray[Scalar[DataTypeT]]: + """Explode list elements, expanding one-level into a new array. + + Equivalent to `polars.{Expr,Series}.explode`. 
+ """ + safe = self._fill_with_null(native) if self.options.any() else native + if not isinstance(safe, pa.Scalar): + return _list_explode(safe) + return chunked_array(_list_explode(safe)) + + def explode_with_indices(self, native: ChunkedList | ListArray) -> pa.Table: + safe = self._fill_with_null(native) if self.options.any() else native + arrays = [_list_parent_indices(safe), _list_explode(safe)] + return concat_horizontal(arrays, ["idx", "values"]) + + def explode_column(self, native: pa.Table, column_name: str, /) -> pa.Table: + """Explode a list-typed column in the context of `native`.""" + ca = native.column(column_name) + if native.num_columns == 1: + return native.from_arrays([self.explode(ca)], [column_name]) + safe = self._fill_with_null(ca) if self.options.any() else ca + exploded = _list_explode(safe) + col_idx = native.schema.get_field_index(column_name) + if len(exploded) == len(native): + return native.set_column(col_idx, column_name, exploded) + return ( + native.remove_column(col_idx) + .take(_list_parent_indices(safe)) + .add_column(col_idx, column_name, exploded) + ) + + def explode_columns(self, native: pa.Table, subset: Collection[str], /) -> pa.Table: + """Explode multiple list-typed columns in the context of `native`.""" + subset = list(subset) + arrays = native.select(subset).columns + first = arrays[0] + first_len = list_len(first) + if self.options.any(): + mask = self._predicate(first_len) + first_safe = self._fill_with_null(first, mask) + it = ( + _list_explode(self._fill_with_null(arr, mask)) + for arr in self._iter_ensure_shape(first_len, arrays[1:]) + ) + else: + first_safe = first + it = ( + _list_explode(arr) + for arr in self._iter_ensure_shape(first_len, arrays[1:]) + ) + column_names = native.column_names + result = native + first_result = _list_explode(first_safe) + if len(first_result) == len(native): + # fastpath for all length-1 lists + # if only the first is length-1, then the others raise during iteration on either branch 
+ for name, arr in zip(subset, chain([first_result], it)): + result = result.set_column(column_names.index(name), name, arr) + else: + result = result.drop_columns(subset).take(_list_parent_indices(first_safe)) + for name, arr in zip(subset, chain([first_result], it)): + result = result.append_column(name, arr) + result = result.select(column_names) + return result + + @classmethod + def explode_column_fast(cls, native: pa.Table, column_name: str, /) -> pa.Table: + """Explode a list-typed column in the context of `native`, ignoring empty and nulls.""" + return cls(empty_as_null=False, keep_nulls=False).explode_column( + native, column_name + ) + + def _iter_ensure_shape( + self, + first_len: ChunkedArray[pa.UInt32Scalar], + arrays: Iterable[ChunkedArrayAny], + /, + ) -> Iterator[ChunkedArrayAny]: + for arr in arrays: + if not first_len.equals(list_len(arr)): + msg = "exploded columns must have matching element counts" + raise ShapeError(msg) + yield arr + + def _predicate(self, lengths: ArrowAny, /) -> Arrow[pa.BooleanScalar]: + """Return True for each sublist length that indicates the original sublist should be replaced with `[None]`.""" + empty_as_null, keep_nulls = self.options.empty_as_null, self.options.keep_nulls + if empty_as_null and keep_nulls: + return or_(is_null(lengths), eq(lengths, lit(0))) + if empty_as_null: + return eq(lengths, lit(0)) + return is_null(lengths) + + def _fill_with_null( + self, native: ArrowListT, mask: Arrow[BooleanScalar] | NoDefault = no_default + ) -> ArrowListT: + """Replace each sublist in `native` with `[None]`, according to `self.options`. + + Arguments: + native: List-typed arrow data. + mask: An optional, pre-computed replacement mask. By default, this is generated from `native`. 
+ """ + predicate = self._predicate(list_len(native)) if mask is no_default else mask + result: ArrowListT = when_then(predicate, lit([None], native.type), native) + return result + + +@t.overload +def _list_explode(native: ChunkedList[DataTypeT]) -> ChunkedArray[Scalar[DataTypeT]]: ... +@t.overload +def _list_explode( + native: ListArray[NonListTypeT] | ListScalar[NonListTypeT], +) -> Array[Scalar[NonListTypeT]]: ... +@t.overload +def _list_explode(native: ListArray[DataTypeT]) -> Array[Scalar[DataTypeT]]: ... +@t.overload +def _list_explode(native: ListScalar[ListTypeT]) -> ListArray[ListTypeT]: ... +def _list_explode(native: Arrow[ListScalar]) -> ChunkedOrArrayAny: + result: ChunkedOrArrayAny = pc.call_function("list_flatten", [native]) + return result + + +@t.overload +def _list_parent_indices(native: ChunkedList) -> ChunkedArray[pa.Int64Scalar]: ... +@t.overload +def _list_parent_indices(native: ListArray) -> pa.Int64Array: ... +def _list_parent_indices( + native: ChunkedOrArray[ListScalar], +) -> ChunkedOrArray[pa.Int64Scalar]: + """Don't use this withut handling nulls!""" + result: ChunkedOrArray[pa.Int64Scalar] = pc.call_function( + "list_parent_indices", [native] + ) + return result + + +@t.overload +def list_len(native: ChunkedList) -> ChunkedArray[pa.UInt32Scalar]: ... +@t.overload +def list_len(native: ListArray) -> pa.UInt32Array: ... +@t.overload +def list_len(native: ListScalar) -> pa.UInt32Scalar: ... +@t.overload +def list_len(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[pa.UInt32Scalar]: ... +@t.overload +def list_len(native: Arrow[ListScalar[Any]]) -> Arrow[pa.UInt32Scalar]: ... +def list_len(native: ArrowAny) -> ArrowAny: + length: Incomplete = pc.list_value_length + result: ArrowAny = length(native).cast(pa.uint32()) + return result + + +@t.overload +def list_get( + native: ChunkedList[DataTypeT], index: int +) -> ChunkedArray[Scalar[DataTypeT]]: ... 
+@t.overload +def list_get(native: ListArray[DataTypeT], index: int) -> Array[Scalar[DataTypeT]]: ... +@t.overload +def list_get(native: ListScalar[DataTypeT], index: int) -> Scalar[DataTypeT]: ... +@t.overload +def list_get(native: SameArrowT, index: int) -> SameArrowT: ... +@t.overload +def list_get(native: ChunkedOrScalarAny, index: int) -> ChunkedOrScalarAny: ... +def list_get(native: ArrowAny, index: int) -> ArrowAny: + list_get_: Incomplete = pc.list_element + result: ArrowAny = list_get_(native, index) + return result + + +_list_join = t.cast( + "Callable[[ChunkedOrArrayAny, Arrow[StringScalar] | str], ChunkedArray[StringScalar] | pa.StringArray]", + pc.binary_join, +) + + +# NOTE: Raised for native null-handling (https://github.com/apache/arrow/issues/48477) +@t.overload +def list_join( + native: ChunkedList[StringType], + separator: Arrow[StringScalar] | str, + *, + ignore_nulls: bool = ..., +) -> ChunkedArray[StringScalar]: ... +@t.overload +def list_join( + native: ListArray[StringType], + separator: Arrow[StringScalar] | str, + *, + ignore_nulls: bool = ..., +) -> pa.StringArray: ... +@t.overload +def list_join( + native: ChunkedOrArray[ListScalar[StringType]], + separator: str, + *, + ignore_nulls: bool = ..., +) -> ChunkedOrArray[StringScalar]: ... +def list_join( + native: ChunkedOrArrayAny, + separator: Arrow[StringScalar] | str, + *, + ignore_nulls: bool = True, +) -> ChunkedOrArrayAny: + """Join all string items in a sublist and place a separator between them. + + Each list of values in the first input is joined using each second input as separator. + If any input list is null or contains a null, the corresponding output will be null. 
+ """ + from narwhals._plan.arrow.group_by import AggSpec + + # (1): Try to return *as-is* from `pc.binary_join` + result = _list_join(native, separator) + if not ignore_nulls or not result.null_count: + return result + is_null_sensitive = pc.and_not(result.is_null(), native.is_null()) + if array(is_null_sensitive, BOOL).true_count == 0: + return result + + # (2): Deal with only the bad kids + lists = native.filter(is_null_sensitive) + + # (2.1): We know that `[None]` should join as `""`, and that is the only length-1 list we could have after the filter + list_len_eq_1 = eq(list_len(lists), lit(1, UI32)) + has_a_len_1_null = any_(list_len_eq_1).as_py() + if has_a_len_1_null: + lists = when_then(list_len_eq_1, lit([EMPTY], lists.type), lists) + + # (2.2): Everything left falls into one of these boxes: + # - (2.1): `[""]` + # - (2.2): `["something", (str | None)*, None]` <--- We fix this here and hope for the best + # - (2.3): `[None, (None)*, None]` + idx, v = "idx", "values" + builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) + explode_w_idx = builder.explode_with_indices(lists) + implode_by_idx = AggSpec.implode(v).over(explode_w_idx.drop_null(), [idx]) + replacements = _list_join(implode_by_idx.column(v), separator) + + # (2.3): The cursed box 😨 + if len(replacements) != len(lists): + # This is a very unlucky case to hit, because we *can* detect the issue earlier + # but we *can't* join a table with a list in it. So we deal with the fallout now ... 
+ # The end result is identical to (2.1) + indices_all = to_table(explode_w_idx.column(idx).unique(), idx) + indices_repaired = implode_by_idx.set_column(1, v, replacements) + replacements = ( + indices_all.join(indices_repaired, idx) + .sort_by(idx) + .column(v) + .fill_null(lit(EMPTY, lists.type.value_type)) + ) + return replace_with_mask(result, is_null_sensitive, replacements) + + +def list_join_scalar( + native: ListScalar[StringType], + separator: StringScalar | str, + *, + ignore_nulls: bool = True, +) -> StringScalar: + """Join all string items in a `ListScalar` and place a separator between them. + + Note: + Consider using `list_join` or `str_join` if you don't already have `native` in this shape. + """ + if ignore_nulls and native.is_valid: + native = implode(_list_explode(native).drop_null()) + result: StringScalar = pc.call_function("binary_join", [native, separator]) + return result + + +@overload +def list_unique(native: ChunkedList) -> ChunkedList: ... +@overload +def list_unique(native: ListScalar) -> ListScalar: ... +@overload +def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScalar]: ... +def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScalar]: + """Get the unique/distinct values in the list. + + There's lots of tricky stuff going on in here, but for good reasons! + + Whenever possible, we want to avoid having to deal with these pesky guys: + + [["okay", None, "still fine"], None, []] + # ^^^^ ^^ + + - Those kinds of list elements are ignored natively + - `unique` is length-changing operation + - We can't use [`pc.replace_with_mask`] on a list + - We can't join when a table contains list columns [apache/arrow#43716] + + **But** - if we're lucky, and we got a non-awful list (or only one element) - then + most issues vanish. 
+ + [`pc.replace_with_mask`]: https://arrow.apache.org/docs/python/generated/pyarrow.compute.replace_with_mask.html + [apache/arrow#43716]: https://github.com/apache/arrow/issues/43716 + """ + from narwhals._plan.arrow.group_by import AggSpec + + if isinstance(native, pa.Scalar): + scalar = t.cast("pa.ListScalar[Any]", native) + if scalar.is_valid and (len(scalar) > 1): + return implode(_list_explode(native).unique()) + return scalar + idx, v = "index", "values" + names = idx, v + len_not_eq_0 = not_eq(list_len(native), lit(0)) + can_fastpath = all_(len_not_eq_0, ignore_nulls=False).as_py() + if can_fastpath: + arrays = [_list_parent_indices(native), _list_explode(native)] + return AggSpec.unique(v).over_index(concat_horizontal(arrays, names), idx) + # Oh no - we caught a bad one! + # We need to split things into good/bad - and only work on the good stuff. + # `int_range` is acting like `parent_indices`, but doesn't give up when it see's `None` or `[]` + indexed = concat_horizontal([int_range(len(native)), native], names) + valid = indexed.filter(len_not_eq_0) + invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) + # To keep track of where we started, our index needs to be exploded with the list elements + explode_with_index = ExplodeBuilder.explode_column_fast(valid, v) + valid_unique = AggSpec.unique(v).over(explode_with_index, [idx]) + # And now, because we can't join - we do a poor man's version of one 😉 + return concat_tables([valid_unique, invalid]).sort_by(idx).column(v) + + +def list_contains( + native: ChunkedOrScalar[ListScalar], item: NonNestedLiteral | ScalarAny +) -> ChunkedOrScalar[pa.BooleanScalar]: + from narwhals._plan.arrow.group_by import AggSpec + + if isinstance(native, pa.Scalar): + scalar = t.cast("pa.ListScalar[Any]", native) + if scalar.is_valid: + if len(scalar): + value_type = scalar.type.value_type + return any_(eq_missing(_list_explode(scalar), lit(item).cast(value_type))) + return lit(False, BOOL) + return lit(None, 
BOOL) + builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) + tbl = builder.explode_with_indices(native) + idx, name = tbl.column_names + contains = eq_missing(tbl.column(name), item) + l_contains = AggSpec.any(name).over_index(tbl.set_column(1, name, contains), idx) + # Here's the really key part: this mask has the same result we want to return + # So by filling the `True`, we can flip those to `False` if needed + # But if we were already `None` or `False` - then that's sticky + propagate_invalid: ChunkedArray[pa.BooleanScalar] = not_eq(list_len(native), lit(0)) + return replace_with_mask(propagate_invalid, propagate_invalid, l_contains) + + +def implode(native: Arrow[Scalar[DataTypeT]]) -> ListScalar[DataTypeT]: + """Aggregate values into a list. + + The returned list itself is a scalar value of `list` dtype. + """ + arr = array(native) + return pa.ListArray.from_arrays([0, len(arr)], arr)[0] + + +def str_join( + native: Arrow[StringScalar], separator: str, *, ignore_nulls: bool = True +) -> StringScalar: + """Vertically concatenate the string values in the column to a single string value.""" + if isinstance(native, pa.Scalar): + # already joined + return native + if ignore_nulls and native.null_count: + native = native.drop_null() + return list_join_scalar(implode(native), separator, ignore_nulls=False) + + +def str_len_chars(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: + len_chars: Incomplete = pc.utf8_length + result: ChunkedOrScalarAny = len_chars(native) + return result + + +def str_slice( + native: ChunkedOrScalarAny, offset: int, length: int | None = None +) -> ChunkedOrScalarAny: + stop = length if length is None else offset + length + return pc.utf8_slice_codeunits(native, offset, stop=stop) + + +def str_pad_start( + native: ChunkedOrScalarAny, length: int, fill_char: str = " " +) -> ChunkedOrScalarAny: # pragma: no cover + return pc.utf8_lpad(native, length, fill_char) + + +@t.overload +def str_find( + native: ChunkedArrayAny, + 
pattern: str, + *, + literal: bool = ..., + not_found: int | None = ..., +) -> ChunkedArray[IntegerScalar]: ... +@t.overload +def str_find( + native: Array, pattern: str, *, literal: bool = ..., not_found: int | None = ... +) -> Array[IntegerScalar]: ... +@t.overload +def str_find( + native: ScalarAny, pattern: str, *, literal: bool = ..., not_found: int | None = ... +) -> IntegerScalar: ... +def str_find( + native: Arrow[StringScalar], + pattern: str, + *, + literal: bool = False, + not_found: int | None = -1, +) -> Arrow[IntegerScalar]: + """Return the bytes offset of the first substring matching a pattern. + + To match `pl.Expr.str.find` behavior, pass `not_found=None`. + + Note: + `pyarrow` distinguishes null *inputs* with `None` and failed matches with `-1`. + """ + # NOTE: `pyarrow-stubs` uses concrete types here + fn_name = "find_substring" if literal else "find_substring_regex" + result: Arrow[IntegerScalar] = pc.call_function( + fn_name, [native], pa_options.match_substring(pattern) + ) + if not_found == -1: + return result + return when_then(eq(result, lit(-1)), lit(not_found, result.type), result) + + +_StringFunction0: TypeAlias = "Callable[[ChunkedOrScalarAny], ChunkedOrScalarAny]" +_StringFunction1: TypeAlias = "Callable[[ChunkedOrScalarAny, str], ChunkedOrScalarAny]" +str_starts_with = t.cast("_StringFunction1", pc.starts_with) +str_ends_with = t.cast("_StringFunction1", pc.ends_with) +str_to_uppercase = t.cast("_StringFunction0", pc.utf8_upper) +str_to_lowercase = t.cast("_StringFunction0", pc.utf8_lower) +str_to_titlecase = t.cast("_StringFunction0", pc.utf8_title) + + +def _str_split( + native: ArrowAny, by: str, n: int | None = None, *, literal: bool = True +) -> Arrow[ListScalar]: + name = "split_pattern" if literal else "split_pattern_regex" + result: Arrow[ListScalar] = pc.call_function( + name, [native], pa_options.split_pattern(by, n) + ) + return result + + +@t.overload +def str_split( + native: ChunkedArrayAny, by: str, *, literal: bool = 
... +) -> ChunkedArray[ListScalar]: ... +@t.overload +def str_split( + native: ChunkedOrScalarAny, by: str, *, literal: bool = ... +) -> ChunkedOrScalar[ListScalar]: ... +@t.overload +def str_split(native: ArrayAny, by: str, *, literal: bool = ...) -> pa.ListArray[Any]: ... +@t.overload +def str_split(native: ArrowAny, by: str, *, literal: bool = ...) -> Arrow[ListScalar]: ... +def str_split(native: ArrowAny, by: str, *, literal: bool = True) -> Arrow[ListScalar]: + return _str_split(native, by, literal=literal) + + +@t.overload +def str_splitn( + native: ChunkedArrayAny, + by: str, + n: int, + *, + literal: bool = ..., + as_struct: bool = ..., +) -> ChunkedArray[ListScalar]: ... +@t.overload +def str_splitn( + native: ChunkedOrScalarAny, + by: str, + n: int, + *, + literal: bool = ..., + as_struct: bool = ..., +) -> ChunkedOrScalar[ListScalar]: ... +@t.overload +def str_splitn( + native: ArrayAny, by: str, n: int, *, literal: bool = ..., as_struct: bool = ... +) -> pa.ListArray[Any]: ... +@t.overload +def str_splitn( + native: ArrowAny, by: str, n: int, *, literal: bool = ..., as_struct: bool = ... +) -> Arrow[ListScalar]: ... +def str_splitn( + native: ArrowAny, by: str, n: int, *, literal: bool = True, as_struct: bool = False +) -> Arrow[ListScalar]: + """Split the string by a substring, restricted to returning at most `n` items.""" + result = _str_split(native, by, n, literal=literal) + if as_struct: + # NOTE: `polars` would return a struct w/ field names (`'field_0`, ..., 'field_n-1`) + msg = "TODO: `ArrowExpr.str.splitn`" + raise NotImplementedError(msg) + return result + + +@t.overload +def str_contains( + native: ChunkedArrayAny, pattern: str, *, literal: bool = ... +) -> ChunkedArray[pa.BooleanScalar]: ... +@t.overload +def str_contains( + native: ChunkedOrScalarAny, pattern: str, *, literal: bool = ... +) -> ChunkedOrScalar[pa.BooleanScalar]: ... +@t.overload +def str_contains( + native: ArrowAny, pattern: str, *, literal: bool = ... 
+) -> Arrow[pa.BooleanScalar]: ... +def str_contains( + native: ArrowAny, pattern: str, *, literal: bool = False +) -> Arrow[pa.BooleanScalar]: + """Check if the string contains a substring that matches a pattern.""" + name = "match_substring" if literal else "match_substring_regex" + result: Arrow[pa.BooleanScalar] = pc.call_function( + name, [native], pa_options.match_substring(pattern) + ) + return result + + +def str_strip_chars(native: Incomplete, characters: str | None) -> Incomplete: + if characters: + return pc.utf8_trim(native, characters) + return pc.utf8_trim_whitespace(native) + + +def str_replace( + native: Incomplete, pattern: str, value: str, *, literal: bool = False, n: int = 1 +) -> Incomplete: + fn = pc.replace_substring if literal else pc.replace_substring_regex + return fn(native, pattern, replacement=value, max_replacements=n) + + +def str_replace_all( + native: Incomplete, pattern: str, value: str, *, literal: bool = False +) -> Incomplete: + return str_replace(native, pattern, value, literal=literal, n=-1) + + +def str_replace_vector( + native: ChunkedArrayAny, + pattern: str, + replacements: ChunkedArrayAny, + *, + literal: bool = False, + n: int | None = 1, +) -> ChunkedArrayAny: + has_match = str_contains(native, pattern, literal=literal) + if not any_(has_match).as_py(): + # fastpath, no work to do + return native + match, match_replacements = filter_arrays(has_match, native, replacements) + if n is None or n == -1: + list_split_by = str_split(match, pattern, literal=literal) + else: + list_split_by = str_splitn(match, pattern, n + 1, literal=literal) + replaced = list_join(list_split_by, match_replacements, ignore_nulls=False) + if all_(has_match, ignore_nulls=False).as_py(): + return chunked_array(replaced) + return replace_with_mask(native, has_match, array(replaced)) + + +def str_zfill(native: ChunkedOrScalarAny, length: int) -> ChunkedOrScalarAny: + if HAS_ZFILL: + zfill: Incomplete = pc.utf8_zero_fill # type: ignore[attr-defined] + 
result: ChunkedOrScalarAny = zfill(native, length) + else: + result = _str_zfill_compat(native, length) + return result + + +# TODO @dangotbanned: Finish tidying this up +def _str_zfill_compat( + native: ChunkedOrScalarAny, length: int +) -> Incomplete: # pragma: no cover + dtype = string_type([native.type]) + hyphen, plus = lit("-", dtype), lit("+", dtype) + + padded_remaining = str_pad_start(str_slice(native, 1), length - 1, "0") + padded_lt_length = str_pad_start(native, length, "0") + + binary_join: Incomplete = pc.binary_join_element_wise + if isinstance(native, pa.Scalar): + case_1: ArrowAny = hyphen # starts with hyphen and less than length + case_2: ArrowAny = plus # starts with plus and less than length + else: + arr_len = len(native) + case_1 = repeat_unchecked(hyphen, arr_len) + case_2 = repeat_unchecked(plus, arr_len) + + first_char = str_slice(native, 0, 1) + lt_length = lt(str_len_chars(native), lit(length)) + first_hyphen_lt_length = and_(eq(first_char, hyphen), lt_length) + first_plus_lt_length = and_(eq(first_char, plus), lt_length) + return when_then( + first_hyphen_lt_length, + binary_join(case_1, padded_remaining, ""), + when_then( + first_plus_lt_length, + binary_join(case_2, padded_remaining, ""), + when_then(lt_length, padded_lt_length, native), + ), + ) + + +@t.overload +def when_then( + predicate: Predicate, then: SameArrowT, otherwise: SameArrowT | None +) -> SameArrowT: ... +@t.overload +def when_then(predicate: Predicate, then: ScalarAny, otherwise: ArrowT) -> ArrowT: ... +@t.overload +def when_then( + predicate: Predicate, then: ArrowT, otherwise: ScalarAny | NonNestedLiteral = ... +) -> ArrowT: ... +@t.overload +def when_then( + predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None +) -> Incomplete: ... 
+def when_then( + predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None +) -> Incomplete: + if is_non_nested_literal(otherwise): + otherwise = lit(otherwise, then.type) + return pc.if_else(predicate, then, otherwise) + + +def any_(native: Incomplete, *, ignore_nulls: bool = True) -> pa.BooleanScalar: + return pc.any(native, min_count=0, skip_nulls=ignore_nulls) -def all_(native: Any) -> pa.BooleanScalar: - return pc.all(native, min_count=0) +def all_(native: Incomplete, *, ignore_nulls: bool = True) -> pa.BooleanScalar: + return pc.all(native, min_count=0, skip_nulls=ignore_nulls) -def sum_(native: Any) -> NativeScalar: +def sum_(native: Incomplete) -> NativeScalar: return pc.sum(native, min_count=0) +def first(native: ChunkedOrArrayAny) -> NativeScalar: + return pc.first(native, options=pa_options.scalar_aggregate()) + + +def last(native: ChunkedOrArrayAny) -> NativeScalar: + return pc.last(native, options=pa_options.scalar_aggregate()) + + min_ = pc.min +# TODO @dangotbanned: Wrap horizontal functions with correct typing +# Should only return scalar if all elements are as well min_horizontal = pc.min_element_wise max_ = pc.max max_horizontal = pc.max_element_wise -mean = pc.mean +mean = t.cast("Callable[[ChunkedOrArray[pc.NumericScalar]], pa.DoubleScalar]", pc.mean) count = pc.count median = pc.approximate_median std = pc.stddev @@ -226,10 +1098,79 @@ def sum_(native: Any) -> NativeScalar: quantile = pc.quantile +def mode_all(native: ChunkedArrayAny) -> ChunkedArrayAny: + struct = pc.mode(native, n=len(native)) + indices: pa.Int32Array = struct.field("count").dictionary_encode().indices # type: ignore[attr-defined] + index_true_modes = lit(0) + return chunked_array(struct.field("mode").filter(pc.equal(indices, index_true_modes))) + + +def mode_any(native: ChunkedArrayAny) -> NativeScalar: + return first(pc.mode(native, n=1).field("mode")) + + +def kurtosis_skew( + native: ChunkedArray[pc.NumericScalar], function: Literal["kurtosis", 
"skew"], / +) -> NativeScalar: + result: NativeScalar + if HAS_KURTOSIS_SKEW: + if pa.types.is_null(native.type): + native = native.cast(F64) + result = getattr(pc, function)(native) + else: + non_null = native.drop_null() + if len(non_null) == 0: + result = lit(None, F64) + elif len(non_null) == 1: + result = lit(float("nan")) + elif function == "skew" and len(non_null) == 2: + result = lit(0.0, F64) + else: + m = sub(non_null, mean(non_null)) + m2 = mean(power(m, lit(2))) + if function == "kurtosis": + m4 = mean(power(m, lit(4))) + result = sub(pc.divide(m4, power(m2, lit(2))), lit(3)) + else: + m3 = mean(power(m, lit(3))) + result = pc.divide(m3, power(m2, lit(1.5))) + return result + + +def clip_lower( + native: ChunkedOrScalarAny, lower: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + return max_horizontal(native, lower) + + +def clip_upper( + native: ChunkedOrScalarAny, upper: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + return min_horizontal(native, upper) + + +def clip( + native: ChunkedOrScalarAny, lower: ChunkedOrScalarAny, upper: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + return clip_lower(clip_upper(native, upper), lower) + + def n_unique(native: Any) -> pa.Int64Scalar: return count(native, mode="all") +@t.overload +def round(native: ChunkedOrScalarAny, decimals: int = ...) -> ChunkedOrScalarAny: ... +@t.overload +def round(native: ChunkedOrArrayT, decimals: int = ...) -> ChunkedOrArrayT: ... +def round(native: ArrowAny, decimals: int = 0) -> ArrowAny: + return pc.round(native, decimals, round_mode="half_towards_infinity") + + +def log(native: ChunkedOrScalarAny, base: float = math.e) -> ChunkedOrScalarAny: + return t.cast("ChunkedOrScalarAny", pc.logb(native, lit(base))) + + def reverse(native: ChunkedOrArrayT) -> ChunkedOrArrayT: """Unlike other slicing ops, `[::-1]` creates a full-copy. 
@@ -258,7 +1199,7 @@ def cum_count(native: ChunkedArrayAny) -> ChunkedArrayAny: return cum_sum(is_not_null(native).cast(pa.uint32())) -CUMULATIVE: Mapping[type[F.CumAgg], Callable[[ChunkedArrayAny], ChunkedArrayAny]] = { +_CUMULATIVE: Mapping[type[F.CumAgg], Callable[[ChunkedArrayAny], ChunkedArrayAny]] = { F.CumSum: cum_sum, F.CumCount: cum_count, F.CumMin: cum_min, @@ -267,23 +1208,32 @@ def cum_count(native: ChunkedArrayAny) -> ChunkedArrayAny: } -def diff(native: ChunkedOrArrayT) -> ChunkedOrArrayT: +def cumulative(native: ChunkedArrayAny, f: F.CumAgg, /) -> ChunkedArrayAny: + func = _CUMULATIVE[type(f)] + return func(native) if not f.reverse else reverse(func(reverse(native))) + + +def diff(native: ChunkedOrArrayT, n: int = 1) -> ChunkedOrArrayT: # pyarrow.lib.ArrowInvalid: Vector kernel cannot execute chunkwise and no chunked exec function was defined return ( - pc.pairwise_diff(native) + pc.pairwise_diff(native, n) if isinstance(native, pa.Array) - else chunked_array(pc.pairwise_diff(native.combine_chunks())) + else chunked_array(pc.pairwise_diff(native.combine_chunks(), n)) ) -def shift(native: ChunkedArrayAny, n: int) -> ChunkedArrayAny: +def shift( + native: ChunkedArrayAny, n: int, *, fill_value: NonNestedLiteral = None +) -> ChunkedArrayAny: if n == 0: return native arr = native if n > 0: - arrays = [nulls_like(n, arr), *arr.slice(length=arr.length() - n).chunks] + filled = repeat_like(fill_value, n, arr) + arrays = [filled, *arr.slice(length=arr.length() - n).chunks] else: - arrays = [*arr.slice(offset=-n).chunks, nulls_like(-n, arr)] + filled = repeat_like(fill_value, -n, arr) + arrays = [*arr.slice(offset=-n).chunks, filled] return pa.chunked_array(arrays) @@ -304,26 +1254,175 @@ def null_count(native: ChunkedOrArrayAny) -> pa.Int64Scalar: return pc.count(native, mode="only_null") -def has_nulls(native: ChunkedOrArrayAny) -> bool: - return bool(native.null_count) - - def preserve_nulls( before: ChunkedOrArrayAny, after: ChunkedOrArrayT, / ) -> 
ChunkedOrArrayT: - if has_nulls(before): - after = pc.if_else(before.is_null(), lit(None, after.type), after) - return after + return when_then(is_not_null(before), after) if before.null_count else after + + +drop_nulls = t.cast("VectorFunction[...]", pc.drop_null) + + +def is_only_nulls(native: ChunkedOrArrayAny, *, nan_is_null: bool = False) -> bool: + """Return True if `native` has no non-null values (and optionally include NaN).""" + return array(native.is_null(nan_is_null=nan_is_null), BOOL).false_count == 0 + + +_FILL_NULL_STRATEGY: Mapping[FillNullStrategy, UnaryFunction] = { + "forward": pc.fill_null_forward, + "backward": pc.fill_null_backward, +} + + +def _fill_null_forward_limit(native: ChunkedArrayAny, limit: int) -> ChunkedArrayAny: + SENTINEL = lit(-1) # noqa: N806 + is_not_null = native.is_valid() + index = int_range(len(native), chunked=False) + index_not_null = cum_max(when_then(is_not_null, index, SENTINEL)) + # NOTE: The correction here is for nulls at either end of the array + # They should be preserved when the `strategy` would need an out-of-bounds index + not_oob = not_eq(index_not_null, SENTINEL) + index_not_null = when_then(not_oob, index_not_null) + beyond_limit = gt(sub(index, index_not_null), lit(limit)) + return when_then(or_(is_not_null, beyond_limit), native, native.take(index_not_null)) + + +@t.overload +def fill_null( + native: ChunkedOrScalarT, value: NonNestedLiteral | ArrowAny +) -> ChunkedOrScalarT: ... +@t.overload +def fill_null( + native: ChunkedOrArrayT, value: ScalarAny | NonNestedLiteral | ChunkedOrArrayT +) -> ChunkedOrArrayT: ... +@t.overload +def fill_null( + native: ChunkedOrScalarAny, value: ChunkedOrScalarAny | NonNestedLiteral +) -> ChunkedOrScalarAny: ... 
+def fill_null(native: ArrowAny, value: ArrowAny | NonNestedLiteral) -> ArrowAny: + fill_value: Incomplete = value + result: ArrowAny = pc.fill_null(native, fill_value) + return result + +@t.overload +def fill_nan( + native: ChunkedOrScalarT, value: NonNestedLiteral | ArrowAny +) -> ChunkedOrScalarT: ... +@t.overload +def fill_nan(native: SameArrowT, value: NonNestedLiteral | ArrowAny) -> SameArrowT: ... +def fill_nan(native: ArrowAny, value: NonNestedLiteral | ArrowAny) -> Incomplete: + return when_then(is_not_nan(native), native, value) + + +def fill_null_forward(native: ChunkedArrayAny) -> ChunkedArrayAny: + return fill_null_with_strategy(native, "forward") + + +def fill_null_with_strategy( + native: ChunkedArrayAny, strategy: FillNullStrategy, limit: int | None = None +) -> ChunkedArrayAny: + null_count = native.null_count + if null_count == 0 or (null_count == len(native)): + return native + if limit is None: + return _FILL_NULL_STRATEGY[strategy](native) + if strategy == "forward": + return _fill_null_forward_limit(native, limit) + return reverse(_fill_null_forward_limit(reverse(native), limit)) + + +def _ensure_all_replaced( + native: ChunkedOrScalarAny, unmatched: ArrowAny +) -> ValueError | None: + if not any_(unmatched).as_py(): + return None + msg = ( + "replace_strict did not replace all non-null values.\n\n" + f"The following did not get replaced: {chunked_array(native).filter(array(unmatched)).unique().to_pylist()}" + ) + return ValueError(msg) + + +def replace_strict( + native: ChunkedOrScalarAny, + old: Seq[Any], + new: Seq[Any], + dtype: pa.DataType | None = None, +) -> ChunkedOrScalarAny: + if isinstance(native, pa.Scalar): + idxs: ArrayAny = array(pc.index_in(native, pa.array(old))) + result: ChunkedOrScalarAny = pa.array(new).take(idxs)[0] + else: + idxs = pc.index_in(native, pa.array(old)) + result = chunked_array(pa.array(new).take(idxs)) + if err := _ensure_all_replaced(native, and_(is_not_null(native), is_null(idxs))): + raise err + return 
result.cast(dtype) if dtype else result + + +def replace_strict_default( + native: ChunkedOrScalarAny, + old: Seq[Any], + new: Seq[Any], + default: ChunkedOrScalarAny, + dtype: pa.DataType | None = None, +) -> ChunkedOrScalarAny: + idxs = pc.index_in(native, pa.array(old)) + result = pa.array(new).take(array(idxs)) + result = when_then(is_null(idxs), default, result.cast(dtype) if dtype else result) + return chunked_array(result) if isinstance(native, pa.ChunkedArray) else result[0] + +@overload +def replace_with_mask( + native: ChunkedOrArrayT, mask: Predicate, replacements: ChunkedOrArrayAny +) -> ChunkedOrArrayT: ... +@overload +def replace_with_mask( + native: ChunkedOrArrayAny, mask: Predicate, replacements: ChunkedOrArrayAny +) -> ChunkedOrArrayAny: ... +def replace_with_mask( + native: ChunkedOrArrayAny, mask: Predicate, replacements: ChunkedOrArrayAny +) -> ChunkedOrArrayAny: + """Replace elements of `native`, at positions defined by `mask`. + + The length of `replacements` must equal the number of `True` values in `mask`. + """ + if isinstance(native, pa.ChunkedArray): + args = [array(p) for p in (native, mask, replacements)] + return chunked_array(pc.call_function("replace_with_mask", args)) + args = [native, array(mask), array(replacements)] + result: ChunkedOrArrayAny = pc.call_function("replace_with_mask", args) + return result + + +@t.overload +def is_between( + native: ChunkedArray[ScalarT], + lower: ChunkedOrScalar[ScalarT] | NumericLiteral, + upper: ChunkedOrScalar[ScalarT] | NumericLiteral, + closed: ClosedInterval, +) -> ChunkedArray[pa.BooleanScalar]: ... +@t.overload +def is_between( + native: ChunkedOrScalar[ScalarT], + lower: ChunkedOrScalar[ScalarT] | NumericLiteral, + upper: ChunkedOrScalar[ScalarT] | NumericLiteral, + closed: ClosedInterval, +) -> ChunkedOrScalar[pa.BooleanScalar]: ... 
def is_between( native: ChunkedOrScalar[ScalarT], - lower: ChunkedOrScalar[ScalarT], - upper: ChunkedOrScalar[ScalarT], + lower: ChunkedOrScalar[ScalarT] | NumericLiteral, + upper: ChunkedOrScalar[ScalarT] | NumericLiteral, closed: ClosedInterval, ) -> ChunkedOrScalar[pa.BooleanScalar]: fn_lhs, fn_rhs = _IS_BETWEEN[closed] - return and_(fn_lhs(native, lower), fn_rhs(native, upper)) + low, high = (el if _is_arrow(el) else lit(el) for el in (lower, upper)) + out: ChunkedOrScalar[pa.BooleanScalar] = and_( + fn_lhs(native, low), fn_rhs(native, high) + ) + return out @t.overload @@ -334,6 +1433,10 @@ def is_in( def is_in(values: ArrayAny, /, other: ChunkedOrArrayAny) -> Array[pa.BooleanScalar]: ... @t.overload def is_in(values: ScalarAny, /, other: ChunkedOrArrayAny) -> pa.BooleanScalar: ... +@t.overload +def is_in( + values: ChunkedOrScalarAny, /, other: ChunkedOrArrayAny +) -> ChunkedOrScalarAny: ... def is_in(values: ArrowAny, /, other: ChunkedOrArrayAny) -> ArrowAny: """Check if elements of `values` are present in `other`. @@ -347,15 +1450,84 @@ def is_in(values: ArrowAny, /, other: ChunkedOrArrayAny) -> ArrowAny: return is_in_(values, other) # type: ignore[no-any-return] +@t.overload +def eq_missing( + native: ChunkedArrayAny, other: NonNestedLiteral | ArrowAny +) -> ChunkedArray[pa.BooleanScalar]: ... +@t.overload +def eq_missing( + native: ArrayAny, other: NonNestedLiteral | ArrowAny +) -> Array[pa.BooleanScalar]: ... +@t.overload +def eq_missing( + native: ScalarAny, other: NonNestedLiteral | ArrowAny +) -> pa.BooleanScalar: ... +@t.overload +def eq_missing( + native: ChunkedOrScalarAny, other: NonNestedLiteral | ArrowAny +) -> ChunkedOrScalarAny: ... +def eq_missing(native: ArrowAny, other: NonNestedLiteral | ArrowAny) -> ArrowAny: + """Equivalent to `native == other` where `None == None`. + + This differs from default `eq` where null values are propagated. + + Note: + Unique to `pyarrow`, this wrapper will ensure `None` uses `native.type`. 
+ """ + if isinstance(other, (pa.Array, pa.ChunkedArray)): + return is_in(native, other) + item = array(other if isinstance(other, pa.Scalar) else lit(other, native.type)) + return is_in(native, item) + + +def ir_min_max(name: str, /) -> MinMax: + return MinMax(expr=ir.col(name)) + + +def _boolean_is_unique( + indices: ChunkedArrayAny, aggregated: ChunkedStruct, / +) -> ChunkedArrayAny: + min, max = struct_fields(aggregated, "min", "max") + return and_(is_in(indices, min), is_in(indices, max)) + + +def _boolean_is_duplicated( + indices: ChunkedArrayAny, aggregated: ChunkedStruct, / +) -> ChunkedArrayAny: + return not_(_boolean_is_unique(indices, aggregated)) + + +BOOLEAN_LENGTH_PRESERVING: Mapping[ + type[ir.boolean.BooleanFunction], tuple[IntoColumnAgg, BooleanLengthPreserving] +] = { + ir.boolean.IsFirstDistinct: (ir.min, is_in), + ir.boolean.IsLastDistinct: (ir.max, is_in), + ir.boolean.IsUnique: (ir_min_max, _boolean_is_unique), + ir.boolean.IsDuplicated: (ir_min_max, _boolean_is_duplicated), +} + + def binary( lhs: ChunkedOrScalarAny, op: type[ops.Operator], rhs: ChunkedOrScalarAny ) -> ChunkedOrScalarAny: return _DISPATCH_BINARY[op](lhs, rhs) +@t.overload +def concat_str( + *arrays: ChunkedArrayAny, separator: str = ..., ignore_nulls: bool = ... +) -> ChunkedArray[StringScalar]: ... +@t.overload +def concat_str( + *arrays: ArrayAny, separator: str = ..., ignore_nulls: bool = ... +) -> Array[StringScalar]: ... +@t.overload def concat_str( - *arrays: ChunkedArrayAny, separator: str = "", ignore_nulls: bool = False -) -> ChunkedArray[StringScalar]: + *arrays: ScalarAny, separator: str = ..., ignore_nulls: bool = ... +) -> StringScalar: ... 
+def concat_str( + *arrays: ArrowAny, separator: str = "", ignore_nulls: bool = False +) -> Arrow[StringScalar]: dtype = string_type(obj.type for obj in arrays) it = (obj.cast(dtype) for obj in arrays) concat: Incomplete = pc.binary_join_element_wise @@ -363,6 +1535,21 @@ def concat_str( return concat(*it, lit(separator, dtype), options=join) # type: ignore[no-any-return] +def random_indices( + end: int, /, n: int, *, with_replacement: bool = False, seed: int | None = None +) -> ArrayAny: + """Generate `n` random indices within the range `[0, end)`.""" + # NOTE: Review this path if anything changes upstream + # https://github.com/apache/arrow/issues/47288#issuecomment-3597653670 + if with_replacement: + rand_values = pc.random(n, initializer="system" if seed is None else seed) + return round(multiply(rand_values, lit(end - 1))).cast(I64) + + import numpy as np # ignore-banned-import + + return array(np.random.default_rng(seed).choice(np.arange(end), n, replace=False)) + + def sort_indices( native: ChunkedOrArrayAny | pa.Table, *order_by: str, @@ -485,21 +1672,175 @@ def date_range( return ca.cast(pa.date32()) +def linear_space( + start: float, end: float, num_samples: int, *, closed: ClosedInterval = "both" +) -> ChunkedArray[pc.NumericScalar]: + """Based on [`new_linear_space_f64`]. + + [`new_linear_space_f64`]: https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-ops/src/series/ops/linear_space.rs#L62-L94 + """ + if num_samples < 0: + msg = f"Number of samples, {num_samples}, must be non-negative." 
+ raise ValueError(msg) + if num_samples == 0: + return chunked_array([[]], F64) + if num_samples == 1: + if closed == "none": + value = (end + start) * 0.5 + elif closed in {"left", "both"}: + value = float(start) + else: + value = float(end) + return chunked_array([[value]], F64) + n = num_samples + span = float(end - start) + if closed == "none": + d = span / (n + 1) + start = start + d + elif closed == "left": + d = span / n + elif closed == "right": + start = start + span / n + d = span / n + else: + d = span / (n - 1) + ca: ChunkedArray[pc.NumericScalar] = multiply(int_range(0, n).cast(F64), lit(d)) + ca = add(ca, lit(start, F64)) + return ca # noqa: RET504 + + +def repeat(value: ScalarAny | NonNestedLiteral, n: int) -> ArrayAny: + value = value if isinstance(value, pa.Scalar) else lit(value) + return repeat_unchecked(value, n) + + +def repeat_unchecked(value: ScalarAny, /, n: int) -> ArrayAny: + repeat_: Incomplete = pa.repeat + result: ArrayAny = repeat_(value, n) + return result + + +def repeat_like(value: NonNestedLiteral, n: int, native: ArrowAny) -> ArrayAny: + return repeat_unchecked(lit(value, native.type), n) + + def nulls_like(n: int, native: ArrowAny) -> ArrayAny: """Create a strongly-typed Array instance with all elements null. Uses the type of `native`. """ - return pa.nulls(n, native.type) # type: ignore[no-any-return] + result: ArrayAny = pa.nulls(n, native.type) + return result + + +def zeros(n: int, /) -> pa.Int64Array: + return pa.repeat(0, n) + + +SearchSortedSide: TypeAlias = Literal["left", "right"] + + +# NOTE @dangotbanned: (wish) replacing `np.searchsorted`? +@t.overload +def search_sorted( + native: ChunkedOrArrayT, + element: ChunkedOrArray[NumericScalar] | Sequence[float], + *, + side: SearchSortedSide = ..., +) -> ChunkedOrArrayT: ... +# NOTE: scalar case may work with only `partition_nth_indices`? +@t.overload +def search_sorted( + native: ChunkedOrArrayT, element: float, *, side: SearchSortedSide = ... +) -> ScalarAny: ... 
+def search_sorted( + native: ChunkedOrArrayT, + element: ChunkedOrArray[NumericScalar] | Sequence[float] | float, + *, + side: SearchSortedSide = "left", +) -> ChunkedOrArrayT | ScalarAny: + """Find indices where elements should be inserted to maintain order.""" + import numpy as np # ignore-banned-import + + indices = np.searchsorted(element, native, side=side) + if isinstance(indices, np.generic): + return lit(indices) + if isinstance(native, pa.ChunkedArray): + return chunked_array([indices]) + return array(indices) + + +def hist_bins( + native: ChunkedArrayAny, + bins: Sequence[float] | ChunkedArray[NumericScalar], + *, + include_breakpoint: bool, +) -> Mapping[str, Iterable[Any]]: + """Bin values into buckets and count their occurrences. + + Notes: + Assumes that the following edge cases have been handled: + - `len(bins) >= 2` + - `bins` increase monotonically + - `bin[0] != bin[-1]` + - `native` contains values that are non-null (including NaN) + """ + if len(bins) == 2: + upper = bins[1] + count = array(is_between(native, bins[0], upper, closed="both"), BOOL).true_count + if include_breakpoint: + return {"breakpoint": [upper], "count": [count]} + return {"count": [count]} + + # lowest bin is inclusive + # NOTE: `np.unique` behavior sorts first + value_counts = ( + when_then(not_eq(native, lit(bins[0])), search_sorted(native, bins), 1) + .sort() + .value_counts() + ) + values, counts = struct_fields(value_counts, "values", "counts") + bin_count = len(bins) + int_range_ = int_range(1, bin_count, chunked=False) + mask = is_in(int_range_, values) + replacements = counts.filter(is_in(values, int_range_)) + counts = replace_with_mask(zeros(bin_count - 1), mask, replacements) + if include_breakpoint: + return {"breakpoint": bins[1:], "count": counts} + return {"count": counts} + +def hist_zeroed_data( + arg: int | Sequence[float], *, include_breakpoint: bool +) -> Mapping[str, Iterable[Any]]: + # NOTE: If adding `linear_space` and `zeros` to `CompliantNamespace`, 
consider moving this.
+    n = arg if isinstance(arg, int) else len(arg) - 1
+    if not include_breakpoint:
+        return {"count": zeros(n)}
+    bp = linear_space(0, 1, arg, closed="right") if isinstance(arg, int) else arg[1:]
+    return {"breakpoint": bp, "count": zeros(n)}
+
+
+@overload
+def lit(value: Any) -> NativeScalar: ...
+@overload
+def lit(value: Any, dtype: BoolType) -> pa.BooleanScalar: ...
+@overload
+def lit(value: Any, dtype: UInt32Type) -> pa.UInt32Scalar: ...
+@overload
+def lit(value: Any, dtype: DataType | None = ...) -> NativeScalar: ...
 def lit(value: Any, dtype: DataType | None = None) -> NativeScalar:
     return pa.scalar(value) if dtype is None else pa.scalar(value, dtype)
 
 
+# TODO @dangotbanned: Report `ListScalar.values` bug upstream
+# See `tests/plan/list_unique_test.py::test_list_unique_scalar[None-None]`
 @overload
 def array(data: ArrowAny, /) -> ArrayAny: ...
 @overload
+def array(data: Arrow[BooleanScalar], dtype: BoolType, /) -> pa.BooleanArray: ...
+@overload
 def array(
     data: Iterable[PythonLiteral], dtype: DataType | None = None, /
 ) -> ArrayAny: ...
@@ -509,13 +1850,17 @@ def array(
     """Convert `data` into an Array instance.
 
     Note:
-        `dtype` is not used for existing `pyarrow` data, use `cast` instead.
+        `dtype` is **not used** for existing `pyarrow` data, but it can be used to signal
+        the concrete `Array` subclass that is returned.
+        To actually change the type, use `cast` instead.
""" if isinstance(data, pa.ChunkedArray): return data.combine_chunks() if isinstance(data, pa.Array): return data if isinstance(data, pa.Scalar): + if isinstance(data, pa.ListScalar) and data.is_valid is False: + return pa.array([None], data.type) return pa.array([data], data.type) return pa.array(data, dtype) @@ -526,27 +1871,27 @@ def chunked_array( return _chunked_array(array(data) if isinstance(data, pa.Scalar) else data, dtype) -def concat_vertical_chunked( - arrays: Iterable[ChunkedArrayAny], dtype: DataType | None = None, / -) -> ChunkedArrayAny: - v_concat: Incomplete = pa.chunked_array - return v_concat(arrays, dtype) # type: ignore[no-any-return] - - -def concat_vertical_table( - tables: Iterable[pa.Table], /, promote_options: PromoteOptions = "none" +def concat_horizontal( + arrays: Collection[ChunkedOrArrayAny], names: Collection[str] ) -> pa.Table: - return pa.concat_tables(tables, promote_options=promote_options) + """Concatenate `arrays` as columns in a new table.""" + table: Incomplete = pa.Table.from_arrays + result: pa.Table = table(arrays, names) + return result -if BACKEND_VERSION >= (14,): +def concat_vertical( + arrays: Iterable[ChunkedOrArrayAny], dtype: DataType | None = None, / +) -> ChunkedArrayAny: + """Concatenate `arrays` into a new array.""" + v_concat: Incomplete = pa.chunked_array + result: ChunkedArrayAny = v_concat(arrays, dtype) + return result - def concat_diagonal(tables: Iterable[pa.Table]) -> pa.Table: - return pa.concat_tables(tables, promote_options="default") -else: - def concat_diagonal(tables: Iterable[pa.Table]) -> pa.Table: - return pa.concat_tables(tables, promote=True) +def to_table(array: ChunkedOrArrayAny, name: str = "") -> pa.Table: + """Equivalent to `Series.to_frame`, but with an option to insert a name for the column.""" + return concat_horizontal((array,), (name,)) def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataType]]: @@ -555,3 +1900,23 @@ def _is_into_pyarrow_schema(obj: 
Mapping[Any, Any]) -> TypeIs[Mapping[str, DataT and isinstance(first[0], str) and isinstance(first[1], pa.DataType) ) + + +def _is_arrow(obj: Arrow[ScalarT] | Any) -> TypeIs[Arrow[ScalarT]]: + return isinstance(obj, (pa.Scalar, pa.Array, pa.ChunkedArray)) + + +def filter_arrays( + predicate: ChunkedOrArray[BooleanScalar] | pc.Expression, + *arrays: Unpack[Ts], + ignore_nulls: bool = True, +) -> tuple[Unpack[Ts]]: + """Apply the same filter to multiple arrays, returning them independently. + + Note: + The typing here is a minefield. You'll get an `*arrays`-length `tuple[ChunkedArray, ...]`. + """ + table: Incomplete = pa.Table.from_arrays + tmp = [str(i) for i in range(len(arrays))] + result = table(arrays, tmp).filter(predicate, "drop" if ignore_nulls else "emit_null") + return t.cast("tuple[Unpack[Ts]]", tuple(result.columns)) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index cce6463239..af918a3e4c 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -17,7 +17,7 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Collection, Iterator, Mapping, Sequence + from collections.abc import Iterable, Iterator, Mapping, Sequence from typing_extensions import Self, TypeAlias @@ -49,24 +49,28 @@ agg.NUnique: "hash_count_distinct", agg.First: "hash_first", agg.Last: "hash_last", + fn.MinMax: "hash_min_max", } SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { ir.Len: "hash_count_all", ir.Column: "hash_list", # `hash_aggregate` only } + +_version_dependent: dict[Any, acero.Aggregation] = {} +if fn.HAS_KURTOSIS_SKEW: + _version_dependent.update( + {ir.functions.Kurtosis: "hash_kurtosis", ir.functions.Skew: "hash_skew"} + ) + SUPPORTED_FUNCTION: Mapping[type[ir.Function], acero.Aggregation] = { ir.boolean.All: "hash_all", ir.boolean.Any: "hash_any", ir.functions.Unique: "hash_distinct", # `hash_aggregate` only ir.functions.NullCount: 
"hash_count",
+    **_version_dependent,
 }
 
-REQUIRES_PYARROW_20: tuple[Literal["kurtosis"], Literal["skew"]] = ("kurtosis", "skew")
-"""They don't show in [our version of the stubs], but are possible in [`pyarrow>=20`].
-
-[our version of the stubs]: https://github.com/narwhals-dev/narwhals/issues/2124#issuecomment-3191374210
-[`pyarrow>=20`]: https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations
-"""
+del _version_dependent
 
 
 class AggSpec:
@@ -132,6 +136,39 @@ def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self:
             fn_name = SUPPORTED_IR[type(expr)]
         return cls(expr.name if isinstance(expr, ir.Column) else (), fn_name, name=name)
 
+    # NOTE: Fast-paths for single column rewrites
+    @classmethod
+    def _from_function(cls, tp: type[ir.Function], name: str) -> Self:
+        return cls(name, SUPPORTED_FUNCTION[tp], options.FUNCTION.get(tp), name)
+
+    @classmethod
+    def any(cls, name: str) -> Self:
+        return cls._from_function(ir.boolean.Any, name)
+
+    @classmethod
+    def unique(cls, name: str) -> Self:
+        return cls._from_function(ir.functions.Unique, name)
+
+    @classmethod
+    def implode(cls, name: str) -> Self:
+        # TODO @dangotbanned: Replace with `agg.Implode` (via `_from_agg`) once both have been added
+        # https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-plan/src/dsl/expr/mod.rs#L44
+        return cls(name, SUPPORTED_IR[ir.Column], None, name)
+
+    def over(self, native: pa.Table, keys: Iterable[acero.Field]) -> pa.Table:
+        """Sugar for `native.group_by(keys).aggregate([self])`.
+
+        Returns a table with columns named: `[*keys, self.name]`
+        """
+        return acero.group_by_table(native, keys, [self])
+
+    def over_index(self, native: pa.Table, index_column: str) -> ChunkedArrayAny:
+        """Execute this aggregation over `index_column`.
+
+        Returns a single, (unnamed) array, representing the aggregation results.
+ """ + return acero.group_by_table(native, [index_column], [self]).column(self.name) + def group_by_error( column_name: str, expr: ir.ExprIR, reason: Literal["too complex"] | None = None @@ -145,15 +182,6 @@ def group_by_error( return InvalidOperationError(msg) -def multiple_null_partitions_error(column_names: Collection[str]) -> NotImplementedError: - backend = Implementation.PYARROW - msg = ( - f"`over(*partition_by)` where multiple columns contain null values is not yet supported for {backend!r}\n" - f"Got: {list(column_names)!r}" - ) - return NotImplementedError(msg) - - class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): _df: Frame _keys: Seq[NamedIR] diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index fa0e9e1e71..534aab37d5 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -11,6 +11,7 @@ from narwhals._plan._guards import is_tuple_of from narwhals._plan.arrow import functions as fn from narwhals._plan.compliant.namespace import EagerNamespace +from narwhals._plan.expressions.expr import RangeExpr from narwhals._plan.expressions.literal import is_literal_scalar from narwhals._utils import Implementation, Version from narwhals.exceptions import InvalidOperationError @@ -25,8 +26,8 @@ from narwhals._plan.arrow.typing import ChunkedArray, IntegerScalar from narwhals._plan.expressions import expr, functions as F from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal - from narwhals._plan.expressions.expr import FunctionExpr, RangeExpr - from narwhals._plan.expressions.ranges import DateRange, IntRange + from narwhals._plan.expressions.expr import FunctionExpr as FExpr, RangeExpr + from narwhals._plan.expressions.ranges import DateRange, IntRange, LinearSpace from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.series import Series as NwSeries from narwhals._plan.typing import NonNestedLiteralT @@ -100,15 +101,13 @@ def lit( nw_ser.to_native(), 
name or node.name, nw_ser.version ) - # NOTE: Update with `ignore_nulls`/`fill_null` behavior once added to each `Function` - # https://github.com/narwhals-dev/narwhals/pull/2719 def _horizontal_function( self, fn_native: Callable[[Any, Any], Any], /, fill: NonNestedLiteral = None - ) -> Callable[[FunctionExpr[Any], Frame, str], Expr | Scalar]: - def func(node: FunctionExpr[Any], frame: Frame, name: str) -> Expr | Scalar: + ) -> Callable[[FExpr[Any], Frame, str], Expr | Scalar]: + def func(node: FExpr[Any], frame: Frame, name: str) -> Expr | Scalar: it = (self._expr.from_ir(e, frame, name).native for e in node.input) if fill is not None: - it = (pc.fill_null(native, fn.lit(fill)) for native in it) + it = (fn.fill_null(native, fill) for native in it) result = reduce(fn_native, it) if isinstance(result, pa.Scalar): return self._scalar.from_native(result, name, self.version) @@ -116,37 +115,46 @@ def func(node: FunctionExpr[Any], frame: Frame, name: str) -> Expr | Scalar: return func + def coalesce(self, node: FExpr[F.Coalesce], frame: Frame, name: str) -> Expr | Scalar: + it = (self._expr.from_ir(e, frame, name).native for e in node.input) + result = pc.coalesce(*it) + if isinstance(result, pa.Scalar): + return self._scalar.from_native(result, name, self.version) + return self._expr.from_native(result, name, self.version) + def any_horizontal( - self, node: FunctionExpr[AnyHorizontal], frame: Frame, name: str + self, node: FExpr[AnyHorizontal], frame: Frame, name: str ) -> Expr | Scalar: - return self._horizontal_function(fn.or_)(node, frame, name) + fill = False if node.function.ignore_nulls else None + return self._horizontal_function(fn.or_, fill)(node, frame, name) def all_horizontal( - self, node: FunctionExpr[AllHorizontal], frame: Frame, name: str + self, node: FExpr[AllHorizontal], frame: Frame, name: str ) -> Expr | Scalar: - return self._horizontal_function(fn.and_)(node, frame, name) + fill = True if node.function.ignore_nulls else None + return 
self._horizontal_function(fn.and_, fill)(node, frame, name) def sum_horizontal( - self, node: FunctionExpr[F.SumHorizontal], frame: Frame, name: str + self, node: FExpr[F.SumHorizontal], frame: Frame, name: str ) -> Expr | Scalar: return self._horizontal_function(fn.add, fill=0)(node, frame, name) def min_horizontal( - self, node: FunctionExpr[F.MinHorizontal], frame: Frame, name: str + self, node: FExpr[F.MinHorizontal], frame: Frame, name: str ) -> Expr | Scalar: return self._horizontal_function(fn.min_horizontal)(node, frame, name) def max_horizontal( - self, node: FunctionExpr[F.MaxHorizontal], frame: Frame, name: str + self, node: FExpr[F.MaxHorizontal], frame: Frame, name: str ) -> Expr | Scalar: return self._horizontal_function(fn.max_horizontal)(node, frame, name) def mean_horizontal( - self, node: FunctionExpr[F.MeanHorizontal], frame: Frame, name: str + self, node: FExpr[F.MeanHorizontal], frame: Frame, name: str ) -> Expr | Scalar: int64 = pa.int64() inputs = [self._expr.from_ir(e, frame, name).native for e in node.input] - filled = (pc.fill_null(native, fn.lit(0)) for native in inputs) + filled = (fn.fill_null(native, 0) for native in inputs) # NOTE: `mypy` doesn't like that `add` is overloaded sum_not_null = reduce( fn.add, # type: ignore[arg-type] @@ -158,7 +166,7 @@ def mean_horizontal( return self._expr.from_native(result, name, self.version) def concat_str( - self, node: FunctionExpr[ConcatStr], frame: Frame, name: str + self, node: FExpr[ConcatStr], frame: Frame, name: str ) -> Expr | Scalar: exprs = (self._expr.from_ir(e, frame, name) for e in node.input) aligned = (ser.native for ser in self._expr.align(exprs)) @@ -169,8 +177,13 @@ def concat_str( return self._scalar.from_native(result, name, self.version) return self._expr.from_native(result, name, self.version) + # TODO @dangotbanned: Refactor alongside `nwp.functions._ensure_range_scalar` + # Consider returning the supertype of inputs def _range_function_inputs( - self, node: RangeExpr, 
frame: Frame, valid_type: type[NonNestedLiteralT] + self, + node: RangeExpr, + frame: Frame, + valid_type: type[NonNestedLiteralT] | tuple[type[NonNestedLiteralT], ...], ) -> tuple[NonNestedLiteralT, NonNestedLiteralT]: start_: PythonLiteral end_: PythonLiteral @@ -191,8 +204,10 @@ def _range_function_inputs( ) raise InvalidOperationError(msg) if isinstance(start_, valid_type) and isinstance(end_, valid_type): - return start_, end_ - msg = f"All inputs for `{node.function}()` must resolve to {valid_type.__name__}, but got \n{start_!r}\n{end_!r}" + return start_, end_ # type: ignore[return-value] + valid_types = (valid_type,) if not isinstance(valid_type, tuple) else valid_type + tp_names = " | ".join(tp.__name__ for tp in valid_types) + msg = f"All inputs for `{node.function}()` must resolve to {tp_names}, but got \n{start_!r}\n{end_!r}" raise InvalidOperationError(msg) def _int_range( @@ -240,6 +255,24 @@ def date_range_eager( native = fn.date_range(start, end, interval, closed=closed) return self._series.from_native(native, name, version=self.version) + def linear_space(self, node: RangeExpr[LinearSpace], frame: Frame, name: str) -> Expr: + start, end = self._range_function_inputs(node, frame, (int, float)) + func = node.function + native = fn.linear_space(start, end, func.num_samples, closed=func.closed) + return self._expr.from_native(native, name, self.version) + + def linear_space_eager( + self, + start: float, + end: float, + num_samples: int, + *, + closed: ClosedInterval = "both", + name: str = "literal", + ) -> Series: + native = fn.linear_space(start, end, num_samples, closed=closed) + return self._series.from_native(native, name, version=self.version) + @overload def concat(self, items: Iterable[Frame], *, how: ConcatMethod) -> Frame: ... 
@overload @@ -260,7 +293,7 @@ def concat( def _concat_diagonal(self, items: Iterable[Frame]) -> Frame: return self._dataframe.from_native( - fn.concat_vertical_table(df.native for df in items), self.version + fn.concat_tables((df.native for df in items), "default"), self.version ) def _concat_horizontal(self, items: Iterable[Frame | Series]) -> Frame: @@ -272,14 +305,14 @@ def gen(objs: Iterable[Frame | Series]) -> Iterator[tuple[ChunkedArrayAny, str]] yield from zip(item.native.itercolumns(), item.columns) arrays, names = zip(*gen(items)) - native = pa.Table.from_arrays(arrays, list(names)) + native = fn.concat_horizontal(arrays, names) return self._dataframe.from_native(native, self.version) def _concat_vertical(self, items: Iterable[Frame | Series]) -> Frame | Series: collected = items if isinstance(items, tuple) else tuple(items) if is_tuple_of(collected, self._series): sers = collected - chunked = fn.concat_vertical_chunked(ser.native for ser in sers) + chunked = fn.concat_vertical(ser.native for ser in sers) return sers[0]._with_native(chunked) if is_tuple_of(collected, self._dataframe): dfs = collected @@ -293,5 +326,5 @@ def _concat_vertical(self, items: Iterable[Frame | Series]) -> Frame | Series: f" - dataframe {i}: {cols_current}\n" ) raise TypeError(msg) - return df._with_native(fn.concat_vertical_table(df.native for df in dfs)) + return df._with_native(fn.concat_tables(df.native for df in dfs)) raise TypeError(items) diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index 83e73eff87..3d44487bc7 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -109,21 +109,16 @@ def _sort_keys_every( return tuple((key, order) for key in by) -def _sort_keys( - by: tuple[str, ...], *, descending: bool | Sequence[bool] -) -> Seq[tuple[str, Order]]: - if not isinstance(descending, bool) and len(descending) == 1: - descending = descending[0] - if isinstance(descending, bool): - return _sort_keys_every(by, 
descending=descending) - it = zip_strict(by, descending) - return tuple(_sort_key(key, descending=desc) for (key, desc) in it) - - def sort( *by: str, descending: bool | Sequence[bool] = False, nulls_last: bool = False ) -> pc.SortOptions: - keys = _sort_keys(by, descending=descending) + if not isinstance(descending, bool) and len(descending) == 1: + descending = descending[0] + if isinstance(descending, bool): + keys = _sort_keys_every(by, descending=descending) + else: + it = zip_strict(by, descending) + keys = tuple(_sort_key(key, descending=desc) for (key, desc) in it) return pc.SortOptions(sort_keys=keys, null_placement=NULL_PLACEMENT[nulls_last]) @@ -138,6 +133,20 @@ def rank( ) +def match_substring(pattern: str) -> pc.MatchSubstringOptions: + return pc.MatchSubstringOptions(pattern) + + +def split_pattern(by: str, n: int | None = None) -> pc.SplitPatternOptions: + """Similar to `str.splitn`. + + Some glue for `max_splits=n - 1` + """ + if n is not None: + return pc.SplitPatternOptions(by, max_splits=n - 1) + return pc.SplitPatternOptions(by) + + def _generate_agg() -> Mapping[type[agg.AggExpr], acero.AggregateOptions]: from narwhals._plan.expressions import aggregation as agg @@ -157,6 +166,7 @@ def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]: boolean.All: scalar_aggregate(ignore_nulls=True), boolean.Any: scalar_aggregate(ignore_nulls=True), functions.NullCount: count("only_null"), + functions.Unique: count("all"), } diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index ffd68b660a..7390e7f7d5 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -1,27 +1,41 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast +import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc from narwhals._arrow.utils import narwhals_to_native_dtype, native_to_narwhals_dtype from narwhals._plan.arrow import 
functions as fn, options from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries +from narwhals._plan.compliant.accessors import SeriesStructNamespace as StructNamespace from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.compliant.typing import namespace +from narwhals._plan.expressions import functions as F from narwhals._utils import Version, generate_repr from narwhals.dependencies import is_numpy_array_1d +from narwhals.schema import Schema if TYPE_CHECKING: from collections.abc import Iterable import polars as pl - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._plan.arrow.dataframe import ArrowDataFrame as DataFrame + from narwhals._plan.arrow.namespace import ArrowNamespace as Namespace from narwhals._plan.arrow.typing import ChunkedArrayAny from narwhals.dtypes import DType - from narwhals.typing import Into1DArray, IntoDType, _1DArray + from narwhals.typing import ( + FillNullStrategy, + Into1DArray, + IntoDType, + NonNestedLiteral, + PythonLiteral, + _1DArray, + ) + +Incomplete: TypeAlias = Any class ArrowSeries(FrameSeries["ChunkedArrayAny"], CompliantSeries["ChunkedArrayAny"]): @@ -46,7 +60,7 @@ def to_polars(self) -> pl.Series: import polars as pl # ignore-banned-import # NOTE: Recommended in https://github.com/pola-rs/polars/issues/22921#issuecomment-2908506022 - return pl.Series(self.native) + return pl.Series(self.name, self.native) def __len__(self) -> int: return self.native.length() @@ -72,7 +86,7 @@ def from_iterable( name: str = "", dtype: IntoDType | None = None, ) -> Self: - dtype_pa = narwhals_to_native_dtype(dtype, version) if dtype else None + dtype_pa = fn.dtype_native(dtype, version) return cls.from_native(fn.chunked_array([data], dtype_pa), name, version=version) def cast(self, dtype: IntoDType) -> Self: @@ -86,11 +100,265 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: def scatter(self, indices: Self, values: 
Self) -> Self: mask = fn.is_in(fn.int_range(len(self), chunked=False), indices.native) - replacements = fn.array(values._gather(pc.sort_indices(indices.native))) - return self._with_native(pc.replace_with_mask(self.native, mask, replacements)) + replacements = values._gather(pc.sort_indices(indices.native)) + return self._with_native(fn.replace_with_mask(self.native, mask, replacements)) def is_in(self, other: Self) -> Self: return self._with_native(fn.is_in(self.native, other.native)) + def is_nan(self) -> Self: + return self._with_native(fn.is_nan(self.native)) + + def is_null(self) -> Self: + return self._with_native(fn.is_null(self.native)) + + def is_not_nan(self) -> Self: + return self._with_native(fn.is_not_nan(self.native)) + + def is_not_null(self) -> Self: + return self._with_native(fn.is_not_null(self.native)) + def has_nulls(self) -> bool: return bool(self.native.null_count) + + def null_count(self) -> int: + return self.native.null_count + + __add__ = fn.bin_op(fn.add) + __and__ = fn.bin_op(fn.and_) + __eq__ = fn.bin_op(fn.eq) + __floordiv__ = fn.bin_op(fn.floordiv) + __ge__ = fn.bin_op(fn.gt_eq) + __gt__ = fn.bin_op(fn.gt) + __le__ = fn.bin_op(fn.lt_eq) + __lt__ = fn.bin_op(fn.lt) + __mod__ = fn.bin_op(fn.modulus) + __mul__ = fn.bin_op(fn.multiply) + __ne__ = fn.bin_op(fn.not_eq) + __or__ = fn.bin_op(fn.or_) + __pow__ = fn.bin_op(fn.power) + __rfloordiv__ = fn.bin_op(fn.floordiv, reflect=True) + __radd__ = fn.bin_op(fn.add, reflect=True) + __rand__ = fn.bin_op(fn.and_, reflect=True) + __rmod__ = fn.bin_op(fn.modulus, reflect=True) + __rmul__ = fn.bin_op(fn.multiply, reflect=True) + __ror__ = fn.bin_op(fn.or_, reflect=True) + __rpow__ = fn.bin_op(fn.power, reflect=True) + __rsub__ = fn.bin_op(fn.sub, reflect=True) + __rtruediv__ = fn.bin_op(fn.truediv, reflect=True) + __rxor__ = fn.bin_op(fn.xor, reflect=True) + __sub__ = fn.bin_op(fn.sub) + __truediv__ = fn.bin_op(fn.truediv) + __xor__ = fn.bin_op(fn.xor) + + def __invert__(self) -> Self: + return 
self._with_native(pc.invert(self.native)) + + def cum_sum(self, *, reverse: bool = False) -> Self: + if not reverse: + return self._with_native(fn.cum_sum(self.native)) + return self._with_native(fn.cumulative(self.native, F.CumSum(reverse=reverse))) + + def cum_count(self, *, reverse: bool = False) -> Self: + if not reverse: + return self._with_native(fn.cum_count(self.native)) + return self._with_native(fn.cumulative(self.native, F.CumCount(reverse=reverse))) + + def cum_max(self, *, reverse: bool = False) -> Self: + if not reverse: + return self._with_native(fn.cum_max(self.native)) + return self._with_native(fn.cumulative(self.native, F.CumMax(reverse=reverse))) + + def cum_min(self, *, reverse: bool = False) -> Self: + if not reverse: + return self._with_native(fn.cum_min(self.native)) + return self._with_native(fn.cumulative(self.native, F.CumMin(reverse=reverse))) + + def cum_prod(self, *, reverse: bool = False) -> Self: + if not reverse: + return self._with_native(fn.cum_prod(self.native)) + return self._with_native(fn.cumulative(self.native, F.CumProd(reverse=reverse))) + + def fill_nan(self, value: float | Self | None) -> Self: + fill_value = value.native if isinstance(value, ArrowSeries) else value + return self._with_native(fn.fill_nan(self.native, fill_value)) + + def fill_null(self, value: NonNestedLiteral | Self) -> Self: + fill_value = value.native if isinstance(value, ArrowSeries) else value + return self._with_native(fn.fill_null(self.native, fill_value)) + + def fill_null_with_strategy( + self, strategy: FillNullStrategy, limit: int | None = None + ) -> Self: + return self._with_native(fn.fill_null_with_strategy(self.native, strategy, limit)) + + def diff(self, n: int = 1) -> Self: + return self._with_native(fn.diff(self.native, n)) + + def shift(self, n: int, *, fill_value: NonNestedLiteral = None) -> Self: + return self._with_native(fn.shift(self.native, n, fill_value=fill_value)) + + def _rolling_center(self, window_size: int) -> tuple[Self, 
int]: + """Think this is similar to [`polars_core::chunked_array::ops::rolling_window::inner_mod::window_edges`]. + + On `main`, this is `narwhals._arrow.utils.pad_series`. + + [`polars_core::chunked_array::ops::rolling_window::inner_mod::window_edges`]: https://github.com/pola-rs/polars/blob/e1d6f294218a36497255e2d872c223e19a47e2ec/crates/polars-core/src/chunked_array/ops/rolling_window.rs#L64-L77 + """ + offset_left = window_size // 2 + # subtract one if window_size is even + offset_right = offset_left - (window_size % 2 == 0) + native = self.native + arrays = ( + fn.nulls_like(offset_left, native), + *native.chunks, + fn.nulls_like(offset_right, native), + ) + offset = offset_left + offset_right + return self._with_native(fn.concat_vertical(arrays)), offset + + def _rolling_sum(self, window_size: int, /) -> Self: + cum_sum = self.cum_sum().fill_null_with_strategy("forward") + return cum_sum.diff(window_size).fill_null(cum_sum) + + def _rolling_count(self, window_size: int, /) -> Self: + cum_count = self.cum_count() + return cum_count.diff(window_size).fill_null(cum_count) + + def rolling_sum( + self, window_size: int, *, min_samples: int, center: bool = False + ) -> Self: + s, offset = self, 0 + if center: + s, offset = self._rolling_center(window_size) + rolling_count = s._rolling_count(window_size) + keep = rolling_count >= min_samples + result = s._rolling_sum(window_size).zip_with(keep, None) + return result.slice(offset) if offset else result + + def rolling_mean( + self, window_size: int, *, min_samples: int, center: bool = False + ) -> Self: + s, offset = self, 0 + if center: + s, offset = self._rolling_center(window_size) + rolling_count = s._rolling_count(window_size) + keep = rolling_count >= min_samples + result = (s._rolling_sum(window_size).zip_with(keep, None)) / rolling_count + return result.slice(offset) if offset else result + + def rolling_var( + self, window_size: int, *, min_samples: int, center: bool = False, ddof: int = 1 + ) -> Self: + s, 
offset = self, 0 + if center: + s, offset = self._rolling_center(window_size) + rolling_count = s._rolling_count(window_size) + keep = rolling_count >= min_samples + + # NOTE: Yes, these two are different + sq_rolling_sum = s.pow(2)._rolling_sum(window_size) + rolling_sum_sq = s._rolling_sum(window_size).pow(2) + + # NOTE: Please somebody rename these two to *something else*! + rolling_something = sq_rolling_sum - (rolling_sum_sq / rolling_count) + denominator = s._with_native(fn.max_horizontal((rolling_count - ddof).native, 0)) + result = rolling_something.zip_with(keep, None) / denominator + return result.slice(offset) if offset else result + + def rolling_std( + self, window_size: int, *, min_samples: int, center: bool = False, ddof: int = 1 + ) -> Self: + return self.rolling_var( + window_size, min_samples=min_samples, center=center, ddof=ddof + ).pow(0.5) + + def zip_with(self, mask: Self, other: Self | None) -> Self: + predicate = mask.native.combine_chunks() + right = other.native if other is not None else other + return self._with_native(fn.when_then(predicate, self.native, right)) + + def all(self) -> bool: + return fn.all_(self.native).as_py() + + def any(self) -> bool: + return fn.any_(self.native).as_py() + + def sum(self) -> float: + result: float = fn.sum_(self.native).as_py() + return result + + def count(self) -> int: + return fn.count(self.native).as_py() + + def unique(self, *, maintain_order: bool = False) -> Self: + return self._with_native(self.native.unique()) + + def drop_nulls(self) -> Self: + return self._with_native(self.native.drop_null()) + + def drop_nans(self) -> Self: + predicate: Incomplete = fn.is_not_nan(self.native) + return self._with_native( + self.native.filter(predicate, null_selection_behavior="emit_null") + ) + + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Self: + exploder = fn.ExplodeBuilder(empty_as_null=empty_as_null, keep_nulls=keep_nulls) + return 
self._with_native(exploder.explode(self.native)) + + def first(self) -> PythonLiteral: + return self.native[0].as_py() if len(self) else None + + def last(self) -> PythonLiteral: + ca = self.native + return ca[height - 1].as_py() if (height := len(ca)) else None + + @property + def struct(self) -> SeriesStructNamespace: + return SeriesStructNamespace(self) + + +class SeriesStructNamespace(StructNamespace["DataFrame", ArrowSeries]): + def __init__(self, compliant: ArrowSeries, /) -> None: + self._compliant: ArrowSeries = compliant + + @property + def compliant(self) -> ArrowSeries: + return self._compliant + + @property + def native(self) -> ChunkedArrayAny: + return self.compliant.native + + def __narwhals_namespace__(self) -> Namespace: + return namespace(self.compliant) + + @property + def version(self) -> Version: + return self.compliant.version + + def with_native(self, native: ChunkedArrayAny, name: str, /) -> ArrowSeries: + return self.compliant.from_native(native, name, version=self.version) + + def unnest(self) -> DataFrame: + native = cast("pa.ChunkedArray[pa.StructScalar]", self.native) + if fn.HAS_FROM_TO_STRUCT_ARRAY: + if len(native): + table = pa.Table.from_struct_array(native) + else: + table = fn.struct_schema(native).empty_table() + else: # pragma: no cover + # NOTE: Too strict, doesn't allow `Array[StructScalar]` + rec_batch: Incomplete = pa.RecordBatch.from_struct_array + batches = (rec_batch(chunk) for chunk in native.chunks) + table = pa.Table.from_batches(batches, fn.struct_schema(native)) + return namespace(self)._dataframe.from_native(table, self.version) + + # name overriding *may* be wrong + def field(self, name: str) -> ArrowSeries: + return self.with_native(fn.struct_field(self.native, name), name) + + @property + def schema(self) -> Schema: + return Schema.from_arrow(fn.struct_schema(self.native)) diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index ad2d42cb16..c2befc214b 100644 --- 
a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -10,14 +10,16 @@ if TYPE_CHECKING: import pyarrow as pa import pyarrow.compute as pc + from pyarrow import lib, types from pyarrow.lib import ( + BoolType as BoolType, Date32Type, Int8Type, Int16Type, Int32Type, Int64Type, - LargeStringType as LargeStringType, - StringType as StringType, + LargeStringType as _LargeStringType, + StringType as _StringType, Uint8Type, Uint16Type, Uint32Type, @@ -28,10 +30,29 @@ from narwhals._native import NativeDataFrame, NativeSeries from narwhals.typing import SizedMultiIndexSelector as _SizedMultiIndexSelector - StringScalar: TypeAlias = "Scalar[StringType | LargeStringType]" + UInt32Type: TypeAlias = "Uint32Type" + StringType: TypeAlias = "_StringType | _LargeStringType" IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type" + StringScalar: TypeAlias = "Scalar[StringType]" IntegerScalar: TypeAlias = "Scalar[IntegerType]" DateScalar: TypeAlias = "Scalar[Date32Type]" + ListScalar: TypeAlias = "Scalar[pa.ListType[DataTypeT_co]]" + BooleanScalar: TypeAlias = "Scalar[BoolType]" + """Only use this for a parameter type, not as a return type!""" + NumericScalar: TypeAlias = "pc.NumericScalar" + + PrimitiveNumericType: TypeAlias = "types._Integer | types._Floating" + NumericType: TypeAlias = "PrimitiveNumericType | types._Decimal" + NumericOrTemporalType: TypeAlias = "NumericType | types._Temporal" + StringOrBinaryType: TypeAlias = "StringType | lib.StringViewType | lib.BinaryType | lib.LargeBinaryType | lib.BinaryViewType" + BasicType: TypeAlias = ( + "NumericOrTemporalType | StringOrBinaryType | BoolType | lib.NullType" + ) + NonListNestedType: TypeAlias = "pa.StructType | pa.DictionaryType[Any, Any] | pa.MapType[Any, Any] | pa.UnionType" + NonListType: TypeAlias = "BasicType | NonListNestedType" + NestedType: TypeAlias = "NonListNestedType | pa.ListType[Any]" + NonListTypeT = 
TypeVar("NonListTypeT", bound="NonListType") + ListTypeT = TypeVar("ListTypeT", bound="pa.ListType[Any]") class NativeArrowSeries(NativeSeries, Protocol): @property @@ -44,9 +65,21 @@ def columns(self) -> Sequence[NativeArrowSeries]: ... P = ParamSpec("P") + class UnaryFunctionP(Protocol[P]): + """A function wrapping at-most 1 `Expr` input.""" + + def __call__( + self, native: ChunkedOrScalarAny, /, *args: P.args, **kwds: P.kwargs + ) -> ChunkedOrScalarAny: ... + class VectorFunction(Protocol[P]): def __call__( - self, native: ChunkedArrayAny, *args: P.args, **kwds: P.kwargs + self, native: ChunkedArrayAny, /, *args: P.args, **kwds: P.kwargs + ) -> ChunkedArrayAny: ... + + class BooleanLengthPreserving(Protocol): + def __call__( + self, indices: ChunkedArrayAny, aggregated: ChunkedArrayAny, / ) -> ChunkedArrayAny: ... @@ -62,66 +95,87 @@ def __call__( ) NumericOrTemporalScalar: TypeAlias = "pc.NumericOrTemporalScalar" NumericOrTemporalScalarT = TypeVar( - "NumericOrTemporalScalarT", - bound=NumericOrTemporalScalar, - default=NumericOrTemporalScalar, + "NumericOrTemporalScalarT", bound=NumericOrTemporalScalar, default="NumericScalar" ) class UnaryFunction(Protocol[ScalarPT_contra, ScalarRT_co]): @overload - def __call__(self, data: ScalarPT_contra, *args: Any, **kwds: Any) -> ScalarRT_co: ... + def __call__( + self, data: ScalarPT_contra, /, *args: Any, **kwds: Any + ) -> ScalarRT_co: ... @overload def __call__( - self, data: ChunkedArray[ScalarPT_contra], *args: Any, **kwds: Any + self, data: ChunkedArray[ScalarPT_contra], /, *args: Any, **kwds: Any ) -> ChunkedArray[ScalarRT_co]: ... @overload def __call__( - self, data: ChunkedOrScalar[ScalarPT_contra], *args: Any, **kwds: Any + self, data: ChunkedOrScalar[ScalarPT_contra], /, *args: Any, **kwds: Any ) -> ChunkedOrScalar[ScalarRT_co]: ... 
@overload def __call__( - self, data: Array[ScalarPT_contra], *args: Any, **kwds: Any + self, data: Array[ScalarPT_contra], /, *args: Any, **kwds: Any ) -> Array[ScalarRT_co]: ... @overload def __call__( - self, data: ChunkedOrArray[ScalarPT_contra], *args: Any, **kwds: Any + self, data: ChunkedOrArray[ScalarPT_contra], /, *args: Any, **kwds: Any ) -> ChunkedOrArray[ScalarRT_co]: ... def __call__( - self, data: Arrow[ScalarPT_contra], *args: Any, **kwds: Any + self, data: Arrow[ScalarPT_contra], /, *args: Any, **kwds: Any ) -> Arrow[ScalarRT_co]: ... class BinaryFunction(Protocol[ScalarPT_contra, ScalarRT_co]): + @overload + def __call__( + self, x: ChunkedArray[ScalarPT_contra], y: ChunkedArray[ScalarPT_contra], / + ) -> ChunkedArray[ScalarRT_co]: ... + @overload + def __call__( + self, x: Array[ScalarPT_contra], y: Array[ScalarPT_contra], / + ) -> Array[ScalarRT_co]: ... @overload def __call__(self, x: ScalarPT_contra, y: ScalarPT_contra, /) -> ScalarRT_co: ... - @overload def __call__( - self, x: ChunkedArray[ScalarPT_contra], y: ChunkedArray[ScalarPT_contra], / + self, x: ChunkedArray[ScalarPT_contra], y: ScalarPT_contra, / ) -> ChunkedArray[ScalarRT_co]: ... - + @overload + def __call__( + self, x: Array[ScalarPT_contra], y: ScalarPT_contra, / + ) -> Array[ScalarRT_co]: ... @overload def __call__( self, x: ScalarPT_contra, y: ChunkedArray[ScalarPT_contra], / ) -> ChunkedArray[ScalarRT_co]: ... - @overload def __call__( - self, x: ChunkedArray[ScalarPT_contra], y: ScalarPT_contra, / + self, x: ScalarPT_contra, y: Array[ScalarPT_contra], / + ) -> Array[ScalarRT_co]: ... + @overload + def __call__( + self, x: ChunkedArray[ScalarPT_contra], y: Array[ScalarPT_contra], / + ) -> ChunkedArray[ScalarRT_co]: ... + @overload + def __call__( + self, x: Array[ScalarPT_contra], y: ChunkedArray[ScalarPT_contra], / ) -> ChunkedArray[ScalarRT_co]: ... 
- @overload def __call__( self, x: ChunkedOrScalar[ScalarPT_contra], y: ChunkedOrScalar[ScalarPT_contra], / ) -> ChunkedOrScalar[ScalarRT_co]: ... + @overload def __call__( - self, x: ChunkedOrScalar[ScalarPT_contra], y: ChunkedOrScalar[ScalarPT_contra], / - ) -> ChunkedOrScalar[ScalarRT_co]: ... + self, x: Arrow[ScalarPT_contra], y: Arrow[ScalarPT_contra], / + ) -> Arrow[ScalarRT_co]: ... + + def __call__( + self, x: Arrow[ScalarPT_contra], y: Arrow[ScalarPT_contra], / + ) -> Arrow[ScalarRT_co]: ... class BinaryComp( @@ -129,12 +183,13 @@ class BinaryComp( ): ... -class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Protocol): ... +class BinaryLogical(BinaryFunction["BooleanScalar", "pa.BooleanScalar"], Protocol): ... BinaryNumericTemporal: TypeAlias = BinaryFunction[ NumericOrTemporalScalarT, NumericOrTemporalScalarT ] +UnaryNumeric: TypeAlias = UnaryFunction["NumericScalar", "NumericScalar"] DataType: TypeAlias = "pa.DataType" DataTypeT = TypeVar("DataTypeT", bound=DataType, default=Any) DataTypeT_co = TypeVar("DataTypeT_co", bound=DataType, covariant=True, default=Any) @@ -150,13 +205,27 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot ChunkedOrScalarAny: TypeAlias = "ChunkedOrScalar[ScalarAny]" ChunkedOrArrayAny: TypeAlias = "ChunkedOrArray[ScalarAny]" ChunkedOrArrayT = TypeVar("ChunkedOrArrayT", ChunkedArrayAny, ArrayAny) +ChunkedOrScalarT = TypeVar("ChunkedOrScalarT", ChunkedArrayAny, ScalarAny) Indices: TypeAlias = "_SizedMultiIndexSelector[ChunkedOrArray[pc.IntegerScalar]]" +ChunkedStruct: TypeAlias = "ChunkedArray[pa.StructScalar]" +StructArray: TypeAlias = "pa.StructArray | Array[pa.StructScalar]" +ChunkedList: TypeAlias = "ChunkedArray[ListScalar[DataTypeT_co]]" +ListArray: TypeAlias = "Array[ListScalar[DataTypeT_co]]" + Arrow: TypeAlias = "ChunkedOrScalar[ScalarT_co] | Array[ScalarT_co]" ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" +SameArrowT = TypeVar("SameArrowT", ChunkedArrayAny, 
ArrayAny, ScalarAny) +ArrowT = TypeVar("ArrowT", bound=ArrowAny) +ArrowListT = TypeVar("ArrowListT", bound="Arrow[ListScalar[Any]]") +Predicate: TypeAlias = "Arrow[BooleanScalar]" +"""Any `pyarrow` container that wraps boolean.""" + NativeScalar: TypeAlias = ScalarAny BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] -StoresNativeT_co = TypeVar("StoresNativeT_co", bound=StoresNative[Any], covariant=True) +StoresNativeT_co = TypeVar( + "StoresNativeT_co", bound=StoresNative[ChunkedOrScalarAny], covariant=True +) DataTypeRemap: TypeAlias = Mapping[DataType, DataType] NullPlacement: TypeAlias = Literal["at_start", "at_end"] diff --git a/narwhals/_plan/compliant/accessors.py b/narwhals/_plan/compliant/accessors.py new file mode 100644 index 0000000000..26df3ff3ba --- /dev/null +++ b/narwhals/_plan/compliant/accessors.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +from narwhals._plan.compliant.typing import ( + DataFrameT_co, + ExprT_co, + FrameT_contra, + SeriesT_co, +) + +if TYPE_CHECKING: + from narwhals._plan.expressions import FunctionExpr as FExpr, lists, strings + from narwhals._plan.expressions.categorical import GetCategories + from narwhals._plan.expressions.struct import FieldByName + from narwhals.schema import Schema + + +class ExprCatNamespace(Protocol[FrameT_contra, ExprT_co]): + def get_categories( + self, node: FExpr[GetCategories], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + + +class ExprListNamespace(Protocol[FrameT_contra, ExprT_co]): + def contains( + self, node: FExpr[lists.Contains], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def get( + self, node: FExpr[lists.Get], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def len( + self, node: FExpr[lists.Len], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def unique( + self, node: FExpr[lists.Unique], frame: FrameT_contra, name: str + ) -> ExprT_co: ... 
+ def join( + self, node: FExpr[lists.Join], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + + +class ExprStringNamespace(Protocol[FrameT_contra, ExprT_co]): + def contains( + self, node: FExpr[strings.Contains], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def ends_with( + self, node: FExpr[strings.EndsWith], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def len_chars( + self, node: FExpr[strings.LenChars], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def replace( + self, node: FExpr[strings.Replace], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def replace_all( + self, node: FExpr[strings.ReplaceAll], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def slice( + self, node: FExpr[strings.Slice], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def split( + self, node: FExpr[strings.Split], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def starts_with( + self, node: FExpr[strings.StartsWith], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def strip_chars( + self, node: FExpr[strings.StripChars], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def to_uppercase( + self, node: FExpr[strings.ToUppercase], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def to_lowercase( + self, node: FExpr[strings.ToLowercase], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def to_titlecase( + self, node: FExpr[strings.ToTitlecase], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def to_date( + self, node: FExpr[strings.ToDate], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def to_datetime( + self, node: FExpr[strings.ToDatetime], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + def zfill( + self, node: FExpr[strings.ZFill], frame: FrameT_contra, name: str + ) -> ExprT_co: ... + + +class ExprStructNamespace(Protocol[FrameT_contra, ExprT_co]): + def field( + self, node: FExpr[FieldByName], frame: FrameT_contra, name: str + ) -> ExprT_co: ... 
+ + +class SeriesStructNamespace(Protocol[DataFrameT_co, SeriesT_co]): + def field(self, name: str) -> SeriesT_co: ... + def unnest(self) -> DataFrameT_co: ... + @property + def schema(self) -> Schema: ... diff --git a/narwhals/_plan/compliant/column.py b/narwhals/_plan/compliant/column.py index 2669a598db..598f462286 100644 --- a/narwhals/_plan/compliant/column.py +++ b/narwhals/_plan/compliant/column.py @@ -42,16 +42,25 @@ def _length_max(cls, lengths: Sequence[LengthT], /) -> LengthT: @classmethod def _length_required( - cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], / + cls, + exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], + /, + default: LengthT | None = None, ) -> LengthT | None: """Return the broadcast length, if all lengths do not equal the maximum.""" @classmethod def align( - cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]] + cls, + *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]], + default: LengthT | None = None, ) -> Iterator[SeriesT]: + """Yield broadcasted `Scalar`s and unwrapped `Expr`s from `exprs`. + + `default` must be provided when operating in a `with_columns` context. 
+ """ exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs)) - length = cls._length_required(exprs) + length = default if len(exprs) == 1 else cls._length_required(exprs, default) if length is None: for e in exprs: yield e.to_series() @@ -85,12 +94,15 @@ def _length_max(cls, lengths: Sequence[int], /) -> int: @classmethod def _length_required( - cls, exprs: Sequence[SupportsBroadcast[SeriesT, int]], / + cls, + exprs: Sequence[SupportsBroadcast[SeriesT, int]], + /, + default: int | None = None, ) -> int | None: lengths = cls._length_all(exprs) max_length = cls._length_max(lengths) required = any(len_ != max_length for len_ in lengths) - return max_length if required else None + return max_length if required else default class ExprDispatch(HasVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index c439cda262..1e953d9c38 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -6,6 +6,7 @@ from narwhals._plan.compliant.group_by import Grouped from narwhals._plan.compliant.typing import ColumnT_co, HasVersion, SeriesT from narwhals._plan.typing import ( + IncompleteCyclic, IntoExpr, NativeDataFrameT, NativeFrameT_co, @@ -30,7 +31,7 @@ from narwhals._plan.compliant.namespace import EagerNamespace from narwhals._plan.dataframe import BaseFrame, DataFrame from narwhals._plan.expressions import NamedIR - from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.options import ExplodeOptions, SortMultipleOptions from narwhals._plan.typing import Seq from narwhals._typing import _EagerAllowedImpl from narwhals._utils import Implementation, Version @@ -43,7 +44,7 @@ class CompliantFrame(HasVersion, Protocol[ColumnT_co, NativeFrameT_co]): implementation: ClassVar[Implementation] - def __narwhals_namespace__(self) -> Any: ... + def __narwhals_namespace__(self) -> IncompleteCyclic: ... 
def _evaluate_irs( self, nodes: Iterable[NamedIR[ir.ExprIR]], / ) -> Iterator[ColumnT_co]: ... @@ -59,6 +60,7 @@ def to_narwhals(self) -> BaseFrame[NativeFrameT_co]: ... def columns(self) -> list[str]: ... def drop(self, columns: Sequence[str]) -> Self: ... def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... + def explode(self, subset: Sequence[str], options: ExplodeOptions) -> Self: ... # Shouldn't *need* to be `NamedIR`, but current impl depends on a name being passed around def filter(self, predicate: NamedIR, /) -> Self: ... def rename(self, mapping: Mapping[str, str]) -> Self: ... @@ -80,6 +82,8 @@ class CompliantDataFrame( implementation: ClassVar[_EagerAllowedImpl] _native: NativeDataFrameT + @property + def shape(self) -> tuple[int, int]: ... def __len__(self) -> int: ... @property def _group_by(self) -> type[DataFrameGroupBy[Self]]: ... @@ -105,6 +109,7 @@ def native(self) -> NativeDataFrameT: def from_dict( cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: ... + def gather_every(self, n: int, offset: int = 0) -> Self: ... def get_column(self, name: str) -> SeriesT: ... def group_by_agg( self, by: OneOrIterable[IntoExpr], aggs: OneOrIterable[IntoExpr], / @@ -135,6 +140,7 @@ def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Se return self._group_by.from_resolver(self, resolver) def filter(self, predicate: NamedIR, /) -> Self: ... + def iter_columns(self) -> Iterator[SeriesT]: ... def join( self, other: Self, @@ -166,9 +172,19 @@ def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: return DataFrame[NativeDataFrameT, NativeSeriesT](self) def to_series(self, index: int = 0) -> SeriesT: ... + def to_struct(self, name: str = "") -> SeriesT: ... def to_polars(self) -> pl.DataFrame: ... def with_row_index(self, name: str) -> Self: ... def slice(self, offset: int, length: int | None = None) -> Self: ... 
+ def sample_frac( + self, fraction: float, *, with_replacement: bool = False, seed: int | None = None + ) -> Self: + n = int(len(self) * fraction) + return self.sample_n(n, with_replacement=with_replacement, seed=seed) + + def sample_n( + self, n: int = 1, *, with_replacement: bool = False, seed: int | None = None + ) -> Self: ... class EagerDataFrame( @@ -178,6 +194,9 @@ class EagerDataFrame( def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... @property def _group_by(self) -> type[EagerDataFrameGroupBy[Self]]: ... + def _evaluate_irs( + self, nodes: Iterable[NamedIR[ir.ExprIR]], /, *, length: int | None = None + ) -> Iterator[SeriesT]: ... def group_by_resolver( self, resolver: GroupByResolver, / @@ -188,7 +207,9 @@ def select(self, irs: Seq[NamedIR]) -> Self: return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) def with_columns(self, irs: Seq[NamedIR]) -> Self: - return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) + return self.__narwhals_namespace__()._concat_horizontal( + self._evaluate_irs(irs, length=len(self)) + ) def to_series(self, index: int = 0) -> SeriesT: return self.get_column(self.columns[index]) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index aead4a8e1a..39a7a912a1 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Literal, Protocol from narwhals._plan.compliant.column import EagerBroadcast, SupportsBroadcast from narwhals._plan.compliant.typing import ( @@ -13,10 +13,16 @@ from narwhals._utils import Version if TYPE_CHECKING: - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan import expressions as ir - from narwhals._plan.compliant.scalar import CompliantScalar + from 
narwhals._plan.compliant.accessors import ( + ExprCatNamespace, + ExprListNamespace, + ExprStringNamespace, + ExprStructNamespace, + ) + from narwhals._plan.compliant.scalar import CompliantScalar, EagerScalar from narwhals._plan.expressions import ( BinaryExpr, FunctionExpr, @@ -30,11 +36,12 @@ IsFirstDistinct, IsLastDistinct, IsNan, + IsNotNan, + IsNotNull, IsNull, Not, ) - -Incomplete: TypeAlias = Any + from narwhals._plan.typing import IncompleteCyclic class CompliantExpr(HasVersion, Protocol[FrameT_contra, SeriesT_co]): @@ -62,6 +69,9 @@ def ewm_mean( def fill_null( self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str ) -> Self: ... + def fill_nan( + self, node: FunctionExpr[F.FillNan], frame: FrameT_contra, name: str + ) -> Self: ... def is_between( self, node: FunctionExpr[IsBetween], frame: FrameT_contra, name: str ) -> Self: ... @@ -80,8 +90,11 @@ def is_nan( def is_null( self, node: FunctionExpr[IsNull], frame: FrameT_contra, name: str ) -> Self: ... - def map_batches( - self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str + def is_not_nan( + self, node: FunctionExpr[IsNotNan], frame: FrameT_contra, name: str + ) -> Self: ... + def is_not_null( + self, node: FunctionExpr[IsNotNull], frame: FrameT_contra, name: str ) -> Self: ... def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... @@ -187,10 +200,16 @@ def skew( self, node: FunctionExpr[F.Skew], frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... - # mixed/todo + # TODO @dangotbanned: Reorder these def clip( self, node: FunctionExpr[F.Clip], frame: FrameT_contra, name: str ) -> Self: ... + def clip_lower( + self, node: FunctionExpr[F.ClipLower], frame: FrameT_contra, name: str + ) -> Self: ... + def clip_upper( + self, node: FunctionExpr[F.ClipUpper], frame: FrameT_contra, name: str + ) -> Self: ... 
def drop_nulls( self, node: FunctionExpr[F.DropNulls], frame: FrameT_contra, name: str ) -> Self: ... @@ -217,12 +236,18 @@ def is_unique( self, node: FunctionExpr[boolean.IsUnique], frame: FrameT_contra, name: str ) -> Self: ... def log(self, node: FunctionExpr[F.Log], frame: FrameT_contra, name: str) -> Self: ... - def mode( - self, node: FunctionExpr[F.Mode], frame: FrameT_contra, name: str + def mode_all( + self, node: FunctionExpr[F.ModeAll], frame: FrameT_contra, name: str ) -> Self: ... + def mode_any( + self, node: FunctionExpr[F.ModeAny], frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... def replace_strict( self, node: FunctionExpr[F.ReplaceStrict], frame: FrameT_contra, name: str ) -> Self: ... + def replace_strict_default( + self, node: FunctionExpr[F.ReplaceStrictDefault], frame: FrameT_contra, name: str + ) -> Self: ... def round( self, node: FunctionExpr[F.Round], frame: FrameT_contra, name: str ) -> Self: ... @@ -232,6 +257,33 @@ def sqrt( def unique( self, node: FunctionExpr[F.Unique], frame: FrameT_contra, name: str ) -> Self: ... + def ceil( + self, node: FunctionExpr[F.Ceil], frame: FrameT_contra, name: str + ) -> Self: ... + def floor( + self, node: FunctionExpr[F.Floor], frame: FrameT_contra, name: str + ) -> Self: ... + @property + def cat( + self, + ) -> ExprCatNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... + @property + def list( + self, + ) -> ExprListNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... + @property + def str( + self, + ) -> ExprStringNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... + @property + def struct( + self, + ) -> ExprStructNamespace[FrameT_contra, CompliantExpr[FrameT_contra, SeriesT_co]]: ... 
+ + # NOTE: This test has a case for detecting `Expr` impl, but missing `CompliantExpr` member + # `tests/plan/dispatch_test.py::test_dispatch` + # TODO @dangotbanned: Update that logic when `dt` namespace is actually implemented + # dt: not_implemented = not_implemented()` class EagerExpr( @@ -244,10 +296,25 @@ def gather_every( ) -> Self: ... def is_in_series( self, - node: FunctionExpr[boolean.IsInSeries[Incomplete]], + node: FunctionExpr[boolean.IsInSeries[IncompleteCyclic]], frame: FrameT_contra, name: str, ) -> Self: ... + # NOTE: `Scalar` when using `returns_scalar=True` + def map_batches( + self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str + ) -> Self | EagerScalar[FrameT_contra, SeriesT]: ... + # NOTE: `n=1` can behave similar to an aggregation in `select(...)`, but requires `.first()` + # to trigger broadcasting in `with_columns(...)` + def sample_n( + self, node: FunctionExpr[F.SampleN], frame: FrameT_contra, name: str + ) -> Self: ... + def sample_frac( + self, node: FunctionExpr[F.SampleFrac], frame: FrameT_contra, name: str + ) -> Self: ... + def __bool__(self) -> Literal[True]: + # NOTE: Avoids falling back to `__len__` when truth-testing on dispatch + return True class LazyExpr( diff --git a/narwhals/_plan/compliant/namespace.py b/narwhals/_plan/compliant/namespace.py index 9fa62aad93..a5a132e0c5 100644 --- a/narwhals/_plan/compliant/namespace.py +++ b/narwhals/_plan/compliant/namespace.py @@ -26,7 +26,7 @@ from narwhals._plan import expressions as ir from narwhals._plan.expressions import FunctionExpr, boolean, functions as F - from narwhals._plan.expressions.ranges import DateRange, IntRange + from narwhals._plan.expressions.ranges import DateRange, IntRange, LinearSpace from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.series import Series from narwhals.dtypes import IntegerType @@ -54,12 +54,18 @@ def col(self, node: ir.Column, frame: FrameT, name: str) -> ExprT_co: ... 
def concat_str( self, node: FunctionExpr[ConcatStr], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... + def coalesce( + self, node: FunctionExpr[F.Coalesce], frame: FrameT, name: str + ) -> ExprT_co | ScalarT_co: ... def date_range( self, node: ir.RangeExpr[DateRange], frame: FrameT, name: str ) -> ExprT_co: ... def int_range( self, node: ir.RangeExpr[IntRange], frame: FrameT, name: str ) -> ExprT_co: ... + def linear_space( + self, node: ir.RangeExpr[LinearSpace], frame: FrameT, name: str + ) -> ExprT_co: ... def len(self, node: ir.Len, frame: FrameT, name: str) -> ScalarT_co: ... def lit( self, node: ir.Literal[Any], frame: FrameT, name: str @@ -156,6 +162,15 @@ def int_range_eager( dtype: IntegerType = Int64, name: str = "literal", ) -> SeriesT: ... + def linear_space_eager( + self, + start: float, + end: float, + num_samples: int, + *, + closed: ClosedInterval = "both", + name: str = "literal", + ) -> SeriesT: ... class LazyNamespace( diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index 3e86327240..ca4562c4ab 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Literal, Protocol from narwhals._plan.compliant.expr import CompliantExpr, EagerExpr, LazyExpr from narwhals._plan.compliant.typing import FrameT_contra, LengthT, SeriesT, SeriesT_co @@ -10,7 +10,11 @@ from typing_extensions import Self from narwhals._plan import expressions as ir - from narwhals._plan.expressions import FunctionExpr, aggregation as agg + from narwhals._plan.expressions import ( + FunctionExpr, + aggregation as agg, + functions as F, + ) from narwhals._plan.expressions.functions import EwmMean, NullCount, Shift from narwhals._utils import Version from narwhals.typing import IntoDType, PythonLiteral @@ -102,6 +106,12 @@ def shift(self, node: FunctionExpr[Shift], frame: 
FrameT_contra, name: str) -> S return self._with_evaluated(self._evaluated, name) return self.from_python(None, name, dtype=None, version=self.version) + def drop_nulls( # type: ignore[override] + self, node: FunctionExpr[F.DropNulls], frame: FrameT_contra, name: str + ) -> Self | CompliantExpr[FrameT_contra, SeriesT_co]: + """Returns a 0-length Series if null, else noop.""" + ... + arg_max = _always_zero # type: ignore[misc] arg_min = _always_zero # type: ignore[misc] is_first_distinct = _always_true # type: ignore[misc] @@ -120,6 +130,8 @@ def shift(self, node: FunctionExpr[Shift], frame: FrameT_contra, name: str) -> S sum = _always_noop # type: ignore[misc] mode = _always_noop # type: ignore[misc] unique = _always_noop # type: ignore[misc] + mode_all = not_implemented() # type: ignore[misc] + mode_any = _always_noop # type: ignore[misc] kurtosis = _always_nan # type: ignore[misc] skew = _always_nan # type: ignore[misc] fill_null_with_strategy = not_implemented() # type: ignore[misc] @@ -136,9 +148,17 @@ class EagerScalar( def __len__(self) -> int: return 1 + def __bool__(self) -> Literal[True]: + # NOTE: Avoids falling back to `__len__` when truth-testing on dispatch + return True + def to_python(self) -> PythonLiteral: ... 
gather_every = not_implemented() # type: ignore[misc] + # NOTE: `n=1` and `fraction=1.0` *could* be special-cased here + # but seems low-priority for a deprecated method + sample_n = not_implemented() # type: ignore[misc] + sample_frac = not_implemented() # type: ignore[misc] class LazyScalar( diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index f9c33523ff..faf14bf2f2 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -1,23 +1,34 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Protocol +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol from narwhals._plan.compliant.typing import HasVersion -from narwhals._plan.typing import NativeSeriesT -from narwhals._utils import Version, _StoresNative +from narwhals._plan.typing import IncompleteCyclic, NativeSeriesT +from narwhals._utils import Version, _StoresNative, unstable if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Sequence import polars as pl - from typing_extensions import Self, TypeAlias + from typing_extensions import Self + from narwhals._plan.compliant.accessors import SeriesStructNamespace + from narwhals._plan.compliant.dataframe import CompliantDataFrame + from narwhals._plan.dataframe import DataFrame from narwhals._plan.series import Series from narwhals._typing import _EagerAllowedImpl from narwhals.dtypes import DType - from narwhals.typing import Into1DArray, IntoDType, SizedMultiIndexSelector, _1DArray - -Incomplete: TypeAlias = Any + from narwhals.typing import ( + FillNullStrategy, + Into1DArray, + IntoDType, + NonNestedLiteral, + NumericLiteral, + PythonLiteral, + SizedMultiIndexSelector, + TemporalLiteral, + _1DArray, + ) class CompliantSeries(HasVersion, Protocol[NativeSeriesT]): @@ -28,10 +39,46 @@ class CompliantSeries(HasVersion, Protocol[NativeSeriesT]): def __len__(self) -> int: return len(self.native) + def 
__add__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __and__(self, other: bool | Self, /) -> Self: ... + def __eq__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... # type: ignore[override] + def __floordiv__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __ge__(self, other: NonNestedLiteral | Self, /) -> Self: ... + def __gt__(self, other: NonNestedLiteral | Self, /) -> Self: ... + def __invert__(self) -> Self: ... + def __le__(self, other: NonNestedLiteral | Self, /) -> Self: ... + def __lt__(self, other: NonNestedLiteral | Self, /) -> Self: ... + def __mod__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __mul__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __ne__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... # type: ignore[override] + def __or__(self, other: bool | Self, /) -> Self: ... + def __pow__(self, other: float | Self, /) -> Self: ... + def __rfloordiv__( + self, other: NumericLiteral | TemporalLiteral | Self, / + ) -> Self: ... + def __radd__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __rand__(self, other: bool | Self, /) -> Self: ... + def __rmod__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __rmul__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __ror__(self, other: bool | Self, /) -> Self: ... + def __rpow__(self, other: float | Self, /) -> Self: ... + def __rsub__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __rtruediv__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __rxor__(self, other: bool | Self, /) -> Self: ... + def __sub__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... + def __truediv__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: ... 
+ def __xor__(self, other: bool | Self, /) -> Self: ... + def len(self) -> int: return len(self.native) - def __narwhals_namespace__(self) -> Incomplete: ... + def not_(self) -> Self: + return self.__invert__() + + def pow(self, exponent: float | Self) -> Self: + return self.__pow__(exponent) + + def __narwhals_namespace__(self) -> IncompleteCyclic: ... def __narwhals_series__(self) -> Self: return self @@ -71,23 +118,75 @@ def name(self) -> str: def native(self) -> NativeSeriesT: return self._native + def all(self) -> bool: ... + def any(self) -> bool: ... + def sum(self) -> float: ... + def count(self) -> int: ... def alias(self, name: str) -> Self: return self.from_native(self.native, name, version=self.version) def cast(self, dtype: IntoDType) -> Self: ... + def cum_count(self, *, reverse: bool = False) -> Self: ... + def cum_max(self, *, reverse: bool = False) -> Self: ... + def cum_min(self, *, reverse: bool = False) -> Self: ... + def cum_prod(self, *, reverse: bool = False) -> Self: ... + def cum_sum(self, *, reverse: bool = False) -> Self: ... + def diff(self, n: int = 1) -> Self: ... + def drop_nulls(self) -> Self: ... + def drop_nans(self) -> Self: ... + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Self: ... + def fill_nan(self, value: float | Self | None) -> Self: ... + def fill_null(self, value: NonNestedLiteral | Self) -> Self: ... + def fill_null_with_strategy( + self, strategy: FillNullStrategy, limit: int | None = None + ) -> Self: ... + def first(self) -> PythonLiteral: ... + def shift(self, n: int, *, fill_value: NonNestedLiteral = None) -> Self: ... def gather( self, indices: SizedMultiIndexSelector[NativeSeriesT] | _StoresNative[NativeSeriesT], ) -> Self: ... + def gather_every(self, n: int, offset: int = 0) -> Self: ... def has_nulls(self) -> bool: ... + def null_count(self) -> int: ... def is_empty(self) -> bool: return len(self) == 0 def is_in(self, other: Self) -> Self: ... + def is_nan(self) -> Self: ... 
+ def is_null(self) -> Self: ... + def is_not_nan(self) -> Self: + return self.is_nan().__invert__() + + def is_not_null(self) -> Self: + return self.is_null().__invert__() + + def last(self) -> PythonLiteral: ... + def rolling_mean( + self, window_size: int, *, min_samples: int, center: bool = False + ) -> Self: ... + def rolling_std( + self, window_size: int, *, min_samples: int, center: bool = False, ddof: int = 1 + ) -> Self: ... + def rolling_sum( + self, window_size: int, *, min_samples: int, center: bool = False + ) -> Self: ... + def rolling_var( + self, window_size: int, *, min_samples: int, center: bool = False, ddof: int = 1 + ) -> Self: ... + def sample_frac( + self, fraction: float, *, with_replacement: bool = False, seed: int | None = None + ) -> Self: + n = int(len(self) * fraction) + return self.sample_n(n, with_replacement=with_replacement, seed=seed) + + def sample_n( + self, n: int = 1, *, with_replacement: bool = False, seed: int | None = None + ) -> Self: ... def scatter(self, indices: Self, values: Self) -> Self: ... def slice(self, offset: int, length: int | None = None) -> Self: ... def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: ... - def to_frame(self) -> Incomplete: ... + def to_frame(self) -> IncompleteCyclic: ... def to_list(self) -> list[Any]: ... def to_narwhals(self) -> Series[NativeSeriesT]: from narwhals._plan.series import Series @@ -96,3 +195,39 @@ def to_narwhals(self) -> Series[NativeSeriesT]: def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ... def to_polars(self) -> pl.Series: ... + def unique(self, *, maintain_order: bool = False) -> Self: ... + def zip_with(self, mask: Self, other: Self) -> Self: ... 
+ @unstable + def hist( + self, + bins: Sequence[float] | None = None, + *, + bin_count: int | None = None, + include_breakpoint: bool = True, + include_category: bool = False, + _compatibility_behavior: Literal["narwhals", "polars"] = "narwhals", + ) -> CompliantDataFrame[Self, IncompleteCyclic, NativeSeriesT]: + from narwhals._plan.expressions import col as ir_col + + expr = ( + ir_col(self.name) + .to_narwhals(self.version) + .hist( + bins, + bin_count=bin_count, + include_breakpoint=include_breakpoint, + include_category=include_category, + ) + ) + df: DataFrame[IncompleteCyclic, NativeSeriesT] = ( + self.to_narwhals().to_frame().select(expr) + ) + if not include_breakpoint and not include_category: + if _compatibility_behavior == "narwhals": + df = df.rename({self.name: "count"}) + else: + df = df.to_series().struct.unnest() + return df._compliant + + @property + def struct(self) -> SeriesStructNamespace[IncompleteCyclic, Self]: ... diff --git a/narwhals/_plan/compliant/typing.py b/narwhals/_plan/compliant/typing.py index 91ad9320ed..01daea4c40 100644 --- a/narwhals/_plan/compliant/typing.py +++ b/narwhals/_plan/compliant/typing.py @@ -20,17 +20,11 @@ from narwhals._plan.compliant.series import CompliantSeries from narwhals._utils import Version -T = TypeVar("T") R_co = TypeVar("R_co", covariant=True) LengthT = TypeVar("LengthT") -NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) - ConcatT1 = TypeVar("ConcatT1") ConcatT2 = TypeVar("ConcatT2", default=ConcatT1) - -ColumnT = TypeVar("ColumnT") ColumnT_co = TypeVar("ColumnT_co", covariant=True) - ResolverT_co = TypeVar("ResolverT_co", bound="GroupByResolver", covariant=True) ExprAny: TypeAlias = "CompliantExpr[Any, Any]" @@ -48,7 +42,6 @@ LazyScalarAny: TypeAlias = "LazyScalar[Any, Any, Any]" ExprT_co = TypeVar("ExprT_co", bound=ExprAny, covariant=True) -ScalarT = TypeVar("ScalarT", bound=ScalarAny) ScalarT_co = TypeVar("ScalarT_co", bound=ScalarAny, covariant=True) SeriesT = TypeVar("SeriesT", 
bound=SeriesAny) SeriesT_co = TypeVar("SeriesT_co", bound=SeriesAny, covariant=True) @@ -56,6 +49,7 @@ FrameT_co = TypeVar("FrameT_co", bound=FrameAny, covariant=True) FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) DataFrameT = TypeVar("DataFrameT", bound=DataFrameAny) +DataFrameT_co = TypeVar("DataFrameT_co", bound=DataFrameAny, covariant=True) NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", covariant=True) EagerExprT_co = TypeVar("EagerExprT_co", bound=EagerExprAny, covariant=True) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 59a22c5438..08f943978e 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -4,18 +4,21 @@ from narwhals._plan import _parse from narwhals._plan._expansion import expand_selector_irs_names, prepare_projection +from narwhals._plan._guards import is_series from narwhals._plan.common import ensure_seq_str, temp from narwhals._plan.group_by import GroupBy, Grouped -from narwhals._plan.options import SortMultipleOptions +from narwhals._plan.options import ExplodeOptions, SortMultipleOptions from narwhals._plan.series import Series from narwhals._plan.typing import ( ColumnNameOrSelector, + IncompleteCyclic, IntoExpr, IntoExprColumn, NativeDataFrameT, NativeDataFrameT_co, NativeFrameT_co, NativeSeriesT, + NativeSeriesT2, NonCrossJoinStrategy, OneOrIterable, PartialSeries, @@ -23,19 +26,34 @@ ) from narwhals._utils import Implementation, Version, generate_repr from narwhals.dependencies import is_pyarrow_table +from narwhals.exceptions import InvalidOperationError, ShapeError from narwhals.schema import Schema -from narwhals.typing import IntoDType, JoinStrategy +from narwhals.typing import EagerAllowed, IntoBackend, IntoDType, IntoSchema, JoinStrategy if TYPE_CHECKING: - from collections.abc import Iterable, Mapping, Sequence + from collections.abc import Iterable, Iterator, Mapping, Sequence import polars as pl import pyarrow as pa from 
typing_extensions import Self, TypeAlias, TypeIs + from narwhals._native import NativeSeries from narwhals._plan.arrow.typing import NativeArrowDataFrame - from narwhals._plan.compliant.dataframe import CompliantDataFrame, CompliantFrame - from narwhals._typing import _EagerAllowedImpl + from narwhals._plan.compliant.dataframe import ( + CompliantDataFrame, + CompliantFrame, + EagerDataFrame, + ) + from narwhals._plan.compliant.namespace import EagerNamespace + from narwhals._plan.compliant.series import CompliantSeries + from narwhals._typing import Arrow, _EagerAllowedImpl + + EagerNs: TypeAlias = EagerNamespace[ + EagerDataFrame[Any, NativeDataFrameT, Any], + CompliantSeries[NativeSeriesT], + Any, + Any, + ] Incomplete: TypeAlias = Any @@ -70,7 +88,7 @@ def __init__(self, compliant: CompliantFrame[Any, NativeFrameT_co], /) -> None: def _with_compliant(self, compliant: CompliantFrame[Any, Incomplete], /) -> Self: return type(self)(compliant) - def to_native(self) -> NativeFrameT_co: # pragma: no cover + def to_native(self) -> NativeFrameT_co: return self._compliant.native def filter( @@ -141,17 +159,50 @@ def with_row_index( by_names = expand_selector_irs_names(by_selectors, schema=self, require_any=True) return self._with_compliant(self._compliant.with_row_index_by(name, by_names)) + def explode( + self, + columns: OneOrIterable[ColumnNameOrSelector], + *more_columns: ColumnNameOrSelector, + empty_as_null: bool = True, + keep_nulls: bool = True, + ) -> Self: + s_ir = _parse.parse_into_combined_selector_ir(columns, *more_columns) + schema = self.collect_schema() + subset = expand_selector_irs_names((s_ir,), schema=schema, require_any=True) + dtypes = self.version.dtypes + tp_list = dtypes.List + for col_to_explode in subset: + dtype = schema[col_to_explode] + if dtype != tp_list: + msg = f"`explode` operation is not supported for dtype `{dtype}`, expected List type" + raise InvalidOperationError(msg) + options = ExplodeOptions(empty_as_null=empty_as_null, 
keep_nulls=keep_nulls) + return self._with_compliant(self._compliant.explode(subset, options)) + + +def _dataframe_from_dict( + data: Mapping[str, Any], + schema: IntoSchema | None, + ns: EagerNs[NativeDataFrameT, NativeSeriesT], + /, +) -> DataFrame[NativeDataFrameT, NativeSeriesT]: + return ns._dataframe.from_dict(data, schema=schema).to_narwhals() + class DataFrame( BaseFrame[NativeDataFrameT_co], Generic[NativeDataFrameT_co, NativeSeriesT] ): - _compliant: CompliantDataFrame[Any, NativeDataFrameT_co, NativeSeriesT] + _compliant: CompliantDataFrame[IncompleteCyclic, NativeDataFrameT_co, NativeSeriesT] @property def implementation(self) -> _EagerAllowedImpl: return self._compliant.implementation - def __len__(self) -> int: # pragma: no cover + @property + def shape(self) -> tuple[int, int]: + return self._compliant.shape + + def __len__(self) -> int: return len(self._compliant) @property @@ -191,6 +242,63 @@ def from_native( raise NotImplementedError(type(native)) + @overload + @classmethod + def from_dict( + cls: type[DataFrame[Any, Any]], + data: Mapping[str, Any], + schema: IntoSchema | None = ..., + *, + backend: Arrow, + ) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: ... + @overload + @classmethod + def from_dict( + cls: type[DataFrame[Any, Any]], + data: Mapping[str, Any], + schema: IntoSchema | None = ..., + *, + backend: IntoBackend[EagerAllowed], + ) -> DataFrame[Any, Any]: ... + @overload + @classmethod + def from_dict( + cls: type[DataFrame[Any, Any]], + data: Mapping[str, Series[NativeSeriesT2]], + schema: IntoSchema | None = ..., + ) -> DataFrame[Any, NativeSeriesT2]: ... 
+ @classmethod + def from_dict( + cls: type[DataFrame[Any, Any]], + data: Mapping[str, Any], + schema: IntoSchema | None = None, + *, + backend: IntoBackend[EagerAllowed] | None = None, + ) -> DataFrame[Any, Any]: + from narwhals._plan import functions as F + + if backend is None: + unwrapped: dict[str, NativeSeries | Any] = {} + impl: _EagerAllowedImpl | None = backend + for k, v in data.items(): + if is_series(v): + current = v.implementation + if impl is None: + impl = current + elif current is not impl: + msg = f"All `Series` must share the same backend, but got:\n -{impl!r}\n -{current!r}" + raise NotImplementedError(msg) + unwrapped[k] = v.to_native() + else: + unwrapped[k] = v + if impl is None: + msg = "Calling `from_dict` without `backend` is only supported if all input values are already Narwhals Series" + raise TypeError(msg) + return _dataframe_from_dict(unwrapped, schema, F._eager_namespace(impl)) + + ns = F._eager_namespace(backend) + return _dataframe_from_dict(data, schema, ns) + @overload def to_dict( self, *, as_series: Literal[True] = ... 
@@ -211,13 +319,19 @@ def to_dict( } return self._compliant.to_dict(as_series=as_series) - def to_series(self, index: int = 0) -> Series[NativeSeriesT]: # pragma: no cover + def to_series(self, index: int = 0) -> Series[NativeSeriesT]: return self._series(self._compliant.to_series(index)) + def to_struct(self, name: str = "") -> Series[NativeSeriesT]: + return self._series(self._compliant.to_struct(name)) + def to_polars(self) -> pl.DataFrame: return self._compliant.to_polars() - def get_column(self, name: str) -> Series[NativeSeriesT]: # pragma: no cover + def gather_every(self, n: int, offset: int = 0) -> Self: + return self._with_compliant(self._compliant.gather_every(n, offset)) + + def get_column(self, name: str) -> Series[NativeSeriesT]: return self._series(self._compliant.get_column(name)) @overload @@ -246,6 +360,10 @@ def group_by( def row(self, index: int) -> tuple[Any, ...]: return self._compliant.row(index) + def iter_columns(self) -> Iterator[Series[NativeSeriesT]]: + for series in self._compliant.iter_columns(): + yield self._series(series) + def join( self, other: Self, @@ -304,9 +422,34 @@ def with_row_index( return self._with_compliant(self._compliant.with_row_index(name)) return super().with_row_index(name, order_by=order_by) - def slice(self, offset: int, length: int | None = None) -> Self: # pragma: no cover + def slice(self, offset: int, length: int | None = None) -> Self: return type(self)(self._compliant.slice(offset=offset, length=length)) + def sample( + self, + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, + ) -> Self: + if n is not None and fraction is not None: + msg = "cannot specify both `n` and `fraction`" + raise ValueError(msg) + df = self._compliant + if fraction is not None: + result = df.sample_frac( + fraction, with_replacement=with_replacement, seed=seed + ) + elif n is None: + result = df.sample_n(with_replacement=with_replacement, seed=seed) + elif not 
with_replacement and n > len(self): + msg = "cannot take a larger sample than the total population when `with_replacement=false`" + raise ShapeError(msg) + else: + result = df.sample_n(n, with_replacement=with_replacement, seed=seed) + return type(self)(result) + def _is_join_strategy(obj: Any) -> TypeIs[JoinStrategy]: return obj in {"inner", "left", "full", "cross", "anti", "semi"} diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 05348372c0..53aaa49f67 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -48,12 +48,29 @@ def function_expr_invalid_operation_error( return InvalidOperationError(msg) +def function_arg_non_scalar_error( + function: Function, arg_name: str, arg_value: Any +) -> InvalidOperationError: + msg = f"`{function!r}({arg_name}=...)` does not support non-scalar expression `{arg_value!r}`." + return InvalidOperationError(msg) + + +def list_literal_error(value: Any) -> TypeError: + msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." + return TypeError(msg) + + # TODO @dangotbanned: Use arguments in error message def hist_bins_monotonic_error(bins: Seq[float]) -> ComputeError: # noqa: ARG001 msg = "bins must increase monotonically" return ComputeError(msg) +def shape_error(expected_length: int, actual_length: int) -> ShapeError: + msg = f"Expected object of length {expected_length}, got {actual_length}." 
+ return ShapeError(msg) + + def _binary_underline( left: ir.ExprIR, operator: Operator, diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 6b5798f8b1..6c8f5a6154 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -24,7 +24,8 @@ SortOptions, rolling_options, ) -from narwhals._utils import Version +from narwhals._typing_compat import deprecated +from narwhals._utils import Version, no_default, not_implemented from narwhals.exceptions import ComputeError if TYPE_CHECKING: @@ -42,10 +43,12 @@ from narwhals._plan.expressions.temporal import ExprDateTimeNamespace from narwhals._plan.meta import MetaNamespace from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq, Udf + from narwhals._typing import NoDefault from narwhals.typing import ( ClosedInterval, FillNullStrategy, IntoDType, + ModeKeepStrategy, NumericLiteral, RankMethod, RollingInterpolationMethod, @@ -192,8 +195,12 @@ def hist( bins: Sequence[float] | None = None, *, bin_count: int | None = None, - include_breakpoint: bool = True, + include_breakpoint: bool = False, + include_category: bool = False, ) -> Self: + if include_category: + msg = f"`Expr.hist({include_category=})` is not yet implemented" + raise NotImplementedError(msg) node: F.Hist if bins is not None: if bin_count is not None: @@ -215,12 +222,21 @@ def exp(self) -> Self: def sqrt(self) -> Self: return self._with_unary(F.Sqrt()) - def kurtosis(self, *, fisher: bool = True, bias: bool = True) -> Self: - return self._with_unary(F.Kurtosis(fisher=fisher, bias=bias)) + def kurtosis(self) -> Self: + return self._with_unary(F.Kurtosis()) def null_count(self) -> Self: return self._with_unary(F.NullCount()) + def fill_nan(self, value: float | Self | None) -> Self: + fill_value = parse_into_expr_ir(value, str_as_lit=True) + root = self._ir + if any(e.meta.has_multiple_outputs() for e in (root, fill_value)): + return self._from_ir(F.FillNan().to_function_expr(root, fill_value)) + # 
https://github.com/pola-rs/polars/blob/e1d6f294218a36497255e2d872c223e19a47e2ec/crates/polars-plan/src/dsl/mod.rs#L894-L902 + predicate = self.is_not_nan() | self.is_null() + return self._from_ir(ir.ternary_expr(predicate._ir, root, fill_value)) + def fill_null( self, value: IntoExpr = None, @@ -238,8 +254,11 @@ def shift(self, n: int) -> Self: def drop_nulls(self) -> Self: return self._with_unary(F.DropNulls()) - def mode(self) -> Self: - return self._with_unary(F.Mode()) + def mode(self, *, keep: ModeKeepStrategy = "all") -> Self: + if func := {"all": F.ModeAll, "any": F.ModeAny}.get(keep): + return self._with_unary(func()) + msg = f"`keep` must be one of ('all', 'any'), but got {keep!r}" + raise TypeError(msg) def skew(self) -> Self: return self._with_unary(F.Skew()) @@ -253,8 +272,15 @@ def clip( lower_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, upper_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, ) -> Self: - it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) - return self._from_ir(F.Clip().to_function_expr(self._ir, *it)) + f: ir.FunctionExpr + if upper_bound is None: + f = F.ClipLower().to_function_expr(self._ir, parse_into_expr_ir(lower_bound)) + elif lower_bound is None: + f = F.ClipUpper().to_function_expr(self._ir, parse_into_expr_ir(upper_bound)) + else: + it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) + f = F.Clip().to_function_expr(self._ir, *it) + return self._from_ir(f) def cum_count(self, *, reverse: bool = False) -> Self: # pragma: no cover return self._with_unary(F.CumCount(reverse=reverse)) @@ -273,7 +299,7 @@ def cum_sum(self, *, reverse: bool = False) -> Self: def rolling_sum( self, window_size: int, *, min_samples: int | None = None, center: bool = False - ) -> Self: # pragma: no cover + ) -> Self: options = rolling_options(window_size, min_samples, center=center) return self._with_unary(F.RollingSum(options=options)) @@ -290,7 +316,7 @@ def rolling_var( min_samples: int | 
None = None, center: bool = False, ddof: int = 1, - ) -> Self: # pragma: no cover + ) -> Self: options = rolling_options(window_size, min_samples, center=center, ddof=ddof) return self._with_unary(F.RollingVar(options=options)) @@ -301,7 +327,7 @@ def rolling_std( min_samples: int | None = None, center: bool = False, ddof: int = 1, - ) -> Self: # pragma: no cover + ) -> Self: options = rolling_options(window_size, min_samples, center=center, ddof=ddof) return self._with_unary(F.RollingStd(options=options)) @@ -314,6 +340,12 @@ def unique(self) -> Self: def round(self, decimals: int = 0) -> Self: return self._with_unary(F.Round(decimals=decimals)) + def ceil(self) -> Self: + return self._with_unary(F.Ceil()) + + def floor(self) -> Self: + return self._with_unary(F.Floor()) + def ewm_mean( self, *, @@ -339,13 +371,14 @@ def ewm_mean( def replace_strict( self, old: Sequence[Any] | Mapping[Any, Any], - new: Sequence[Any] | None = None, + new: Sequence[Any] | NoDefault = no_default, *, + default: IntoExpr | NoDefault = no_default, return_dtype: IntoDType | None = None, ) -> Self: before: Seq[Any] after: Seq[Any] - if new is None: + if new is no_default: if not isinstance(old, Mapping): msg = "`new` argument is required if `old` argument is not a Mapping type" raise TypeError(msg) @@ -359,8 +392,15 @@ def replace_strict( after = tuple(new) if return_dtype is not None: return_dtype = common.into_dtype(return_dtype) - function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) - return self._with_unary(function) + + if default is no_default: + function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) + return self._with_unary(function) + function = F.ReplaceStrictDefault( + old=before, new=after, return_dtype=return_dtype + ) + default_ir = parse_into_expr_ir(default, str_as_lit=True) + return self._from_ir(function.to_function_expr(self._ir, default_ir)) def gather_every(self, n: int, offset: int = 0) -> Self: return 
self._with_unary(F.GatherEvery(n=n, offset=offset)) @@ -384,6 +424,19 @@ def map_batches( ) ) + # TODO @dangotbanned: Come back to this when *properly* building out `Version` support + @deprecated("Use `v1.Expr.sample` or `{DataFrame,Series}.sample` instead") + def sample( + self, + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, + ) -> Self: + f = F.sample(n, fraction=fraction, with_replacement=with_replacement, seed=seed) + return self._with_unary(f) + def any(self) -> Self: return self._with_unary(ir.boolean.Any()) @@ -402,6 +455,12 @@ def is_nan(self) -> Self: def is_null(self) -> Self: return self._with_unary(ir.boolean.IsNull()) + def is_not_nan(self) -> Self: + return self._with_unary(ir.boolean.IsNotNan()) + + def is_not_null(self) -> Self: + return self._with_unary(ir.boolean.IsNotNull()) + def is_first_distinct(self) -> Self: return self._with_unary(ir.boolean.IsFirstDistinct()) @@ -548,14 +607,14 @@ def name(self) -> ExprNameNamespace: >>> >>> renamed = nw.col("a", "b").name.suffix("_changed") >>> str(renamed._ir) - "RenameAlias(expr=RootSelector(selector=ByName(names=[a, b], require_all=True)), function=Suffix(suffix='_changed'))" + "RenameAlias(expr=RootSelector(selector=ByName(names=['a', 'b'], require_all=True)), function=Suffix(suffix='_changed'))" """ from narwhals._plan.expressions.name import ExprNameNamespace return ExprNameNamespace(_expr=self) @property - def cat(self) -> ExprCatNamespace: # pragma: no cover + def cat(self) -> ExprCatNamespace: from narwhals._plan.expressions.categorical import ExprCatNamespace return ExprCatNamespace(_expr=self) @@ -573,7 +632,7 @@ def dt(self) -> ExprDateTimeNamespace: return ExprDateTimeNamespace(_expr=self) @property - def list(self) -> ExprListNamespace: # pragma: no cover + def list(self) -> ExprListNamespace: from narwhals._plan.expressions.lists import ExprListNamespace return ExprListNamespace(_expr=self) @@ -584,6 +643,10 @@ def 
str(self) -> ExprStringNamespace: return ExprStringNamespace(_expr=self) + is_close = not_implemented() + head = not_implemented() + tail = not_implemented() + class ExprV1(Expr): _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 0d97c20288..56e4f5d4ce 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -38,9 +38,11 @@ RootSelector, Sort, SortBy, + StructExpr, TernaryExpr, WindowExpr, col, + ternary_expr, ) from narwhals._plan.expressions.name import KeepName, RenameAlias from narwhals._plan.expressions.window import over, over_ordered @@ -71,6 +73,7 @@ "SelectorIR", "Sort", "SortBy", + "StructExpr", "TernaryExpr", "WindowExpr", "aggregation", @@ -89,4 +92,5 @@ "strings", "struct", "temporal", + "ternary_expr", ] diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index b6bc93dd88..5b24fed778 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -8,7 +8,6 @@ from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.options import FEOptions, FunctionOptions from narwhals._plan.typing import NativeSeriesT -from narwhals._typing_compat import TypeVar if TYPE_CHECKING: from typing_extensions import Self @@ -19,22 +18,25 @@ from narwhals._plan.typing import Seq from narwhals.typing import ClosedInterval -OtherT = TypeVar("OtherT") -ExprT = TypeVar("ExprT", bound="ExprIR", default="ExprIR") - # fmt: off class BooleanFunction(Function, options=FunctionOptions.elementwise): ... class All(BooleanFunction, options=FunctionOptions.aggregation): ... -class AllHorizontal(HorizontalFunction, BooleanFunction): ... class Any(BooleanFunction, options=FunctionOptions.aggregation): ... -class AnyHorizontal(HorizontalFunction, BooleanFunction): ... 
+class AllHorizontal(HorizontalFunction, BooleanFunction): + __slots__ = ("ignore_nulls",) + ignore_nulls: bool +class AnyHorizontal(HorizontalFunction, BooleanFunction): + __slots__ = ("ignore_nulls",) + ignore_nulls: bool class IsDuplicated(BooleanFunction, options=FunctionOptions.length_preserving): ... class IsFinite(BooleanFunction): ... class IsFirstDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... class IsLastDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... class IsNan(BooleanFunction): ... class IsNull(BooleanFunction): ... +class IsNotNan(BooleanFunction): ... +class IsNotNull(BooleanFunction): ... class IsUnique(BooleanFunction, options=FunctionOptions.length_preserving): ... class Not(BooleanFunction, config=FEOptions.renamed("not_")): ... # fmt: on diff --git a/narwhals/_plan/expressions/categorical.py b/narwhals/_plan/expressions/categorical.py index 5bb7157f5d..7c59fd4443 100644 --- a/narwhals/_plan/expressions/categorical.py +++ b/narwhals/_plan/expressions/categorical.py @@ -20,7 +20,7 @@ class IRCatNamespace(IRNamespace): class ExprCatNamespace(ExprNamespace[IRCatNamespace]): @property def _ir_namespace(self) -> type[IRCatNamespace]: - return IRCatNamespace # pragma: no cover + return IRCatNamespace def get_categories(self) -> Expr: - return self._with_unary(self._ir.get_categories()) # pragma: no cover + return self._with_unary(self._ir.get_categories()) diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 62c211500e..87271ea1db 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -27,6 +27,7 @@ SelectorOperatorT, SelectorT, Seq, + StructT_co, ) from narwhals.exceptions import InvalidOperationError @@ -60,9 +61,11 @@ "SelectorIR", "Sort", "SortBy", + "StructExpr", "TernaryExpr", "WindowExpr", "col", + "ternary_expr", ] @@ -273,7 +276,11 @@ def dispatch( return self.function.__expr_ir_dispatch__(self, ctx, frame, name) 
-class RollingExpr(FunctionExpr[RollingT_co]): ... +class RollingExpr(FunctionExpr[RollingT_co]): + def dispatch( + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + ) -> R_co: + return self.__expr_ir_dispatch__(self, ctx, frame, name) class AnonymousExpr( @@ -316,6 +323,20 @@ def __repr__(self) -> str: return f"{self.function!r}({list(self.input)!r})" +class StructExpr(FunctionExpr[StructT_co]): + """E.g. `col("a").struct.field(...)`. + + Requires special handling during expression expansion. + """ + + def needs_expansion(self) -> bool: + return self.function.needs_expansion or super().needs_expansion() + + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield self + yield from super().iter_output_name() # pragma: no cover + + class Filter(ExprIR, child=("expr", "by")): __slots__ = ("expr", "by") # noqa: RUF023 expr: ExprIR @@ -376,8 +397,9 @@ def __repr__(self) -> str: args = f"partition_by={list(self.partition_by)!r}, order_by={list(order)!r}" return f"{self.expr!r}.over({args})" + # TODO @dangotbanned: Update to align with https://github.com/pola-rs/polars/pull/25117/files#diff-45d1f22172e291bd4a5ce36d1fb8233698394f9590bcf11382b9c99b5449fff5 def iter_root_names(self) -> t.Iterator[ExprIR]: - # NOTE: `order_by` is never considered in `polars` + # NOTE: `order_by` ~~is~~ was never considered in `polars` # To match that behavior for `root_names` - but still expand in all other cases # - this little escape hatch exists # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86 @@ -524,3 +546,7 @@ def matches(self, dtype: IntoDType) -> bool: def to_dtype_selector(self) -> Self: return replace(self, selector=self.selector.to_dtype_selector()) + + +def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: + return TernaryExpr(predicate=predicate, truthy=truthy, falsy=falsy) diff --git a/narwhals/_plan/expressions/functions.py 
b/narwhals/_plan/expressions/functions.py index 8e910113a1..89ae663241 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -9,7 +9,7 @@ from narwhals._plan.options import FunctionFlags, FunctionOptions if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Mapping from typing import Any from _typeshed import ConvertibleToInt @@ -44,10 +44,25 @@ class Abs(Function, options=FunctionOptions.elementwise): ... class NullCount(Function, options=FunctionOptions.aggregation): ... class Exp(Function, options=FunctionOptions.elementwise): ... class Sqrt(Function, options=FunctionOptions.elementwise): ... +class Ceil(Function, options=FunctionOptions.elementwise): ... +class Floor(Function, options=FunctionOptions.elementwise): ... class DropNulls(Function, options=FunctionOptions.row_separable): ... -class Mode(Function): ... +class ModeAll(Function): ... +class ModeAny(Function, options=FunctionOptions.aggregation): ... +class Kurtosis(Function, options=FunctionOptions.aggregation): ... class Skew(Function, options=FunctionOptions.aggregation): ... -class Clip(Function, options=FunctionOptions.elementwise): ... +class Clip(Function, options=FunctionOptions.elementwise): + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR, ExprIR]: + expr, lower_bound, upper_bound = node.input + return expr, lower_bound, upper_bound +class ClipLower(Function, options=FunctionOptions.elementwise): + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, lower_bound = node.input + return expr, lower_bound +class ClipUpper(Function, options=FunctionOptions.elementwise): + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, upper_bound = node.input + return expr, upper_bound class CumCount(CumAgg): ... class CumMin(CumAgg): ... class CumMax(CumAgg): ... @@ -63,10 +78,9 @@ class SumHorizontal(HorizontalFunction): ... 
class MinHorizontal(HorizontalFunction): ... class MaxHorizontal(HorizontalFunction): ... class MeanHorizontal(HorizontalFunction): ... +class Coalesce(HorizontalFunction): ... # fmt: on class Hist(Function): - """Only supported for `Series` so far.""" - __slots__ = ("include_breakpoint",) include_breakpoint: bool @@ -78,7 +92,7 @@ def __repr__(self) -> str: # They're also more widely defined to what will work at runtime @staticmethod def from_bins( - bins: Iterable[float], /, *, include_breakpoint: bool = True + bins: Iterable[float], /, *, include_breakpoint: bool = False ) -> HistBins: bins = tuple(bins) for i in range(1, len(bins)): @@ -88,10 +102,17 @@ def from_bins( @staticmethod def from_bin_count( - count: ConvertibleToInt = 10, /, *, include_breakpoint: bool = True + count: ConvertibleToInt = 10, /, *, include_breakpoint: bool = False ) -> HistBinCount: return HistBinCount(bin_count=int(count), include_breakpoint=include_breakpoint) + @property + def empty_data(self) -> Mapping[str, Iterable[Any]]: + # NOTE: May need to adapt for `include_category`? 
+ return ( + {"breakpoint": [], "count": []} if self.include_breakpoint else {"count": []} + ) + class HistBins(Hist): __slots__ = ("bins",) @@ -116,13 +137,15 @@ def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: return base, exponent -class Kurtosis(Function, options=FunctionOptions.aggregation): - __slots__ = ("bias", "fisher") - fisher: bool - bias: bool +class FillNull(Function, options=FunctionOptions.elementwise): + """N-ary (expr, value).""" + + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, value = node.input + return expr, value -class FillNull(Function, options=FunctionOptions.elementwise): +class FillNan(Function, options=FunctionOptions.elementwise): """N-ary (expr, value).""" def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: @@ -163,6 +186,12 @@ class ReplaceStrict(Function, options=FunctionOptions.elementwise): return_dtype: DType | None +class ReplaceStrictDefault(ReplaceStrict): + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, default = node.input + return expr, default + + class GatherEvery(Function): __slots__ = ("n", "offset") n: int @@ -190,3 +219,32 @@ def to_function_expr(self, *inputs: ExprIR) -> AnonymousExpr: options = self.function_options return AnonymousExpr(input=inputs, function=self, options=options) + + +class SampleN(Function): + __slots__ = ("n", "seed", "with_replacement") + n: int + with_replacement: bool + seed: int | None + + +class SampleFrac(Function): + __slots__ = ("fraction", "seed", "with_replacement") + fraction: float + with_replacement: bool + seed: int | None + + +def sample( + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, +) -> SampleFrac | SampleN: + if n is not None and fraction is not None: + msg = "cannot specify both `n` and `fraction`" + raise ValueError(msg) + if fraction is not None: + return 
SampleFrac(fraction=fraction, with_replacement=with_replacement, seed=seed) + return SampleN(n=1 if n is None else n, with_replacement=with_replacement, seed=seed) diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index b14090b985..e35b4fb41e 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -3,25 +3,78 @@ from typing import TYPE_CHECKING, ClassVar from narwhals._plan._function import Function +from narwhals._plan._parse import parse_into_expr_ir +from narwhals._plan.exceptions import function_arg_non_scalar_error from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions +from narwhals._utils import ensure_type +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: + from typing_extensions import Self + from narwhals._plan.expr import Expr + from narwhals._plan.expressions import ExprIR, FunctionExpr as FExpr + from narwhals._plan.typing import IntoExpr # fmt: off -class ListFunction(Function, accessor="list"): ... -class Len(ListFunction, options=FunctionOptions.elementwise): ... +class ListFunction(Function, accessor="list", options=FunctionOptions.elementwise): ... +class Len(ListFunction): ... +class Unique(ListFunction): ... 
+class Get(ListFunction): + __slots__ = ("index",) + index: int +class Join(ListFunction): + """Join all string items in a sublist and place a separator between them.""" + + __slots__ = ("ignore_nulls", "separator") + separator: str + ignore_nulls: bool # fmt: on +class Contains(ListFunction): + """N-ary (expr, item).""" + + def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, item = node.input + return expr, item + + class IRListNamespace(IRNamespace): len: ClassVar = Len + unique: ClassVar = Unique + contains: ClassVar = Contains + get: ClassVar = Get + join: ClassVar = Join class ExprListNamespace(ExprNamespace[IRListNamespace]): @property def _ir_namespace(self) -> type[IRListNamespace]: - return IRListNamespace # pragma: no cover + return IRListNamespace def len(self) -> Expr: - return self._with_unary(self._ir.len()) # pragma: no cover + return self._with_unary(self._ir.len()) + + def unique(self) -> Expr: + return self._with_unary(self._ir.unique()) + + def get(self, index: int) -> Expr: + ensure_type(index, int, param_name="index") + if index < 0: + msg = f"`index` is out of bounds; must be >= 0, got {index}" + raise InvalidOperationError(msg) + return self._with_unary(self._ir.get(index=index)) + + def contains(self, item: IntoExpr) -> Expr: + item_ir = parse_into_expr_ir(item, str_as_lit=True) + contains = self._ir.contains() + if not item_ir.is_scalar: + raise function_arg_non_scalar_error(contains, "item", item_ir) + return self._expr._from_ir(contains.to_function_expr(self._expr._ir, item_ir)) + + def join(self, separator: str, *, ignore_nulls: bool = True) -> Expr: + ensure_type(separator, str, param_name="separator") + return self._with_unary( + self._ir.join(separator=separator, ignore_nulls=ignore_nulls) + ) diff --git a/narwhals/_plan/expressions/literal.py b/narwhals/_plan/expressions/literal.py index 7d46c8436c..b6b8659512 100644 --- a/narwhals/_plan/expressions/literal.py +++ b/narwhals/_plan/expressions/literal.py @@ 
-7,6 +7,8 @@ from narwhals._plan.typing import LiteralT, NativeSeriesT, NonNestedLiteralT if TYPE_CHECKING: + from collections.abc import Iterator + from typing_extensions import TypeIs from narwhals._plan.expressions.expr import Literal @@ -74,6 +76,11 @@ def __repr__(self) -> str: def unwrap(self) -> Series[NativeSeriesT]: return self.value + @property + def __immutable_values__(self) -> Iterator[Any]: + # NOTE: Adding `Series.__eq__` means this needed a manual override + yield from (self.name, self.dtype, id(self.value)) + def is_literal_scalar( obj: Literal[NonNestedLiteralT] | Any, diff --git a/narwhals/_plan/expressions/ranges.py b/narwhals/_plan/expressions/ranges.py index 644a89dc1e..fba04afbbf 100644 --- a/narwhals/_plan/expressions/ranges.py +++ b/narwhals/_plan/expressions/ranges.py @@ -38,3 +38,11 @@ class DateRange(RangeFunction, options=FunctionOptions.row_separable): __slots__ = ("interval", "closed") # noqa: RUF023 interval: int closed: ClosedInterval + + +class LinearSpace(RangeFunction, options=FunctionOptions.row_separable): + """N-ary (start, end).""" + + __slots__ = ("num_samples", "closed") # noqa: RUF023 + num_samples: int + closed: ClosedInterval diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 5478a7154c..42ea31bb70 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -3,11 +3,15 @@ from typing import TYPE_CHECKING, ClassVar from narwhals._plan._function import Function, HorizontalFunction +from narwhals._plan._parse import parse_into_expr_ir from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: + from typing_extensions import Self + from narwhals._plan.expr import Expr + from narwhals._plan.expressions import ExprIR, FunctionExpr as FExpr # fmt: off @@ -15,6 +19,7 @@ class 
StringFunction(Function, accessor="str", options=FunctionOptions.elementwi class LenChars(StringFunction): ... class ToLowercase(StringFunction): ... class ToUppercase(StringFunction): ... +class ToTitlecase(StringFunction): ... # fmt: on class ConcatStr(HorizontalFunction, StringFunction): __slots__ = ("ignore_nulls", "separator") @@ -34,17 +39,32 @@ class EndsWith(StringFunction): class Replace(StringFunction): - __slots__ = ("literal", "n", "pattern", "value") + """N-ary (expr, value).""" + + def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, value = node.input + return expr, value + + __slots__ = ("literal", "n", "pattern") pattern: str - value: str literal: bool n: int class ReplaceAll(StringFunction): - __slots__ = ("literal", "pattern", "value") + """N-ary (expr, value).""" + + def unwrap_input( + self, node: FExpr[Self], / + ) -> tuple[ExprIR, ExprIR]: # pragma: no cover + expr, value = node.input + return expr, value + + def to_replace_n(self, n: int) -> Replace: + return Replace(pattern=self.pattern, literal=self.literal, n=n) + + __slots__ = ("literal", "pattern") pattern: str - value: str literal: bool @@ -69,46 +89,55 @@ class StripChars(StringFunction): characters: str | None +class ToDate(StringFunction): + __slots__ = ("format",) + format: str | None + + class ToDatetime(StringFunction): __slots__ = ("format",) format: str | None +class ZFill(StringFunction, config=FEOptions.renamed("zfill")): + __slots__ = ("length",) + length: int + + class IRStringNamespace(IRNamespace): len_chars: ClassVar = LenChars - to_lowercase: ClassVar = ToUppercase - to_uppercase: ClassVar = ToLowercase + to_lowercase: ClassVar = ToLowercase + to_uppercase: ClassVar = ToUppercase + to_titlecase: ClassVar = ToTitlecase split: ClassVar = Split starts_with: ClassVar = StartsWith ends_with: ClassVar = EndsWith + zfill: ClassVar = ZFill - def replace( - self, pattern: str, value: str, *, literal: bool = False, n: int = 1 - ) -> Replace: # pragma: no 
cover - return Replace(pattern=pattern, value=value, literal=literal, n=n) + def replace(self, pattern: str, *, literal: bool = False, n: int = 1) -> Replace: + return Replace(pattern=pattern, literal=literal, n=n) - def replace_all( - self, pattern: str, value: str, *, literal: bool = False - ) -> ReplaceAll: # pragma: no cover - return ReplaceAll(pattern=pattern, value=value, literal=literal) + def replace_all(self, pattern: str, *, literal: bool = False) -> ReplaceAll: + return ReplaceAll(pattern=pattern, literal=literal) - def strip_chars( - self, characters: str | None = None - ) -> StripChars: # pragma: no cover + def strip_chars(self, characters: str | None = None) -> StripChars: return StripChars(characters=characters) def contains(self, pattern: str, *, literal: bool = False) -> Contains: return Contains(pattern=pattern, literal=literal) - def slice(self, offset: int, length: int | None = None) -> Slice: # pragma: no cover + def slice(self, offset: int, length: int | None = None) -> Slice: return Slice(offset=offset, length=length) - def head(self, n: int = 5) -> Slice: # pragma: no cover + def head(self, n: int = 5) -> Slice: return self.slice(0, n) - def tail(self, n: int = 5) -> Slice: # pragma: no cover + def tail(self, n: int = 5) -> Slice: return self.slice(-n) + def to_date(self, format: str | None = None) -> ToDate: # pragma: no cover + return ToDate(format=format) + def to_datetime(self, format: str | None = None) -> ToDatetime: # pragma: no cover return ToDatetime(format=format) @@ -122,44 +151,57 @@ def len_chars(self) -> Expr: return self._with_unary(self._ir.len_chars()) def replace( - self, pattern: str, value: str, *, literal: bool = False, n: int = 1 - ) -> Expr: # pragma: no cover - return self._with_unary(self._ir.replace(pattern, value, literal=literal, n=n)) + self, pattern: str, value: str | Expr, *, literal: bool = False, n: int = 1 + ) -> Expr: + other = parse_into_expr_ir(value, str_as_lit=True) + replace = self._ir.replace(pattern, 
literal=literal, n=n) + return self._expr._from_ir(replace.to_function_expr(self._expr._ir, other)) def replace_all( - self, pattern: str, value: str, *, literal: bool = False - ) -> Expr: # pragma: no cover - return self._with_unary(self._ir.replace_all(pattern, value, literal=literal)) + self, pattern: str, value: str | Expr, *, literal: bool = False + ) -> Expr: + other = parse_into_expr_ir(value, str_as_lit=True) + replace = self._ir.replace_all(pattern, literal=literal) + return self._expr._from_ir(replace.to_function_expr(self._expr._ir, other)) - def strip_chars(self, characters: str | None = None) -> Expr: # pragma: no cover + def strip_chars(self, characters: str | None = None) -> Expr: return self._with_unary(self._ir.strip_chars(characters)) - def starts_with(self, prefix: str) -> Expr: # pragma: no cover + def starts_with(self, prefix: str) -> Expr: return self._with_unary(self._ir.starts_with(prefix=prefix)) - def ends_with(self, suffix: str) -> Expr: # pragma: no cover + def ends_with(self, suffix: str) -> Expr: return self._with_unary(self._ir.ends_with(suffix=suffix)) def contains(self, pattern: str, *, literal: bool = False) -> Expr: return self._with_unary(self._ir.contains(pattern, literal=literal)) - def slice(self, offset: int, length: int | None = None) -> Expr: # pragma: no cover + def slice(self, offset: int, length: int | None = None) -> Expr: return self._with_unary(self._ir.slice(offset, length)) - def head(self, n: int = 5) -> Expr: # pragma: no cover + def head(self, n: int = 5) -> Expr: return self._with_unary(self._ir.head(n)) - def tail(self, n: int = 5) -> Expr: # pragma: no cover + def tail(self, n: int = 5) -> Expr: return self._with_unary(self._ir.tail(n)) - def split(self, by: str) -> Expr: # pragma: no cover + def split(self, by: str) -> Expr: return self._with_unary(self._ir.split(by=by)) + def to_date(self, format: str | None = None) -> Expr: # pragma: no cover + return self._with_unary(self._ir.to_date(format)) + def 
to_datetime(self, format: str | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_datetime(format)) - def to_lowercase(self) -> Expr: # pragma: no cover + def to_lowercase(self) -> Expr: return self._with_unary(self._ir.to_lowercase()) - def to_uppercase(self) -> Expr: # pragma: no cover + def to_uppercase(self) -> Expr: return self._with_unary(self._ir.to_uppercase()) + + def to_titlecase(self) -> Expr: + return self._with_unary(self._ir.to_titlecase()) + + def zfill(self, length: int) -> Expr: + return self._with_unary(self._ir.zfill(length=length)) diff --git a/narwhals/_plan/expressions/struct.py b/narwhals/_plan/expressions/struct.py index e3625adb8a..6350f0668e 100644 --- a/narwhals/_plan/expressions/struct.py +++ b/narwhals/_plan/expressions/struct.py @@ -7,10 +7,23 @@ from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._plan._expr_ir import ExprIR from narwhals._plan.expr import Expr + from narwhals._plan.expressions.expr import StructExpr + + +class StructFunction(Function, accessor="struct"): + def to_function_expr(self, *inputs: ExprIR) -> StructExpr[Self]: + from narwhals._plan.expressions.expr import StructExpr + return StructExpr(input=inputs, function=self, options=self.function_options) -class StructFunction(Function, accessor="struct"): ... 
+ @property + def needs_expansion(self) -> bool: + msg = f"{type(self).__name__}.needs_expansion" + raise NotImplementedError(msg) class FieldByName( @@ -22,6 +35,10 @@ class FieldByName( def __repr__(self) -> str: return f"{super().__repr__()}({self.name!r})" + @property + def needs_expansion(self) -> bool: + return True + class IRStructNamespace(IRNamespace): field: ClassVar = FieldByName diff --git a/narwhals/_plan/expressions/temporal.py b/narwhals/_plan/expressions/temporal.py index 35a622ebd2..257495d2b6 100644 --- a/narwhals/_plan/expressions/temporal.py +++ b/narwhals/_plan/expressions/temporal.py @@ -8,7 +8,7 @@ from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: - from typing_extensions import TypeAlias, TypeIs + from typing_extensions import Self, TypeAlias, TypeIs from narwhals._duration import IntervalUnit from narwhals._plan.expr import Expr @@ -71,18 +71,24 @@ def __repr__(self) -> str: return f"{super().__repr__()}[{self.time_unit!r}]" -class Truncate(TemporalFunction): +class _IntervalFunction(TemporalFunction): __slots__ = ("multiple", "unit") multiple: int unit: IntervalUnit - @staticmethod - def from_string(every: str, /) -> Truncate: - return Truncate.from_interval(Interval.parse(every)) + @classmethod + def from_string(cls, interval: str, /) -> Self: + return cls.from_interval(Interval.parse(interval)) - @staticmethod - def from_interval(every: Interval, /) -> Truncate: - return Truncate(multiple=every.multiple, unit=every.unit) + @classmethod + def from_interval(cls, interval: Interval, /) -> Self: + return cls(multiple=interval.multiple, unit=interval.unit) + + +# fmt: off +class Truncate(_IntervalFunction): ... +class OffsetBy(_IntervalFunction): ... 
+# fmt: on class IRDateTimeNamespace(IRNamespace): @@ -106,6 +112,7 @@ class IRDateTimeNamespace(IRNamespace): to_string: ClassVar = ToString replace_time_zone: ClassVar = ReplaceTimeZone convert_time_zone: ClassVar = ConvertTimeZone + offset_by: ClassVar = staticmethod(OffsetBy.from_string) truncate: ClassVar = staticmethod(Truncate.from_string) timestamp: ClassVar = staticmethod(Timestamp.from_time_unit) @@ -169,7 +176,7 @@ def total_nanoseconds(self) -> Expr: # pragma: no cover def to_string(self, format: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_string(format=format)) - def replace_time_zone(self, time_zone: str | None) -> Expr: # pragma: no cover + def replace_time_zone(self, time_zone: str | None) -> Expr: return self._with_unary(self._ir.replace_time_zone(time_zone=time_zone)) def convert_time_zone(self, time_zone: str) -> Expr: # pragma: no cover @@ -180,3 +187,6 @@ def timestamp(self, time_unit: TimeUnit = "us") -> Expr: def truncate(self, every: str) -> Expr: return self._with_unary(self._ir.truncate(every)) + + def offset_by(self, by: str) -> Expr: # pragma: no cover + return self._with_unary(self._ir.offset_by(by)) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 6feb2f1700..11900306aa 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -3,19 +3,26 @@ import builtins import datetime as dt import typing as t -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Final from narwhals._duration import Interval from narwhals._plan import _guards, _parse, common, expressions as ir, selectors as cs from narwhals._plan._dispatch import get_dispatch_name +from narwhals._plan.exceptions import list_literal_error from narwhals._plan.expressions import functions as F from narwhals._plan.expressions.literal import ScalarLiteral, SeriesLiteral -from narwhals._plan.expressions.ranges import DateRange, IntRange, RangeFunction +from narwhals._plan.expressions.ranges import ( + 
DateRange, + IntRange, + LinearSpace, + RangeFunction, +) from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.when_then import When from narwhals._utils import ( Implementation, Version, + ensure_type, flatten, is_eager_allowed, qualified_type_name, @@ -52,6 +59,8 @@ t.Any, CompliantSeries[NativeSeriesT], t.Any, t.Any ] +_dtypes: Final = Version.MAIN.dtypes + def col(*names: str | t.Iterable[str]) -> Expr: flat = tuple(flatten(names)) @@ -72,8 +81,7 @@ def lit( if _guards.is_series(value): return SeriesLiteral(value=value).to_literal().to_narwhals() if not _guards.is_non_nested_literal(value): - msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." - raise TypeError(msg) + raise list_literal_error(value) if dtype is None: dtype = common.py_to_narwhals_dtype(value, Version.MAIN) else: @@ -113,14 +121,26 @@ def sum(*columns: str) -> Expr: return col(columns).sum() -def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: +def all_horizontal( + *exprs: IntoExpr | t.Iterable[IntoExpr], ignore_nulls: bool = False +) -> Expr: it = _parse.parse_into_seq_of_expr_ir(*exprs) - return ir.boolean.AllHorizontal().to_function_expr(*it).to_narwhals() + return ( + ir.boolean.AllHorizontal(ignore_nulls=ignore_nulls) + .to_function_expr(*it) + .to_narwhals() + ) -def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: +def any_horizontal( + *exprs: IntoExpr | t.Iterable[IntoExpr], ignore_nulls: bool = False +) -> Expr: it = _parse.parse_into_seq_of_expr_ir(*exprs) - return ir.boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() + return ( + ir.boolean.AnyHorizontal(ignore_nulls=ignore_nulls) + .to_function_expr(*it) + .to_narwhals() + ) def sum_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: @@ -157,6 +177,33 @@ def concat_str( ) +def coalesce(exprs: IntoExpr | t.Iterable[IntoExpr], *more_exprs: IntoExpr) -> Expr: + it = _parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) + return 
F.Coalesce().to_function_expr(*it).to_narwhals() + + +def format(f_string: str, *args: IntoExpr) -> Expr: + """Format expressions as a string. + + Arguments: + f_string: A string with placeholders. + args: Expression(s) that fill the placeholders. + """ + if (n_placeholders := f_string.count("{}")) != builtins.len(args): + msg = f"number of placeholders should equal the number of arguments. Expected {n_placeholders} arguments, got {builtins.len(args)}." + raise ValueError(msg) + string = _dtypes.String() + exprs: list[ir.ExprIR] = [] + it = iter(args) + for i, s in enumerate(f_string.split("{}")): + if i > 0: + exprs.append(_parse.parse_into_expr_ir(next(it))) + if s: + exprs.append(lit(s, string)._ir) + f = ConcatStr(separator="", ignore_nulls=False) + return f.to_function_expr(*exprs).to_narwhals() + + +def when( *predicates: IntoExprColumn | t.Iterable[IntoExprColumn], **constraints: t.Any ) -> When: @@ -273,6 +320,95 @@ def date_range( ) +@t.overload +def linear_space( + start: float | IntoExprColumn, + end: float | IntoExprColumn, + num_samples: int, + *, + closed: ClosedInterval = ..., + eager: t.Literal[False] = ..., +) -> Expr: ... +@t.overload +def linear_space( + start: float, + end: float, + num_samples: int, + *, + closed: ClosedInterval = ..., + eager: Arrow, +) -> Series[pa.ChunkedArray[t.Any]]: ... +@t.overload +def linear_space( + start: float, + end: float, + num_samples: int, + *, + closed: ClosedInterval = ..., + eager: IntoBackend[EagerAllowed], +) -> Series: ... +def linear_space( + start: float | IntoExprColumn, + end: float | IntoExprColumn, + num_samples: int, + *, + closed: ClosedInterval = "both", + eager: IntoBackend[EagerAllowed] | t.Literal[False] = False, +) -> Expr | Series: + """Create a sequence of evenly-spaced points. + + Arguments: + start: Lower bound of the range. + end: Upper bound of the range. + num_samples: Number of samples in the output sequence. + closed: Define which sides of the interval are closed (inclusive). 
+ eager: If set to `False` (default), then an expression is returned. + If set to an (eager) implementation ("pandas", "polars" or "pyarrow"), then + a `Series` is returned. + + Notes: + Unlike `pl.linear_space`, *currently* only numeric dtypes (and not temporal) are supported. + + Examples: + >>> import narwhals._plan as nwp + >>> nwp.linear_space(start=0, end=1, num_samples=3, eager="pyarrow").to_list() + [0.0, 0.5, 1.0] + + >>> nwp.linear_space(0, 1, 3, closed="left", eager="pyarrow").to_list() + [0.0, 0.3333333333333333, 0.6666666666666666] + + >>> nwp.linear_space(0, 1, 3, closed="right", eager="pyarrow").to_list() + [0.3333333333333333, 0.6666666666666666, 1.0] + + >>> nwp.linear_space(0, 1, 3, closed="none", eager="pyarrow").to_list() + [0.25, 0.5, 0.75] + + >>> df = nwp.DataFrame.from_dict({"a": [1, 2, 3, 4, 5]}, backend="pyarrow") + >>> df.with_columns(nwp.linear_space(0, 10, 5).alias("ls")) + ┌──────────────────────┐ + | nw.DataFrame | + |----------------------| + |pyarrow.Table | + |a: int64 | + |ls: double | + |---- | + |a: [[1,2,3,4,5]] | + |ls: [[0,2.5,5,7.5,10]]| + └──────────────────────┘ + """ + ensure_type(num_samples, int, param_name="num_samples") + closed = _ensure_closed_interval(closed) + if eager: + ns = _eager_namespace(eager) + start, end = _ensure_range_scalar(start, end, (float, int), LinearSpace, eager) + return _linear_space_eager(start, end, num_samples, closed, ns) + return ( + LinearSpace(num_samples=num_samples, closed=closed) + .to_function_expr(*_parse.parse_into_seq_of_expr_ir(start, end)) + .to_narwhals() + ) + + @t.overload def _eager_namespace(backend: Arrow, /) -> _arrow.Namespace: ... 
@t.overload @@ -291,11 +427,12 @@ def _eager_namespace( raise ValueError(msg) -# NOTE: If anything beyond `{date,int}_range` are added, move to `RangeFunction` +# TODO @dangotbanned: Handle this in `RangeFunction` or `RangeExpr` +# NOTE: `ArrowNamespace._range_function_inputs` has some duplicated logic too def _ensure_range_scalar( start: t.Any, end: t.Any, - valid_type: type[NonNestedLiteralT], + valid_type: type[NonNestedLiteralT] | tuple[type[NonNestedLiteralT], ...], function: type[RangeFunction], eager: IntoBackend[EagerAllowed], ) -> tuple[NonNestedLiteralT, NonNestedLiteralT]: @@ -303,8 +440,10 @@ def _ensure_range_scalar( return start, end tp_start = qualified_type_name(start) tp_end = qualified_type_name(end) + valid_types = (valid_type,) if not isinstance(valid_type, tuple) else valid_type + tp_names = " | ".join(tp.__name__ for tp in valid_types) msg = ( - f"Expected `start` and `end` to be {valid_type.__name__} values since `eager={eager}`, but got: ({tp_start}, {tp_end})\n\n" + f"Expected `start` and `end` to be {tp_names} values since `eager={eager}`, but got: ({tp_start}, {tp_end})\n\n" f"Hint: Calling `nw.{get_dispatch_name(function)}` with expressions requires:\n" " - `eager=False`\n" " - a context such as `select` or `with_columns`" @@ -333,6 +472,17 @@ def _int_range_eager( return ns.int_range_eager(start, end, step, dtype=dtype).to_narwhals() +def _linear_space_eager( + start: float, + end: float, + num_samples: int, + closed: ClosedInterval, + ns: EagerNs[NativeSeriesT], + /, +) -> Series[NativeSeriesT]: + return ns.linear_space_eager(start, end, num_samples, closed=closed).to_narwhals() + + def _ensure_closed_interval(closed: ClosedInterval, /) -> ClosedInterval: closed_intervals = "left", "right", "none", "both" if closed not in closed_intervals: diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 41487b503b..5938795a13 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -14,6 +14,7 @@ from 
narwhals._plan.expressions import selectors as cs from narwhals._plan.expressions.literal import is_literal_scalar from narwhals._plan.expressions.namespace import IRNamespace +from narwhals._plan.expressions.struct import FieldByName from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.utils import Version @@ -113,6 +114,8 @@ def _expr_output_name(expr: ir.ExprIR, /) -> str | ComputeError: isinstance(e.selector, cs.ByName) and len(e.selector.names) == 1 ): return e.selector.names[0] + if isinstance(e, ir.StructExpr) and isinstance(e.function, FieldByName): + return e.function.name msg = ( f"unable to find root column name for expr '{expr!r}' when calling 'output_name'" ) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 4a31389170..ab43aa23ed 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from narwhals._plan._immutable import Immutable -from narwhals._utils import Implementation +from narwhals._utils import Implementation, ensure_type from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: @@ -76,6 +76,9 @@ def __str__(self) -> str: return name.replace("|", " | ") +_INVALID = FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING + + class FunctionOptions(Immutable): """https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs""" # noqa: D415 @@ -101,11 +104,12 @@ def is_input_wildcard_expansion(self) -> bool: return self.flags.is_input_wildcard_expansion() def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions: - if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags: - msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." 
# pragma: no cover - raise TypeError(msg) # pragma: no cover + new_flags = self.flags | flags + if _INVALID in new_flags: + msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." + raise TypeError(msg) obj = FunctionOptions.__new__(FunctionOptions) - object.__setattr__(obj, "flags", self.flags | flags) + object.__setattr__(obj, "flags", new_flags) return obj def with_elementwise(self) -> FunctionOptions: @@ -246,13 +250,30 @@ class RollingOptionsFixedWindow(Immutable): center: bool fn_params: RollingVarParams | None + @property + def ddof(self) -> int: + return 1 if self.fn_params is None else self.fn_params.ddof + def rolling_options( window_size: int, min_samples: int | None, /, *, center: bool, ddof: int | None = None ) -> RollingOptionsFixedWindow: + ensure_type(window_size, int, param_name="window_size") + ensure_type(min_samples, int, type(None), param_name="min_samples") + if window_size < 1: + msg = "`window_size` must be >= 1" + raise InvalidOperationError(msg) + if min_samples is None: + min_samples = window_size + elif min_samples < 1: + msg = "`min_samples` must be >= 1" + raise InvalidOperationError(msg) + elif min_samples > window_size: + msg = "`min_samples` must be <= `window_size`" + raise InvalidOperationError(msg) return RollingOptionsFixedWindow( window_size=window_size, - min_samples=window_size if min_samples is None else min_samples, + min_samples=min_samples, center=center, fn_params=ddof if ddof is None else RollingVarParams(ddof=ddof), ) @@ -307,3 +328,15 @@ def default(cls) -> Self: FEOptions = FunctionExprOptions + + +class ExplodeOptions(Immutable): + __slots__ = ("empty_as_null", "keep_nulls") + empty_as_null: bool + """Explode an empty list into a `null`.""" + keep_nulls: bool + """Explode a `null` into a `null`.""" + + def any(self) -> bool: + """Return True if we need to handle empty lists and/or nulls.""" + return self.empty_as_null or self.keep_nulls diff --git a/narwhals/_plan/series.py 
b/narwhals/_plan/series.py index 0220253087..075fe569d7 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -1,32 +1,46 @@ from __future__ import annotations -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, ClassVar, Generic +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal from narwhals._plan._guards import is_series -from narwhals._plan.typing import NativeSeriesT, NativeSeriesT_co, OneOrIterable +from narwhals._plan.typing import ( + IncompleteCyclic, + NativeSeriesT, + NativeSeriesT_co, + OneOrIterable, + SeriesT, +) from narwhals._utils import ( Implementation, Version, generate_repr, is_eager_allowed, qualified_type_name, + unstable, ) from narwhals.dependencies import is_pyarrow_chunked_array +from narwhals.exceptions import ShapeError if TYPE_CHECKING: from collections.abc import Iterator import polars as pl - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.dataframe import DataFrame from narwhals._typing import EagerAllowed, IntoBackend, _EagerAllowedImpl from narwhals.dtypes import DType - from narwhals.typing import IntoDType, NonNestedLiteral, SizedMultiIndexSelector - -Incomplete: TypeAlias = Any + from narwhals.schema import Schema + from narwhals.typing import ( + IntoDType, + NonNestedLiteral, + NumericLiteral, + PythonLiteral, + SizedMultiIndexSelector, + TemporalLiteral, + ) class Series(Generic[NativeSeriesT_co]): @@ -49,6 +63,10 @@ def name(self) -> str: def implementation(self) -> _EagerAllowedImpl: return self._compliant.implementation + @property + def shape(self) -> tuple[int]: + return (self._compliant.len(),) + def __init__(self, compliant: CompliantSeries[NativeSeriesT_co], /) -> None: self._compliant = compliant @@ -90,9 +108,7 @@ def from_native( raise NotImplementedError(type(native)) - # NOTE: `Incomplete` 
until `CompliantSeries` can avoid a cyclic dependency back to `CompliantDataFrame` - # Currently an issue on `main` and leads to a lot of intermittent warnings - def to_frame(self) -> DataFrame[Incomplete, NativeSeriesT_co]: + def to_frame(self) -> DataFrame[IncompleteCyclic, NativeSeriesT_co]: import narwhals._plan.dataframe as _df # NOTE: Missing placeholder for `DataFrameV1` @@ -107,34 +123,40 @@ def to_list(self) -> list[Any]: def to_polars(self) -> pl.Series: return self._compliant.to_polars() - def __iter__(self) -> Iterator[Any]: # pragma: no cover + # TODO @dangotbanned: Figure out if this should be yielding `pa.Scalar` + def __iter__(self) -> Iterator[Any]: yield from self.to_native() def alias(self, name: str) -> Self: return type(self)(self._compliant.alias(name)) + rename = alias + + def cast(self, dtype: IntoDType) -> Self: + return type(self)(self._compliant.cast(dtype)) + def __len__(self) -> int: return len(self._compliant) - def gather(self, indices: SizedMultiIndexSelector[Self]) -> Self: # pragma: no cover + def gather(self, indices: SizedMultiIndexSelector[Self]) -> Self: if len(indices) == 0: return self.slice(0, 0) - rows = indices._compliant if isinstance(indices, Series) else indices - return type(self)(self._compliant.gather(rows)) + return type(self)(self._compliant.gather(self._parse_into_compliant(indices))) - def has_nulls(self) -> bool: # pragma: no cover + def gather_every(self, n: int, offset: int = 0) -> Self: + return type(self)(self._compliant.gather_every(n, offset)) + + def has_nulls(self) -> bool: return self._compliant.has_nulls() - def slice(self, offset: int, length: int | None = None) -> Self: # pragma: no cover + def slice(self, offset: int, length: int | None = None) -> Self: return type(self)(self._compliant.slice(offset=offset, length=length)) - def sort( - self, *, descending: bool = False, nulls_last: bool = False - ) -> Self: # pragma: no cover + def sort(self, *, descending: bool = False, nulls_last: bool = False) 
-> Self: result = self._compliant.sort(descending=descending, nulls_last=nulls_last) return type(self)(result) - def is_empty(self) -> bool: # pragma: no cover + def is_empty(self) -> bool: return self._compliant.is_empty() def _unwrap_compliant( @@ -173,6 +195,140 @@ def scatter( def is_in(self, other: Iterable[Any]) -> Self: return type(self)(self._compliant.is_in(self._parse_into_compliant(other))) + def is_nan(self) -> Self: + return type(self)(self._compliant.is_nan()) + + def is_null(self) -> Self: + return type(self)(self._compliant.is_null()) + + def is_not_nan(self) -> Self: + return type(self)(self._compliant.is_not_nan()) + + def is_not_null(self) -> Self: + return type(self)(self._compliant.is_not_null()) + + def null_count(self) -> int: + return self._compliant.null_count() + + def fill_nan(self, value: float | Self | None) -> Self: + other = self._unwrap_compliant(value) if is_series(value) else value + return type(self)(self._compliant.fill_nan(other)) + + def sample( + self, + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, + ) -> Self: + if n is not None and fraction is not None: + msg = "cannot specify both `n` and `fraction`" + raise ValueError(msg) + s = self._compliant + if fraction is not None: + result = s.sample_frac(fraction, with_replacement=with_replacement, seed=seed) + elif n is None: + result = s.sample_n(with_replacement=with_replacement, seed=seed) + elif not with_replacement and n > len(self): + msg = "cannot take a larger sample than the total population when `with_replacement=false`" + raise ShapeError(msg) + else: + result = s.sample_n(n, with_replacement=with_replacement, seed=seed) + return type(self)(result) + + def __eq__(self, other: NumericLiteral | TemporalLiteral | Self) -> Self: # type: ignore[override] + other_ = self._unwrap_compliant(other) if is_series(other) else other + return type(self)(self._compliant.__eq__(other_)) + + def __and__(self, 
other: bool | Self, /) -> Self: + other_ = self._unwrap_compliant(other) if is_series(other) else other + return type(self)(self._compliant.__and__(other_)) + + def __or__(self, other: bool | Self, /) -> Self: + other_ = self._unwrap_compliant(other) if is_series(other) else other + return type(self)(self._compliant.__or__(other_)) + + def __invert__(self) -> Self: + return type(self)(self._compliant.__invert__()) + + def __add__(self, other: NumericLiteral | TemporalLiteral | Self, /) -> Self: + other_ = self._unwrap_compliant(other) if is_series(other) else other + return type(self)(self._compliant.__add__(other_)) + + def all(self) -> bool: + return self._compliant.all() + + def any(self) -> bool: + return self._compliant.any() + + def sum(self) -> float: + return self._compliant.sum() + + def count(self) -> int: + return self._compliant.count() + + def first(self) -> PythonLiteral: + return self._compliant.first() + + def last(self) -> PythonLiteral: + return self._compliant.last() + + def unique(self, *, maintain_order: bool = False) -> Self: + return type(self)(self._compliant.unique(maintain_order=maintain_order)) + + def drop_nulls(self) -> Self: + return type(self)(self._compliant.drop_nulls()) + + def drop_nans(self) -> Self: + return type(self)(self._compliant.drop_nans()) + + @unstable + def hist( + self, + bins: Sequence[float] | None = None, + *, + bin_count: int | None = None, + include_breakpoint: bool = True, + include_category: bool = False, + _compatibility_behavior: Literal["narwhals", "polars"] = "narwhals", + ) -> DataFrame[IncompleteCyclic, NativeSeriesT_co]: + return self._compliant.hist( + bins, + bin_count=bin_count, + include_breakpoint=include_breakpoint, + include_category=include_category, + _compatibility_behavior=_compatibility_behavior, + ).to_narwhals() + + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Self: + return type(self)( + self._compliant.explode(empty_as_null=empty_as_null, 
keep_nulls=keep_nulls) + ) + + @property + def struct(self) -> SeriesStructNamespace[Self]: + return SeriesStructNamespace(self) + + +class SeriesStructNamespace(Generic[SeriesT]): + def __init__(self, series: SeriesT) -> None: + self._series: SeriesT = series + + def unnest(self) -> DataFrame[Any, Any]: + """Convert this struct Series to a DataFrame with a separate column for each field.""" + result: DataFrame[Any, Any] = ( + self._series._compliant.struct.unnest().to_narwhals() + ) + return result + + def field(self, name: str) -> SeriesT: # pragma: no cover + return type(self._series)(self._series._compliant.struct.field(name)) + + @property + def schema(self) -> Schema: + return self._series._compliant.struct.schema + class SeriesV1(Series[NativeSeriesT_co]): _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index c0973d0487..4b35e7722b 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -21,6 +21,7 @@ from narwhals._plan.expressions.functions import RollingWindow from narwhals._plan.expressions.namespace import IRNamespace from narwhals._plan.expressions.ranges import RangeFunction + from narwhals._plan.expressions.struct import StructFunction from narwhals._plan.selectors import Selector from narwhals._plan.series import Series from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -61,6 +62,9 @@ RangeT_co = TypeVar( "RangeT_co", bound="RangeFunction", default="RangeFunction", covariant=True ) +StructT_co = TypeVar( + "StructT_co", bound="StructFunction", default="StructFunction", covariant=True +) LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator") RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") @@ -88,11 +92,11 @@ "NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral" ) NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries", default="NativeSeries") 
+NativeSeriesT2 = TypeVar("NativeSeriesT2", bound="NativeSeries", default="NativeSeries") NativeSeriesAnyT = TypeVar("NativeSeriesAnyT", bound="NativeSeries", default="t.Any") NativeSeriesT_co = TypeVar( "NativeSeriesT_co", bound="NativeSeries", covariant=True, default="NativeSeries" ) -NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame", default="NativeFrame") NativeFrameT_co = TypeVar( "NativeFrameT_co", bound="NativeFrame", covariant=True, default="NativeFrame" ) @@ -123,9 +127,10 @@ IntoExprColumn: TypeAlias = "Expr | Series[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" ColumnNameOrSelector: TypeAlias = "str | Selector" -OneOrIterable: TypeAlias = "T | t.Iterable[T]" +OneOrIterable: TypeAlias = "T | Iterable[T]" OneOrSeq: TypeAlias = t.Union[T, Seq[T]] DataFrameT = TypeVar("DataFrameT", bound="DataFrame[t.Any, t.Any]") +SeriesT = TypeVar("SeriesT", bound="Series[t.Any]") Order: TypeAlias = t.Literal["ascending", "descending"] NonCrossJoinStrategy: TypeAlias = t.Literal["inner", "left", "full", "semi", "anti"] PartialSeries: TypeAlias = "Callable[[Iterable[t.Any]], Series[NativeSeriesAnyT]]" @@ -136,3 +141,13 @@ [^1]: `ByName`, `ByIndex` will never be ignored. """ + + +IncompleteCyclic: TypeAlias = "t.Any" +"""Placeholder for typing that introduces a cyclic dependency. + +Mainly for spelling `(Compliant)DataFrame` from within `(Compliant)Series`. + +On `main`, this works fine when running a type checker from the CLI - but causes +intermittent warnings when running in a language server. 
+""" diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index ce51e19087..63347ee429 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -8,10 +8,11 @@ parse_predicates_constraints_into_expr_ir, ) from narwhals._plan.expr import Expr +from narwhals._plan.expressions import ternary_expr from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: - from narwhals._plan.expressions import ExprIR, TernaryExpr + from narwhals._plan.expressions import ExprIR from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq @@ -114,9 +115,3 @@ def _from_ir(cls, expr_ir: ExprIR, /) -> Expr: # type: ignore[override] def __eq__(self, other: IntoExpr) -> Expr: # type: ignore[override] return Expr.__eq__(self, other) - - -def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: - from narwhals._plan.expressions.expr import TernaryExpr - - return TernaryExpr(predicate=predicate, truthy=truthy, falsy=falsy) diff --git a/tests/plan/all_any_horizontal_test.py b/tests/plan/all_any_horizontal_test.py new file mode 100644 index 0000000000..3db55ec694 --- /dev/null +++ b/tests/plan/all_any_horizontal_test.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "a": [False, False, True], + "b": [False, True, True], + "c": [True, True, False], + "d": [True, None, None], + "e": [None, True, False], + } + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.all_horizontal("a", nwp.col("b"), ignore_nulls=True), + {"a": [False, False, True]}, + ), + pytest.param( + nwp.all_horizontal("c", "d", ignore_nulls=True), + {"c": [True, True, False]}, + id="ignore_nulls-1", + ), 
+ (nwp.all_horizontal("c", "d", ignore_nulls=False), {"c": [True, None, False]}), + ( + nwp.all_horizontal(nwp.nth(0, 1), ignore_nulls=True), + {"a": [False, False, True]}, + ), + pytest.param( + nwp.all_horizontal( + nwp.col("a"), nwp.nth(0), ncs.first(), "a", ignore_nulls=True + ), + {"a": [False, False, True]}, + id="duplicated", + ), + ( + nwp.all_horizontal(nwp.exclude("a", "b"), ignore_nulls=False), + {"c": [None, None, False]}, + ), + pytest.param( + nwp.all_horizontal(ncs.all() - ncs.by_index(0, 1), ignore_nulls=True), + {"c": [True, True, False]}, + id="ignore_nulls-2", + ), + ], +) +def test_all_horizontal(data: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data).select(expr) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.any_horizontal("a", nwp.col("b"), ignore_nulls=True), + {"a": [False, True, True]}, + ), + (nwp.any_horizontal("c", "d", ignore_nulls=False), {"c": [True, True, None]}), + pytest.param( + nwp.any_horizontal("c", "d", ignore_nulls=True), + {"c": [True, True, False]}, + id="ignore_nulls-1", + ), + ( + nwp.any_horizontal(nwp.nth(0, 1), ignore_nulls=False), + {"a": [False, True, True]}, + ), + pytest.param( + nwp.any_horizontal( + nwp.col("a"), nwp.nth(0), ncs.first(), "a", ignore_nulls=True + ), + {"a": [False, False, True]}, + id="duplicated", + ), + ( + nwp.any_horizontal(nwp.exclude("a", "b"), ignore_nulls=False), + {"c": [True, True, None]}, + ), + pytest.param( + nwp.any_horizontal(ncs.all() - ncs.by_index(0, 1), ignore_nulls=True), + {"c": [True, True, False]}, + id="ignore_nulls-2", + ), + ], +) +def test_any_horizontal(data: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data).select(expr) + assert_equal_data(result, expected) diff --git a/tests/plan/cat_get_categories_test.py b/tests/plan/cat_get_categories_test.py new file mode 100644 index 0000000000..fa13f3c40b --- /dev/null +++ b/tests/plan/cat_get_categories_test.py @@ -0,0 
+1,32 @@ +from __future__ import annotations + +import pytest + +import narwhals as nw +import narwhals._plan as nwp # noqa: F401 +import narwhals._plan.selectors as ncs +from narwhals._utils import Implementation +from tests.plan.utils import assert_equal_data, dataframe +from tests.utils import PYARROW_VERSION + +pytest.importorskip("pyarrow") + + +@pytest.mark.parametrize( + ("values", "expected"), + [(["one", "two", "two"], ["one", "two"]), (["A", "B", None, "D"], ["A", "B", "D"])], + ids=["full", "nulls"], +) +def test_get_categories( + values: list[str], expected: list[str], request: pytest.FixtureRequest +) -> None: + data = {"a": values} + df = dataframe(data) + request.applymarker( + pytest.mark.xfail( + (df.implementation is Implementation.PYARROW and PYARROW_VERSION < (15,)), + reason="Unsupported cast from string to dictionary using function cast_dictionary", + ) + ) + result = df.select(ncs.first().cast(nw.Categorical).cat.get_categories()) + assert_equal_data(result, {"a": expected}) diff --git a/tests/plan/ceil_floor_test.py b/tests/plan/ceil_floor_test.py new file mode 100644 index 0000000000..ff08fce41a --- /dev/null +++ b/tests/plan/ceil_floor_test.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [1.12345, 2.56789, 3.901234, -0.5], "b": [1.045, None, 2.221, -5.9446]} + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.col("a").ceil(), [2.0, 3.0, 4.0, 0.0]), + (nwp.col("a").floor(), [1.0, 2.0, 3.0, -1.0]), + (nwp.col("b").ceil(), [2.0, None, 3.0, -5.0]), + (nwp.col("b").floor(), [1.0, None, 2.0, -6.0]), + ], + ids=["ceil", "floor", "ceil-nulls", "floor-nulls"], +) +def test_ceil_floor(data: Data, expr: nwp.Expr, expected: list[Any]) -> None: + result = 
dataframe(data).select(result=expr) + assert_equal_data(result, {"result": expected}) diff --git a/tests/plan/clip_test.py b/tests/plan/clip_test.py new file mode 100644 index 0000000000..3beb3f4db1 --- /dev/null +++ b/tests/plan/clip_test.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from narwhals import _plan as nwp +from narwhals.exceptions import MultiOutputExpressionError +from tests.plan.utils import assert_equal_data, dataframe, series + +if TYPE_CHECKING: + from narwhals._plan.typing import IntoExprColumn + from narwhals.typing import NumericLiteral, TemporalLiteral + +pytest.importorskip("pyarrow") + + +@pytest.mark.parametrize( + ("lower", "upper", "expected"), + [ + (3, 4, [3, 3, 3, 3, 4]), + (0, 4, [1, 2, 3, 0, 4]), + (None, 4, [1, 2, 3, -4, 4]), + (-2, 0, [0, 0, 0, -2, 0]), + (-2, None, [1, 2, 3, -2, 5]), + ("lb", nwp.col("ub") + 1, [3, 2, 3, 1, 3]), + (series([1, 1, 2, 4, 3]), None, [1, 2, 3, 4, 5]), + ], +) +def test_clip_expr( + lower: IntoExprColumn | NumericLiteral | TemporalLiteral | None, + upper: IntoExprColumn | NumericLiteral | TemporalLiteral | None, + expected: list[int], +) -> None: + data = {"a": [1, 2, 3, -4, 5], "lb": [3, 2, 1, 1, 1], "ub": [4, 4, 2, 2, 2]} + result = dataframe(data).select(nwp.col("a").clip(lower, upper)) + assert_equal_data(result, {"a": expected}) + + +def test_clip_invalid() -> None: + df = dataframe({"a": [1, 2, 3], "b": [4, 5, 6]}) + with pytest.raises(MultiOutputExpressionError): + df.select(nwp.col("a").clip(nwp.all(), nwp.col("a", "b"))) diff --git a/tests/plan/coalesce_test.py b/tests/plan/coalesce_test.py new file mode 100644 index 0000000000..a3005d7a0d --- /dev/null +++ b/tests/plan/coalesce_test.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if 
TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data_int() -> Data: + return { + "a": [0, None, None, None, None], + "b": [1, None, None, 5, 3], + "c": [5, None, 3, 2, 1], + } + + +@pytest.fixture(scope="module") +def data_str() -> Data: + return { + "a": ["0", None, None, None, None], + "b": ["1", None, None, "5", "3"], + "c": ["5", None, "3", "2", "1"], + } + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.coalesce("a", "b", "c"), {"a": [0, None, 3, 5, 3]}), + ( + nwp.coalesce("a", "b", "c", nwp.lit(-100)).alias("lit"), + {"lit": [0, -100, 3, 5, 3]}, + ), + ( + nwp.coalesce(nwp.lit(None, nw.Int64), "b", "c", 500).alias("into_lit"), + {"into_lit": [1, 500, 3, 5, 3]}, + ), + ], +) +def test_coalesce_numeric(data_int: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data_int).select(expr) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.coalesce("a", "b", "c").alias("no_lit"), + {"no_lit": ["0", None, "3", "5", "3"]}, + ), + (nwp.coalesce("a", "b", "c", nwp.lit("xyz")), {"a": ["0", "xyz", "3", "5", "3"]}), + ], +) +def test_coalesce_strings(data_str: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data_str).select(expr) + assert_equal_data(result, expected) + + +def test_coalesce_series(data_str: Data) -> None: + df = dataframe(data_str) + ser = df.get_column("b").alias("b_renamed") + exprs = nwp.coalesce(ser, "a", nwp.col("c").fill_null("filled")), nwp.lit("ignored") + result = df.select(exprs) + assert_equal_data(result, {"b_renamed": ["1", "filled", "3", "5", "3"]}) + + +def test_coalesce_raises_non_expr() -> None: + class NotAnExpr: ... 
+ + with pytest.raises( + TypeError, match=re.escape("'NotAnExpr' is not supported in `nw.lit`") + ): + nwp.coalesce("a", "b", "c", NotAnExpr()) # type: ignore[arg-type] + + +def test_coalesce_multi_output() -> None: + data = { + "col1": [True, None, False, False, None], + "col2": [True, False, True, False, None], + } + df = dataframe(data) + result = df.select(nwp.coalesce(nwp.all(), True)) + expected = {"col1": [True, False, False, False, True]} + assert_equal_data(result, expected) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 541983246b..37d6cc605e 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -11,17 +11,22 @@ pytest.importorskip("pyarrow") pytest.importorskip("numpy") -import datetime as dt -import numpy as np import pyarrow as pa import narwhals as nw from narwhals import _plan as nwp -from narwhals._utils import Version -from tests.plan.utils import assert_equal_data, dataframe, first, last, series +from tests.plan.utils import ( + assert_equal_data, + assert_equal_series, + dataframe, + first, + last, + series, +) if TYPE_CHECKING: + import datetime as dt from collections.abc import Sequence from narwhals._plan.typing import ColumnNameOrSelector, OneOrIterable @@ -124,39 +129,6 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: .name.to_uppercase(), {"C": [2.0, 9.0, 4.0], "D": [7.0, 8.0, 8.0]}, ), - ([nwp.int_range(5)], {"literal": [0, 1, 2, 3, 4]}), - ([nwp.int_range(nwp.len())], {"literal": [0, 1, 2]}), - (nwp.int_range(nwp.len() * 5, 20).alias("lol"), {"lol": [15, 16, 17, 18, 19]}), - (nwp.int_range(nwp.col("b").min() + 4, nwp.col("d").last()), {"b": [5, 6, 7]}), - ( - [ - nwp.date_range( - dt.date(2020, 1, 1), - dt.date(2020, 4, 30), - interval="25d", - closed="none", - ) - ], - { - "literal": [ - dt.date(2020, 1, 26), - dt.date(2020, 2, 20), - dt.date(2020, 3, 16), - dt.date(2020, 4, 10), - ] - }, - ), - ( - ( - nwp.date_range( - dt.date(2021, 1, 30), - nwp.lit(18747, nw.Int32).cast(nw.Date), 
- interval="90d", - closed="left", - ).alias("date_range_cast_expr"), - {"date_range_cast_expr": [dt.date(2021, 1, 30)]}, - ) - ), (nwp.col("b") ** 2, {"b": [1, 4, 9]}), ( [2 ** nwp.col("b"), (nwp.lit(2.0) ** nwp.nth(1)).alias("lit")], @@ -205,7 +177,7 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: ), (nwp.col("e", "d").is_null().any(), {"e": [True], "d": [False]}), ( - [(~nwp.col("e", "d").is_null()).all(), "b"], + [(nwp.col("e", "d").is_not_null()).all(), "b"], {"e": [False, False, False], "d": [True, True, True], "b": [1, 2, 3]}, ), pytest.param( @@ -390,48 +362,6 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: {"literal": ["a|b|c|d|20"]}, id="concat_str-all-lit", ), - pytest.param( - [ - nwp.col("a") - .alias("...") - .map_batches( - lambda s: s.from_iterable( - [*((len(s) - 1) * [type(s.dtype).__name__.lower()]), "last"], - version=Version.MAIN, - name="funky", - ), - is_elementwise=True, - ), - nwp.col("a"), - ], - {"funky": ["string", "string", "last"], "a": ["A", "B", "A"]}, - id="map_batches-series", - ), - pytest.param( - nwp.col("b") - .map_batches(lambda s: s.to_numpy() + 1, nw.Float64(), is_elementwise=True) - .sum(), - {"b": [9.0]}, - id="map_batches-numpy", - ), - pytest.param( - ncs.by_name("b", "c", "d") - .map_batches(lambda s: np.append(s.to_numpy(), [10, 2]), is_elementwise=True) - .sort(), - {"b": [1, 2, 2, 3, 10], "c": [2, 2, 4, 9, 10], "d": [2, 7, 8, 8, 10]}, - id="map_batches-selector", - ), - pytest.param( - nwp.col("j", "k") - .fill_null(15) - .map_batches(lambda s: (s.to_numpy().max()), returns_scalar=True), - {"j": [15], "k": [42]}, - id="map_batches-return_scalar", - marks=pytest.mark.xfail( - reason="not implemented `map_batches(returns_scalar=True)` for `pyarrow`", - raises=NotImplementedError, - ), - ), pytest.param( [nwp.col("g").len(), nwp.col("m").last(), nwp.col("h").count()], {"g": [3], "m": [2], "h": [1]}, @@ -521,6 +451,26 @@ def test_with_columns( assert_equal_data(result, expected) +@pytest.mark.parametrize( + ("expr", 
"expected"), + [ + (nwp.all().first(), {"a": 8, "b": 58, "c": 2.5, "d": 2, "idx": 0}), + (ncs.numeric().null_count(), {"a": 1, "b": 0, "c": 0, "d": 0, "idx": 0}), + ( + ncs.by_index(range(5)).cast(nw.Boolean).fill_null(False).all(), + {"a": False, "b": True, "c": True, "d": True, "idx": False}, + ), + ], +) +def test_with_columns_all_aggregates( + data_indexed: dict[str, Any], expr: nwp.Expr, expected: dict[str, PythonLiteral] +) -> None: + height = len(next(iter(data_indexed.values()))) + expected_full = {k: height * [v] for k, v in expected.items()} + result = dataframe(data_indexed).with_columns(expr) + assert_equal_data(result, expected_full) + + @pytest.mark.parametrize( ("agg", "expected"), [ @@ -694,6 +644,113 @@ def test_series_to_polars(values: Sequence[PythonLiteral]) -> None: pl_assert_series_equal(result, expected) +def test_dataframe_iter_columns(data_small: Data) -> None: + df = dataframe(data_small) + result = df.from_dict({s.name: s for s in df.iter_columns()}).to_dict(as_series=False) + assert_equal_data(df, result) + + +def test_dataframe_from_dict_misc(data_small: Data) -> None: + pytest.importorskip("pyarrow") + items = iter(data_small.items()) + name, values = next(items) + mapping: dict[str, Any] = { + name: nwp.Series.from_iterable(values, name=name, backend="pyarrow") + } + mapping.update(items) + result = nwp.DataFrame.from_dict(mapping) + assert_equal_data(result, data_small) + + with pytest.raises(TypeError, match=r"from_dict.+without.+backend"): + nwp.DataFrame.from_dict(data_small) # type: ignore[arg-type] + + +def test_dataframe_to_struct(data_small_af: Data) -> None: + pytest.importorskip("pyarrow") + + schema = { + "a": nw.String(), + "b": nw.Int64(), + "c": nw.Int64(), + "d": nw.Int64(), + "e": nw.Int64(), + "f": nw.Boolean(), + } + + df = dataframe(data_small_af).with_columns( + nwp.col(name).cast(dtype) for name, dtype in schema.items() + ) + result = df.to_struct("struct_series") + result_dtype = result.dtype + assert 
isinstance(result_dtype, nw.Struct) + result_schema = dict(result_dtype.to_schema()) + assert result_schema == schema + + expected = [ + {"a": "A", "b": 1, "c": 9, "d": 8, "e": None, "f": True}, + {"a": "B", "b": 2, "c": 2, "d": 7, "e": 9, "f": False}, + {"a": "A", "b": 3, "c": 4, "d": 8, "e": 7, "f": None}, + ] + assert_equal_series(result, expected, "struct_series") + + +# TODO @dangotbanned: Split this up +def test_series_misc() -> None: + pytest.importorskip("pyarrow") + + values = [1.0, None, 7.1, float("nan"), 4.9, 12.0, 1.1, float("nan"), 0.2, None] + name = "ser" + ser = nwp.Series.from_iterable(values, name=name, dtype=nw.Float64, backend="pyarrow") + assert ser.is_empty() is False + assert ser.has_nulls() + assert ser.null_count() == 2 + + is_null = ser.is_null() + is_nan = ser.is_nan() + is_not_null = ser.is_not_null() + is_not_nan = ser.is_not_nan() + is_useful = ~(is_null | is_nan) + + assert is_useful.any() + + assert_equal_series(is_null, ~is_not_null) + assert_equal_series(~is_null, is_not_null) + assert_equal_series(is_nan, ~is_not_nan) + assert_equal_series(~is_nan, is_not_nan) + + expected = [False, None, False, False, False, False, False, False, False, None] + assert_equal_series(is_null & is_nan, expected, name) + expected = [False, True, False, False, False, False, False, False, False, True] + assert_equal_series(is_null, expected, name) + expected = [True, False, True, False, True, True, True, False, True, False] + assert_equal_series(is_not_nan & is_not_null, expected, name) + + assert ser.unique().drop_nans().drop_nulls().count() == 6 + assert len(list(ser)) == len(values) + + +def test_series_sort() -> None: + ser = series([1.0, 7.1, None, 4.9]) + assert_equal_series(ser.sort(), [None, 1.0, 4.9, 7.1], "") + assert_equal_series(ser.sort(nulls_last=True), [1.0, 4.9, 7.1, None], "") + assert_equal_series(ser.sort(descending=True), [None, 7.1, 4.9, 1.0], "") + assert_equal_series( + ser.sort(descending=True, nulls_last=True), [7.1, 4.9, 1.0, 
None], "" + ) + + +def test_series_cast() -> None: + pytest.importorskip("pyarrow") + ser = nwp.int_range(10, step=2, eager="pyarrow", dtype=nw.Int64) + assert ser.dtype == nw.Int64 + ser_float = ser.cast(nw.Float64) + assert ser_float.dtype == nw.Float64 + assert ser.dtype == nw.Int64 + result = ser_float + 0.5 + expected = [0.5, 2.5, 4.5, 6.5, 8.5] + assert_equal_series(result, expected, "literal") + + if TYPE_CHECKING: from typing_extensions import assert_type diff --git a/tests/plan/dispatch_test.py b/tests/plan/dispatch_test.py index 8d8f8d994e..23cc085959 100644 --- a/tests/plan/dispatch_test.py +++ b/tests/plan/dispatch_test.py @@ -47,11 +47,11 @@ def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: df.select(nwp.col("c").ewm_mean()) missing_protocol = re_compile( - r"str\.contains.+has not been implemented.+compliant.+" - r"Hint.+try adding.+CompliantExpr\.str\.contains\(\)" + r"dt\.offset_by.+has not been implemented.+compliant.+" + r"Hint.+try adding.+CompliantExpr\.dt\.offset_by\(\)" ) with pytest.raises(NotImplementedError, match=missing_protocol): - df.select(nwp.col("d").str.contains("a")) + df.select(nwp.col("d").dt.offset_by("1d")) with pytest.raises( TypeError, @@ -78,11 +78,23 @@ def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: (nwp.int_range(10), "int_range"), (nwp.col("a") + nwp.col("b") + 10, "binary_expr"), (nwp.when(nwp.col("c")).then(5).when(nwp.col("d")).then(20), "ternary_expr"), + (nwp.col("a").rolling_sum(2), "rolling_expr"), + (nwp.col("a").cum_sum(), "cum_sum"), + (nwp.col("a").cat.get_categories(), "cat.get_categories"), + (nwp.col("a").dt.timestamp(), "dt.timestamp"), + (nwp.col("a").dt.replace_time_zone(None), "dt.replace_time_zone"), + (nwp.col("a").list.len(), "list.len"), (nwp.col("a").cast(nw.String).str.starts_with("something"), ("str.starts_with")), + (nwp.col("a").str.slice(1), ("str.slice")), + (nwp.col("a").str.head(), ("str.slice")), + (nwp.col("a").str.tail(), ("str.slice")), 
+ (nwp.col("a").struct.field("b"), "struct.field"), (nwp.mean("a"), "mean"), (nwp.nth(1).first(), "first"), (nwp.col("a").sum(), "sum"), + (~nwp.col("a"), "not_"), (nwp.col("a").drop_nulls().arg_min(), "arg_min"), + (nwp.col("a").map_batches(lambda x: x), "map_batches"), pytest.param(nwp.col("a").alias("b"), "Alias", id="no_dispatch-Alias"), pytest.param(ncs.string(), "RootSelector", id="no_dispatch-RootSelector"), ], diff --git a/tests/plan/explode_test.py b/tests/plan/explode_test.py new file mode 100644 index 0000000000..bb8601dc8c --- /dev/null +++ b/tests/plan/explode_test.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Final + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs +from narwhals.exceptions import InvalidOperationError, ShapeError +from tests.plan.utils import ( + assert_equal_data, + assert_equal_series, + dataframe, + re_compile, + series, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + + from narwhals._plan.typing import ColumnNameOrSelector + from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +@pytest.fixture(scope="module") +def data() -> Data: + # For context, polars allows to explode multiple columns only if the columns + # have matching element counts, therefore, l1 and l2 but not l1 and l3 together. 
+ return { + "a": ["x", "y", "z", "w"], + "l1": [[1, 2], None, [None], []], + "l2": [[3, None], None, [42], []], + "l3": [[1, 2], [3], [None], [1]], + "l4": [[1, 2], [3], [123], [456]], + "l5": [[None, None], [None], [99], [83]], + } + + +@pytest.mark.parametrize( + ("column", "expected_values"), + [("l2", [None, 3, None, None, 42]), ("l3", [1, 1, 2, 3, None])], +) +def test_explode_frame_single_col( + column: str, expected_values: list[int | None], data: Data +) -> None: + result = ( + dataframe(data) + .with_columns(nwp.col(column).cast(nw.List(nw.Int32()))) + .explode(column) + .select("a", column) + .sort("a", column, nulls_last=True) + ) + expected = {"a": ["w", "x", "x", "y", "z"], column: expected_values} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("column", "expected_values"), + [ + ("l2", [None, None, None, 3, 42]), + ("l3", [None, 1, 1, 2, 3]), + ("l4", [1, 2, 3, 123, 456]), + ("l5", [None, None, None, 83, 99]), + ], +) +def test_explode_frame_only_column( + column: str, expected_values: list[int | None], data: Data +) -> None: + result = ( + dataframe(data) + .select(nwp.col(column).cast(nw.List(nw.Int32()))) + .explode(column) + .sort(column) + ) + assert_equal_data(result, {column: expected_values}) + + +@pytest.mark.parametrize( + ("column", "more_columns", "expected"), + [ + ( + "l1", + ["l2"], + { + "a": ["w", "x", "x", "y", "z"], + "l1": [None, 1, 2, None, None], + "l2": [None, 3, None, None, 42], + }, + ), + ( + "l3", + ["l4"], + { + "a": ["w", "x", "x", "y", "z"], + "l3": [1, 1, 2, 3, None], + "l4": [456, 1, 2, 3, 123], + }, + ), + ], +) +def test_explode_frame_multiple_cols( + column: str, + more_columns: Sequence[str], + expected: dict[str, list[str | int | None]], + data: Data, +) -> None: + result = ( + dataframe(data) + .with_columns(nwp.col(column, *more_columns).cast(nw.List(nw.Int32()))) + .explode(column, *more_columns) + .select("a", column, *more_columns) + .sort("a", column, nulls_last=True) + ) + 
assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + ncs.by_index(-1, -2, -3), + { + "a": ["w", "x", "x", "y", "z"], + "l5": [83, None, None, None, 99], + "l4": [456, 1, 2, 3, 123], + "l3": [1, 1, 2, 3, None], + }, + ), + ( + ncs.matches(r"l[3|5]"), + { + "a": ["w", "x", "x", "y", "z"], + "l3": [1, 1, 2, 3, None], + "l5": [83, None, None, None, 99], + }, + ), + ], +) +def test_explode_frame_selectors(expr: nwp.Selector, expected: Data, data: Data) -> None: + result = ( + dataframe(data) + .with_columns(expr.cast(nw.List(nw.Int32()))) + .explode(expr) + .select("a", expr) + .sort("a", expr, nulls_last=True) + ) + assert_equal_data(result, expected) + + +def test_explode_frame_shape_error(data: Data) -> None: + with pytest.raises( + ShapeError, match=r".*exploded columns (must )?have matching element counts" + ): + dataframe(data).with_columns( + nwp.col("l1", "l2", "l3").cast(nw.List(nw.Int32())) + ).explode(ncs.list()) + + +def test_explode_frame_invalid_operation_error(data: Data) -> None: + with pytest.raises( + InvalidOperationError, + match=re_compile(r"explode.+not supported for.+string.+expected.+list"), + ): + dataframe(data).explode("a") + + +@pytest.mark.parametrize( + ("values", "expected"), + [ + ([[1, 2, 3]], [1, 2, 3]), + ([[1, 2, 3], None], [1, 2, 3, None]), + ([[1, 2, 3], []], [1, 2, 3, None]), + ], +) +def test_explode_series_default(values: list[Any], expected: list[Any]) -> None: + # Based on https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/operations/test_explode.py#L465-L470 + result = series(values).explode() + assert_equal_series(result, expected, "") + + +@pytest.mark.parametrize( + ("values", "expected"), + [ + ([[1, 2, 3], [1, 2], [1, 2]], [1, 2, 3, None, 1, 2]), + ([[1, 2, 3], [], [1, 2]], [1, 2, 3, None, 1, 2]), + ], +) +def test_explode_series_default_masked(values: list[Any], expected: list[Any]) -> None: + # Based on 
https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/operations/test_explode.py#L471-484 + result = ( + series(values) + .to_frame() + .select(nwp.when(series([True, False, True])).then(nwp.col(""))) + .to_series() + .explode() + ) + assert_equal_series(result, expected, "") + + +DROP_EMPTY: Final = {"empty_as_null": False} +DROP_NULLS: Final = {"keep_nulls": False} +DROP_BOTH: Final = {"empty_as_null": False, "keep_nulls": False} +DEFAULT: Final[Data] = {} + + +@pytest.mark.parametrize( + ("values", "kwds", "expected"), + [ + ([[1, 2, 3]], DROP_BOTH, [1, 2, 3]), + ([[1, 2, 3], None], DROP_NULLS, [1, 2, 3]), + ([[1, 2, 3], [None]], DROP_NULLS, [1, 2, 3, None]), + ([[1, 2, 3], []], DROP_EMPTY, [1, 2, 3]), + ([[1, 2, 3], [None]], DROP_EMPTY, [1, 2, 3, None]), + ], +) +def test_explode_series_options( + values: list[Any], kwds: dict[str, Any], expected: list[Any] +) -> None: + # Based on https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/operations/test_explode.py#L486-L505 + result = series(values).explode(**kwds) + assert_equal_series(result, expected, "") + + +A = ("a",) +BA = "b", "a" + +DEFAULT_A: Final = [1, 2, 3, None, 4, 5, 6, None] +DEFAULT_I: Final = [1, 1, 1, 2, 3, 3, 3, 4] +DEFAULT_B: Final = [None, "dog", "cat", None, "narwhal", None, "orca", None] +EMPTY_A: Final = [1, 2, 3, None, 4, 5, 6] +EMPTY_I: Final = [1, 1, 1, 2, 3, 3, 3] +EMPTY_B: Final = [None, "dog", "cat", None, "narwhal", None, "orca"] +NULLS_A: Final = [1, 2, 3, 4, 5, 6, None] +NULLS_I: Final = [1, 1, 1, 3, 3, 3, 4] +NULLS_B: Final = [None, "dog", "cat", "narwhal", None, "orca", None] +BOTH_A: Final = [1, 2, 3, 4, 5, 6] +BOTH_I: Final = [1, 1, 1, 3, 3, 3] +BOTH_B: Final = [None, "dog", "cat", "narwhal", None, "orca"] + + +@pytest.mark.parametrize( + ("columns", "kwds", "expected"), + [ + (A, DEFAULT, {"a": DEFAULT_A, "i": DEFAULT_I}), + (A, DROP_EMPTY, {"a": EMPTY_A, "i": EMPTY_I}), + (A, 
DROP_NULLS, {"a": NULLS_A, "i": NULLS_I}), + (A, DROP_BOTH, {"a": BOTH_A, "i": BOTH_I}), + (BA, DEFAULT, {"b": DEFAULT_B, "a": DEFAULT_A, "i": DEFAULT_I}), + (BA, DROP_EMPTY, {"b": EMPTY_B, "a": EMPTY_A, "i": EMPTY_I}), + (BA, DROP_NULLS, {"b": NULLS_B, "a": NULLS_A, "i": NULLS_I}), + (BA, DROP_BOTH, {"b": BOTH_B, "a": BOTH_A, "i": BOTH_I}), + ], +) +def test_explode_frame_options( + columns: Sequence[ColumnNameOrSelector], kwds: dict[str, Any], expected: Data +) -> None: + # Based on https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/operations/test_explode.py#L596-L616 + data = { + "a": [[1, 2, 3], None, [4, 5, 6], []], + "b": [[None, "dog", "cat"], None, ["narwhal", None, "orca"], []], + "i": [1, 2, 3, 4], + } + result = ( + dataframe(data) + .with_columns( + nwp.col("a").cast(nw.List(nw.Int32())), nwp.col("b").cast(nw.List(nw.String)) + ) + .select(*columns, "i") + .explode(columns, **kwds) + ) + assert_equal_data(result, expected) + + +def test_explode_frame_single_elements() -> None: + data = {"a": [[1], [2], [3]], "b": [[4], [5], [6]], "i": [0, 10, 20]} + df = dataframe(data).with_columns(nwp.col("a", "b").cast(nw.List(nw.Int32()))) + + result = df.explode("a") + expected = {"a": [1, 2, 3], "b": [[4], [5], [6]], "i": [0, 10, 20]} + assert_equal_data(result, expected) + + result = df.explode("b", "a") + expected = {"a": [1, 2, 3], "b": [4, 5, 6], "i": [0, 10, 20]} + assert_equal_data(result, expected) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index b588b966b8..34ab20a1c0 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -15,7 +15,7 @@ from narwhals._plan import expressions as ir from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.expressions import functions as F, operators as ops -from narwhals._plan.expressions.literal import SeriesLiteral +from narwhals._plan.expressions.literal import ScalarLiteral, 
SeriesLiteral from narwhals._plan.expressions.ranges import IntRange from narwhals._utils import Implementation from narwhals.exceptions import ( @@ -26,7 +26,7 @@ MultiOutputExpressionError, ShapeError, ) -from tests.plan.utils import assert_expr_ir_equal, re_compile +from tests.plan.utils import assert_equal_data, assert_expr_ir_equal, re_compile if TYPE_CHECKING: from contextlib import AbstractContextManager @@ -209,13 +209,7 @@ def test_date_range_invalid() -> None: nwp.date_range(start, end, interval="3y") -def test_int_range_eager() -> None: - series = nwp.int_range(50, eager="pyarrow") - assert isinstance(series, nwp.Series) - assert series.to_list() == list(range(50)) - series = nwp.int_range(50, eager=Implementation.PYARROW) - assert series.to_list() == list(range(50)) - +def test_int_range_eager_invalid() -> None: with pytest.raises(InvalidOperationError): nwp.int_range(nwp.len(), eager="pyarrow") # type: ignore[call-overload] with pytest.raises(InvalidOperationError): @@ -226,49 +220,6 @@ def test_int_range_eager() -> None: nwp.int_range(10, eager="duckdb") # type: ignore[call-overload] -def test_date_range_eager() -> None: - leap_year = 2024 - series_leap = nwp.date_range( - dt.date(leap_year, 2, 25), dt.date(leap_year, 3, 25), eager="pyarrow" - ) - series_regular = nwp.date_range( - dt.date(leap_year + 1, 2, 25), - dt.date(leap_year + 1, 3, 25), - interval=dt.timedelta(days=1), - eager="pyarrow", - ) - assert len(series_regular) == 29 - assert len(series_leap) == 30 - - expected = [ - dt.date(2000, 1, 1), - dt.date(2002, 9, 14), - dt.date(2005, 5, 28), - dt.date(2008, 2, 9), - dt.date(2010, 10, 23), - dt.date(2013, 7, 6), - dt.date(2016, 3, 19), - dt.date(2018, 12, 1), - dt.date(2021, 8, 14), - ] - - series = nwp.date_range( - dt.date(2000, 1, 1), dt.date(2023, 8, 31), interval="987d", eager="pyarrow" - ) - result = series.to_list() - assert result == expected - - expected = [dt.date(2006, 10, 14), dt.date(2013, 7, 27), dt.date(2020, 5, 9)] - result = 
nwp.date_range( - dt.date(2000, 1, 1), - dt.date(2023, 8, 31), - interval="354w", - closed="right", - eager="pyarrow", - ).to_list() - assert result == expected - - def test_date_range_eager_invalid() -> None: start, end = dt.date(2000, 1, 1), dt.date(2001, 1, 1) @@ -377,6 +328,14 @@ def test_binary_expr_shape_invalid() -> None: a.fill_null(1) // b.rolling_mean(5) +def test_map_batches_invalid() -> None: + with pytest.raises( + TypeError, + match=r"A function cannot both return a scalar and preserve length, they are mutually exclusive", + ): + nwp.col("a").map_batches(lambda x: x, is_elementwise=True, returns_scalar=True) + + @pytest.mark.parametrize("into_iter", [list, tuple, deque, iter, dict.fromkeys, set]) def test_is_in_seq(into_iter: IntoIterable) -> None: expected = 1, 2, 3 @@ -576,14 +535,14 @@ def test_hist_bins() -> None: def test_hist_bin_count() -> None: bin_count_default = 10 - include_breakpoint_default = True + include_breakpoint_default = False a = nwp.col("a") hist_1 = a.hist( bin_count=bin_count_default, include_breakpoint=include_breakpoint_default ) hist_2 = a.hist() hist_3 = a.hist(bin_count=5) - hist_4 = a.hist(include_breakpoint=False) + hist_4 = a.hist(include_breakpoint=True) ir_1 = hist_1._ir ir_2 = hist_2._ir @@ -693,3 +652,113 @@ def test_replace_strict_invalid() -> None: match="`new` argument cannot be used if `old` argument is a Mapping type", ): nwp.col("a").replace_strict(old={1: 2, 3: 4}, new=[5, 6, 7]) + + +def test_mode_invalid() -> None: + with pytest.raises( + TypeError, match=r"keep.+must be one of.+all.+any.+but got 'first'" + ): + nwp.col("a").mode(keep="first") # type: ignore[arg-type] + + +def test_broadcast_len_1_series_invalid() -> None: + pytest.importorskip("pyarrow") + data = {"a": [1, 2, 3]} + values = [4] + df = nwp.DataFrame.from_dict(data, backend="pyarrow") + ser = nwp.Series.from_iterable(values, name="bad", backend="pyarrow") + with pytest.raises( + ShapeError, + match=re_compile( + 
r"series.+bad.+length.+1.+match.+DataFrame.+height.+3.+broadcasted.+\.first\(\)" + ), + ): + df.with_columns(ser) + + expected_series = {"a": [1, 2, 3], "literal": [4, 4, 4]} + # we can only preserve `Series.name` if we got a `lit(Series).first()`, not `lit(Series.first())` + expected_series_literal = {"a": [1, 2, 3], "bad": [4, 4, 4]} + + assert_equal_data(df.with_columns(ser.first()), expected_series) + assert_equal_data(df.with_columns(ser.last()), expected_series) + assert_equal_data(df.with_columns(nwp.lit(ser).first()), expected_series_literal) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "context"), + [ + (-1, None, pytest.raises(ValueError, match=r"window_size.+>= 1")), + (2, -1, pytest.raises(ValueError, match=r"min_samples.+>= 1")), + ( + 1, + 2, + pytest.raises(InvalidOperationError, match=r"min_samples.+<=.+window_size"), + ), + ( + 4.2, + None, + pytest.raises(TypeError, match=r"Expected.+int.+got.+float.+\s+window_size="), + ), + ( + 2, + 4.2, + pytest.raises(TypeError, match=r"Expected.+int.+got.+float.+\s+min_samples="), + ), + ], +) +def test_rolling_expr_invalid( + window_size: int, min_samples: int | None, context: pytest.RaisesExc[Any] +) -> None: + a = nwp.col("a") + with context: + a.rolling_sum(window_size, min_samples=min_samples) + with context: + a.rolling_mean(window_size, min_samples=min_samples) + with context: + a.rolling_var(window_size, min_samples=min_samples) + with context: + a.rolling_std(window_size, min_samples=min_samples) + + +def test_list_contains_invalid() -> None: + a = nwp.col("a") + + ok = a.list.contains("a") + assert_expr_ir_equal( + ok, + ir.FunctionExpr( + input=( + ir.col("a"), + ir.Literal(value=ScalarLiteral(value="a", dtype=nw.String())), + ), + function=ir.lists.Contains(), + options=ir.lists.Contains().function_options, + ), + ) + assert a.list.contains(a.first()) + assert a.list.contains(1) + assert a.list.contains(nwp.lit(1)) + assert a.list.contains(dt.datetime(2000, 2, 1, 9, 26, 5)) + 
assert a.list.contains(a.abs().fill_null(5).mode(keep="any")) + + with pytest.raises( + InvalidOperationError, match=r"list.contains.+non-scalar.+`col\('a'\)" + ): + a.list.contains(a) + + with pytest.raises(InvalidOperationError, match=r"list.contains.+non-scalar.+abs"): + a.list.contains(a.abs()) + + with pytest.raises(TypeError, match=r"list.+not.+supported.+nw.lit.+1.+2.+3"): + a.list.contains([1, 2, 3]) # type: ignore[arg-type] + + +def test_list_get_invalid() -> None: + a = nwp.col("a") + assert a.list.get(0) + pattern = re_compile(r"expected.+int.+got.+str.+'not an index'") + with pytest.raises(TypeError, match=pattern): + a.list.get("not an index") # type: ignore[arg-type] + pattern = re_compile(r"index.+out of bounds.+>= 0.+got -1") + with pytest.raises(InvalidOperationError, match=pattern): + a.list.get(-1) diff --git a/tests/plan/fill_nan_test.py b/tests/plan/fill_nan_test.py new file mode 100644 index 0000000000..a0c57fc90a --- /dev/null +++ b/tests/plan/fill_nan_test.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs +from tests.plan.utils import assert_equal_data, assert_equal_series, dataframe, series + +if TYPE_CHECKING: + from narwhals._plan.typing import OneOrIterable + from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"int": [-1, 1, None]} + + +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + ( + [nwp.col("no_nan").fill_nan(None), nwp.col("float_nan").fill_nan(None)], + [None, 1.0, None], + ), + ( + [nwp.col("no_nan").fill_nan(3.0), nwp.col("float_nan").fill_nan(3.0)], + [3.0, 1.0, None], + ), + (nwp.all().fill_nan(None), [None, 1.0, None]), + (nwp.all().fill_nan(3.0), [3.0, 1.0, None]), + ( + ncs.numeric().as_expr().fill_nan(nwp.lit(series([55.5, -100, -200]))), + [55.5, 1.0, None], + ), + ( + [ + 
nwp.col("no_nan"), + nwp.col("float_nan").fill_nan(nwp.col("no_nan").max() * 6), + ], + [6.0, 1.0, None], + ), + ], +) +def test_fill_nan( + data: Data, exprs: OneOrIterable[nwp.Expr], expected: list[Any] +) -> None: + base = nwp.col("int") + df = dataframe(data).select( + base.cast(nw.Float64).alias("no_nan"), (base**0.5).alias("float_nan") + ) + result = df.select(exprs) + assert_equal_data(result, {"no_nan": [-1.0, 1.0, None], "float_nan": expected}) + assert result.get_column("float_nan").null_count() == expected.count(None) + + +def test_fill_nan_series(data: Data) -> None: + ser = dataframe(data).select(float_nan=nwp.col("int") ** 0.5).get_column("float_nan") + result = ser.fill_nan(999) + assert_equal_series(result, [999.0, 1.0, None], "float_nan") + result = ser.fill_nan(series([1.23, None, None])) + assert_equal_series(result, [1.23, 1.0, None], "float_nan") diff --git a/tests/plan/fill_null_test.py b/tests/plan/fill_null_test.py new file mode 100644 index 0000000000..450eca14d5 --- /dev/null +++ b/tests/plan/fill_null_test.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from narwhals import _plan as nwp +from narwhals._plan import selectors as ncs +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from narwhals._plan.typing import OneOrIterable + from tests.conftest import Data + +DATA_1 = { + "a": [0.0, None, 2.0, 3.0, 4.0], + "b": [1.0, None, None, 5.0, 3.0], + "c": [5.0, None, 3.0, 2.0, 1.0], +} +DATA_2 = { + "a": [0.0, None, 2.0, 3.0, 4.0], + "b": [1.0, None, None, 5.0, 3.0], + "c": [5.0, 2.0, None, 2.0, 1.0], +} +DATA_LIMITS = { + "a": [1, None, None, None, 5, 6, None, None, None, 10], + "b": ["a", None, None, None, "b", "c", None, None, None, "d"], + "c": [None, 2.5, None, None, None, None, 3.6, None, 2.2, 3.0], + "d": [1, None, None, None, None, None, None, None, 2, None], + "idx": list(range(10)), +} + + +@pytest.mark.parametrize( + ("data", "exprs", 
"expected"), + [ + pytest.param( + DATA_1, + nwp.all().fill_null(value=99), + {"a": [0.0, 99, 2, 3, 4], "b": [1.0, 99, 99, 5, 3], "c": [5.0, 99, 3, 2, 1]}, + id="literal", + ), + pytest.param( + {"a": [0.5, None, 2.0, 3.0, 4.5], "b": ["xx", "yy", "zz", None, "yy"]}, + [nwp.col("a").fill_null(nwp.col("a").mean()), nwp.col("b").fill_null("a")], + {"a": [0.5, 2.5, 2.0, 3.0, 4.5], "b": ["xx", "yy", "zz", "a", "yy"]}, + id="expr-aggregate", + ), + pytest.param( + DATA_2, + nwp.nth(0, 1).fill_null(nwp.col("c")), + {"a": [0.0, 2, 2, 3, 4], "b": [1.0, 2, None, 5, 3]}, + id="expr-column", + ), + pytest.param( + DATA_LIMITS, + ncs.by_index(0, 1).fill_null(strategy="forward").over(order_by="idx"), + { + "a": [1, 1, 1, 1, 5, 6, 6, 6, 6, 10], + "b": ["a", "a", "a", "a", "b", "c", "c", "c", "c", "d"], + }, + id="forward", + ), + pytest.param( + DATA_LIMITS, + nwp.exclude("idx").fill_null(strategy="backward").over(order_by="idx"), + { + "a": [1, 5, 5, 5, 5, 6, 10, 10, 10, 10], + "b": ["a", "b", "b", "b", "b", "c", "d", "d", "d", "d"], + "c": [2.5, 2.5, 3.6, 3.6, 3.6, 3.6, 3.6, 2.2, 2.2, 3.0], + "d": [1, 2, 2, 2, 2, 2, 2, 2, 2, None], + }, + id="backward", + ), + pytest.param( + DATA_LIMITS, + nwp.col("a", "b").fill_null(strategy="forward", limit=2).over(order_by="idx"), + { + "a": [1, 1, 1, None, 5, 6, 6, 6, None, 10], + "b": ["a", "a", "a", None, "b", "c", "c", "c", None, "d"], + }, + id="forward-limit", + ), + pytest.param( + DATA_LIMITS, + [ + nwp.col("a", "b") + .fill_null(strategy="backward", limit=2) + .over(order_by="idx"), + nwp.col("c").fill_null(strategy="backward", limit=3).over(order_by="idx"), + ], + { + "a": [1, None, 5, 5, 5, 6, None, 10, 10, 10], + "b": ["a", None, "b", "b", "b", "c", None, "d", "d", "d"], + "c": [2.5, 2.5, None, 3.6, 3.6, 3.6, 3.6, 2.2, 2.2, 3.0], + }, + id="backward-limit", + ), + pytest.param( + DATA_LIMITS, + nwp.col("c").fill_null(strategy="forward", limit=3).over(order_by="idx"), + {"c": [None, 2.5, 2.5, 2.5, 2.5, None, 3.6, 3.6, 2.2, 3.0]}, 
+ id="forward-limit-nulls-first", + ), + pytest.param( + DATA_LIMITS, + nwp.col("d").fill_null(strategy="backward", limit=3).over(order_by="idx"), + {"d": [1, None, None, None, None, 2, 2, 2, 2, None]}, + id="backward-limit-nulls-last", + ), + ], +) +def test_fill_null(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None: + df = dataframe(data) + assert_equal_data(df.select(exprs), expected) + + +@pytest.mark.parametrize( + "expr", + [ + (~ncs.last()).fill_null(strategy="forward"), + (~ncs.last()).fill_null(strategy="backward"), + (~ncs.last()).fill_null(strategy="forward", limit=100), + (~ncs.last()).fill_null(strategy="backward", limit=20), + ], +) +def test_fill_null_strategy_noop(expr: nwp.Expr) -> None: + data = {"a": [1, 2, 3], "b": [None, None, None], "i": [0, 1, 2]} + expected = {"a": [1, 2, 3], "b": [None, None, None]} + df = dataframe(data) + assert_equal_data(df.select(expr), expected) + assert_equal_data(df.select(expr.over(order_by=ncs.last())), expected) diff --git a/tests/plan/format_test.py b/tests/plan/format_test.py new file mode 100644 index 0000000000..84c479fd77 --- /dev/null +++ b/tests/plan/format_test.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.format("hello {} {} wassup", "name", nwp.col("surname")), + [ + "hello bob builder wassup", + "hello alice wonderlander wassup", + "hello dodo extinct wassup", + ], + ), + ( + nwp.format("{} {} wassup", "name", nwp.col("surname")), + ["bob builder wassup", "alice wonderlander wassup", "dodo extinct wassup"], + ), + ], +) +def test_format(expr: nwp.Expr, expected: list[str]) -> None: + data = { + "name": ["bob", "alice", "dodo"], + "surname": ["builder", "wonderlander", "extinct"], + } + result = dataframe(data).select(fmt=expr) + assert_equal_data(result, {"fmt": expected}) + + +def 
test_format_invalid() -> None: + with pytest.raises(ValueError, match="Expected 2 arguments, got 1"): + nwp.format("hello {} {} wassup", "name") diff --git a/tests/plan/gather_test.py b/tests/plan/gather_test.py new file mode 100644 index 0000000000..d8708297bf --- /dev/null +++ b/tests/plan/gather_test.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs +from narwhals.exceptions import ShapeError +from tests.plan.utils import assert_equal_data, assert_equal_series, dataframe, series + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "idx": list(range(10)), + "name": ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"], + } + + +@pytest.mark.parametrize("n", [1, 2, 3]) +@pytest.mark.parametrize("offset", [0, 1, 2, 3]) +@pytest.mark.parametrize("column", ["idx", "name"]) +def test_gather_every_series(data: Data, n: int, offset: int, column: str) -> None: + ser = series(data[column]).alias(column) + result = ser.gather_every(n, offset) + expected = data[column][offset::n] + assert_equal_series(result, expected, column) + + +@pytest.mark.parametrize( + ("column", "indices", "expected"), + [ + ("idx", [], []), + ("name", [], []), + ("idx", [0, 4, 2], [0, 4, 2]), + ("name", [1, 5, 5], ["b", "f", "f"]), + pytest.param( + "idx", + [-1], + [9], + marks=pytest.mark.xfail( + reason="TODO: Handle negative indices", raises=IndexError + ), + ), + ("name", range(5, 7), ["f", "g"]), + ], +) +def test_gather_series( + data: Data, column: str, indices: Any, expected: list[Any] +) -> None: + ser = series(data[column]).alias(column) + result = ser.gather(indices) + assert_equal_series(result, expected, column) + + +@pytest.mark.parametrize("n", [1, 2, 3]) +@pytest.mark.parametrize("offset", [0, 1, 2, 3]) +def test_gather_every_dataframe(data: Data, 
n: int, offset: int) -> None: + result = dataframe(data).gather_every(n, offset) + indices = slice(offset, None, n) + expected = {"idx": data["idx"][indices], "name": data["name"][indices]} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize("n", [1, 2, 3]) +@pytest.mark.parametrize("offset", [0, 1, 2, 3]) +def test_gather_every_expr(data: Data, n: int, offset: int) -> None: + df = dataframe(data) + indices = slice(offset, None, n) + v_idx, v_name = data["idx"][indices], data["name"][indices] + e_idx, e_name = nwp.col("idx"), nwp.col("name") + gather = partial(nwp.Expr.gather_every, n=n, offset=offset) + + result = df.select(gather(nwp.col("idx", "name"))) + expected = {"idx": v_idx, "name": v_name} + assert_equal_data(result, expected) + expected = {"name": v_name} + assert_equal_data(df.select(gather(e_name)), expected) + expected = {"name": v_name, "idx": v_idx} + assert_equal_data(df.select(gather(nwp.nth(1, 0))), expected) + expected = {"idx": v_idx, "name": v_name} + assert_equal_data(df.select(gather(e_idx), gather(ncs.last())), expected) + + if n == 1 and offset == 0: + result = df.select(gather(e_name), e_idx) + expected = {"name": data["name"], "idx": data["idx"]} + assert_equal_data(result, expected) + else: + with pytest.raises(ShapeError): + df.select(gather(e_name), e_idx) + result = df.select(gather(e_name), e_idx.first()) + expected = {"name": v_name, "idx": [0] * len(result)} + assert_equal_data(result, expected) diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index b7f5035fe2..d82d372d0b 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -8,6 +8,7 @@ import narwhals as nw from narwhals import _plan as nwp from narwhals._plan import selectors as ncs +from narwhals._utils import Implementation from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_equal_data, dataframe from tests.utils import PYARROW_VERSION, assert_equal_data as _assert_equal_data 
@@ -587,8 +588,13 @@ def test_group_by_agg_last( "b_first": [3, 1, 3, 2, 1], }, ), + ( + ["d"], + [nwp.col("e", "b").unique()], + {"d": ["one", "three"], "e": [[1, 2], [None, 1]], "b": [[1, 3], [1, 2, 3]]}, + ), ], - ids=["Unique-Single", "Unique-Multi", "Unique-Selector-Fancy"], + ids=["Unique-Single", "Unique-Multi", "Unique-Selector-Fancy", "Unique-Nulls"], ) def test_group_by_agg_unique( keys: Sequence[str], aggs: Sequence[IntoExpr], expected: Mapping[str, Any] @@ -598,8 +604,42 @@ def test_group_by_agg_unique( "b": [1, 2, 1, 3, 3], "c": [5, 4, 3, 2, 1], "d": ["three", "three", "one", "three", "one"], + "e": [None, 1, 1, None, 2], + } + df = dataframe(data) + result = df.group_by(keys).agg(aggs).sort(keys) + assert_equal_data(result, expected) + + +def test_group_by_agg_kurtosis_skew(request: pytest.FixtureRequest) -> None: + data = { + "p1": ["a", "b", None, None, "b", "b"], + "p2": [1, 2, 1, None, None, None], + "p3": [None, 1, 1, 2, 2, None], + "a": [1, 2, 3, 4, 2, 1], + "b": [None, 9.9, 1.5, None, 1.0, 2.1], + } + expected = { + "p1": [None, "a", "b"], + "a_skew": [0.0, float("nan"), -0.707107], + "b_skew": [float("nan"), None, 0.666442], + "b_kurtosis": [float("nan"), None, -1.4999999999999996], + "a_kurtosis": [-2.0, float("nan"), -1.4999999999999998], } df = dataframe(data) + + request.applymarker( + pytest.mark.xfail( + (df.implementation is Implementation.PYARROW and PYARROW_VERSION < (20,)), + reason="too old for `pyarrow.compute.{kurtosis,skew}`", + ) + ) + + keys = ("p1",) + aggs = ( + nwp.col("a", "b").skew().name.suffix("_skew"), + nwp.nth(-1, -2).kurtosis().name.suffix("_kurtosis"), + ) result = df.group_by(keys).agg(aggs).sort(keys) assert_equal_data(result, expected) diff --git a/tests/plan/hist_test.py b/tests/plan/hist_test.py new file mode 100644 index 0000000000..fdd42c40c4 --- /dev/null +++ b/tests/plan/hist_test.py @@ -0,0 +1,417 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import 
narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data + +if TYPE_CHECKING: + from collections.abc import Sequence + + from narwhals.typing import EagerAllowed, IntoDType + from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "int": [0, 1, 2, 3, 4, 5, 6], + "float": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + "int_shuffled": [1, 0, 2, 3, 6, 5, 4], + "float_shuffled": [1.0, 0.0, 2.0, 3.0, 6.0, 5.0, 4.0], + } + + +@pytest.fixture(scope="module") +def schema_data() -> nw.Schema: + return nw.Schema( + { + "int": nw.Int64(), + "float": nw.Float64(), + "int_shuffled": nw.Int64(), + "float_shuffled": nw.Float64(), + } + ) + + +@pytest.fixture(scope="module", params=["int", "float", "int_shuffled", "float_shuffled"]) +def column_data(request: pytest.FixtureRequest) -> str: + result: str = request.param + return result + + +@pytest.fixture(scope="module") +def data_missing(data: Data) -> Data: + return {"has_nan": [float("nan"), *data["int"]], "has_null": [None, *data["int"]]} + + +@pytest.fixture(scope="module") +def schema_data_missing() -> nw.Schema: + return nw.Schema({"has_nan": nw.Float64(), "has_null": nw.Int64()}) + + +@pytest.fixture(scope="module", params=["has_nan", "has_null"]) +def column_data_missing(request: pytest.FixtureRequest) -> str: + result: str = request.param + return result + + +@pytest.fixture(scope="module", params=["pyarrow"]) +def backend(request: pytest.FixtureRequest) -> EagerAllowed: + result: EagerAllowed = request.param + return result + + +@pytest.fixture(scope="module", params=[True, False]) +def include_breakpoint(request: pytest.FixtureRequest) -> bool: + result: bool = request.param + return result + + +def _series( + name: str, source: Data, schema: nw.Schema, backend: EagerAllowed, / +) -> nwp.Series[Any]: + values, dtype = (source[name], schema[name]) + return nwp.Series.from_iterable(values, name=name, dtype=dtype, 
backend=backend) + + +def _expected( + bins: Sequence[float], count: Sequence[int], *, include_breakpoint: bool +) -> dict[str, Any]: + if not include_breakpoint: + return {"count": count} + return {"breakpoint": bins[1:] if len(bins) > len(count) else bins, "count": count} + + +SHIFT_BINS_BY = 10 +"""shift bins property""" + +bins_cases = pytest.mark.parametrize( + ("bins", "expected_count"), + [ + pytest.param( + [-float("inf"), 2.5, 5.5, float("inf")], [3, 3, 1], id="4_bins-neg-inf-inf" + ), + pytest.param([1.0, 2.5, 5.5, float("inf")], [2, 3, 1], id="4_bins-inf"), + pytest.param([1.0, 2.5, 5.5], [2, 3], id="3_bins"), + pytest.param([-10.0, -1.0, 2.5, 5.5], [0, 3, 3], id="4_bins"), + pytest.param([1.0, 2.0625], [2], id="2_bins-1"), + pytest.param([1], [], id="1_bins"), + pytest.param([0, 10], [7], id="2_bins-2"), + ], +) + + +@bins_cases +def test_hist_bins( + data: Data, + schema_data: nw.Schema, + backend: EagerAllowed, + column_data: str, + bins: Sequence[float], + expected_count: Sequence[int], + *, + include_breakpoint: bool, +) -> None: + ser = _series(column_data, data, schema_data, backend) + expected = _expected(bins, expected_count, include_breakpoint=include_breakpoint) + result = ser.hist(bins, include_breakpoint=include_breakpoint) + assert_equal_data(result, expected) + assert len(result) == max(len(bins) - 1, 0) + + +@bins_cases +def test_hist_bins_shifted( + data: Data, + schema_data: nw.Schema, + backend: EagerAllowed, + column_data: str, + bins: Sequence[float], + expected_count: Sequence[int], + *, + include_breakpoint: bool, +) -> None: + shifted_bins = [b + SHIFT_BINS_BY for b in bins] + expected = _expected( + shifted_bins, expected_count, include_breakpoint=include_breakpoint + ) + ser = _series(column_data, data, schema_data, backend) + SHIFT_BINS_BY + result = ser.hist(shifted_bins, include_breakpoint=include_breakpoint) + assert_equal_data(result, expected) + + +@bins_cases +def test_hist_bins_missing( + data_missing: Data, + 
schema_data_missing: nw.Schema, + backend: EagerAllowed, + column_data_missing: str, + bins: Sequence[float], + expected_count: Sequence[int], + *, + include_breakpoint: bool, +) -> None: + ser = _series(column_data_missing, data_missing, schema_data_missing, backend) + expected = _expected(bins, expected_count, include_breakpoint=include_breakpoint) + result = ser.hist(bins, include_breakpoint=include_breakpoint) + assert_equal_data(result, expected) + + +bin_count_cases = pytest.mark.parametrize( + ("bin_count", "expected_bins", "expected_count"), + [ + (4, [1.5, 3.0, 4.5, 6.0], [2, 2, 1, 2]), + ( + 12, + [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0], + [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + ), + (1, [6], [7]), + (0, [], []), + ], +) + + +@bin_count_cases +def test_hist_bin_count( + data: Data, + schema_data: nw.Schema, + backend: EagerAllowed, + column_data: str, + bin_count: int, + expected_bins: Sequence[float], + expected_count: Sequence[int], + *, + include_breakpoint: bool, +) -> None: + ser = _series(column_data, data, schema_data, backend) + expected = _expected( + expected_bins, expected_count, include_breakpoint=include_breakpoint + ) + result = ser.hist(bin_count=bin_count, include_breakpoint=include_breakpoint) + assert_equal_data(result, expected) + assert len(result) == bin_count + if bin_count > 0: + assert result.get_column("count").sum() == ser.drop_nans().count() + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.all().hist(bin_count=5), + { + "int": [2, 1, 1, 1, 2], + "float": [2, 1, 1, 1, 2], + "int_shuffled": [2, 1, 1, 1, 2], + "float_shuffled": [2, 1, 1, 1, 2], + }, + ), + ( + (99 + nwp.all()).hist(bin_count=2).name.keep(), + { + "int": [4, 3], + "float": [4, 3], + "int_shuffled": [4, 3], + "float_shuffled": [4, 3], + }, + ), + ( + nwp.all().hist([-3, -2, 3, 6]).name.to_uppercase(), + { + "INT": [0, 4, 3], + "FLOAT": [0, 4, 3], + "INT_SHUFFLED": [0, 4, 3], + "FLOAT_SHUFFLED": [0, 4, 3], + }, + ), + ( + 
nwp.all().clip(upper_bound=4).hist([2, 3, 4, 5, 6]), + { + "int": [2, 3, 0, 0], + "float": [2, 3, 0, 0], + "int_shuffled": [2, 3, 0, 0], + "float_shuffled": [2, 3, 0, 0], + }, + ), + ( + (nwp.all() * 2.7).hist([1.3, 5.1, 8.98, 11.3]), + { + "int": [1, 2, 1], + "float": [1, 2, 1], + "int_shuffled": [1, 2, 1], + "float_shuffled": [1, 2, 1], + }, + ), + ], +) +def test_hist_expr_counts_only( + data: Data, + schema_data: nw.Schema, + backend: EagerAllowed, + expr: nwp.Expr, + expected: dict[str, Any], +) -> None: + df = nwp.DataFrame.from_dict(data, schema_data, backend=backend) + result = df.select(expr) + assert_equal_data(result, expected) + + +def test_hist_expr_breakpoint( + data: Data, schema_data: nw.Schema, backend: EagerAllowed +) -> None: + df = nwp.DataFrame.from_dict(data, schema_data, backend=backend) + expr = nwp.all().hist(bin_count=3, include_breakpoint=True) + result = df.select(expr) + result_schema = result.collect_schema() + + dtype_breakpoint: IntoDType = nw.Float64 + # NOTE: To match polars it would be this, but maybe i64 is okay? 
+ dtype_count: IntoDType = nw.UInt32 + dtype_count = nw.Int64 + + dtype_struct = nw.Struct({"breakpoint": dtype_breakpoint, "count": dtype_count}) + schema_struct = nw.Schema({"breakpoint": nw.Float64(), "count": nw.Int64()}) + expected_schema = nw.Schema( + [ + ("int", dtype_struct), + ("float", dtype_struct), + ("int_shuffled", dtype_struct), + ("float_shuffled", dtype_struct), + ] + ) + expected_data = { + "int": [ + {"breakpoint": 2.0, "count": 3}, + {"breakpoint": 4.0, "count": 2}, + {"breakpoint": 6.0, "count": 2}, + ], + "float": [ + {"breakpoint": 2.0, "count": 3}, + {"breakpoint": 4.0, "count": 2}, + {"breakpoint": 6.0, "count": 2}, + ], + "int_shuffled": [ + {"breakpoint": 2.0, "count": 3}, + {"breakpoint": 4.0, "count": 2}, + {"breakpoint": 6.0, "count": 2}, + ], + "float_shuffled": [ + {"breakpoint": 2.0, "count": 3}, + {"breakpoint": 4.0, "count": 2}, + {"breakpoint": 6.0, "count": 2}, + ], + } + assert result_schema == expected_schema + assert_equal_data(result, expected_data) + for ser in result.iter_columns(): + assert ser.struct.schema == schema_struct + + +@bin_count_cases +def test_hist_bin_count_missing( + data_missing: Data, + schema_data_missing: nw.Schema, + backend: EagerAllowed, + column_data_missing: str, + bin_count: int, + expected_bins: Sequence[float], + expected_count: Sequence[int], + *, + include_breakpoint: bool, +) -> None: + ser = _series(column_data_missing, data_missing, schema_data_missing, backend) + expected = _expected( + expected_bins, expected_count, include_breakpoint=include_breakpoint + ) + result = ser.hist(bin_count=bin_count, include_breakpoint=include_breakpoint) + assert_equal_data(result, expected) + assert len(result) == bin_count + if bin_count > 0: + assert result.get_column("count").sum() == ser.drop_nans().count() + + +@pytest.mark.parametrize( + ("column", "bin_count", "expected_breakpoint", "expected_count"), + [ + ("all_zero", 4, [-0.25, 0.0, 0.25, 0.5], [0, 3, 0, 0]), + ("all_non_zero", 4, [4.75, 5.0, 
5.25, 5.5], [0, 3, 0, 0]), + ("all_zero", 1, [0.5], [3]), + ], +) +def test_hist_bin_count_no_spread( + backend: EagerAllowed, + column: str, + bin_count: int, + expected_breakpoint: Sequence[float], + expected_count: Sequence[int], +) -> None: + data = {"all_zero": [0, 0, 0], "all_non_zero": [5, 5, 5]} + ser = nwp.DataFrame.from_dict(data, backend=backend).get_column(column) + result = ser.hist(bin_count=bin_count, include_breakpoint=True) + expected = {"breakpoint": expected_breakpoint, "count": expected_count} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize("bins", [[1, 5, 10]]) +def test_hist_bins_no_data( + backend: EagerAllowed, bins: list[int], *, include_breakpoint: bool +) -> None: + s = nwp.Series.from_iterable([], dtype=nw.Float64(), backend=backend) + result = s.hist(bins, include_breakpoint=include_breakpoint) + assert len(result) == 2 + assert result.get_column("count").sum() == 0 + + +@pytest.mark.parametrize("bin_count", [1, 10]) +def test_hist_bin_count_no_data( + backend: EagerAllowed, bin_count: int, *, include_breakpoint: bool +) -> None: + s = nwp.Series.from_iterable([], dtype=nw.Float64(), backend=backend) + result = s.hist(bin_count=bin_count, include_breakpoint=include_breakpoint) + assert len(result) == bin_count + assert result.get_column("count").sum() == 0 + + if include_breakpoint: + bps = result.get_column("breakpoint").to_list() + assert bps[0] == (1 / bin_count) + if bin_count > 1: + assert bps[-1] == 1 + + +def test_hist_bins_none(backend: EagerAllowed) -> None: + s = nwp.Series.from_iterable([1, 2, 3], backend=backend) + result = s.hist(bins=None, bin_count=None) + assert len(result) == 10 + + +def test_hist_series_compat_flag(backend: EagerAllowed) -> None: + # NOTE: Mainly for verifying `Expr.hist` has handled naming/collecting as struct + # The flag itself is not desirable + values = [1, 3, 8, 8, 2, 1, 3] + s = nwp.Series.from_iterable(values, name="original", backend=backend) + + result = s.hist( + 
bin_count=4, + include_breakpoint=False, + include_category=False, + _compatibility_behavior="narwhals", + ) + assert_equal_data(result, {"count": [3, 2, 0, 2]}) + + result = s.hist( + bin_count=4, + include_breakpoint=False, + include_category=False, + _compatibility_behavior="polars", + ) + assert_equal_data(result, {"original": [3, 2, 0, 2]}) + + result = s.hist(bin_count=4, include_breakpoint=True, include_category=False) + expected = {"breakpoint": [2.75, 4.5, 6.25, 8.0], "count": [3, 2, 0, 2]} + assert_equal_data(result, expected) diff --git a/tests/plan/is_duplicated_unique_test.py b/tests/plan/is_duplicated_unique_test.py new file mode 100644 index 0000000000..7f5ceac074 --- /dev/null +++ b/tests/plan/is_duplicated_unique_test.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from narwhals import _plan as nwp +from narwhals._plan import selectors as ncs +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture +def data() -> Data: + return { + "v1": [None, 2, 1, 4, 1], + "v2": ["a", "c", "c", None, None], + "p1": [2, 2, 2, 1, 1], + "i": [0, 1, 2, 3, 4], + } + + +def test_is_duplicated_unique(data: Data) -> None: + expected = { + "v1_is_unique": [True, True, False, True, False], + "v2_is_unique": [True, False, False, False, False], + "v1_is_duplicated": [False, False, True, False, True], + "v2_is_duplicated": [False, True, True, True, True], + } + vals = nwp.col("v1", "v2") + exprs = ( + vals.is_unique().name.suffix("_is_unique"), + vals.is_duplicated().name.suffix("_is_duplicated"), + ) + result = dataframe(data).select("i", *exprs).sort("i").drop("i") + assert_equal_data(result, expected) + + +# NOTE: Not supported on `main` +def test_is_duplicated_unique_partitioned(data: Data) -> None: + expected = { + "v1_is_unique": [True, True, True, True, True], + "v2_is_unique": [True, False, False, False, False], + 
"v1_is_duplicated": [False, False, False, False, False], + "v2_is_duplicated": [False, True, True, True, True], + } + vals = ncs.by_index(0, 1) + exprs = ( + vals.is_unique().name.suffix("_is_unique").over("p1"), + vals.is_duplicated().name.suffix("_is_duplicated").over("p1"), + ) + result = dataframe(data).select("i", *exprs).sort("i").drop("i") + assert_equal_data(result, expected) diff --git a/tests/plan/kurtosis_skew_test.py b/tests/plan/kurtosis_skew_test.py new file mode 100644 index 0000000000..1dd98483d5 --- /dev/null +++ b/tests/plan/kurtosis_skew_test.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import pytest + +from narwhals import _plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + + +@pytest.mark.parametrize( + ("data", "expected_kurtosis", "expected_skew"), + [ + ([], None, None), + ([None], None, None), + ([1], None, None), + ([1, 2], -2, 0.0), + ([0.0, 0.0, 0.0], None, None), + ([1, 2, 3, 2, 1], -1.153061, 0.343622), + ([None, 1.4, 1.3, 5.9, None, 2.9], -1.014744, 0.801638), + ], +) +def test_kurtosis_skew_expr( + data: list[float], expected_kurtosis: float | None, expected_skew: float | None +) -> None: + df = dataframe({"a": data}) + kurtosis = nwp.col("a").kurtosis() + skew = nwp.col("a").skew() + height = len(data) + + assert_equal_data(df.select(kurtosis), {"a": [expected_kurtosis]}) + assert_equal_data(df.select(skew), {"a": [expected_skew]}) + assert_equal_data(df.with_columns(kurtosis), {"a": [expected_kurtosis] * height}) + assert_equal_data(df.with_columns(skew), {"a": [expected_skew] * height}) diff --git a/tests/plan/list_contains_test.py b/tests/plan/list_contains_test.py new file mode 100644 index 0000000000..0761434e97 --- /dev/null +++ b/tests/plan/list_contains_test.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Final + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe 
+ +if TYPE_CHECKING: + from narwhals._plan.typing import IntoExpr + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "a": [[2, 2, 3, None, None], None, [], [None]], + "b": [[1, 2, 2], [3, 4], [5, 5, 5, 6], [7]], + "c": [1, 3, None, 2], + "d": ["B", None, "A", "C"], + } + + +a = nwp.col("a") +b = nwp.col("b") + + +@pytest.mark.parametrize( + ("item", "expected"), + [ + (2, [True, None, False, False]), + (4, [False, None, False, False]), + (nwp.col("c").last() + 1, [True, None, False, False]), + (nwp.lit(None, nw.Int32), [True, None, False, True]), + ], +) +def test_list_contains(data: Data, item: IntoExpr, expected: list[bool | None]) -> None: + df = dataframe(data).with_columns(a.cast(nw.List(nw.Int32))) + result = df.select(a.list.contains(item)) + assert_equal_data(result, {"a": expected}) + + +R1: Final[list[Any]] = [None, "A", "B", "A", "A", "B"] +R2: Final = None +R3: Final[list[Any]] = [] +R4: Final = [None] + + +@pytest.mark.parametrize( + ("row", "item", "expected"), + [ + (R1, "A", True), + (R2, "A", None), + (R3, "A", False), + (R4, "A", False), + (R1, None, True), + (R2, None, None), + (R3, None, False), + (R4, None, True), + (R1, "C", False), + (R2, "C", None), + (R3, "C", False), + (R4, "C", False), + ], +) +def test_list_contains_scalar( + row: list[str | None] | None, item: IntoExpr, *, expected: bool | None +) -> None: + data = {"a": [row]} + df = dataframe(data).select(a.cast(nw.List(nw.String))) + result = df.select(a.first().list.contains(item)) + assert_equal_data(result, {"a": [expected]}) diff --git a/tests/plan/list_get_test.py b/tests/plan/list_get_test.py new file mode 100644 index 0000000000..64fa644349 --- /dev/null +++ b/tests/plan/list_get_test.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: 
+ from narwhals._plan.typing import OneOrIterable + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "a": [[1, 2], [3, 4, None], None, [None]], + "b": [[None, "o"], ["b", None, "b"], [None, "oops", None, "hi"], None], + } + + +a = nwp.nth(0) +b = nwp.col("b") + + +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + (a.list.get(0), {"a": [1, 3, None, None]}), + (b.list.get(1), {"b": ["o", None, "oops", None]}), + ], +) +def test_list_get(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None: + df = dataframe(data).with_columns( + a.cast(nw.List(nw.Int32())), b.cast(nw.List(nw.String)) + ) + result = df.select(exprs) + assert_equal_data(result, expected) diff --git a/tests/plan/list_join_test.py b/tests/plan/list_join_test.py new file mode 100644 index 0000000000..881700f6c7 --- /dev/null +++ b/tests/plan/list_join_test.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Final, TypeVar + + from typing_extensions import TypeAlias + + from tests.conftest import Data + + T = TypeVar("T") + SubList: TypeAlias = list[T] | list[T | None] | list[None] | None + SubListStr: TypeAlias = SubList[str] + + +R1: Final[SubListStr] = ["a", "b", "c"] +R2: Final[SubListStr] = [None, None, None] +R3: Final[SubListStr] = [None, None, "1", "2", None, "3", None] +R4: Final[SubListStr] = ["x", "y"] +R5: Final[SubListStr] = ["1", None, "3"] +R6: Final[SubListStr] = [None] +R7: Final[SubListStr] = None +R8: Final[SubListStr] = [] +R9: Final[SubListStr] = [None, None] + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [R1, R2, R3, R4, R5, R6, R7, R8, R9]} + + +a = nwp.col("a") + + +@pytest.mark.parametrize( + ("separator", "ignore_nulls", "expected"), 
+ [ + ("-", False, ["a-b-c", None, None, "x-y", None, None, None, "", None]), + ("-", True, ["a-b-c", "", "1-2-3", "x-y", "1-3", "", None, "", ""]), + ("", False, ["abc", None, None, "xy", None, None, None, "", None]), + ("", True, ["abc", "", "123", "xy", "13", "", None, "", ""]), + ], + ids=[ + "hyphen-propagate_nulls", + "hyphen-ignore_nulls", + "empty-propagate_nulls", + "empty-ignore_nulls", + ], +) +def test_list_join( + data: Data, separator: str, *, ignore_nulls: bool, expected: list[str | None] +) -> None: + df = dataframe(data).with_columns(a.cast(nw.List(nw.String))) + expr = a.list.join(separator, ignore_nulls=ignore_nulls) + result = df.select(expr) + assert_equal_data(result, {"a": expected}) + + +@pytest.mark.parametrize( + "ignore_nulls", [True, False], ids=["ignore_nulls", "propagate_nulls"] +) +@pytest.mark.parametrize("separator", ["?", "", " "], ids=["question", "empty", "space"]) +@pytest.mark.parametrize( + "row", [R1, R2, R3, R4, R5, R6, R7, R8, R9], ids=[f"row-{i}" for i in range(1, 10)] +) +def test_list_join_scalar(row: SubListStr, separator: str, *, ignore_nulls: bool) -> None: + data = {"a": [row]} + df = dataframe(data).select(a.cast(nw.List(nw.String))) + expr = a.first().list.join(separator, ignore_nulls=ignore_nulls) + result = df.select(expr) + expected: str | None + if row is None: + expected = None + elif row == []: + expected = "" + elif any(el is None for el in row): + if not ignore_nulls: + expected = None + elif all(el is None for el in row): + expected = "" + else: + expected = separator.join(el for el in row if el is not None) + else: + expected = separator.join(el for el in row if el is not None) + + assert_equal_data(result, {"a": [expected]}) + + +@pytest.mark.parametrize( + ("rows", "expected"), + [ + ([R1, R4, ["all", "okay"]], ["a b c", "x y", "all okay"]), + ( + [ + None, + ["no", "nulls", "inside"], + None, + None, + ["only", "on", "validity"], + None, + ], + [None, "no nulls inside", None, None, "only on validity", 
None], + ), + ( + [["just", "empty", "lists"], [], [], ["nothing", "fancy"], []], + ["just empty lists", "", "", "nothing fancy", ""], + ), + ([None, None, None], [None, None, None]), + ( + [ + ["every", None, "null"], + None, + [None, "is", "lonely"], + ["not", "even"], + ["a", "single", None, "friend"], + [None], + ], + ["every null", None, "is lonely", "not even", "a single friend", ""], + ), + ( + [ + ["even", None, "this"], + [], + [None], + None, + [None], + [None, "can", "be", None, "cheap"], + [], + None, + [None], + ], + ["even this", "", "", None, "", "can be cheap", "", None, ""], + ), + ], + ids=[ + "all-good", + "no-nulls-inside", + "only-empty-lists", + "full-null", + "max-1-null", + "mixed-bag", + ], +) +def test_list_join_ignore_nulls_fastpaths( + rows: Sequence[SubListStr], expected: list[str | None] +) -> None: + # When we don't need to handle *every* edge case at the same time ... + # ... things can be simpler + separator = " " + data = {"a": list(rows)} + df = dataframe(data).with_columns(a.cast(nw.List(nw.String))) + expr = a.list.join(separator) + result = df.select(expr) + assert_equal_data(result, {"a": expected}) diff --git a/tests/plan/list_len_test.py b/tests/plan/list_len_test.py new file mode 100644 index 0000000000..f23c8c61be --- /dev/null +++ b/tests/plan/list_len_test.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from narwhals._plan.typing import OneOrIterable + from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [[1, 2], [3, 4, None], None, [], [None]], "i": [4, 3, 2, 1, 0]} + + +a = nwp.nth(0) + + +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + (a.list.len(), {"a": [2, 3, None, 0, 1]}), + ( + [a.first().list.len().alias("first"), 
a.last().list.len().alias("last")], + {"first": [2], "last": [1]}, + ), + ( # NOTE: `polars` produces nulls following the `over(order_by=...)` + # That's either a bug, or something that won't be ported to `narwhals` + [ + a.first().over(order_by="i").list.len().alias("first_order_i"), + a.last().over(order_by="i").list.len().alias("last_order_i"), + ], + {"first_order_i": [1], "last_order_i": [2]}, + ), + ( + # NOTE: This does work already in `polars` + [ + a.sort_by("i").first().list.len().alias("sort_by_i_first"), + a.sort_by("i").last().list.len().alias("sort_by_i_last"), + ], + {"sort_by_i_first": [1], "sort_by_i_last": [2]}, + ), + ], +) +def test_list_len(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None: + df = dataframe(data).with_columns(a.cast(nw.List(nw.Int32()))) + result = df.select(exprs) + assert_equal_data(result, expected) diff --git a/tests/plan/list_unique_test.py b/tests/plan/list_unique_test.py new file mode 100644 index 0000000000..7f82e593b5 --- /dev/null +++ b/tests/plan/list_unique_test.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_series, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "a": [[2, 2, 3, None, None], None, [], [None]], + "b": [[1, 2, 2], [3, 4], [5, 5, 5, 6], [7]], + } + + +a = nwp.col("a") +b = nwp.col("b") + + +def test_list_unique(data: Data) -> None: + df = dataframe(data).select(a.cast(nw.List(nw.Int32))) + ser = df.select(a.list.unique()).to_series() + result = ser.to_list() + assert len(result) == 4 + assert len(result[0]) == 3 + assert set(result[0]) == {2, 3, None} + assert result[1] is None + assert len(result[2]) == 0 + assert len(result[3]) == 1 + + assert_equal_series(ser.explode(), [2, 3, None, None, None, None], "a") + + +# TODO @dangotbanned: 
Report `ListScalar.values` bug upstream +# - Returning `None` breaks: `__len__`,` __getitem__`, `__iter__` +# - Which breaks `pa.array([], pa.list_(pa.int64()))` +@pytest.mark.parametrize( + ("row", "expected"), + [ + ([None, "A", "B", "A", "A", "B"], [None, "A", "B"]), + (None, None), + ([], []), + ([None], [None]), + ], +) +def test_list_unique_scalar( + row: list[str | None] | None, expected: list[str | None] | None +) -> None: + data = {"a": [row]} + df = dataframe(data).select(a.cast(nw.List(nw.String))) + # NOTE: Don't separate `first().list.unique()` + # The chain is required to force the transition from `Expr` -> `Scalar` + result = df.select(a.first().list.unique()).to_series() + assert_equal_series(result, [expected], "a") + + +def test_list_unique_all_valid(data: Data) -> None: + df = dataframe(data).select(b.cast(nw.List(nw.Int32))) + ser = df.select(b.list.unique()).to_series() + result = ser.to_list() + assert set(result[0]) == {1, 2} + assert set(result[1]) == {3, 4} + assert set(result[2]) == {5, 6} + assert set(result[3]) == {7} diff --git a/tests/plan/map_batches_test.py b/tests/plan/map_batches_test.py new file mode 100644 index 0000000000..6c8f9776e8 --- /dev/null +++ b/tests/plan/map_batches_test.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from narwhals._plan import selectors as ncs +from tests.plan.utils import assert_equal_data, dataframe, re_compile + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + import pyarrow as pa + + from narwhals._plan.compliant.typing import ( + SeriesAny as CompliantSeriesAny, + SeriesT as CompliantSeriesT, + ) + from narwhals.typing import _1DArray, _NumpyScalar + from tests.conftest import Data + + +pytest.importorskip("numpy") +import numpy as np + + +@pytest.fixture +def data() -> Data: + return { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [9, 2, 4], + "d": [8, 7, 
8], + "e": ["A", "B", "A"], + "j": [12.1, None, 4.0], + "k": [42, 10, None], + "z": [7.0, 8.0, 9.0], + } + + +def elementwise_series(s: CompliantSeriesT, /) -> CompliantSeriesT: + dtype_name = type(s.dtype).__name__.lower() + repeat_name = (dtype_name,) * (len(s) - 1) + values = [*repeat_name, "last"] + return s.from_iterable(values, version=s.version, name="funky") + + +def elementwise_1d_array(s: CompliantSeriesAny, /) -> _1DArray: + return s.to_numpy() + 1 + + +def to_numpy(s: CompliantSeriesAny, /) -> _1DArray: + return s.to_numpy() + + +def groupwise_1d_array(s: CompliantSeriesAny, /) -> _1DArray: + result: _1DArray = np.append(s.to_numpy(), [10, 2]) + return result + + +def aggregation_np_scalar(s: CompliantSeriesAny, /) -> _NumpyScalar: + result: _NumpyScalar = s.to_numpy().max() + return result + + +def aggregation_pa_scalar(s: CompliantSeriesAny) -> pa.Scalar[Any]: + pytest.importorskip("pyarrow") + import pyarrow as pa + + result: pa.Scalar[Any] = pa.array(s.to_list())[0] + return result + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + pytest.param( + [ + nwp.col("e") + .alias("...") + .map_batches(elementwise_series, is_elementwise=True), + nwp.col("e"), + ], + {"funky": ["string", "string", "last"], "e": ["A", "B", "A"]}, + id="is_elementwise-series", + ), + pytest.param( + nwp.col("a", "b", "z").map_batches(to_numpy), + {"a": [1, 2, 3], "b": [4, 5, 6], "z": [7.0, 8.0, 9.0]}, + id="to-numpy", + ), + pytest.param( + nwp.col("a") + .map_batches(elementwise_1d_array, nw.Float64, is_elementwise=True) + .sum(), + {"a": [9.0]}, + id="is_elementwise-1d-array", + ), + pytest.param( + nwp.col("a").map_batches(elementwise_1d_array, nw.Float64).sum(), + {"a": [9.0]}, + id="unknown-1d-array", + ), + pytest.param( + ncs.by_index(0, 2, 3) + .map_batches(groupwise_1d_array, is_elementwise=True) + .sort(), + {"a": [1, 2, 2, 3, 10], "c": [2, 2, 4, 9, 10], "d": [2, 7, 8, 8, 10]}, + # NOTE: Maybe this should be rejected because of the length change? 
+ # It doesn't break broadcasting rules, but uses an optional argument incorrectly + # and we only know *after* execution + id="is_elementwise-1d-array-groupwise", + ), + pytest.param( + nwp.col("j", "k") + .fill_null(15) + .map_batches(aggregation_np_scalar, returns_scalar=True), + {"j": [15], "k": [42]}, + id="returns_scalar-np-scalar", + ), + pytest.param( + [ + nwp.col("a").map_batches( + lambda _: [1, 2], + returns_scalar=True, + return_dtype=nw.List(nw.Int64()), + ), + nwp.col("b").last(), + ], + {"a": [[1, 2]], "b": [6]}, + id="returns_scalar-list", + ), + ], +) +def test_map_batches( + data: Data, expr: nwp.Expr | Sequence[nwp.Expr], expected: Data +) -> None: + result = dataframe(data).select(expr) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("udf", "result_type_name"), + [ + (aggregation_np_scalar, "'numpy.int64'"), + (aggregation_pa_scalar, ".+pyarrow.+scalar.+"), + (len, "'int'"), + (str, "'str'"), + ], +) +def test_map_batches_invalid( + data: Data, udf: Callable[[Any], Any], result_type_name: str +) -> None: + expr = nwp.col("a", "b", "z").map_batches(udf) + pattern = re_compile( + rf"map.+ with `returns_scalar=False` must return a Series.+{result_type_name}" + ) + with pytest.raises(TypeError, match=pattern): + dataframe(data).select(expr) diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index f7738dbf5e..5385c65e17 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -17,15 +17,31 @@ pytest.importorskip("polars") import polars as pl -if POLARS_VERSION >= (1, 0): - # https://github.com/pola-rs/polars/pull/16743 - OVER_CASE = ( +if POLARS_VERSION >= (1, 0): # https://github.com/pola-rs/polars/pull/16743 + if POLARS_VERSION >= (1, 36): # pragma: no cover + # TODO @dangotbanned: Update special-casing in `OrderedWindowExpr` + # https://github.com/pola-rs/polars/pull/25117/files#diff-45d1f22172e291bd4a5ce36d1fb8233698394f9590bcf11382b9c99b5449fff5 + marks: tuple[pytest.MarkDecorator, ...] 
= ( + pytest.mark.xfail( + reason=( + "`polars==1.36.0b1` now considers `order_by` in `root_names`\n" + r"https://github.com/pola-rs/polars/pull/25117" + ), + raises=AssertionError, + ), + ) + else: # pragma: no cover + marks = () + OVER_CASE = pytest.param( nwp.col("a").last().over("b", order_by="c"), pl.col("a").last().over("b", order_by="c"), ["a", "b"], + marks=marks, ) else: # pragma: no cover - OVER_CASE = (nwp.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"]) + OVER_CASE = pytest.param( + nwp.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"] + ) if POLARS_VERSION >= (0, 20, 5): LEN_CASE = (nwp.len(), pl.len(), "len") else: # pragma: no cover @@ -325,11 +341,6 @@ def test_literal_output_name() -> None: assert e.meta.output_name() == "" -# NOTE: Very low-priority -@pytest.mark.xfail( - reason="TODO: `Expr.struct.field` influences `meta.output_name`.", - raises=AssertionError, -) def test_struct_field_output_name_24003() -> None: assert nwp.col("ball").struct.field("radius").meta.output_name() == "radius" diff --git a/tests/plan/mode_test.py b/tests/plan/mode_test.py new file mode 100644 index 0000000000..1a647d2b68 --- /dev/null +++ b/tests/plan/mode_test.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from narwhals import _plan as nwp +from narwhals._plan import selectors as ncs +from narwhals.exceptions import ShapeError +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [1, 1, 2, 2, 3], "b": [1, 2, 3, 3, 4]} + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.col("b").mode(), {"b": [3]}), + (nwp.col("a").mode(keep="all"), {"a": [1, 2]}), + (nwp.col("b").filter(nwp.col("b") != 3).mode(), {"b": [1, 2, 4]}), + (nwp.col("a").mode().sum(), {"a": [3]}), + ], + ids=["single", "multiple-1", "multiple-2", 
"multiple-agg"], +) +def test_mode_expr_keep_all(data: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data).select(expr).sort(ncs.first()) + assert_equal_data(result, expected) + + +def test_mode_expr_different_lengths_keep_all(data: Data) -> None: + df = dataframe(data) + with pytest.raises(ShapeError): + df.select(nwp.col("a", "b").mode(keep="all")) + + +def test_mode_expr_keep_any(data: Data) -> None: + result = dataframe(data).select(nwp.col("a", "b").mode(keep="any")) + try: + expected = {"a": [1], "b": [3]} + assert_equal_data(result, expected) + except AssertionError: # pragma: no cover + expected = {"a": [2], "b": [3]} + assert_equal_data(result, expected) diff --git a/tests/plan/over_test.py b/tests/plan/over_test.py index ccea0eec77..42f61b5467 100644 --- a/tests/plan/over_test.py +++ b/tests/plan/over_test.py @@ -5,18 +5,20 @@ import pytest +from tests.utils import PYARROW_VERSION + pytest.importorskip("pyarrow") import narwhals as nw import narwhals._plan as nwp from narwhals._plan import selectors as ncs -from narwhals._utils import zip_strict +from narwhals._utils import Implementation, zip_strict from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_equal_data, dataframe, re_compile if TYPE_CHECKING: - from collections.abc import Callable, Mapping, Sequence + from collections.abc import Callable, Iterable, Mapping, Sequence from _pytest.mark import ParameterSet from typing_extensions import TypeAlias @@ -274,6 +276,80 @@ def test_null_count_over() -> None: assert_equal_data(result, expected) +@pytest.fixture(scope="module") +def data_kurtosis_skew() -> Data: + return { + "p1": ["a", "a", "b", "a", "b", "b"], + "p2": ["d", "e", "e", "e", "d", "d"], + "v1": [0.2, 5.0, 1.0, 0.7, 0.5, 1.0], + "v2": [-1.0, 0.8, 0.6, 0.0, 1.1, 19.0], + "v3": [None, 1.2, 2.1, 0.4, 5.0, 3.2], + } + + +EXPECTED_SKEW = { + "v1_p1": [0.678654, 0.678654, -0.707107, 0.678654, -0.707107, -0.707107], + "v2_p1": [-0.135062, 
-0.135062, 0.705297, -0.135062, 0.705297, 0.705297], + "v3_p1": [-4.33681e-16, -4.33681e-16, 0.285361, -4.33681e-16, 0.285361, 0.285361], + "v1_p1_p2": [float("nan"), -2.68106e-16, float("nan"), -2.68106e-16, 0.0, 0.0], + "v2_p1_p2": [float("nan"), 0.0, float("nan"), 0.0, -2.37866e-16, -2.37866e-16], + "v3_p1_p2": [ + None, + -4.33681e-16, + float("nan"), + -4.33681e-16, + 1.44679e-15, + 1.44679e-15, + ], +} +EXPECTED_KURTOSIS = { + "v1_p1": [-1.5, -1.5, -1.5, -1.5, -1.5, -1.5], + "v2_p1": [-1.5, -1.5, -1.5, -1.5, -1.5, -1.5], + "v3_p1": [-2.0, -2.0, -1.5, -2.0, -1.5, -1.5], + "v1_p1_p2": [float("nan"), -2.0, float("nan"), -2.0, -2.0, -2.0], + "v2_p1_p2": [float("nan"), -2.0, float("nan"), -2.0, -2.0, -2.0], + "v3_p1_p2": [None, -2.0, float("nan"), -2.0, -2.0, -2.0], +} +string = ncs.string() +not_string = ~string + + +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + ( + [ + not_string.skew().over("p1").name.suffix("_p1"), + not_string.skew().over(string).name.suffix("_p1_p2"), + ], + EXPECTED_SKEW, + ), + ( + [ + not_string.kurtosis().over("p1").name.suffix("_p1"), + not_string.kurtosis().over(string).name.suffix("_p1_p2"), + ], + EXPECTED_KURTOSIS, + ), + ], +) +def test_kurtosis_over_skew( + data_kurtosis_skew: Data, + request: pytest.FixtureRequest, + exprs: Iterable[nwp.Expr], + expected: Data, +) -> None: + df = dataframe(data_kurtosis_skew) + request.applymarker( + pytest.mark.xfail( + (df.implementation is Implementation.PYARROW and PYARROW_VERSION < (20,)), + reason="too old for `pyarrow.compute.{kurtosis,skew}`", + ) + ) + result = df.select(exprs) + assert_equal_data(result, expected) + + @pytest.fixture def data_groups() -> Data: return { diff --git a/tests/plan/range_test.py b/tests/plan/range_test.py new file mode 100644 index 0000000000..0dc02bc5ad --- /dev/null +++ b/tests/plan/range_test.py @@ -0,0 +1,297 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Final, Literal + +import pytest + +from 
narwhals.exceptions import ShapeError +from tests.utils import PYARROW_VERSION + +if PYARROW_VERSION < (21,):  # pragma: no cover + pytest.importorskip("numpy") +import datetime as dt + +import narwhals as nw +from narwhals import _plan as nwp +from tests.conftest import TEST_EAGER_BACKENDS +from tests.plan.utils import assert_equal_data, assert_equal_series, dataframe + +if TYPE_CHECKING: + from collections.abc import Sequence + + from narwhals.typing import ClosedInterval, EagerAllowed, IntoDType + + +@pytest.fixture(scope="module") +def data() -> dict[str, Any]: + """Variant of `compliant_test.data_small`, with only numeric data.""" + return { + "b": [1, 2, 3], + "c": [9, 2, 4], + "d": [8, 7, 8], + "e": [None, 9, 7], + "j": [12.1, None, 4.0], + "k": [42, 10, None], + "l": [4, 5, 6], + "m": [0, 1, 2], + } + + +_HAS_IMPLEMENTATION = frozenset((nw.Implementation.PYARROW, "pyarrow")) +"""Used to filter *the source* of `eager_backend` - which includes `polars` and `pandas` when available. + +For now, this lets some tests be written in a backend-agnostic way. 
+""" + +_HAS_IMPLEMENTATION_IMPL = frozenset( + el for el in _HAS_IMPLEMENTATION if isinstance(el, nw.Implementation) +) +"""Filtered for heavily parametric tests.""" + + +@pytest.fixture( + scope="module", params=_HAS_IMPLEMENTATION.intersection(TEST_EAGER_BACKENDS) +) +def eager(request: pytest.FixtureRequest) -> EagerAllowed: + result: EagerAllowed = request.param + return result + + +@pytest.fixture( + scope="module", + params=_HAS_IMPLEMENTATION_IMPL.intersection(TEST_EAGER_BACKENDS).union([False]), +) +def backend(request: pytest.FixtureRequest) -> EagerAllowed | Literal[False]: + result: EagerAllowed | Literal[False] = request.param + return result + + +@pytest.fixture(scope="module", params=[2024, 2400]) +def leap_year(request: pytest.FixtureRequest) -> int: + result: int = request.param + return result + + +EXPECTED_DATE_1: Final = [ + dt.date(2020, 1, 26), + dt.date(2020, 2, 20), + dt.date(2020, 3, 16), + dt.date(2020, 4, 10), +] +EXPECTED_DATE_2: Final = [dt.date(2021, 1, 30)] +EXPECTED_DATE_3: Final = [ + dt.date(2000, 1, 1), + dt.date(2002, 9, 14), + dt.date(2005, 5, 28), + dt.date(2008, 2, 9), + dt.date(2010, 10, 23), + dt.date(2013, 7, 6), + dt.date(2016, 3, 19), + dt.date(2018, 12, 1), + dt.date(2021, 8, 14), +] +EXPECTED_DATE_4: Final = [ + dt.date(2006, 10, 14), + dt.date(2013, 7, 27), + dt.date(2020, 5, 9), +] + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + [ + nwp.date_range( + dt.date(2020, 1, 1), + dt.date(2020, 4, 30), + interval="25d", + closed="none", + ) + ], + {"literal": EXPECTED_DATE_1}, + ), + ( + ( + nwp.date_range( + dt.date(2021, 1, 30), + nwp.lit(18747, nw.Int32).cast(nw.Date), + interval="90d", + closed="left", + ).alias("date_range_cast_expr"), + {"date_range_cast_expr": EXPECTED_DATE_2}, + ) + ), + ], +) +def test_date_range( + expr: nwp.Expr | Sequence[nwp.Expr], + expected: dict[str, Any], + data: dict[str, list[dt.date]], +) -> None: + pytest.importorskip("pyarrow") + result = dataframe(data).select(expr) + 
assert_equal_data(result, expected) + + +def test_date_range_eager_leap(eager: EagerAllowed, leap_year: int) -> None: + series_leap = nwp.date_range( + dt.date(leap_year, 2, 25), dt.date(leap_year, 3, 25), eager=eager + ) + series_regular = nwp.date_range( + dt.date(leap_year + 1, 2, 25), + dt.date(leap_year + 1, 3, 25), + interval=dt.timedelta(days=1), + eager=eager, + ) + assert len(series_regular) == 29 + assert len(series_leap) == 30 + + +@pytest.mark.parametrize( + ("start", "end", "interval", "closed", "expected"), + [ + (dt.date(2000, 1, 1), dt.date(2023, 8, 31), "987d", "both", EXPECTED_DATE_3), + (dt.date(2000, 1, 1), dt.date(2023, 8, 31), "354w", "right", EXPECTED_DATE_4), + ], +) +def test_date_range_eager( + start: dt.date, + end: dt.date, + interval: str | dt.timedelta, + closed: ClosedInterval, + expected: list[dt.date], + eager: EagerAllowed, +) -> None: + ser = nwp.date_range(start, end, interval=interval, closed=closed, eager=eager) + result = ser.to_list() + assert result == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ([nwp.int_range(5)], {"literal": [0, 1, 2, 3, 4]}), + ([nwp.int_range(nwp.len())], {"literal": [0, 1, 2]}), + (nwp.int_range(nwp.len() * 5, 20).alias("lol"), {"lol": [15, 16, 17, 18, 19]}), + (nwp.int_range(nwp.col("b").min() + 4, nwp.col("d").last()), {"b": [5, 6, 7]}), + ], +) +def test_int_range( + expr: nwp.Expr | Sequence[nwp.Expr], expected: dict[str, Any], data: dict[str, Any] +) -> None: + pytest.importorskip("pyarrow") + result = dataframe(data).select(expr) + assert_equal_data(result, expected) + + +def test_int_range_eager(eager: EagerAllowed) -> None: + ser = nwp.int_range(50, eager=eager) + assert isinstance(ser, nwp.Series) + assert ser.to_list() == list(range(50)) + + +@pytest.mark.parametrize(("start", "end"), [(0, 0), (0, 1), (-1, 0), (-2.1, 3.4)]) +@pytest.mark.parametrize("num_samples", [0, 1, 2, 5, 1_000]) +@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +def 
test_linear_space_values( + start: float, + end: float, + num_samples: int, + interval: ClosedInterval, + *, + backend: EagerAllowed | Literal[False], +) -> None: + # NOTE: Adapted from https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/functions/range/test_linear_space.py#L19-L56 + if backend: + result = nwp.linear_space( + start, end, num_samples, closed=interval, eager=backend + ).rename("ls") + else: + result = ( + dataframe({}) + .select(ls=nwp.linear_space(start, end, num_samples, closed=interval)) + .to_series() + ) + + pytest.importorskip("numpy") + import numpy as np + + if interval == "both": + expected = np.linspace(start, end, num_samples) + elif interval == "left": + expected = np.linspace(start, end, num_samples, endpoint=False) + elif interval == "right": + expected = np.linspace(start, end, num_samples + 1)[1:] + else: + expected = np.linspace(start, end, num_samples + 2)[1:-1] + + assert_equal_series(result, expected, "ls") + + +def test_linear_space_expr() -> None: + # NOTE: Adapted from https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/functions/range/test_linear_space.py#L59-L68 + pytest.importorskip("pyarrow") + df = dataframe({"a": [1, 2, 3, 4, 5]}) + + result = df.select(nwp.linear_space(0, nwp.col("a").len(), 3)) + expected = df.select( + literal=nwp.Series.from_iterable( + [0.0, 2.5, 5.0], dtype=nw.Float64, backend="pyarrow" + ) + ) + assert_equal_data(result, expected) + + result = df.select(nwp.linear_space(nwp.col("a").len(), 0, 3)) + expected = df.select( + a=nwp.Series.from_iterable([5.0, 2.5, 0.0], dtype=nw.Float64, backend="pyarrow") + ) + assert_equal_data(result, expected) + + +# NOTE: More general "supertyping" behavior would need `pyarrow.unify_schemas` +# (https://arrow.apache.org/docs/14.0/python/generated/pyarrow.unify_schemas.html) +@pytest.mark.parametrize( + ("dtype_start", "dtype_end", "dtype_expected"), + [ + pytest.param( 
+ nw.Float32, + nw.Float32, + nw.Float32, + marks=pytest.mark.xfail( + reason="Didn't preserve `Float32` dtype, promoted to `Float64`", + raises=AssertionError, + ), + ), + (nw.Float32, nw.Float64, nw.Float64), + (nw.Float64, nw.Float32, nw.Float64), + (nw.Float64, nw.Float64, nw.Float64), + (nw.UInt8, nw.UInt32, nw.Float64), + (nw.Int16, nw.Int128, nw.Float64), + (nw.Int8, nw.Float64, nw.Float64), + ], +) +def test_linear_space_expr_numeric_dtype( + dtype_start: IntoDType, dtype_end: IntoDType, dtype_expected: IntoDType +) -> None: + # NOTE: Adapted from https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/functions/range/test_linear_space.py#L71-L95 + pytest.importorskip("pyarrow") + df = dataframe({}) + result = df.select( + ls=nwp.linear_space(nwp.lit(0, dtype=dtype_start), nwp.lit(1, dtype=dtype_end), 6) + ) + expected = df.select( + ls=nwp.Series.from_iterable( + [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=dtype_expected, backend="pyarrow" + ) + ) + assert result.get_column("ls").dtype == dtype_expected + assert_equal_data(result, expected) + + +def test_linear_space_expr_wrong_length() -> None: + # NOTE: Adapted from https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/py-polars/tests/unit/functions/range/test_linear_space.py#L194-L199 + pytest.importorskip("pyarrow") + df = dataframe({"a": [1, 2, 3, 4, 5]}) + with pytest.raises(ShapeError, match="Expected object of length 6, got 5"): + df.with_columns(nwp.linear_space(0, 1, 6)) diff --git a/tests/plan/replace_strict_test.py b/tests/plan/replace_strict_test.py new file mode 100644 index 0000000000..f4195dbd69 --- /dev/null +++ b/tests/plan/replace_strict_test.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from itertools import chain +from typing import TYPE_CHECKING, Any, Literal + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from narwhals._utils import no_default +from narwhals.exceptions import 
InvalidOperationError +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from collections.abc import Collection, Iterable, Iterator, Mapping + + from _pytest.mark import ParameterSet + + from narwhals._plan.typing import IntoExpr + from narwhals._typing import NoDefault + from narwhals.typing import IntoDType, NonNestedLiteral + from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "str": ["one", "two", "three", "four"], + "int": [1, 2, 3, 4], + "str-null": ["one", None, "three", "four"], + "int-null": [1, 2, None, 4], + "str-alt": ["beluga", "narwhal", "orca", "vaquita"], + } + + +def cases( + column: Literal["str", "int", "str-null", "int-null", "str-alt"], + replacements: Mapping[Any, Any], + return_dtypes: Iterable[IntoDType | None], + *, + default: IntoExpr | NoDefault = no_default, + expected: list[NonNestedLiteral] | None = None, + marks: pytest.MarkDecorator | Collection[pytest.MarkDecorator | pytest.Mark] = (), +) -> Iterator[ParameterSet]: + old, new = list(replacements), tuple(replacements.values()) + base = nwp.col(column) + alt_name = f"{column}_seqs" + alt = nwp.col(column).alias(alt_name) + if expected: + expected_m = {column: expected, alt_name: expected} + else: + expected_m = {column: list(new), alt_name: list(new)} + if default is no_default: + suffix = "" + else: + tp = type(default._ir) if isinstance(default, nwp.Expr) else type(default) + suffix = f"-default-{tp.__name__}" + + for dtype in return_dtypes: + exprs = ( + base.replace_strict(replacements, default=default, return_dtype=dtype), + alt.replace_strict(old, new, default=default, return_dtype=dtype), + ) + schema = {column: dtype, alt_name: dtype} if dtype else None + id = f"{column}-{dtype}{suffix}" + yield pytest.param(exprs, expected_m, schema, id=id, marks=marks) + + +@pytest.mark.parametrize( + ("exprs", "expected", "schema"), + chain( + cases( + "str", + {"one": 1, 
"two": 2, "three": 3, "four": 4}, + [nw.Int8, nw.Float32, None], + ), + cases("int", {1: "one", 2: "two", 3: "three", 4: "four"}, [nw.String(), None]), + cases( + "int", + {1: "one", 2: "two"}, + [nw.String, None], + default=nwp.lit("other"), + expected=["one", "two", "other", "other"], + ), + cases( + "int-null", + {1: 10, 2: 20}, + [nw.Int64, None], + default=99, + expected=[10, 20, 99, 99], + ), + cases( + "int", + {1: "one", 2: "two", 3: None}, + [nw.String, None], + default="other", + expected=["one", "two", None, "other"], + ), + cases( + "int", + {1: "one", 2: "two"}, + [nw.String, None], + default=nwp.col("str-alt"), + expected=["one", "two", "orca", "vaquita"], + ), + cases( + "int", + {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"}, + [None], + default="hundred", + expected=["one", "two", "three", "four"], + ), + ), +) +def test_replace_strict_expr( + data: Data, + exprs: Iterable[nwp.Expr], + expected: Data, + schema: Mapping[str, IntoDType] | None, +) -> None: + result = dataframe(data).select(exprs) + assert_equal_data(result, expected) + if schema is not None: + assert result.collect_schema() == schema + + +@pytest.mark.parametrize( + "expr", + [ + nwp.col("int").replace_strict([1, 3], [3, 4]), + nwp.col("str-null").replace_strict({"one": "two", "four": "five"}), + ], +) +def test_replace_strict_expr_non_full(data: Data, expr: nwp.Expr) -> None: + with pytest.raises( + (ValueError, InvalidOperationError), match=r"did not replace all non-null" + ): + dataframe(data).select(expr) + + +def test_replace_strict_scalar(data: Data) -> None: + df = dataframe(data) + expr = ( + nwp.col("str-null") + .first() + .replace_strict({"one": 1, "two": 2, "three": 3, "four": 4}) + ) + assert_equal_data(df.select(expr), {"str-null": [1]}) + + int_null = nwp.col("int-null") + repl_ints = {1: 10, 2: 20, 4: 40} + + expr = int_null.last().replace_strict(repl_ints, default=999) + assert_equal_data(df.select(expr), {"int-null": [40]}) + + expr = 
int_null.sort(nulls_last=True).last().replace_strict(repl_ints, default=999) + assert_equal_data(df.select(expr), {"int-null": [999]}) diff --git a/tests/plan/rolling_expr_test.py b/tests/plan/rolling_expr_test.py new file mode 100644 index 0000000000..2bff999053 --- /dev/null +++ b/tests/plan/rolling_expr_test.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from narwhals.typing import NonNestedLiteral + from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +def sqrt_or_null(*values: float | None) -> list[float | None]: + return [el if el is None else math.sqrt(el) for el in values] + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "a": [None, 1, 2, None, 4, 6, 11], + "b": [1, None, 2, None, 4, 6, 11], + "c": [1, None, 2, 3, 4, 5, 6], + "var_std": [1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0], + "i": list(range(7)), + } + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "ddof", "expected"), + [ + (3, None, False, 1, [None, None, 1 / 3, 1, 4 / 3, 7 / 3, 3]), + (3, 1, False, 1, [None, 0.5, 1 / 3, 1.0, 4 / 3, 7 / 3, 3]), + (2, 1, False, 1, [None, 0.5, 0.5, 2.0, 2.0, 4.5, 4.5]), + (5, 1, True, 1, [1 / 3, 11 / 12, 4 / 5, 17 / 10, 2.0, 2.25, 3]), + (4, 1, True, 1, [0.5, 1 / 3, 11 / 12, 11 / 12, 2.25, 2.25, 3]), + (3, None, False, 2, [None, None, 2 / 3, 2.0, 8 / 3, 14 / 3, 6.0]), + ], +) +def test_rolling_var( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + ddof: int, + expected: list[NonNestedLiteral], +) -> None: + expr = nwp.col("var_std").rolling_var( + window_size, min_samples=min_samples, center=center, ddof=ddof + ) + result = dataframe(data).select(expr) + assert_equal_data(result, {"var_std": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "ddof", 
"expected"), + [ + (3, None, False, 1, sqrt_or_null(None, None, 1 / 3, 1, 4 / 3, 7 / 3, 3)), + (3, 1, False, 1, sqrt_or_null(None, 0.5, 1 / 3, 1.0, 4 / 3, 7 / 3, 3)), + (2, 1, False, 1, sqrt_or_null(None, 0.5, 0.5, 2.0, 2.0, 4.5, 4.5)), + (5, 1, True, 1, sqrt_or_null(1 / 3, 11 / 12, 4 / 5, 17 / 10, 2.0, 2.25, 3)), + (4, 1, True, 1, sqrt_or_null(0.5, 1 / 3, 11 / 12, 11 / 12, 2.25, 2.25, 3)), + (3, None, False, 2, sqrt_or_null(None, None, 2 / 3, 2.0, 8 / 3, 14 / 3, 6.0)), + ], +) +def test_rolling_std( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + ddof: int, + expected: list[NonNestedLiteral], +) -> None: + expr = nwp.col("var_std").rolling_std( + window_size, min_samples=min_samples, center=center, ddof=ddof + ) + result = dataframe(data).select(expr) + assert_equal_data(result, {"var_std": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "expected"), + [ + (3, None, False, [None, None, None, None, None, None, 21]), + (3, 1, False, [None, 1.0, 3.0, 3.0, 6.0, 10.0, 21.0]), + (2, 1, False, [None, 1.0, 3.0, 2.0, 4.0, 10.0, 17.0]), + (5, 1, True, [3.0, 3.0, 7.0, 13.0, 23.0, 21.0, 21.0]), + (4, 1, True, [1.0, 3.0, 3.0, 7.0, 12.0, 21.0, 21.0]), + ], +) +def test_rolling_sum( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + expected: list[NonNestedLiteral], +) -> None: + expr = nwp.col("a").rolling_sum(window_size, min_samples=min_samples, center=center) + result = dataframe(data).select(expr) + assert_equal_data(result, {"a": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "expected"), + [ + (3, None, False, [None, None, None, None, None, None, 7.0]), + (3, 1, False, [None, 1.0, 1.5, 1.5, 3.0, 5.0, 7.0]), + (2, 1, False, [None, 1.0, 1.5, 2.0, 4.0, 5.0, 8.5]), + (5, 1, True, [1.5, 1.5, 7 / 3, 3.25, 5.75, 7.0, 7.0]), + (4, 1, True, [1.0, 1.5, 1.5, 7 / 3, 4.0, 7.0, 7.0]), + ], +) +def test_rolling_mean( + data: Data, + window_size: 
int, + *, + min_samples: int | None, + center: bool, + expected: list[NonNestedLiteral], +) -> None: + expr = nwp.col("a").rolling_mean(window_size, min_samples=min_samples, center=center) + result = dataframe(data).select(expr) + assert_equal_data(result, {"a": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "expected"), + [ + (2, None, False, [None, None, 3, None, None, 10, 17]), + (2, 2, False, [None, None, 3, None, None, 10, 17]), + (3, 2, False, [None, None, 3, 3, 6, 10, 21]), + (3, 1, False, [1, None, 3, 3, 6, 10, 21]), + (3, 1, True, [3, 1, 3, 6, 10, 21, 17]), + (4, 1, True, [3, 1, 3, 7, 12, 21, 21]), + (5, 1, True, [3, 3, 7, 13, 23, 21, 21]), + ], +) +def test_rolling_sum_order_by( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + expected: list[NonNestedLiteral], +) -> None: + expr = ( + nwp.col("b") + .rolling_sum(window_size, min_samples=min_samples, center=center) + .over(order_by="c") + ) + result = dataframe(data).with_columns(expr).select("b", "i").sort("i").drop("i") + assert_equal_data(result, {"b": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "expected"), + [ + (2, None, False, [None, None, 1.5, None, None, 5, 8.5]), + (2, 2, False, [None, None, 1.5, None, None, 5, 8.5]), + (3, 2, False, [None, None, 1.5, 1.5, 3, 5, 7]), + (3, 1, False, [1, None, 1.5, 1.5, 3, 5, 7]), + (3, 1, True, [1.5, 1, 1.5, 3, 5, 7, 8.5]), + (4, 1, True, [1.5, 1, 1.5, 2.333333, 4, 7, 7]), + (5, 1, True, [1.5, 1.5, 2.333333, 3.25, 5.75, 7.0, 7.0]), + ], +) +def test_rolling_mean_order_by( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + expected: list[NonNestedLiteral], +) -> None: + expr = ( + nwp.col("b") + .rolling_mean(window_size, min_samples=min_samples, center=center) + .over(order_by="c") + ) + result = dataframe(data).with_columns(expr).select("b", "i").sort("i").drop("i") + assert_equal_data(result, {"b": expected}) + + 
+@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "ddof", "expected"), + [ + (2, None, False, 0, [None, None, 0.25, None, None, 1, 6.25]), + (2, 2, False, 1, [None, None, 0.5, None, None, 2, 12.5]), + (3, 2, False, 1, [None, None, 0.5, 0.5, 2, 2, 13]), + (3, 1, False, 0, [0, None, 0.25, 0.25, 1, 1, 8.666666]), + (3, 1, True, 1, [0.5, None, 0.5, 2, 2, 13, 12.5]), + (4, 1, True, 1, [0.5, None, 0.5, 2.333333, 4, 13, 13]), + (5, 1, True, 0, [0.25, 0.25, 1.555555, 3.6875, 11.1875, 8.666666, 8.666666]), + ], +) +def test_rolling_var_order_by( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + ddof: int, + expected: list[NonNestedLiteral], +) -> None: + expr = ( + nwp.col("b") + .rolling_var(window_size, min_samples=min_samples, center=center, ddof=ddof) + .over(order_by="c") + ) + result = dataframe(data).with_columns(expr).select("b", "i").sort("i").drop("i") + assert_equal_data(result, {"b": expected}) + + +@pytest.mark.parametrize( + ("window_size", "min_samples", "center", "ddof", "expected"), + [ + (2, None, False, 0, [None, None, 0.5, None, None, 1, 2.5]), + (2, 2, False, 1, [None, None, 0.707107, None, None, 1.414214, 3.535534]), + (3, 2, False, 1, [None, None, 0.707107, 0.707107, 1.414214, 1.414214, 3.605551]), + (3, 1, False, 0, [0.0, None, 0.5, 0.5, 1.0, 1.0, 2.943920]), + ( + 3, + 1, + True, + 1, + [0.707107, None, 0.707107, 1.414214, 1.414214, 3.605551, 3.535534], + ), + (4, 1, True, 1, [0.707107, None, 0.707107, 1.527525, 2.0, 3.605551, 3.605551]), + (5, 1, True, 0, [0.5, 0.5, 1.247219, 1.920286, 3.344772, 2.943920, 2.943920]), + ], +) +def test_rolling_std_order_by( + data: Data, + window_size: int, + *, + min_samples: int | None, + center: bool, + ddof: int, + expected: list[NonNestedLiteral], +) -> None: + expr = ( + nwp.col("b") + .rolling_std(window_size, min_samples=min_samples, center=center, ddof=ddof) + .over(order_by="c") + ) + result = dataframe(data).with_columns(expr).select("b", 
"i").sort("i").drop("i") + assert_equal_data(result, {"b": expected}) diff --git a/tests/plan/sample_test.py b/tests/plan/sample_test.py new file mode 100644 index 0000000000..5491a4189b --- /dev/null +++ b/tests/plan/sample_test.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import sys +from contextlib import AbstractContextManager, nullcontext +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs +from narwhals.exceptions import ShapeError +from tests.plan.utils import dataframe, series + +if TYPE_CHECKING: + from collections.abc import Callable + + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": [1, 2, 3] * 10, "b": [4, 5, 6] * 10} + + +@pytest.fixture(scope="module") +def data_big() -> Data: + return {"a": list(range(100))} + + +if sys.version_info >= (3, 13): # pragma: no cover + # NOTE: (#2705) Would've added the handling for `category` + # The default triggers a warning, but only on `>=3.13` + deprecated_call: Callable[..., AbstractContextManager[Any]] = pytest.deprecated_call +else: # pragma: no cover + deprecated_call = nullcontext + + +@pytest.mark.parametrize("n", [None, 1, 7, 29]) +def test_sample_n_series(data: Data, n: int | None) -> None: + result = series(data["a"]).sample(n).shape + expected = (1,) if n is None else (n,) + assert result == expected + + +def test_sample_fraction_series(data: Data) -> None: + result = series(data["a"]).sample(fraction=0.1).shape + expected = (3,) + assert result == expected + + +@pytest.mark.parametrize("n", [10]) +def test_sample_with_seed_series(data_big: Data, n: int) -> None: + ser = series(data_big["a"]) + seed1 = ser.sample(n, seed=123) + seed2 = ser.sample(n, seed=123) + seed3 = ser.sample(n, seed=42) + result = {"res1": [(seed1 == seed2).all()], "res2": [(seed1 == seed3).all()]} + expected = {"res1": [True], "res2": [False]} + assert result == expected + + 
+@pytest.mark.parametrize("n", [2, None, 1, 18]) +def test_sample_n_dataframe(data: Data, n: int | None) -> None: + result = dataframe(data).sample(n=n).shape + expected = (1, 2) if n is None else (n, 2) + assert result == expected + + +def test_sample_fraction_dataframe(data: Data) -> None: + result = dataframe(data).sample(fraction=0.5).shape + expected = (15, 2) + assert result == expected + + +@pytest.mark.parametrize("n", [10]) +def test_sample_with_seed_dataframe(data_big: Data, n: int) -> None: + df = dataframe(data_big) + r1 = df.sample(n, seed=123).to_native() + r2 = df.sample(n, seed=123).to_native() + r3 = df.sample(n, seed=42).to_native() + assert r1.equals(r2) + assert not r1.equals(r3) + + +@pytest.mark.parametrize("n", [39, 42, 20, 99]) +def test_sample_with_replacement_series(data: Data, n: int) -> None: + result = series(data["a"]).slice(0, 10).sample(n, with_replacement=True) + assert len(result) == n + + +@pytest.mark.parametrize("n", [10, 15, 28, 100]) +def test_sample_with_replacement_dataframe(data: Data, n: int) -> None: + result = dataframe(data).slice(0, 5).sample(n, with_replacement=True) + assert len(result) == n + + +@pytest.mark.parametrize( + ("base", "kwds", "expected"), + [ + (nwp.col("a"), {"n": 2}, (2, 1)), + (nwp.all(), {"n": 1}, (1, 2)), + (nwp.nth(1, 0), {}, (1, 2)), + (~ncs.string(), {"fraction": 0.5}, (15, 2)), + (ncs.last(), {"n": 75, "with_replacement": True, "seed": 99}, (75, 1)), + ], +) +def test_sample_expr( + data: Data, base: nwp.Expr, kwds: dict[str, Any], expected: tuple[int, int] +) -> None: + with deprecated_call(): + expr = base.sample(**kwds) + result = dataframe(data).select(expr).shape + assert result == expected + + +def test_sample_invalid(data: Data) -> None: + df = dataframe(data) + ser = df.to_series() + + both_n_fraction = r"cannot specify both `n` and `fraction`" + too_high_n = r"cannot take a larger sample than the total population when `with_replacement=false`" + + with pytest.raises(ValueError, 
match=both_n_fraction): + df.sample(n=1, fraction=0.5) + with pytest.raises(ValueError, match=both_n_fraction): + ser.sample(n=567, fraction=0.1) + with pytest.raises(ValueError, match=both_n_fraction), deprecated_call(): + nwp.col("a").sample(n=30, fraction=0.3) + with pytest.raises(ShapeError, match=too_high_n): + df.sample(n=1_000) + with pytest.raises(ShapeError, match=too_high_n): + ser.sample(n=2_000) + with pytest.raises(ShapeError), deprecated_call(): + df.with_columns(nwp.col("b").sample(123, with_replacement=True)) diff --git a/tests/plan/selectors_test.py b/tests/plan/selectors_test.py index 56ba3b1251..a9565521f8 100644 --- a/tests/plan/selectors_test.py +++ b/tests/plan/selectors_test.py @@ -18,7 +18,7 @@ from narwhals._plan import Selector, selectors as ncs from narwhals._plan._guards import is_expr, is_selector from narwhals._utils import zip_strict -from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError +from narwhals.exceptions import ColumnNotFoundError, DuplicateError, InvalidOperationError from tests.plan.utils import ( Frame, assert_expr_ir_equal, @@ -749,3 +749,14 @@ def test_when_then_keep_map_13858() -> None: ) df.assert_selects(aliased, "b_other") df.assert_selects(when_keep_chain, "b_other") + + +def test_keep_name_struct_field_23669() -> None: + df = Frame.from_mapping( + {"foo": nw.Struct({"x": nw.Int64}), "bar": nw.Struct({"x": nw.Int64})} + ) + + with pytest.raises(DuplicateError): + df.project(nwp.all().struct.field("x")) + + df.assert_selects(nwp.all().struct.field("x").name.keep(), "foo", "bar") diff --git a/tests/plan/str_contains_test.py b/tests/plan/str_contains_test.py new file mode 100644 index 0000000000..666992e557 --- /dev/null +++ b/tests/plan/str_contains_test.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest 
import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"pets": ["cat", "dog", "rabbit and parrot", "dove", "Parrot|dove", None]} + + +@pytest.mark.parametrize( + ("pattern", "literal", "expected"), + [ + ("(?i)parrot|Dove", False, [False, False, True, True, True, None]), + ("parrot|Dove", False, [False, False, True, False, False, None]), + ("Parrot|dove", False, [False, False, False, True, True, None]), + ("Parrot|dove", True, [False, False, False, False, True, None]), + ], + ids=["case_insensitive", "case_sensitive-1", "case_sensitive-2", "literal"], +) +def test_str_contains( + data: Data, pattern: str, *, literal: bool, expected: list[bool | None] +) -> None: + result = dataframe(data).select( + nwp.col("pets").str.contains(pattern, literal=literal) + ) + assert_equal_data(result, {"pets": expected}) diff --git a/tests/plan/str_len_chars_test.py b/tests/plan/str_len_chars_test.py new file mode 100644 index 0000000000..e5e3589592 --- /dev/null +++ b/tests/plan/str_len_chars_test.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + + +def test_len_chars() -> None: + data = {"a": ["foo", "foobar", "Café", "345", "東京"]} + expected = {"a": [3, 6, 4, 3, 2]} + result = dataframe(data).select(nwp.col("a").str.len_chars()) + assert_equal_data(result, expected) diff --git a/tests/plan/str_replace_test.py b/tests/plan/str_replace_test.py new file mode 100644 index 0000000000..b61db9010e --- /dev/null +++ b/tests/plan/str_replace_test.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Final + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from collections.abc import Sequence + +A1: Final = ["123abc", "abc456"] +A2: Final = ["abc abc", "abc456"] +A3: Final = ["abc abc abc", "456abc"] +A4: Final = ["Dollar $ign", "literal"] +A5: Final 
= [None, "oop"] +B: Final = ["ghi", "jkl"] + + +replace_scalar = pytest.mark.parametrize( + ("data", "pattern", "value", "n", "literal", "expected"), + [ + (A1, r"abc\b", "ABC", 1, False, ["123ABC", "abc456"]), + (A2, r"abc", "", 1, False, [" abc", "456"]), + (A3, r"abc", "", -1, False, [" ", "456"]), + (A4, r"$", "S", -1, True, ["Dollar Sign", "literal"]), + ], +) +replace_vector = pytest.mark.parametrize( + ("data", "pattern", "value", "n", "literal", "expected"), + [ + pytest.param( + A1, r"abc", nwp.col("b"), 1, False, ["123ghi", "jkl456"], id="n-1-single" + ), + pytest.param( + A2, r"abc", nwp.col("b"), 1, False, ["ghi abc", "jkl456"], id="n-1-mixed" + ), + pytest.param( + A3, + r"a", + nwp.col("b"), + 2, + False, + ["ghibc ghibc abc", "456jklbc"], + id="n-2-mixed", + ), + pytest.param( + A3, + r"abc", + nwp.col("b"), + -1, + False, + ["ghi ghi ghi", "456jkl"], + id="replace_all", + ), + pytest.param( + A4, + r"$", + nwp.col("b"), + -1, + True, + ["Dollar ghiign", "literal"], + id="literal-replace_all", + ), + pytest.param( + ["dogcatdogcat", "dog dog"], + "cat", + nwp.col("b").last(), + 1, + True, + ["dogjkldogcat", "dog dog"], + id="agg-replacement", + ), + pytest.param( + A3, + r"^abc", + nwp.col("b").str.to_uppercase(), + 1, + False, + ["GHI abc abc", "456abc"], + id="transformed-replacement", + ), + pytest.param(A5, r"o", nwp.col("b"), 1, False, [None, "jklop"], id="null-input"), + ], +) +replace_all_scalar = pytest.mark.parametrize( + ("data", "pattern", "value", "literal", "expected"), + [ + (A1, r"abc\b", "ABC", False, ["123ABC", "abc456"]), + (A2, r"abc", "", False, [" ", "456"]), + (A3, r"abc", "", False, [" ", "456"]), + (A4, r"$", "S", True, ["Dollar Sign", "literal"]), + ], +) + + +replace_all_vector = pytest.mark.parametrize( + ("data", "pattern", "value", "literal", "expected"), + [ + pytest.param(A1, r"abc", nwp.col("b"), False, ["123ghi", "jkl456"], id="single"), + pytest.param(A2, r"abc", nwp.col("b"), False, ["ghi ghi", "jkl456"], 
id="mixed"), + pytest.param( + A4, r"$", nwp.col("b"), True, ["Dollar ghiign", "literal"], id="literal" + ), + pytest.param(A5, r"o", nwp.col("b"), False, [None, "jkljklp"], id="null-input"), + pytest.param( + A3, + r"\d", + nwp.col("b").first(), + False, + ["abc abc abc", "ghighighiabc"], + id="agg-replacement", + ), + pytest.param( + A3, + r" ?abc$", + nwp.lit(" HELLO").str.to_lowercase().str.strip_chars(), + False, + ["abc abchello", "456hello"], + id="transformed-replacement", + ), + ], +) + + +@replace_scalar +def test_str_replace_scalar( + data: list[str], + pattern: str, + value: str, + n: int, + *, + literal: bool, + expected: list[str], +) -> None: + df = dataframe({"a": data}) + result = df.select(nwp.col("a").str.replace(pattern, value, n=n, literal=literal)) + assert_equal_data(result, {"a": expected}) + + +@replace_vector +def test_str_replace_vector( + data: Sequence[str | None], + pattern: str, + value: nwp.Expr, + n: int, + *, + literal: bool, + expected: Sequence[str | None], +) -> None: + df = dataframe({"a": data, "b": B}) + result = df.select(nwp.col("a").str.replace(pattern, value, n=n, literal=literal)) + assert_equal_data(result, {"a": expected}) + + +@replace_all_scalar +def test_str_replace_all_scalar( + data: list[str], pattern: str, value: str, *, literal: bool, expected: list[str] +) -> None: + df = dataframe({"a": data}) + result = df.select(nwp.col("a").str.replace_all(pattern, value, literal=literal)) + assert_equal_data(result, {"a": expected}) + + +@replace_all_vector +def test_str_replace_all_vector( + data: Sequence[str | None], + pattern: str, + value: nwp.Expr, + *, + literal: bool, + expected: Sequence[str | None], +) -> None: + df = dataframe({"a": data, "b": B}) + result = df.select(nwp.col("a").str.replace_all(pattern, value, literal=literal)) + assert_equal_data(result, {"a": expected}) diff --git a/tests/plan/str_slice_test.py b/tests/plan/str_slice_test.py new file mode 100644 index 0000000000..b2ee40c300 --- /dev/null 
+++ b/tests/plan/str_slice_test.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.mark.parametrize( + ("offset", "length", "expected"), + [(1, 2, {"a": ["da", "df"]}), (-2, None, {"a": ["as", "as"]})], +) +def test_str_slice(offset: int, length: int | None, expected: Data) -> None: + data = {"a": ["fdas", "edfas"]} + result = dataframe(data).select(nwp.col("a").str.slice(offset, length)) + assert_equal_data(result, expected) diff --git a/tests/plan/str_split_test.py b/tests/plan/str_split_test.py new file mode 100644 index 0000000000..548c909d33 --- /dev/null +++ b/tests/plan/str_split_test.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + + +@pytest.mark.parametrize( + ("by", "expected"), + [ + ("_", [["foo bar"], ["foo", "bar"], ["foo", "bar", "baz"], ["foo,bar"]]), + (",", [["foo bar"], ["foo_bar"], ["foo_bar_baz"], ["foo", "bar"]]), + ], +) +def test_str_split(by: str, expected: list[list[str]]) -> None: + data = {"a": ["foo bar", "foo_bar", "foo_bar_baz", "foo,bar"]} + result = dataframe(data).select(nwp.col("a").str.split(by)) + assert_equal_data(result, {"a": expected}) diff --git a/tests/plan/str_starts_ends_with_test.py b/tests/plan/str_starts_ends_with_test.py new file mode 100644 index 0000000000..ebf3f4e0f9 --- /dev/null +++ b/tests/plan/str_starts_ends_with_test.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from tests.conftest import Data + + +@pytest.fixture(scope="module") +def data() -> Data: + return {"a": ["Starts_with", "starts_with", 
"Ends_with", "ends_With", None]} + + +@pytest.mark.parametrize( + ("prefix", "expected"), + [ + ("start", [False, True, False, False, None]), + ("End", [False, False, True, False, None]), + ], +) +def test_str_starts_with(data: Data, prefix: str, expected: list[bool | None]) -> None: + result = dataframe(data).select(nwp.col("a").str.starts_with(prefix)) + assert_equal_data(result, {"a": expected}) + + +@pytest.mark.parametrize( + ("suffix", "expected"), + [("With", [False, False, False, True, None]), ("th", [True, True, True, True, None])], +) +def test_str_ends_with(data: Data, suffix: str, expected: list[bool | None]) -> None: + result = dataframe(data).select(nwp.col("a").str.ends_with(suffix)) + assert_equal_data(result, {"a": expected}) diff --git a/tests/plan/str_strip_chars_test.py b/tests/plan/str_strip_chars_test.py new file mode 100644 index 0000000000..9910d69ec1 --- /dev/null +++ b/tests/plan/str_strip_chars_test.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + + +@pytest.mark.parametrize( + ("characters", "expected"), + [(None, ["foobar", "bar", "baz"]), ("foo", ["bar", "bar\n", " baz"])], +) +def test_str_strip_chars(characters: str | None, expected: list[str]) -> None: + data = {"a": ["foobar", "bar\n", " baz"]} + result = dataframe(data).select(nwp.col("a").str.strip_chars(characters)) + assert_equal_data(result, {"a": expected}) diff --git a/tests/plan/str_transform_case_test.py b/tests/plan/str_transform_case_test.py new file mode 100644 index 0000000000..e7bf2be7c7 --- /dev/null +++ b/tests/plan/str_transform_case_test.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + StrData: TypeAlias = dict[str, list[str]] + + 
+@pytest.fixture(scope="module") +def data() -> StrData: + return { + "a": [ + "e.t. phone home", + "they're bill's friends from the UK", + "to infinity,and BEYOND!", + "with123numbers", + "__dunder__score_a1_.2b ?three", + ] + } + + +@pytest.fixture(scope="module") +def data_lower(data: StrData) -> StrData: + return {"a": [*data["a"], "SPECIAL CASE ß", "ΣPECIAL CAΣE"]} + + +@pytest.fixture(scope="module") +def expected_title(data: StrData) -> StrData: + return {"a": [s.title() for s in data["a"]]} + + +@pytest.fixture(scope="module") +def expected_upper(data: StrData) -> StrData: + return {"a": [s.upper() for s in data["a"]]} + + +@pytest.fixture(scope="module") +def expected_lower(data_lower: StrData) -> StrData: + return {"a": [s.lower() for s in data_lower["a"]]} + + +def test_str_to_titlecase(data: StrData, expected_title: StrData) -> None: + result = dataframe(data).select(nwp.col("a").str.to_titlecase()) + assert_equal_data(result, expected_title) + + +def test_str_to_uppercase(data: StrData, expected_upper: StrData) -> None: + result = dataframe(data).select(nwp.col("a").str.to_uppercase()) + assert_equal_data(result, expected_upper) + + +def test_str_to_lowercase(data_lower: StrData, expected_lower: StrData) -> None: + result = dataframe(data_lower).select(nwp.col("a").str.to_lowercase()) + assert_equal_data(result, expected_lower) diff --git a/tests/plan/str_zfill_test.py b/tests/plan/str_zfill_test.py new file mode 100644 index 0000000000..1d471b70de --- /dev/null +++ b/tests/plan/str_zfill_test.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + + +def test_str_zfill() -> None: + data = {"a": ["-1", "+1", "1", "12", "123", "99999", "+9999", None]} + expected = {"a": ["-01", "+01", "001", "012", "123", "99999", "+9999", None]} + result = dataframe(data).select(nwp.col("a").str.zfill(3)) + assert_equal_data(result, expected) diff --git 
a/tests/plan/struct_field_test.py b/tests/plan/struct_field_test.py new file mode 100644 index 0000000000..4c7b3b1d90 --- /dev/null +++ b/tests/plan/struct_field_test.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from collections.abc import Iterable + + from tests.conftest import Data + +pytest.importorskip("pyarrow") + + +@pytest.mark.parametrize( + ("exprs", "expected"), + [ + pytest.param( + nwp.col("user").struct.field("id"), {"id": ["0", "1"]}, id="field-single" + ), + pytest.param( + [nwp.col("user").struct.field("id"), nwp.col("user").struct.field("name")], + {"id": ["0", "1"], "name": ["john", "jane"]}, + id="multiple-fields-same-root", + ), + pytest.param( + nwp.col("user").struct.field("id").name.keep(), + {"user": ["0", "1"]}, + id="field-single-keep-root", + ), + ], +) +def test_struct_field(exprs: nwp.Expr | Iterable[nwp.Expr], expected: Data) -> None: + data = {"user": [{"id": "0", "name": "john"}, {"id": "1", "name": "jane"}]} + result = dataframe(data).select(exprs) + assert_equal_data(result, expected) diff --git a/tests/plan/temp_test.py b/tests/plan/temp_test.py index 9dd7a0e42f..58873be4d1 100644 --- a/tests/plan/temp_test.py +++ b/tests/plan/temp_test.py @@ -66,14 +66,14 @@ def test_temp_column_names_sources(source: _StoresColumns | Iterable[str]) -> No @given(n_chars=st.integers(6, 106)) @pytest.mark.slow -def test_temp_column_name_n_chars(n_chars: int) -> None: +def test_temp_column_name_n_chars(n_chars: int) -> None: # pragma: no cover name = temp.column_name(_COLUMNS, n_chars=n_chars) assert name not in _COLUMNS @given(n_new_names=st.integers(10_000, 100_000)) @pytest.mark.slow -def test_temp_column_names_always_new_names(n_new_names: int) -> None: +def test_temp_column_names_always_new_names(n_new_names: int) -> None: # pragma: no cover it = 
temp.column_names(_COLUMNS) new_names = set(islice(it, n_new_names)) assert len(new_names) == n_new_names diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 09acbf750f..92446a48cd 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload import pytest @@ -13,11 +13,13 @@ pytest.importorskip("pyarrow") +from collections.abc import Sequence + import pyarrow as pa if TYPE_CHECKING: import sys - from collections.abc import Iterable, Mapping, Sequence + from collections.abc import Iterable, Mapping from typing_extensions import LiteralString, TypeAlias @@ -218,14 +220,27 @@ def series(values: Iterable[Any], /) -> nwp.Series[pa.ChunkedArray[Any]]: def assert_equal_data( - result: nwp.DataFrame[Any, Any], expected: Mapping[str, Any] + result: nwp.DataFrame[Any, Any], expected: Mapping[str, Any] | nwp.DataFrame[Any, Any] ) -> None: + if isinstance(expected, nwp.DataFrame): + expected = expected.to_dict(as_series=False) _assert_equal_data(result.to_dict(as_series=False), expected) +@overload +def assert_equal_series(result: nwp.Series[Any], expected: nwp.Series[Any]) -> None: ... +@overload +def assert_equal_series( + result: nwp.Series[Any], expected: Iterable[Any], name: str +) -> None: ... def assert_equal_series( - result: nwp.Series[Any], expected: Sequence[Any], name: str + result: nwp.Series[Any], expected: Iterable[Any], name: str = "" ) -> None: + if isinstance(expected, nwp.Series): + name = expected.name + expected = expected.to_list() + else: + expected = expected if isinstance(expected, Sequence) else tuple(expected) assert_equal_data(result.to_frame(), {name: expected})