diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index e0709a84ff..6d7cd7a87a 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -406,8 +406,8 @@ def join( other: Self, *, how: Literal["left", "inner", "cross", "anti", "semi"], - left_on: list[str] | None, - right_on: list[str] | None, + left_on: Sequence[str] | None, + right_on: Sequence[str] | None, suffix: str, ) -> Self: how_to_join_map: dict[str, JoinType] = { @@ -442,8 +442,8 @@ def join( return self._from_native_frame( self._native_frame.join( other._native_frame, - keys=left_on or [], - right_keys=right_on, + keys=left_on or [], # type: ignore[arg-type] + right_keys=right_on, # type: ignore[arg-type] join_type=how_to_join_map[how], right_suffix=suffix, ), diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 780bf7c195..a8563859db 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -7,6 +7,8 @@ from typing import Mapping from typing import Sequence +import pyarrow.compute as pc + from narwhals._arrow.expr_cat import ArrowExprCatNamespace from narwhals._arrow.expr_dt import ArrowExprDateTimeNamespace from narwhals._arrow.expr_list import ArrowExprListNamespace @@ -22,6 +24,7 @@ from narwhals.exceptions import ColumnNotFoundError from narwhals.typing import CompliantExpr from narwhals.utils import Implementation +from narwhals.utils import generate_temporary_column_name from narwhals.utils import not_implemented if TYPE_CHECKING: @@ -423,28 +426,57 @@ def clip(self: Self, lower_bound: Any | None, upper_bound: Any | None) -> Self: self, "clip", lower_bound=lower_bound, upper_bound=upper_bound ) - def over(self: Self, keys: Sequence[str], kind: ExprKind) -> Self: - if not is_scalar_like(kind): - msg = "Only aggregation or literal operations are supported in `over` context for PyArrow." 
+ def over( + self: Self, + partition_by: Sequence[str], + kind: ExprKind, + order_by: Sequence[str] | None, + ) -> Self: + if partition_by and not is_scalar_like(kind): + msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow." raise NotImplementedError(msg) - def func(df: ArrowDataFrame) -> list[ArrowSeries]: - output_names, aliases = evaluate_output_names_and_aliases(self, df, []) - if overlap := set(output_names).intersection(keys): - # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined, - # we just don't support it yet. - msg = ( - f"Column names {overlap} appear in both expression output names and in `over` keys.\n" - "This is not yet supported." + if not partition_by: + # e.g. `nw.col('a').cum_sum().order_by(key)` + # which we can always easily support, as it doesn't require grouping. + assert order_by is not None # help type checkers # noqa: S101 + + def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]: + token = generate_temporary_column_name(8, df.columns) + df = df.with_row_index(token).sort( + *order_by, descending=False, nulls_last=False + ) + result = self(df) + # TODO(marco): is there a way to do this efficiently without + # doing 2 sorts? Here we're sorting the dataframe and then + # again calling `sort_indices`. `ArrowSeries.scatter` would also sort. + sorting_indices = pc.sort_indices(df[token]._native_series) # type: ignore[call-overload] + return [ + ser._from_native_series(pc.take(ser._native_series, sorting_indices)) # type: ignore[call-overload] + for ser in result + ] + else: + + def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]: + output_names, aliases = evaluate_output_names_and_aliases(self, df, []) + if overlap := set(output_names).intersection(partition_by): + # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined, + # we just don't support it yet. 
+ msg = ( + f"Column names {overlap} appear in both expression output names and in `over` keys.\n" + "This is not yet supported." + ) + raise NotImplementedError(msg) + + tmp = df.group_by(*partition_by, drop_null_keys=False).agg(self) + tmp = df.simple_select(*partition_by).join( + tmp, + how="left", + left_on=partition_by, + right_on=partition_by, + suffix="_right", ) - raise NotImplementedError(msg) - - tmp = df.group_by(*keys, drop_null_keys=False).agg(self) - on = list(keys) - tmp = df.simple_select(*keys).join( - tmp, how="left", left_on=on, right_on=on, suffix="_right" - ) - return [tmp[alias] for alias in aliases] + return [tmp[alias] for alias in aliases] return self.__class__( func, diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 5cb473773d..bb1c1eed6b 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -26,6 +26,8 @@ from narwhals.utils import not_implemented if TYPE_CHECKING: + from narwhals._expression_parsing import ExprKind + try: import dask.dataframe.dask_expr as dx except ModuleNotFoundError: @@ -343,6 +345,7 @@ def shift(self: Self, n: int) -> Self: def cum_sum(self: Self, *, reverse: bool) -> Self: if reverse: # pragma: no cover + # https://github.com/dask/dask/issues/11802 msg = "`cum_sum(reverse=True)` is not supported with Dask backend" raise NotImplementedError(msg) @@ -533,41 +536,58 @@ def null_count(self: Self) -> Self: lambda _input: _input.isna().sum().to_series(), "null_count" ) - def over(self: Self, keys: Sequence[str], kind: ExprKind) -> Self: + def over( + self: Self, + partition_by: Sequence[str], + kind: ExprKind, + order_by: Sequence[str] | None, + ) -> Self: # pandas is a required dependency of dask so it's safe to import this from narwhals._pandas_like.group_by import AGGREGATIONS_TO_PANDAS_EQUIVALENT - if not is_elementary_expression(self): # pragma: no cover + if not partition_by: + assert order_by is not None # help type checkers # noqa: S101 + + # This is something like 
`nw.col('a').cum_sum().order_by(key)` + # which we can always easily support, as it doesn't require grouping. + def func(df: DaskLazyFrame) -> Sequence[dx.Series]: + return self(df.sort(*order_by, descending=False, nulls_last=False)) + elif not is_elementary_expression(self): # pragma: no cover msg = ( "Only elementary expressions are supported for `.over` in dask.\n\n" "Please see: " "https://narwhals-dev.github.io/narwhals/pandas_like_concepts/improve_group_by_operation/" ) raise NotImplementedError(msg) - function_name = re.sub(r"(\w+->)", "", self._function_name) - try: - dask_function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT[function_name] - except KeyError: - msg = ( - f"Unsupported function: {function_name} in `over` context.\n\n." - f"Supported functions are {', '.join(AGGREGATIONS_TO_PANDAS_EQUIVALENT)}\n" - ) - raise NotImplementedError(msg) from None - - def func(df: DaskLazyFrame) -> list[dx.Series]: - output_names, aliases = evaluate_output_names_and_aliases(self, df, []) - - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message=".*`meta` is not specified", category=UserWarning - ) - res_native = df._native_frame.groupby(keys)[list(output_names)].transform( - dask_function_name, **self._call_kwargs + else: + function_name = re.sub(r"(\w+->)", "", self._function_name) + try: + dask_function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT[function_name] + except KeyError: + # window functions are unsupported: https://github.com/dask/dask/issues/11806 + msg = ( + f"Unsupported function: {function_name} in `over` context.\n\n"
+ f"Supported functions are {', '.join(AGGREGATIONS_TO_PANDAS_EQUIVALENT)}\n" ) - result_frame = df._from_native_frame( - res_native.rename(columns=dict(zip(output_names, aliases))) - )._native_frame - return [result_frame[name] for name in aliases] + raise NotImplementedError(msg) from None + + def func(df: DaskLazyFrame) -> Sequence[dx.Series]: + output_names, aliases = evaluate_output_names_and_aliases(self, df, []) + + with warnings.catch_warnings(): + # https://github.com/dask/dask/issues/11804 + warnings.filterwarnings( + "ignore", + message=".*`meta` is not specified", + category=UserWarning, + ) + res_native = df._native_frame.groupby(partition_by)[ + list(output_names) + ].transform(dask_function_name, **self._call_kwargs) + result_frame = df._from_native_frame( + res_native.rename(columns=dict(zip(output_names, aliases))) + )._native_frame + return [result_frame[name] for name in aliases] return self.__class__( func, diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 49345e153c..8aa4a15a30 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -19,11 +19,11 @@ from narwhals._pandas_like.expr_str import PandasLikeExprStringNamespace from narwhals._pandas_like.group_by import AGGREGATIONS_TO_PANDAS_EQUIVALENT from narwhals._pandas_like.series import PandasLikeSeries -from narwhals._pandas_like.utils import rename from narwhals.dependencies import get_numpy from narwhals.dependencies import is_numpy_array from narwhals.exceptions import ColumnNotFoundError from narwhals.typing import CompliantExpr +from narwhals.utils import generate_temporary_column_name if TYPE_CHECKING: from typing_extensions import Self @@ -457,63 +457,90 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: call_kwargs=self._call_kwargs, ) - def over(self: Self, partition_by: Sequence[str], kind: ExprKind) -> Self: - if not is_elementary_expression(self): + def over( + self: Self, + partition_by: Sequence[str], + kind: 
ExprKind, + order_by: Sequence[str] | None, + ) -> Self: + if not partition_by: + # e.g. `nw.col('a').cum_sum().order_by(key)` + # We can always easily support this as it doesn't require grouping. + assert order_by is not None # noqa: S101 # help type-check + + def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: + token = generate_temporary_column_name(8, df.columns) + df = df.with_row_index(token).sort( + *order_by, descending=False, nulls_last=False + ) + results = self(df) + sorting_indices = df[token] + for s in results: + s._scatter_in_place(sorting_indices, s) + return results + elif not is_elementary_expression(self): msg = ( "Only elementary expressions are supported for `.over` in pandas-like backends.\n\n" "Please see: " "https://narwhals-dev.github.io/narwhals/pandas_like_concepts/improve_group_by_operation/" ) raise NotImplementedError(msg) - function_name = re.sub(r"(\w+->)", "", self._function_name) - try: - pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT[function_name] - except KeyError: - try: - pandas_function_name = AGGREGATIONS_TO_PANDAS_EQUIVALENT[function_name] - except KeyError: + else: + function_name: str = re.sub(r"(\w+->)", "", self._function_name) + if pandas_function_name := WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get( + function_name, AGGREGATIONS_TO_PANDAS_EQUIVALENT.get(function_name, None) + ): + pass + else: msg = ( f"Unsupported function: {function_name} in `over` context.\n\n" f"Supported functions are {', '.join(WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT)}\n" f"and {', '.join(AGGREGATIONS_TO_PANDAS_EQUIVALENT)}." 
) - raise NotImplementedError(msg) from None - - def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - output_names, aliases = evaluate_output_names_and_aliases(self, df, []) + raise NotImplementedError(msg) pandas_kwargs = window_kwargs_to_pandas_equivalent( function_name, self._call_kwargs ) - if function_name == "cum_count": - plx = self.__narwhals_namespace__() - df = df.with_columns(~plx.col(*output_names).is_null()) - if function_name.startswith("cum_"): - reverse = self._call_kwargs["reverse"] - else: - assert "reverse" not in self._call_kwargs # debug assertion # noqa: S101 - reverse = False - if reverse: - # Only select the columns we need to avoid reversing columns - # unnecessarily - columns = list(set(partition_by).union(output_names)) - native_frame = df[columns]._native_frame[::-1] - else: - native_frame = df._native_frame - res_native = native_frame.groupby(partition_by)[list(output_names)].transform( - pandas_function_name, **pandas_kwargs - ) - result_frame = df._from_native_frame( - rename( - res_native, - columns=dict(zip(output_names, aliases)), - implementation=self._implementation, - backend_version=self._backend_version, + def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: + output_names, aliases = evaluate_output_names_and_aliases(self, df, []) + + if function_name == "cum_count": + plx = self.__narwhals_namespace__() + df = df.with_columns(~plx.col(*output_names).is_null()) + + if function_name.startswith("cum_"): + reverse = self._call_kwargs["reverse"] + else: + assert "reverse" not in self._call_kwargs # noqa: S101 + reverse = False + + if order_by: + columns = list(set(partition_by).union(output_names).union(order_by)) + token = generate_temporary_column_name(8, columns) + df = ( + df[columns] + .with_row_index(token) + .sort(*order_by, descending=reverse, nulls_last=reverse) + ) + sorting_indices = df[token] + elif reverse: + columns = list(set(partition_by).union(output_names)) + df = df[columns][::-1] + 
res_native = df._native_frame.groupby(partition_by)[ + list(output_names) + ].transform(pandas_function_name, **pandas_kwargs) + result_frame = df._from_native_frame(res_native).rename( + dict(zip(output_names, aliases)) ) - ) - if reverse: - return [result_frame[name][::-1] for name in aliases] - return [result_frame[name] for name in aliases] + results = [result_frame[name] for name in aliases] + if order_by: + for s in results: + s._scatter_in_place(sorting_indices, s) + return results + if reverse: + return [s[::-1] for s in results] + return results return self.__class__( func, diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index b80f2859c2..87057c3d6b 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -9,6 +9,8 @@ from typing import cast from typing import overload +import numpy as np + from narwhals._pandas_like.series_cat import PandasLikeSeriesCatNamespace from narwhals._pandas_like.series_dt import PandasLikeSeriesDateTimeNamespace from narwhals._pandas_like.series_list import PandasLikeSeriesListNamespace @@ -27,6 +29,7 @@ from narwhals.typing import CompliantSeries from narwhals.utils import Implementation from narwhals.utils import import_dtypes_module +from narwhals.utils import parse_version from narwhals.utils import validate_backend_version if TYPE_CHECKING: @@ -242,6 +245,25 @@ def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self: s.name = self.name return self._from_native_series(s) + def _scatter_in_place(self: Self, indices: Self, values: Self) -> None: + # Scatter, modifying original Series. Use with care! 
+ values_native = set_index( + values._native_series, + self._native_series.index[indices._native_series], + implementation=self._implementation, + backend_version=self._backend_version, + ) + if self._implementation is Implementation.PANDAS and parse_version(np) < (2,): + values_native = values_native.copy() # pragma: no cover + min_pd_version = (1, 2) + if ( + self._implementation is Implementation.PANDAS + and self._backend_version < min_pd_version + ): + self._native_series.iloc[indices._native_series.values] = values_native # noqa: PD011 + else: + self._native_series.iloc[indices._native_series] = values_native + def cast(self: Self, dtype: DType | type[DType]) -> Self: ser = self._native_series dtype_backend = get_dtype_backend( diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index d39d2cdb31..7ebbe4ff78 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -97,8 +97,22 @@ def is_nan(self: Self) -> Self: ) return self._from_native_expr(self._native_expr.is_nan()) - def over(self: Self, keys: list[str], kind: ExprKind) -> Self: - return self._from_native_expr(self._native_expr.over(keys)) + def over( + self: Self, + partition_by: Sequence[str], + kind: ExprKind, + order_by: Sequence[str] | None, + ) -> Self: + if self._backend_version < (1, 9): + if order_by: + msg = "`order_by` in Polars requires version 1.10 or greater" + raise NotImplementedError(msg) + return self._from_native_expr( + self._native_expr.over(partition_by or pl.lit(1)) + ) + return self._from_native_expr( + self._native_expr.over(partition_by or pl.lit(1), order_by=order_by) + ) def rolling_var( self: Self, diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 965627ff82..69148eca34 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -29,6 +29,7 @@ from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.namespace import SparkLikeNamespace + from narwhals._spark_like.typing 
import WindowFunction from narwhals.dtypes import DType from narwhals.utils import Version @@ -54,6 +55,7 @@ def __init__( self._backend_version = backend_version self._version = version self._implementation = implementation + self._window_function: WindowFunction | None = None def __call__(self: Self, df: SparkLikeLazyFrame) -> Sequence[Column]: return self._call(df) @@ -211,6 +213,22 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: implementation=self._implementation, ) + def _with_window_function( + self: Self, + window_function: WindowFunction, + ) -> Self: + result = self.__class__( + self._call, + function_name=self._function_name, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + result._window_function = window_function + return result + def __eq__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] return self._from_call( lambda _input, other: _input.__eq__(other), "__eq__", other=other @@ -496,9 +514,27 @@ def _n_unique(_input: Column) -> Column: return self._from_call(_n_unique, "n_unique") - def over(self: Self, keys: Sequence[str], kind: ExprKind) -> Self: - def func(df: SparkLikeLazyFrame) -> list[Column]: - return [expr.over(self._Window.partitionBy(*keys)) for expr in self._call(df)] + def over( + self: Self, + partition_by: Sequence[str], + kind: ExprKind, + order_by: Sequence[str] | None, + ) -> Self: + if (window_function := self._window_function) is not None: + assert order_by is not None # noqa: S101 + + def func(df: SparkLikeLazyFrame) -> list[Column]: + return [ + window_function(expr, partition_by, order_by) + for expr in self._call(df) + ] + else: + + def func(df: SparkLikeLazyFrame) -> list[Column]: + return [ + expr.over(self._Window.partitionBy(*partition_by)) + for expr in self._call(df) + ] return self.__class__( func, @@ -521,6 +557,24 @@ def 
_is_nan(_input: Column) -> Column: return self._from_call(_is_nan, "is_nan") + def cum_sum(self, *, reverse: bool) -> Self: + def func( + _input: Column, partition_by: Sequence[str], order_by: Sequence[str] + ) -> Column: + if reverse: + order_by_cols = [self._F.col(x).desc_nulls_last() for x in order_by] + else: + order_by_cols = [self._F.col(x).asc_nulls_first() for x in order_by] + window = ( + self._Window() + .partitionBy(list(partition_by)) + .orderBy(order_by_cols) + .rowsBetween(self._Window().unboundedPreceding, 0) # row-based frame: a range frame would also sum peers tied on the order key + ) + return self._F.sum(_input).over(window) + + return self._with_window_function(func) + @property def str(self: Self) -> SparkLikeExprStringNamespace: return SparkLikeExprStringNamespace(self) @@ -559,7 +613,6 @@ def list(self: Self) -> SparkLikeExprListNamespace: shift = not_implemented() is_first_distinct = not_implemented() is_last_distinct = not_implemented() - cum_sum = not_implemented() cum_count = not_implemented() cum_min = not_implemented() cum_max = not_implemented() diff --git a/narwhals/_spark_like/typing.py b/narwhals/_spark_like/typing.py new file mode 100644 index 0000000000..b55ef935f7 --- /dev/null +++ b/narwhals/_spark_like/typing.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Protocol +from typing import Sequence + +if TYPE_CHECKING: + from pyspark.sql import Column + + class WindowFunction(Protocol): + def __call__( + self, _input: Column, partition_by: Sequence[str], order_by: Sequence[str] + ) -> Column: ... diff --git a/narwhals/expr.py b/narwhals/expr.py index b829d3789b..e6c6b5bd13 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1551,15 +1551,22 @@ def over( if self._metadata.kind.is_filtration(): msg = "`.over()` can not be used for expressions which change length."
raise LengthChangingExprError(msg) + + flat_partition_by = flatten(partition_by) + order_by = [_order_by] if isinstance(_order_by, str) else _order_by + if not flat_partition_by and not _order_by: # pragma: no cover + msg = "At least one of `partition_by` or `order_by` must be specified." + raise ValueError(msg) + kind = ExprKind.TRANSFORM n_open_windows = self._metadata.n_open_windows if _order_by is not None and self._metadata.kind.is_window(): n_open_windows -= 1 metadata = ExprMetadata(kind, n_open_windows=n_open_windows) - flat_partition_by = flatten(partition_by) + return self.__class__( lambda plx: self._to_compliant_expr(plx).over( - flat_partition_by, kind=self._metadata.kind + flat_partition_by, order_by=order_by, kind=self._metadata.kind ), metadata, ) diff --git a/narwhals/typing.py b/narwhals/typing.py index 7ec2179a3a..3c9f1da1d6 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -193,7 +193,12 @@ def replace_strict( *, return_dtype: DType | type[DType] | None, ) -> Self: ... - def over(self: Self, keys: Sequence[str], kind: ExprKind) -> Self: ... + def over( + self: Self, + partition_by: Sequence[str], + kind: ExprKind, + order_by: Sequence[str] | None, + ) -> Self: ...
def sample( self, n: int | None, diff --git a/narwhals/utils.py b/narwhals/utils.py index 1d384cc1da..0fa1f312fc 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -884,7 +884,7 @@ def maybe_set_index( msg = "Only one of `column_names` or `index` should be provided" raise ValueError(msg) - if not column_names and not index: + if not column_names and index is None: msg = "Either `column_names` or `index` should be provided" raise ValueError(msg) diff --git a/tests/expr_and_series/cum_sum_test.py b/tests/expr_and_series/cum_sum_test.py index eb1b985b84..8ea8cb5eb5 100644 --- a/tests/expr_and_series/cum_sum_test.py +++ b/tests/expr_and_series/cum_sum_test.py @@ -3,6 +3,8 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import POLARS_VERSION +from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -24,6 +26,190 @@ def test_cum_sum_expr(constructor_eager: ConstructorEager, *, reverse: bool) -> assert_equal_data(result, {name: expected[name]}) +@pytest.mark.parametrize( + ("reverse", "expected_a"), + [ + (False, [3, 2, 6]), + (True, [4, 6, 3]), + ], +) +def test_lazy_cum_sum_grouped( + constructor: Constructor, + request: pytest.FixtureRequest, + *, + reverse: bool, + expected_a: list[int], +) -> None: + if "duckdb" in str(constructor): + # no window function support yet in duckdb + request.applymarker(pytest.mark.xfail) + if "pyarrow_table" in str(constructor): + # grouped window functions not yet supported + request.applymarker(pytest.mark.xfail) + if "modin" in str(constructor): + # bugged + request.applymarker(pytest.mark.xfail) + if "dask" in str(constructor): + # https://github.com/dask/dask/issues/11806 + request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + pytest.skip(reason="too old version") + + df = nw.from_native( + constructor( + { + "a": [1, 2, 3], + "b": [1, 0, 2], + "i": [0, 1, 2], + "g": [1, 1, 1], + } + ) + ) + 
result = df.with_columns( + nw.col("a").cum_sum(reverse=reverse).over("g", _order_by="b") + ).sort("i") + expected = {"a": expected_a, "b": [1, 0, 2], "i": [0, 1, 2]} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("reverse", "expected_a"), + [ + (False, [10, 6, 14, 11, 16, 9, 4]), + (True, [7, 12, 5, 6, 2, 10, 16]), + ], +) +def test_lazy_cum_sum_ordered_by_nulls( + constructor: Constructor, + request: pytest.FixtureRequest, + *, + reverse: bool, + expected_a: list[int], +) -> None: + if "duckdb" in str(constructor): + # no window function support yet in duckdb + request.applymarker(pytest.mark.xfail) + if "pyarrow_table" in str(constructor): + # grouped window functions not yet supported + request.applymarker(pytest.mark.xfail) + if "modin" in str(constructor): + # bugged + request.applymarker(pytest.mark.xfail) + if "dask" in str(constructor): + # https://github.com/dask/dask/issues/11806 + request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + pytest.skip(reason="too old version") + + df = nw.from_native( + constructor( + { + "a": [1, 2, 3, 1, 2, 3, 4], + "b": [1, -1, 3, 2, 5, 0, None], + "i": [0, 1, 2, 3, 4, 5, 6], + "g": [1, 1, 1, 1, 1, 1, 1], + } + ) + ) + result = df.with_columns( + nw.col("a").cum_sum(reverse=reverse).over("g", _order_by="b") + ).sort("i") + expected = { + "a": expected_a, + "b": [1, -1, 3, 2, 5, 0, None], + "i": [0, 1, 2, 3, 4, 5, 6], + } + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("reverse", "expected_a"), + [ + (False, [3, 2, 6]), + (True, [4, 6, 3]), + ], +) +def test_lazy_cum_sum_ungrouped( + constructor: Constructor, + request: pytest.FixtureRequest, + *, + reverse: bool, + expected_a: list[int], +) -> None: + if "duckdb" in str(constructor): + # no window function support yet in duckdb + request.applymarker(pytest.mark.xfail) + if "dask" in str(constructor) and reverse: + # https://github.com/dask/dask/issues/11802 + 
request.applymarker(pytest.mark.xfail) + if "modin" in str(constructor): + # probably bugged + request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + pytest.skip(reason="too old version") + + df = nw.from_native( + constructor( + { + "a": [2, 3, 1], + "b": [0, 2, 1], + "i": [1, 2, 0], + } + ) + ).sort("i") + result = df.with_columns( + nw.col("a").cum_sum(reverse=reverse).over(_order_by="b") + ).sort("i") + expected = {"a": expected_a, "b": [1, 0, 2], "i": [0, 1, 2]} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("reverse", "expected_a"), + [ + (False, [10, 6, 14, 11, 16, 9, 4]), + (True, [7, 12, 5, 6, 2, 10, 16]), + ], +) +def test_lazy_cum_sum_ungrouped_ordered_by_nulls( + constructor: Constructor, + request: pytest.FixtureRequest, + *, + reverse: bool, + expected_a: list[int], +) -> None: + if "duckdb" in str(constructor): + # no window function support yet in duckdb + request.applymarker(pytest.mark.xfail) + if "dask" in str(constructor): + # https://github.com/dask/dask/issues/11806 + request.applymarker(pytest.mark.xfail) + if "modin" in str(constructor): + # probably bugged + request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + pytest.skip(reason="too old version") + + df = nw.from_native( + constructor( + { + "a": [1, 2, 3, 1, 2, 3, 4], + "b": [1, -1, 3, 2, 5, 0, None], + "i": [0, 1, 2, 3, 4, 5, 6], + } + ) + ).sort("i") + result = df.with_columns( + nw.col("a").cum_sum(reverse=reverse).over(_order_by="b") + ).sort("i") + expected = { + "a": expected_a, + "b": [1, -1, 3, 2, 5, 0, None], + "i": [0, 1, 2, 3, 4, 5, 6], + } + assert_equal_data(result, expected) + + def test_cum_sum_series(constructor_eager: ConstructorEager) -> None: df = nw.from_native(constructor_eager(data), eager_only=True) result = df.select( diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index 5c72dc00ef..38bfdd2cc7 100644 --- 
a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -10,6 +10,7 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import LengthChangingExprError from tests.utils import PANDAS_VERSION +from tests.utils import POLARS_VERSION from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -247,7 +248,7 @@ def test_over_anonymous_cumulative( if "cudf" in str(constructor_eager): # https://github.com/rapidsai/cudf/issues/18159 request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor_eager({"a": [1, 1, 2], "b": [4, 5, 6]})) + df = nw.from_native(constructor_eager({"": [1, 1, 2], "b": [4, 5, 6]})) context = ( pytest.raises(NotImplementedError) if df.implementation.is_pyarrow() @@ -260,12 +261,12 @@ def test_over_anonymous_cumulative( ) with context: result = df.with_columns( - nw.all().cum_sum().over("a").name.suffix("_cum_sum") - ).sort("a", "b") + nw.all().cum_sum().over("").name.suffix("_cum_sum") + ).sort("", "b") expected = { - "a": [1, 1, 2], + "": [1, 1, 2], "b": [4, 5, 6], - "a_cum_sum": [1, 2, 2], + "_cum_sum": [1, 2, 2], "b_cum_sum": [4, 9, 6], } assert_equal_data(result, expected) @@ -413,3 +414,24 @@ def test_unsupported_over() -> None: tbl = pa.table(data) # type: ignore[arg-type] with pytest.raises(NotImplementedError, match="aggregation or literal"): nw.from_native(tbl).select(nw.col("a").shift(1).cum_sum().over("b")) + + +def test_over_without_partition_by( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "polars" in str(constructor) and POLARS_VERSION < (1, 10): + pytest.skip() + if "duckdb" in str(constructor): + # windows not yet supported + request.applymarker(pytest.mark.xfail) + if "modin" in str(constructor): + # probably bugged + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor({"a": [1, -1, 2], "i": [0, 2, 1]})) + result = ( + 
df.with_columns(b=nw.col("a").abs().cum_sum().over(_order_by="i")) + .sort("i") + .select("a", "b", "i") + ) + expected = {"a": [1, 2, -1], "b": [1, 3, 4], "i": [0, 1, 2]} + assert_equal_data(result, expected) diff --git a/tests/series_only/scatter_test.py b/tests/series_only/scatter_test.py index d7f0aa7792..26bb88672d 100644 --- a/tests/series_only/scatter_test.py +++ b/tests/series_only/scatter_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +import pandas as pd import pytest import narwhals as nw @@ -27,6 +28,13 @@ def test_scatter( assert_equal_data(result, expected) +def test_scatter_indices() -> None: + s = nw.from_native(pd.Series([2, 3, 6], index=[1, 0, 2]), series_only=True) + result = s.scatter([1, 0, 2], s) + expected = pd.Series([3, 2, 6], index=[1, 0, 2]) + pd.testing.assert_series_equal(result.to_native(), expected) + + def test_scatter_unchanged(constructor_eager: ConstructorEager) -> None: df = nw.from_native( constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True