diff --git a/docs/api-reference/expr_list.md b/docs/api-reference/expr_list.md index f44a25d751..2078a1e394 100644 --- a/docs/api-reference/expr_list.md +++ b/docs/api-reference/expr_list.md @@ -11,6 +11,7 @@ - mean - median - min + - sort - sum - unique show_source: false diff --git a/docs/api-reference/series_list.md b/docs/api-reference/series_list.md index 39adbad185..2629758001 100644 --- a/docs/api-reference/series_list.md +++ b/docs/api-reference/series_list.md @@ -11,6 +11,7 @@ - mean - median - min + - sort - sum - unique show_source: false diff --git a/narwhals/_arrow/series_list.py b/narwhals/_arrow/series_list.py index 25e598aedd..d864150654 100644 --- a/narwhals/_arrow/series_list.py +++ b/narwhals/_arrow/series_list.py @@ -5,7 +5,7 @@ import pyarrow as pa import pyarrow.compute as pc -from narwhals._arrow.utils import ArrowSeriesNamespace, list_agg +from narwhals._arrow.utils import ArrowSeriesNamespace, list_agg, list_sort from narwhals._compliant.any_namespace import ListNamespace from narwhals._utils import not_implemented @@ -35,5 +35,10 @@ def median(self) -> ArrowSeries: def sum(self) -> ArrowSeries: return self.with_native(list_agg(self.native, "sum")) + def sort(self, *, descending: bool, nulls_last: bool) -> ArrowSeries: + return self.with_native( + list_sort(self.native, descending=descending, nulls_last=nulls_last) + ) + unique = not_implemented() contains = not_implemented() diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index dbd8aa6c62..bb9921b97d 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -532,3 +532,35 @@ def list_agg( ) ] ) + + +def list_sort( + array: ChunkedArrayAny, *, descending: bool, nulls_last: bool +) -> ChunkedArrayAny: + sort_direction: Literal["ascending", "descending"] = ( + "descending" if descending else "ascending" + ) + nulls_position: Literal["at_start", "at_end"] = "at_end" if nulls_last else "at_start" + idx, v = "idx", "values" + is_not_sorted = pc.greater(pc.list_value_length(array), lit(0)) + indexed = pa.Table.from_arrays( + [arange(start=0, end=len(array), step=1), array], names=[idx, v] + ) + not_sorted_part = indexed.filter(is_not_sorted) + pass_through = indexed.filter(pc.fill_null(pc.invert(is_not_sorted), lit(True))) # pyright: ignore[reportArgumentType] + exploded = pa.Table.from_arrays( + [pc.list_flatten(array), pc.list_parent_indices(array)], names=[v, idx] + ) + sorted_indices = pc.sort_indices( + exploded, + sort_keys=[(idx, "ascending"), (v, sort_direction)], + null_placement=nulls_position, + ) + offsets = not_sorted_part.column(v).combine_chunks().offsets # type: ignore[attr-defined] + sorted_imploded = pa.ListArray.from_arrays( + offsets, pa.array(exploded.take(sorted_indices).column(v)) + ) + imploded_by_idx = pa.Table.from_arrays( + [not_sorted_part.column(idx), sorted_imploded], names=[idx, v] + ) + return pa.concat_tables([imploded_by_idx, pass_through]).sort_by(idx).column(v) diff --git a/narwhals/_compliant/any_namespace.py b/narwhals/_compliant/any_namespace.py index cd76470d89..cf1e6fe805 100644 --- a/narwhals/_compliant/any_namespace.py +++ b/narwhals/_compliant/any_namespace.py @@ -75,6 +75,7 @@ def max(self) -> CompliantT_co: ... def mean(self) -> CompliantT_co: ... def median(self) -> CompliantT_co: ... def sum(self) -> CompliantT_co: ... + def sort(self, *, descending: bool, nulls_last: bool) -> CompliantT_co: ... class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 8aa52f8d61..6e0971b27d 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -1014,6 +1014,11 @@ def median(self) -> EagerExprT: def sum(self) -> EagerExprT: return self.compliant._reuse_series_namespace("list", "sum") + def sort(self, *, descending: bool, nulls_last: bool) -> EagerExprT: + return self.compliant._reuse_series_namespace( + "list", "sort", descending=descending, nulls_last=nulls_last + ) + class CompliantExprNameNamespace( # type: ignore[misc] _ExprNamespace[CompliantExprT_co], diff --git a/narwhals/_duckdb/expr_list.py b/narwhals/_duckdb/expr_list.py index 15b3f70a53..22d694bde9 100644 --- a/narwhals/_duckdb/expr_list.py +++ b/narwhals/_duckdb/expr_list.py @@ -64,3 +64,10 @@ def func(expr: Expression) -> Expression: ) return self.compliant._with_callable(func) + + def sort(self, *, descending: bool, nulls_last: bool) -> DuckDBExpr: + sort_direction = "DESC" if descending else "ASC" + nulls_position = "NULLS LAST" if nulls_last else "NULLS FIRST" + return self.compliant._with_elementwise( + lambda expr: F("list_sort", expr, lit(sort_direction), lit(nulls_position)) + ) diff --git a/narwhals/_ibis/expr_list.py b/narwhals/_ibis/expr_list.py index 64cc053831..aeef63c916 100644 --- a/narwhals/_ibis/expr_list.py +++ b/narwhals/_ibis/expr_list.py @@ -52,4 +52,18 @@ def func(expr: ir.ArrayColumn) -> ir.Value: return self.compliant._with_callable(func) + def sort(self, *, descending: bool, nulls_last: bool) -> IbisExpr: + if descending: + msg = "Descending sort is not currently supported for Ibis." + raise NotImplementedError(msg) + + def func(expr: ir.ArrayColumn) -> ir.ArrayValue: + if nulls_last: + return expr.sort() + expr_no_nulls = expr.filter(lambda x: x.notnull()) + expr_nulls = expr.filter(lambda x: x.isnull()) + return expr_nulls.concat(expr_no_nulls.sort()) + + return self.compliant._with_callable(func) + median = not_implemented() diff --git a/narwhals/_pandas_like/series_list.py b/narwhals/_pandas_like/series_list.py index d1fb085262..2b60119cd2 100644 --- a/narwhals/_pandas_like/series_list.py +++ b/narwhals/_pandas_like/series_list.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import partial from typing import TYPE_CHECKING from narwhals._compliant.any_namespace import ListNamespace @@ -34,18 +35,12 @@ def len(self) -> PandasLikeSeries: ) return self.with_native(result.astype(dtype)).alias(self.native.name) - unique = not_implemented() - - contains = not_implemented() - def get(self, index: int) -> PandasLikeSeries: result = self.native.list[index] result.name = self.native.name return self.with_native(result) - def _agg( - self, func: Literal["min", "max", "mean", "approximate_median", "sum"] - ) -> PandasLikeSeries: + def _raise_if_not_pyarrow_backend(self) -> None: dtype_backend = get_dtype_backend( self.native.dtype, self.compliant._implementation ) @@ -53,16 +48,15 @@ def _agg( msg = "Only pyarrow backend is currently supported." raise NotImplementedError(msg) - from narwhals._arrow.utils import list_agg, native_to_narwhals_dtype + def _agg( + self, func: Literal["min", "max", "mean", "approximate_median", "sum"] + ) -> PandasLikeSeries: + self._raise_if_not_pyarrow_backend() + + from narwhals._arrow.utils import list_agg - ca = self.native.array._pa_array - result_arr = list_agg(ca, func) - nw_dtype = native_to_narwhals_dtype(result_arr.type, self.version) - out_dtype = narwhals_to_native_dtype( - nw_dtype, "pyarrow", self.implementation, self.version - ) - result_native = type(self.native)( - result_arr, dtype=out_dtype, index=self.native.index, name=self.native.name + result_native = self.compliant._apply_pyarrow_compute_func( + self.native, partial(list_agg, func=func) ) return self.with_native(result_native) @@ -80,3 +74,16 @@ def median(self) -> PandasLikeSeries: def sum(self) -> PandasLikeSeries: return self._agg("sum") + + def sort(self, *, descending: bool, nulls_last: bool) -> PandasLikeSeries: + self._raise_if_not_pyarrow_backend() + + from narwhals._arrow.utils import list_sort + + result_native = self.compliant._apply_pyarrow_compute_func( + self.native, partial(list_sort, descending=descending, nulls_last=nulls_last) + ) + return self.with_native(result_native) + + unique = not_implemented() + contains = not_implemented() diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index f2081f99dc..26f781d44e 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -370,6 +370,8 @@ def len(self) -> CompliantT: ... min: Method[CompliantT] + sort: Method[CompliantT] + sum: Method[CompliantT] diff --git a/narwhals/_spark_like/expr_list.py b/narwhals/_spark_like/expr_list.py index 1b2ec723bc..898cded7be 100644 --- a/narwhals/_spark_like/expr_list.py +++ b/narwhals/_spark_like/expr_list.py @@ -82,3 +82,15 @@ def func(expr: Column) -> Column: # pragma: no cover ) return self.compliant._with_elementwise(func) + + def sort(self, *, descending: bool, nulls_last: bool) -> SparkLikeExpr: + def func(expr: Column) -> Column: + F = self.compliant._F + if not descending and nulls_last: + return F.array_sort(expr) + if descending and not nulls_last: # pragma: no cover + # https://github.com/eakmanrq/sqlframe/issues/559 + return F.reverse(F.array_sort(expr)) + return F.sort_array(expr, asc=not descending) + + return self.compliant._with_elementwise(func) diff --git a/narwhals/expr_list.py b/narwhals/expr_list.py index dcc54d4404..48f322c7dd 100644 --- a/narwhals/expr_list.py +++ b/narwhals/expr_list.py @@ -262,3 +262,39 @@ def sum(self) -> ExprT: └────────────────────────┘ """ return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.sum")) + + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> ExprT: + """Sort the lists of the expression. + + Arguments: + descending: Sort in descending order. + nulls_last: Place null values last. + + Examples: + >>> import duckdb + >>> import narwhals as nw + >>> df_native = duckdb.sql( + ... "SELECT * FROM VALUES ([2, -1, 1]), ([3, -4, NULL]) df(a)" + ... ) + >>> df = nw.from_native(df_native) + >>> df.with_columns(a_sorted=nw.col("a").list.sort()) + ┌─────────────────────────────────┐ + | Narwhals LazyFrame | + |---------------------------------| + |┌───────────────┬───────────────┐| + |│ a │ a_sorted │| + |│ int32[] │ int32[] │| + |├───────────────┼───────────────┤| + |│ [2, -1, 1] │ [-1, 1, 2] │| + |│ [3, -4, NULL] │ [NULL, -4, 3] │| + |└───────────────┴───────────────┘| + └─────────────────────────────────┘ + """ + return self._expr._append_node( + ExprNode( + ExprKind.ELEMENTWISE, + "list.sort", + descending=descending, + nulls_last=nulls_last, + ) + ) diff --git a/narwhals/series_list.py b/narwhals/series_list.py index d5ead244e1..d48dc2b18c 100644 --- a/narwhals/series_list.py +++ b/narwhals/series_list.py @@ -219,3 +219,29 @@ def sum(self) -> SeriesT: return self._narwhals_series._with_compliant( self._narwhals_series._compliant_series.list.sum() ) + + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> SeriesT: + """Sort the lists of the series. + + Arguments: + descending: Sort in descending order. + nulls_last: Place null values last. + + Examples: + >>> import polars as pl + >>> import narwhals as nw + >>> s_native = pl.Series([[2, -1, 1], [3, -4, None]]) + >>> s = nw.from_native(s_native, series_only=True) + >>> s.list.sort().to_native() # doctest: +NORMALIZE_WHITESPACE + shape: (2,) + Series: '' [list[i64]] + [ + [-1, 1, 2] + [null, -4, 3] + ] + """ + return self._narwhals_series._with_compliant( + self._narwhals_series._compliant_series.list.sort( + descending=descending, nulls_last=nulls_last + ) + ) diff --git a/tests/expr_and_series/list/sort_test.py b/tests/expr_and_series/list/sort_test.py new file mode 100644 index 0000000000..41a17e1eec --- /dev/null +++ b/tests/expr_and_series/list/sort_test.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +from tests.utils import PANDAS_VERSION, POLARS_VERSION, assert_equal_data + +if TYPE_CHECKING: + from typing import Any + + from tests.utils import Constructor, ConstructorEager + + +data = {"a": [[3, 2, 2, 4, -10, None, None], [-1], None, [None, None, None], []]} +expected_desc_nulls_last = [ + [4, 3, 2, 2, -10, None, None], + [-1], + None, + [None, None, None], + [], +] +expected_desc_nulls_first = [ + [None, None, 4, 3, 2, 2, -10], + [-1], + None, + [None, None, None], + [], +] +expected_asc_nulls_last = [ + [-10, 2, 2, 3, 4, None, None], + [-1], + None, + [None, None, None], + [], +] +expected_asc_nulls_first = [ + [None, None, -10, 2, 2, 3, 4], + [-1], + None, + [None, None, None], + [], +] + + +@pytest.mark.parametrize( + ("descending", "nulls_last", "expected"), + [ + (True, True, expected_desc_nulls_last), + (True, False, expected_desc_nulls_first), + (False, True, expected_asc_nulls_last), + (False, False, expected_asc_nulls_first), + ], +) +def test_sort_expr_args( + request: pytest.FixtureRequest, + constructor: Constructor, + descending: bool, # noqa: FBT001 + nulls_last: bool, # noqa: FBT001 + expected: list[Any], +) -> None: + if any(backend in str(constructor) for backend in ("dask", "cudf")): + request.applymarker(pytest.mark.xfail) + if "ibis" in str(constructor) and descending: + # https://github.com/ibis-project/ibis/issues/11735 + request.applymarker(pytest.mark.xfail) + if "sqlframe" in str(constructor) and not nulls_last: + # https://github.com/eakmanrq/sqlframe/issues/559 + # https://github.com/eakmanrq/sqlframe/issues/560 + request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 5): + pytest.skip() + if "pandas" in str(constructor): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + result = nw.from_native(constructor(data)).select( + nw.col("a") + .cast(nw.List(nw.Int32())) + .list.sort(descending=descending, nulls_last=nulls_last) + ) + assert_equal_data(result, {"a": expected}) + + +@pytest.mark.parametrize( + ("descending", "nulls_last", "expected"), + [ + (True, True, expected_desc_nulls_last), + (True, False, expected_desc_nulls_first), + (False, True, expected_asc_nulls_last), + (False, False, expected_asc_nulls_first), + ], +) +def test_sort_series_args( + request: pytest.FixtureRequest, + constructor_eager: ConstructorEager, + descending: bool, # noqa: FBT001 + nulls_last: bool, # noqa: FBT001 + expected: list[Any], +) -> None: + if any(backend in str(constructor_eager) for backend in ("dask", "cudf")): + request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 5): + pytest.skip() + if "pandas" in str(constructor_eager): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + df = nw.from_native(constructor_eager(data), eager_only=True) + result = ( + df["a"] + .cast(nw.List(nw.Int32())) + .list.sort(descending=descending, nulls_last=nulls_last) + ) + assert_equal_data({"a": result}, {"a": expected})