From 97af862ad578f60096cbea8c80e5e26d74493812 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Mon, 15 Dec 2025 11:47:00 +0000 Subject: [PATCH 01/11] feat: add `list_sort` --- docs/api-reference/expr_list.md | 1 + docs/api-reference/series_list.md | 1 + narwhals/_compliant/any_namespace.py | 1 + narwhals/_compliant/expr.py | 5 + narwhals/_duckdb/expr_list.py | 7 ++ narwhals/_ibis/expr_list.py | 14 +++ narwhals/_polars/utils.py | 2 + narwhals/_spark_like/expr_list.py | 12 ++ narwhals/expr_list.py | 36 ++++++ narwhals/series_list.py | 26 +++++ tests/expr_and_series/list/sort_test.py | 141 ++++++++++++++++++++++++ 11 files changed, 246 insertions(+) create mode 100644 tests/expr_and_series/list/sort_test.py diff --git a/docs/api-reference/expr_list.md b/docs/api-reference/expr_list.md index f44a25d751..2078a1e394 100644 --- a/docs/api-reference/expr_list.md +++ b/docs/api-reference/expr_list.md @@ -11,6 +11,7 @@ - mean - median - min + - sort - sum - unique show_source: false diff --git a/docs/api-reference/series_list.md b/docs/api-reference/series_list.md index 39adbad185..2629758001 100644 --- a/docs/api-reference/series_list.md +++ b/docs/api-reference/series_list.md @@ -11,6 +11,7 @@ - mean - median - min + - sort - sum - unique show_source: false diff --git a/narwhals/_compliant/any_namespace.py b/narwhals/_compliant/any_namespace.py index cd76470d89..cf1e6fe805 100644 --- a/narwhals/_compliant/any_namespace.py +++ b/narwhals/_compliant/any_namespace.py @@ -75,6 +75,7 @@ def max(self) -> CompliantT_co: ... def mean(self) -> CompliantT_co: ... def median(self) -> CompliantT_co: ... def sum(self) -> CompliantT_co: ... + def sort(self, *, descending: bool, nulls_last: bool) -> CompliantT_co: ... 
class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 8aa52f8d61..6e0971b27d 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -1014,6 +1014,11 @@ def median(self) -> EagerExprT: def sum(self) -> EagerExprT: return self.compliant._reuse_series_namespace("list", "sum") + def sort(self, *, descending: bool, nulls_last: bool) -> EagerExprT: + return self.compliant._reuse_series_namespace( + "list", "sort", descending=descending, nulls_last=nulls_last + ) + class CompliantExprNameNamespace( # type: ignore[misc] _ExprNamespace[CompliantExprT_co], diff --git a/narwhals/_duckdb/expr_list.py b/narwhals/_duckdb/expr_list.py index 15b3f70a53..22d694bde9 100644 --- a/narwhals/_duckdb/expr_list.py +++ b/narwhals/_duckdb/expr_list.py @@ -64,3 +64,10 @@ def func(expr: Expression) -> Expression: ) return self.compliant._with_callable(func) + + def sort(self, *, descending: bool, nulls_last: bool) -> DuckDBExpr: + sort_direction = "DESC" if descending else "ASC" + nulls_position = "NULLS LAST" if nulls_last else "NULLS FIRST" + return self.compliant._with_elementwise( + lambda expr: F("list_sort", expr, lit(sort_direction), lit(nulls_position)) + ) diff --git a/narwhals/_ibis/expr_list.py b/narwhals/_ibis/expr_list.py index 64cc053831..aeef63c916 100644 --- a/narwhals/_ibis/expr_list.py +++ b/narwhals/_ibis/expr_list.py @@ -52,4 +52,18 @@ def func(expr: ir.ArrayColumn) -> ir.Value: return self.compliant._with_callable(func) + def sort(self, *, descending: bool, nulls_last: bool) -> IbisExpr: + if descending: + msg = "Descending sort is not currently supported for Ibis." 
+ raise NotImplementedError(msg) + + def func(expr: ir.ArrayColumn) -> ir.ArrayValue: + if nulls_last: + return expr.sort() + expr_no_nulls = expr.filter(lambda x: x.notnull()) + expr_nulls = expr.filter(lambda x: x.isnull()) + return expr_nulls.concat(expr_no_nulls.sort()) + + return self.compliant._with_callable(func) + median = not_implemented() diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index f2081f99dc..26f781d44e 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -370,6 +370,8 @@ def len(self) -> CompliantT: ... min: Method[CompliantT] + sort: Method[CompliantT] + sum: Method[CompliantT] diff --git a/narwhals/_spark_like/expr_list.py b/narwhals/_spark_like/expr_list.py index 1b2ec723bc..898cded7be 100644 --- a/narwhals/_spark_like/expr_list.py +++ b/narwhals/_spark_like/expr_list.py @@ -82,3 +82,15 @@ def func(expr: Column) -> Column: # pragma: no cover ) return self.compliant._with_elementwise(func) + + def sort(self, *, descending: bool, nulls_last: bool) -> SparkLikeExpr: + def func(expr: Column) -> Column: + F = self.compliant._F + if not descending and nulls_last: + return F.array_sort(expr) + if descending and not nulls_last: # pragma: no cover + # https://github.com/eakmanrq/sqlframe/issues/559 + return F.reverse(F.array_sort(expr)) + return F.sort_array(expr, asc=not descending) + + return self.compliant._with_elementwise(func) diff --git a/narwhals/expr_list.py b/narwhals/expr_list.py index dcc54d4404..48f322c7dd 100644 --- a/narwhals/expr_list.py +++ b/narwhals/expr_list.py @@ -262,3 +262,39 @@ def sum(self) -> ExprT: └────────────────────────┘ """ return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.sum")) + + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> ExprT: + """Sort the lists of the expression. + + Arguments: + descending: Sort in descending order. + nulls_last: Place null values last. 
+ + Examples: + >>> import duckdb + >>> import narwhals as nw + >>> df_native = duckdb.sql( + ... "SELECT * FROM VALUES ([2, -1, 1]), ([3, -4, NULL]) df(a)" + ... ) + >>> df = nw.from_native(df_native) + >>> df.with_columns(a_sorted=nw.col("a").list.sort()) + ┌─────────────────────────────────┐ + | Narwhals LazyFrame | + |---------------------------------| + |┌───────────────┬───────────────┐| + |│ a │ a_sorted │| + |│ int32[] │ int32[] │| + |├───────────────┼───────────────┤| + |│ [2, -1, 1] │ [-1, 1, 2] │| + |│ [3, -4, NULL] │ [NULL, -4, 3] │| + |└───────────────┴───────────────┘| + └─────────────────────────────────┘ + """ + return self._expr._append_node( + ExprNode( + ExprKind.ELEMENTWISE, + "list.sort", + descending=descending, + nulls_last=nulls_last, + ) + ) diff --git a/narwhals/series_list.py b/narwhals/series_list.py index d5ead244e1..d48dc2b18c 100644 --- a/narwhals/series_list.py +++ b/narwhals/series_list.py @@ -219,3 +219,29 @@ def sum(self) -> SeriesT: return self._narwhals_series._with_compliant( self._narwhals_series._compliant_series.list.sum() ) + + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> SeriesT: + """Sort the lists of the series. + + Arguments: + descending: Sort in descending order. + nulls_last: Place null values last. 
+ + Examples: + >>> import polars as pl + >>> import narwhals as nw + >>> s_native = pl.Series([[2, -1, 1], [3, -4, None]]) + >>> s = nw.from_native(s_native, series_only=True) + >>> s.list.sort().to_native() # doctest: +NORMALIZE_WHITESPACE + shape: (2,) + Series: '' [list[i64]] + [ + [-1, 1, 2] + [null, -4, 3] + ] + """ + return self._narwhals_series._with_compliant( + self._narwhals_series._compliant_series.list.sort( + descending=descending, nulls_last=nulls_last + ) + ) diff --git a/tests/expr_and_series/list/sort_test.py b/tests/expr_and_series/list/sort_test.py new file mode 100644 index 0000000000..38fa20774c --- /dev/null +++ b/tests/expr_and_series/list/sort_test.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +from tests.utils import assert_equal_data + +if TYPE_CHECKING: + from typing import Any + + from tests.utils import Constructor, ConstructorEager + + +data = {"a": [[3, 2, 2, 4, -10, None, None], [-1], None, [None, None, None], []]} +expected_desc_nulls_last = [ + [4, 3, 2, 2, -10, None, None], + [-1], + None, + [None, None, None], + [], +] +expected_desc_nulls_first = [ + [None, None, 4, 3, 2, 2, -10], + [-1], + None, + [None, None, None], + [], +] +expected_asc_nulls_last = [ + [-10, 2, 2, 3, 4, None, None], + [-1], + None, + [None, None, None], + [], +] +expected_asc_nulls_first = [ + [None, None, -10, 2, 2, 3, 4], + [-1], + None, + [None, None, None], + [], +] + + +def test_sort_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if any( + backend in str(constructor) for backend in ("dask", "cudf", "pandas", "pyarrow") + ): + # PyArrow issue: https://github.com/apache/arrow/issues/48060#issuecomment-3510993921 + request.applymarker(pytest.mark.xfail) + if "sqlframe" in str(constructor): + # https://github.com/eakmanrq/sqlframe/issues/559 + # https://github.com/eakmanrq/sqlframe/issues/560 + request.applymarker(pytest.mark.xfail) + result 
= nw.from_native(constructor(data)).select( + nw.col("a").cast(nw.List(nw.Int32())).list.sort() + ) + assert_equal_data(result, {"a": expected_asc_nulls_first}) + + +@pytest.mark.parametrize( + ("descending", "nulls_last", "expected"), + [ + (True, True, expected_desc_nulls_last), + (True, False, expected_desc_nulls_first), + (False, True, expected_asc_nulls_last), + (False, False, expected_asc_nulls_first), + ], +) +def test_sort_expr_args( + request: pytest.FixtureRequest, + constructor: Constructor, + descending: bool, # noqa: FBT001 + nulls_last: bool, # noqa: FBT001 + expected: list[Any], +) -> None: + if any( + backend in str(constructor) for backend in ("dask", "cudf", "pandas", "pyarrow") + ): + # PyArrow issue: https://github.com/apache/arrow/issues/48060#issuecomment-3510993921 + request.applymarker(pytest.mark.xfail) + if "ibis" in str(constructor) and descending: + # https://github.com/ibis-project/ibis/issues/11735 + request.applymarker(pytest.mark.xfail) + if "sqlframe" in str(constructor) and not nulls_last: + # https://github.com/eakmanrq/sqlframe/issues/559 + # https://github.com/eakmanrq/sqlframe/issues/560 + request.applymarker(pytest.mark.xfail) + result = nw.from_native(constructor(data)).select( + nw.col("a") + .cast(nw.List(nw.Int32())) + .list.sort(descending=descending, nulls_last=nulls_last) + ) + assert_equal_data(result, {"a": expected}) + + +def test_sort_series( + request: pytest.FixtureRequest, constructor_eager: ConstructorEager +) -> None: + if any( + backend in str(constructor_eager) + for backend in ("dask", "cudf", "pandas", "pyarrow") + ): + # PyArrow issue: https://github.com/apache/arrow/issues/48060#issuecomment-3510993921 + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df["a"].cast(nw.List(nw.Int32())).list.sort() + assert_equal_data({"a": result}, {"a": expected_asc_nulls_first}) + + +@pytest.mark.parametrize( + ("descending", "nulls_last", "expected"), + [ 
+ (True, True, expected_desc_nulls_last), + (True, False, expected_desc_nulls_first), + (False, True, expected_asc_nulls_last), + (False, False, expected_asc_nulls_first), + ], +) +def test_sort_series_args( + request: pytest.FixtureRequest, + constructor_eager: ConstructorEager, + descending: bool, # noqa: FBT001 + nulls_last: bool, # noqa: FBT001 + expected: list[Any], +) -> None: + if any( + backend in str(constructor_eager) + for backend in ("dask", "cudf", "pandas", "pyarrow") + ): + # PyArrow issue: https://github.com/apache/arrow/issues/48060#issuecomment-3510993921 + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor_eager(data), eager_only=True) + result = ( + df["a"] + .cast(nw.List(nw.Int32())) + .list.sort(descending=descending, nulls_last=nulls_last) + ) + assert_equal_data({"a": result}, {"a": expected}) From 7c2f1f3e066462944a8951cf4500c1a550bb7730 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Mon, 15 Dec 2025 14:09:19 +0000 Subject: [PATCH 02/11] add not_implemented to pyarrow and pandas --- narwhals/_arrow/series_list.py | 1 + narwhals/_pandas_like/series_list.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/narwhals/_arrow/series_list.py b/narwhals/_arrow/series_list.py index 25e598aedd..7e367239e0 100644 --- a/narwhals/_arrow/series_list.py +++ b/narwhals/_arrow/series_list.py @@ -37,3 +37,4 @@ def sum(self) -> ArrowSeries: unique = not_implemented() contains = not_implemented() + sort = not_implemented() diff --git a/narwhals/_pandas_like/series_list.py b/narwhals/_pandas_like/series_list.py index d1fb085262..dc698bfbd5 100644 --- a/narwhals/_pandas_like/series_list.py +++ b/narwhals/_pandas_like/series_list.py @@ -34,10 +34,6 @@ def len(self) -> PandasLikeSeries: ) return self.with_native(result.astype(dtype)).alias(self.native.name) - unique = not_implemented() - - contains = not_implemented() - def get(self, index: int) -> PandasLikeSeries: result = 
self.native.list[index] result.name = self.native.name @@ -80,3 +76,7 @@ def median(self) -> PandasLikeSeries: def sum(self) -> PandasLikeSeries: return self._agg("sum") + + unique = not_implemented() + contains = not_implemented() + sort = not_implemented() From 42cfcd5b843f58c90e8125d71dac6c32597ea7a5 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Mon, 15 Dec 2025 14:23:11 +0000 Subject: [PATCH 03/11] skip old polars as `nulls_last` arg was not implemented --- tests/expr_and_series/list/sort_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/expr_and_series/list/sort_test.py b/tests/expr_and_series/list/sort_test.py index 38fa20774c..389ea684e8 100644 --- a/tests/expr_and_series/list/sort_test.py +++ b/tests/expr_and_series/list/sort_test.py @@ -5,7 +5,7 @@ import pytest import narwhals as nw -from tests.utils import assert_equal_data +from tests.utils import POLARS_VERSION, assert_equal_data if TYPE_CHECKING: from typing import Any @@ -88,6 +88,8 @@ def test_sort_expr_args( # https://github.com/eakmanrq/sqlframe/issues/559 # https://github.com/eakmanrq/sqlframe/issues/560 request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 5): + pytest.skip() result = nw.from_native(constructor(data)).select( nw.col("a") .cast(nw.List(nw.Int32())) @@ -132,6 +134,8 @@ def test_sort_series_args( ): # PyArrow issue: https://github.com/apache/arrow/issues/48060#issuecomment-3510993921 request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 5): + pytest.skip() df = nw.from_native(constructor_eager(data), eager_only=True) result = ( df["a"] From 4402e98b1e7d2bef0723a3697abfa9525cceed57 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Mon, 15 Dec 2025 14:28:18 +0000 Subject: [PATCH 04/11] skip old polars again --- tests/expr_and_series/list/sort_test.py | 4 ++++ 1 file 
changed, 4 insertions(+) diff --git a/tests/expr_and_series/list/sort_test.py b/tests/expr_and_series/list/sort_test.py index 389ea684e8..0b96c5bdfd 100644 --- a/tests/expr_and_series/list/sort_test.py +++ b/tests/expr_and_series/list/sort_test.py @@ -54,6 +54,8 @@ def test_sort_expr(request: pytest.FixtureRequest, constructor: Constructor) -> # https://github.com/eakmanrq/sqlframe/issues/559 # https://github.com/eakmanrq/sqlframe/issues/560 request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 5): + pytest.skip() result = nw.from_native(constructor(data)).select( nw.col("a").cast(nw.List(nw.Int32())).list.sort() ) @@ -107,6 +109,8 @@ def test_sort_series( ): # PyArrow issue: https://github.com/apache/arrow/issues/48060#issuecomment-3510993921 request.applymarker(pytest.mark.xfail) + if "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 5): + pytest.skip() df = nw.from_native(constructor_eager(data), eager_only=True) result = df["a"].cast(nw.List(nw.Int32())).list.sort() assert_equal_data({"a": result}, {"a": expected_asc_nulls_first}) From 5ce0cf9aa2995d22f3ab95b0d340957daa7bcc2f Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Mon, 15 Dec 2025 15:59:35 +0000 Subject: [PATCH 05/11] update pyarrow issue link --- tests/expr_and_series/list/sort_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/expr_and_series/list/sort_test.py b/tests/expr_and_series/list/sort_test.py index 0b96c5bdfd..37fd9290ff 100644 --- a/tests/expr_and_series/list/sort_test.py +++ b/tests/expr_and_series/list/sort_test.py @@ -48,7 +48,7 @@ def test_sort_expr(request: pytest.FixtureRequest, constructor: Constructor) -> if any( backend in str(constructor) for backend in ("dask", "cudf", "pandas", "pyarrow") ): - # PyArrow issue: https://github.com/apache/arrow/issues/48060#issuecomment-3510993921 + # PyArrow issue: https://github.com/apache/arrow/issues/48060 
request.applymarker(pytest.mark.xfail) if "sqlframe" in str(constructor): # https://github.com/eakmanrq/sqlframe/issues/559 @@ -81,7 +81,7 @@ def test_sort_expr_args( if any( backend in str(constructor) for backend in ("dask", "cudf", "pandas", "pyarrow") ): - # PyArrow issue: https://github.com/apache/arrow/issues/48060#issuecomment-3510993921 + # PyArrow issue: https://github.com/apache/arrow/issues/48060 request.applymarker(pytest.mark.xfail) if "ibis" in str(constructor) and descending: # https://github.com/ibis-project/ibis/issues/11735 @@ -107,7 +107,7 @@ def test_sort_series( backend in str(constructor_eager) for backend in ("dask", "cudf", "pandas", "pyarrow") ): - # PyArrow issue: https://github.com/apache/arrow/issues/48060#issuecomment-3510993921 + # PyArrow issue: https://github.com/apache/arrow/issues/48060 request.applymarker(pytest.mark.xfail) if "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 5): pytest.skip() @@ -136,7 +136,7 @@ def test_sort_series_args( backend in str(constructor_eager) for backend in ("dask", "cudf", "pandas", "pyarrow") ): - # PyArrow issue: https://github.com/apache/arrow/issues/48060#issuecomment-3510993921 + # PyArrow issue: https://github.com/apache/arrow/issues/48060 request.applymarker(pytest.mark.xfail) if "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 5): pytest.skip() From f1308bbc120e24c55f37c61677e029f73a85e85e Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:58:52 +0000 Subject: [PATCH 06/11] implement pyarrow sort --- narwhals/_arrow/series_list.py | 8 +++++-- narwhals/_arrow/utils.py | 31 +++++++++++++++++++++++++ narwhals/_pandas_like/series_list.py | 22 +++++++++++++++++- tests/expr_and_series/list/sort_test.py | 22 ++++-------------- 4 files changed, 62 insertions(+), 21 deletions(-) diff --git a/narwhals/_arrow/series_list.py b/narwhals/_arrow/series_list.py index 7e367239e0..d864150654 100644 --- 
a/narwhals/_arrow/series_list.py +++ b/narwhals/_arrow/series_list.py @@ -5,7 +5,7 @@ import pyarrow as pa import pyarrow.compute as pc -from narwhals._arrow.utils import ArrowSeriesNamespace, list_agg +from narwhals._arrow.utils import ArrowSeriesNamespace, list_agg, list_sort from narwhals._compliant.any_namespace import ListNamespace from narwhals._utils import not_implemented @@ -35,6 +35,10 @@ def median(self) -> ArrowSeries: def sum(self) -> ArrowSeries: return self.with_native(list_agg(self.native, "sum")) + def sort(self, *, descending: bool, nulls_last: bool) -> ArrowSeries: + return self.with_native( + list_sort(self.native, descending=descending, nulls_last=nulls_last) + ) + unique = not_implemented() contains = not_implemented() - sort = not_implemented() diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index dbd8aa6c62..b5f80a12cc 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -532,3 +532,34 @@ def list_agg( ) ] ) + + +def list_sort( + array: ChunkedArrayAny, *, descending: bool, nulls_last: bool +) -> ChunkedArrayAny: + sort_direction: Literal["ascending", "descending"] = ( + "descending" if descending else "ascending" + ) + nulls_position: Literal["at_start", "at_end"] = "at_end" if nulls_last else "at_start" + idx, v = "idx", "values" + len_gt_0 = pc.greater(pc.list_value_length(array), lit(0)) + arange = pa.arange(0, len(array)) # type: ignore[attr-defined] + indexed = pa.Table.from_arrays([arange, array], names=[idx, v]) + valid = indexed.filter(len_gt_0) + invalid = indexed.filter(pc.or_kleene(array.is_null(), pc.invert(len_gt_0))) + agg = pa.Table.from_arrays( + [pc.list_flatten(array), pc.list_parent_indices(array)], names=[v, idx] + ) + sorted_indices = pc.sort_indices( + agg, + sort_keys=[(idx, "ascending"), (v, sort_direction)], + null_placement=nulls_position, + ) + offsets = valid.column(v).combine_chunks().offsets # type: ignore[attr-defined] + sorted_imploded = pa.ListArray.from_arrays( + offsets, 
pa.array(agg.take(sorted_indices).column(v)) + ) + valid_finished = pa.Table.from_arrays( + [valid.column(idx), sorted_imploded], names=[idx, v] + ) + return pa.concat_tables([valid_finished, invalid]).sort_by(idx).column(v) diff --git a/narwhals/_pandas_like/series_list.py b/narwhals/_pandas_like/series_list.py index dc698bfbd5..699a4b2387 100644 --- a/narwhals/_pandas_like/series_list.py +++ b/narwhals/_pandas_like/series_list.py @@ -77,6 +77,26 @@ def median(self) -> PandasLikeSeries: def sum(self) -> PandasLikeSeries: return self._agg("sum") + def sort(self, *, descending: bool, nulls_last: bool) -> PandasLikeSeries: + dtype_backend = get_dtype_backend( + self.native.dtype, self.compliant._implementation + ) + if dtype_backend != "pyarrow": # pragma: no cover + msg = "Only pyarrow backend is currently supported." + raise NotImplementedError(msg) + + from narwhals._arrow.utils import list_sort, native_to_narwhals_dtype + + ca = self.native.array._pa_array + result_arr = list_sort(ca, descending=descending, nulls_last=nulls_last) + nw_dtype = native_to_narwhals_dtype(result_arr.type, self.version) + out_dtype = narwhals_to_native_dtype( + nw_dtype, "pyarrow", self.implementation, self.version + ) + result_native = type(self.native)( + result_arr, dtype=out_dtype, index=self.native.index, name=self.native.name + ) + return self.with_native(result_native) + unique = not_implemented() contains = not_implemented() - sort = not_implemented() diff --git a/tests/expr_and_series/list/sort_test.py b/tests/expr_and_series/list/sort_test.py index 37fd9290ff..5803097b62 100644 --- a/tests/expr_and_series/list/sort_test.py +++ b/tests/expr_and_series/list/sort_test.py @@ -45,10 +45,7 @@ def test_sort_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if any( - backend in str(constructor) for backend in ("dask", "cudf", "pandas", "pyarrow") - ): - # PyArrow issue: https://github.com/apache/arrow/issues/48060 + if any(backend in str(constructor) for 
backend in ("dask", "cudf")): request.applymarker(pytest.mark.xfail) if "sqlframe" in str(constructor): # https://github.com/eakmanrq/sqlframe/issues/559 @@ -78,10 +75,7 @@ def test_sort_expr_args( nulls_last: bool, # noqa: FBT001 expected: list[Any], ) -> None: - if any( - backend in str(constructor) for backend in ("dask", "cudf", "pandas", "pyarrow") - ): - # PyArrow issue: https://github.com/apache/arrow/issues/48060 + if any(backend in str(constructor) for backend in ("dask", "cudf")): request.applymarker(pytest.mark.xfail) if "ibis" in str(constructor) and descending: # https://github.com/ibis-project/ibis/issues/11735 @@ -103,11 +97,7 @@ def test_sort_expr_args( def test_sort_series( request: pytest.FixtureRequest, constructor_eager: ConstructorEager ) -> None: - if any( - backend in str(constructor_eager) - for backend in ("dask", "cudf", "pandas", "pyarrow") - ): - # PyArrow issue: https://github.com/apache/arrow/issues/48060 + if any(backend in str(constructor_eager) for backend in ("dask", "cudf")): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 5): pytest.skip() @@ -132,11 +122,7 @@ def test_sort_series_args( nulls_last: bool, # noqa: FBT001 expected: list[Any], ) -> None: - if any( - backend in str(constructor_eager) - for backend in ("dask", "cudf", "pandas", "pyarrow") - ): - # PyArrow issue: https://github.com/apache/arrow/issues/48060 + if any(backend in str(constructor_eager) for backend in ("dask", "cudf")): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 5): pytest.skip() From ea596942298bedfea24a650d2ced96f14c48f3f3 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Thu, 18 Dec 2025 15:58:17 +0000 Subject: [PATCH 07/11] use arange from utils --- narwhals/_arrow/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py 
index b5f80a12cc..38c206fadf 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -543,8 +543,9 @@ def list_sort( nulls_position: Literal["at_start", "at_end"] = "at_end" if nulls_last else "at_start" idx, v = "idx", "values" len_gt_0 = pc.greater(pc.list_value_length(array), lit(0)) - arange = pa.arange(0, len(array)) # type: ignore[attr-defined] - indexed = pa.Table.from_arrays([arange, array], names=[idx, v]) + indexed = pa.Table.from_arrays( + [arange(start=0, end=len(array), step=1), array], names=[idx, v] + ) valid = indexed.filter(len_gt_0) invalid = indexed.filter(pc.or_kleene(array.is_null(), pc.invert(len_gt_0))) agg = pa.Table.from_arrays( From 58b86aded0faa598f851a77f7972c64dc7ce3504 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Thu, 18 Dec 2025 15:59:25 +0000 Subject: [PATCH 08/11] skip old pandas --- tests/expr_and_series/list/sort_test.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/expr_and_series/list/sort_test.py b/tests/expr_and_series/list/sort_test.py index 5803097b62..dd9bc85826 100644 --- a/tests/expr_and_series/list/sort_test.py +++ b/tests/expr_and_series/list/sort_test.py @@ -5,7 +5,7 @@ import pytest import narwhals as nw -from tests.utils import POLARS_VERSION, assert_equal_data +from tests.utils import PANDAS_VERSION, POLARS_VERSION, assert_equal_data if TYPE_CHECKING: from typing import Any @@ -53,6 +53,10 @@ def test_sort_expr(request: pytest.FixtureRequest, constructor: Constructor) -> request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 5): pytest.skip() + if "pandas" in str(constructor): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") result = nw.from_native(constructor(data)).select( nw.col("a").cast(nw.List(nw.Int32())).list.sort() ) @@ -86,6 +90,10 @@ def test_sort_expr_args( request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) 
and POLARS_VERSION < (0, 20, 5): pytest.skip() + if "pandas" in str(constructor): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") result = nw.from_native(constructor(data)).select( nw.col("a") .cast(nw.List(nw.Int32())) @@ -101,6 +109,10 @@ def test_sort_series( request.applymarker(pytest.mark.xfail) if "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 5): pytest.skip() + if "pandas" in str(constructor_eager): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") df = nw.from_native(constructor_eager(data), eager_only=True) result = df["a"].cast(nw.List(nw.Int32())).list.sort() assert_equal_data({"a": result}, {"a": expected_asc_nulls_first}) @@ -126,6 +138,10 @@ def test_sort_series_args( request.applymarker(pytest.mark.xfail) if "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 5): pytest.skip() + if "pandas" in str(constructor_eager): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") df = nw.from_native(constructor_eager(data), eager_only=True) result = ( df["a"] From aa167980595359ffca768047a9c8e56fd2b22d07 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Fri, 19 Dec 2025 15:29:32 +0000 Subject: [PATCH 09/11] change var names --- narwhals/_arrow/utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 38c206fadf..6105914886 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -542,25 +542,25 @@ def list_sort( ) nulls_position: Literal["at_start", "at_end"] = "at_end" if nulls_last else "at_start" idx, v = "idx", "values" - len_gt_0 = pc.greater(pc.list_value_length(array), lit(0)) + is_not_sorted = pc.greater(pc.list_value_length(array), lit(0)) indexed = pa.Table.from_arrays( [arange(start=0, end=len(array), step=1), array], names=[idx, v] ) - valid = indexed.filter(len_gt_0) - invalid = 
indexed.filter(pc.or_kleene(array.is_null(), pc.invert(len_gt_0))) - agg = pa.Table.from_arrays( + not_sorted_part = indexed.filter(is_not_sorted) + pass_through = indexed.filter(pc.fill_null(pc.invert(is_not_sorted), True)) # pyright: ignore[reportArgumentType] + exploded = pa.Table.from_arrays( [pc.list_flatten(array), pc.list_parent_indices(array)], names=[v, idx] ) sorted_indices = pc.sort_indices( - agg, + exploded, sort_keys=[(idx, "ascending"), (v, sort_direction)], null_placement=nulls_position, ) - offsets = valid.column(v).combine_chunks().offsets # type: ignore[attr-defined] + offsets = not_sorted_part.column(v).combine_chunks().offsets # type: ignore[attr-defined] sorted_imploded = pa.ListArray.from_arrays( - offsets, pa.array(agg.take(sorted_indices).column(v)) + offsets, pa.array(exploded.take(sorted_indices).column(v)) ) - valid_finished = pa.Table.from_arrays( - [valid.column(idx), sorted_imploded], names=[idx, v] + imploded_by_idx = pa.Table.from_arrays( + [not_sorted_part.column(idx), sorted_imploded], names=[idx, v] ) - return pa.concat_tables([valid_finished, invalid]).sort_by(idx).column(v) + return pa.concat_tables([imploded_by_idx, pass_through]).sort_by(idx).column(v) From 5343e0fe6a98e50459389937ab17404627c6dd68 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Fri, 19 Dec 2025 15:35:01 +0000 Subject: [PATCH 10/11] fix typing --- narwhals/_arrow/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 6105914886..bb9921b97d 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -547,7 +547,7 @@ def list_sort( [arange(start=0, end=len(array), step=1), array], names=[idx, v] ) not_sorted_part = indexed.filter(is_not_sorted) - pass_through = indexed.filter(pc.fill_null(pc.invert(is_not_sorted), True)) # pyright: ignore[reportArgumentType] + pass_through = indexed.filter(pc.fill_null(pc.invert(is_not_sorted), 
lit(True))) # pyright: ignore[reportArgumentType] exploded = pa.Table.from_arrays( [pc.list_flatten(array), pc.list_parent_indices(array)], names=[v, idx] ) From 4107ad171233af91e7806dcf821e098d1a902f26 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Wed, 24 Dec 2025 11:11:25 +0000 Subject: [PATCH 11/11] refactor pandas, remove tests for the default params --- narwhals/_pandas_like/series_list.py | 41 +++++++++---------------- tests/expr_and_series/list/sort_test.py | 35 --------------------- 2 files changed, 14 insertions(+), 62 deletions(-) diff --git a/narwhals/_pandas_like/series_list.py b/narwhals/_pandas_like/series_list.py index 699a4b2387..2b60119cd2 100644 --- a/narwhals/_pandas_like/series_list.py +++ b/narwhals/_pandas_like/series_list.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import partial from typing import TYPE_CHECKING from narwhals._compliant.any_namespace import ListNamespace @@ -39,9 +40,7 @@ def get(self, index: int) -> PandasLikeSeries: result.name = self.native.name return self.with_native(result) - def _agg( - self, func: Literal["min", "max", "mean", "approximate_median", "sum"] - ) -> PandasLikeSeries: + def _raise_if_not_pyarrow_backend(self) -> None: dtype_backend = get_dtype_backend( self.native.dtype, self.compliant._implementation ) @@ -49,16 +48,15 @@ def _agg( msg = "Only pyarrow backend is currently supported."
raise NotImplementedError(msg) - from narwhals._arrow.utils import list_agg, native_to_narwhals_dtype + def _agg( + self, func: Literal["min", "max", "mean", "approximate_median", "sum"] + ) -> PandasLikeSeries: + self._raise_if_not_pyarrow_backend() + + from narwhals._arrow.utils import list_agg - ca = self.native.array._pa_array - result_arr = list_agg(ca, func) - nw_dtype = native_to_narwhals_dtype(result_arr.type, self.version) - out_dtype = narwhals_to_native_dtype( - nw_dtype, "pyarrow", self.implementation, self.version - ) - result_native = type(self.native)( - result_arr, dtype=out_dtype, index=self.native.index, name=self.native.name + result_native = self.compliant._apply_pyarrow_compute_func( + self.native, partial(list_agg, func=func) ) return self.with_native(result_native) @@ -78,23 +76,12 @@ def sum(self) -> PandasLikeSeries: return self._agg("sum") def sort(self, *, descending: bool, nulls_last: bool) -> PandasLikeSeries: - dtype_backend = get_dtype_backend( - self.native.dtype, self.compliant._implementation - ) - if dtype_backend != "pyarrow": # pragma: no cover - msg = "Only pyarrow backend is currently supported." 
- raise NotImplementedError(msg) + self._raise_if_not_pyarrow_backend() - from narwhals._arrow.utils import list_sort, native_to_narwhals_dtype + from narwhals._arrow.utils import list_sort - ca = self.native.array._pa_array - result_arr = list_sort(ca, descending=descending, nulls_last=nulls_last) - nw_dtype = native_to_narwhals_dtype(result_arr.type, self.version) - out_dtype = narwhals_to_native_dtype( - nw_dtype, "pyarrow", self.implementation, self.version - ) - result_native = type(self.native)( - result_arr, dtype=out_dtype, index=self.native.index, name=self.native.name + result_native = self.compliant._apply_pyarrow_compute_func( + self.native, partial(list_sort, descending=descending, nulls_last=nulls_last) ) return self.with_native(result_native) diff --git a/tests/expr_and_series/list/sort_test.py b/tests/expr_and_series/list/sort_test.py index dd9bc85826..41a17e1eec 100644 --- a/tests/expr_and_series/list/sort_test.py +++ b/tests/expr_and_series/list/sort_test.py @@ -44,25 +44,6 @@ ] -def test_sort_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if any(backend in str(constructor) for backend in ("dask", "cudf")): - request.applymarker(pytest.mark.xfail) - if "sqlframe" in str(constructor): - # https://github.com/eakmanrq/sqlframe/issues/559 - # https://github.com/eakmanrq/sqlframe/issues/560 - request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 5): - pytest.skip() - if "pandas" in str(constructor): - if PANDAS_VERSION < (2, 2): - pytest.skip() - pytest.importorskip("pyarrow") - result = nw.from_native(constructor(data)).select( - nw.col("a").cast(nw.List(nw.Int32())).list.sort() - ) - assert_equal_data(result, {"a": expected_asc_nulls_first}) - - @pytest.mark.parametrize( ("descending", "nulls_last", "expected"), [ @@ -102,22 +83,6 @@ def test_sort_expr_args( assert_equal_data(result, {"a": expected}) -def test_sort_series( - request: pytest.FixtureRequest, constructor_eager: 
ConstructorEager -) -> None: - if any(backend in str(constructor_eager) for backend in ("dask", "cudf")): - request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 5): - pytest.skip() - if "pandas" in str(constructor_eager): - if PANDAS_VERSION < (2, 2): - pytest.skip() - pytest.importorskip("pyarrow") - df = nw.from_native(constructor_eager(data), eager_only=True) - result = df["a"].cast(nw.List(nw.Int32())).list.sort() - assert_equal_data({"a": result}, {"a": expected_asc_nulls_first}) - - @pytest.mark.parametrize( ("descending", "nulls_last", "expected"), [