Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/expr_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- mean
- median
- min
- sort
- sum
- unique
show_source: false
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- mean
- median
- min
- sort
- sum
- unique
show_source: false
Expand Down
7 changes: 6 additions & 1 deletion narwhals/_arrow/series_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace, list_agg
from narwhals._arrow.utils import ArrowSeriesNamespace, list_agg, list_sort
from narwhals._compliant.any_namespace import ListNamespace
from narwhals._utils import not_implemented

Expand Down Expand Up @@ -35,5 +35,10 @@ def median(self) -> ArrowSeries:
def sum(self) -> ArrowSeries:
    """Aggregate every list with `"sum"` via `list_agg` and wrap the result."""
    summed = list_agg(self.native, "sum")
    return self.with_native(summed)

def sort(self, *, descending: bool, nulls_last: bool) -> ArrowSeries:
    """Sort each list of the native array with `list_sort` and wrap the result.

    Arguments:
        descending: Sort in descending order.
        nulls_last: Place null values last.
    """
    sorted_native = list_sort(
        self.native, descending=descending, nulls_last=nulls_last
    )
    return self.with_native(sorted_native)

unique = not_implemented()
contains = not_implemented()
32 changes: 32 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,35 @@ def list_agg(
)
]
)


def list_sort(
    array: ChunkedArrayAny, *, descending: bool, nulls_last: bool
) -> ChunkedArrayAny:
    """Sort the elements inside every list of `array`.

    The lists are flattened, ordered by (parent row, value) and stitched back
    together, while rows that are not selected for sorting (zero-length or
    null lists) are carried through unchanged.

    Arguments:
        array: Chunked array of lists.
        descending: Sort the elements of each list in descending order.
        nulls_last: Place null elements at the end of each list.

    Returns:
        The `values` column of the recombined table, ordered by original row.
    """
    order: Literal["ascending", "descending"] = "ascending"
    if descending:
        order = "descending"
    null_placement: Literal["at_start", "at_end"] = "at_start"
    if nulls_last:
        null_placement = "at_end"
    row_key, values_key = "idx", "values"

    # Only rows with at least one element are sorted.
    has_elements = pc.greater(pc.list_value_length(array), lit(0))
    row_numbers = arange(start=0, end=len(array), step=1)
    numbered = pa.Table.from_arrays(
        [row_numbers, array], names=[row_key, values_key]
    )
    to_sort = numbered.filter(has_elements)
    # Null mask entries are filled with True so those rows are passed through.
    keep_mask = pc.fill_null(pc.invert(has_elements), lit(True))
    untouched = numbered.filter(keep_mask)  # pyright: ignore[reportArgumentType]

    # One row per list element, tagged with the row it came from.
    flat = pa.Table.from_arrays(
        [pc.list_flatten(array), pc.list_parent_indices(array)],
        names=[values_key, row_key],
    )
    # Keeping the parent row as the primary key groups each list's elements
    # together, so the existing offsets can be reused to rebuild the lists.
    permutation = pc.sort_indices(
        flat,
        sort_keys=[(row_key, "ascending"), (values_key, order)],
        null_placement=null_placement,
    )
    list_offsets = to_sort.column(values_key).combine_chunks().offsets  # type: ignore[attr-defined]
    sorted_values = pa.array(flat.take(permutation).column(values_key))
    rebuilt_lists = pa.ListArray.from_arrays(list_offsets, sorted_values)
    sorted_rows = pa.Table.from_arrays(
        [to_sort.column(row_key), rebuilt_lists], names=[row_key, values_key]
    )
    # Restore the original row order after stacking both parts.
    recombined = pa.concat_tables([sorted_rows, untouched])
    return recombined.sort_by(row_key).column(values_key)
1 change: 1 addition & 0 deletions narwhals/_compliant/any_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def max(self) -> CompliantT_co: ...
# Body-less protocol stubs (presumably members of `ListNamespace` — confirm
# against the enclosing class, which is outside this excerpt).
def mean(self) -> CompliantT_co: ...
def median(self) -> CompliantT_co: ...
def sum(self) -> CompliantT_co: ...
# Both flags are keyword-only and have no defaults at this level.
def sort(self, *, descending: bool, nulls_last: bool) -> CompliantT_co: ...


class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
Expand Down
5 changes: 5 additions & 0 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,11 @@ def median(self) -> EagerExprT:
def sum(self) -> EagerExprT:
    """Forward `list.sum` to the series-level `list` namespace."""
    compliant = self.compliant
    return compliant._reuse_series_namespace("list", "sum")

def sort(self, *, descending: bool, nulls_last: bool) -> EagerExprT:
    """Forward `list.sort` to the series-level `list` namespace.

    Arguments:
        descending: Sort in descending order.
        nulls_last: Place null values last.
    """
    options = {"descending": descending, "nulls_last": nulls_last}
    return self.compliant._reuse_series_namespace("list", "sort", **options)


class CompliantExprNameNamespace( # type: ignore[misc]
_ExprNamespace[CompliantExprT_co],
Expand Down
7 changes: 7 additions & 0 deletions narwhals/_duckdb/expr_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,10 @@ def func(expr: Expression) -> Expression:
)

return self.compliant._with_callable(func)

def sort(self, *, descending: bool, nulls_last: bool) -> DuckDBExpr:
    """Sort each list by calling the `list_sort` function with order keywords.

    Arguments:
        descending: Sort in descending order.
        nulls_last: Place null values last.
    """
    if descending:
        order_keyword = "DESC"
    else:
        order_keyword = "ASC"
    if nulls_last:
        nulls_keyword = "NULLS LAST"
    else:
        nulls_keyword = "NULLS FIRST"

    def func(expr: Expression) -> Expression:
        # The literals are built on each call, exactly as the callable is applied.
        return F("list_sort", expr, lit(order_keyword), lit(nulls_keyword))

    return self.compliant._with_elementwise(func)
14 changes: 14 additions & 0 deletions narwhals/_ibis/expr_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,18 @@ def func(expr: ir.ArrayColumn) -> ir.Value:

return self.compliant._with_callable(func)

def sort(self, *, descending: bool, nulls_last: bool) -> IbisExpr:
    """Sort each array in ascending order.

    Arguments:
        descending: Must be false; descending order is rejected.
        nulls_last: When true, the result is the plain `sort()` of the array;
            when false, null elements are split off and placed in front of
            the sorted non-null elements.

    Raises:
        NotImplementedError: If `descending` is true.
    """
    if descending:
        msg = "Descending sort is not currently supported for Ibis."
        raise NotImplementedError(msg)

    def func(expr: ir.ArrayColumn) -> ir.ArrayValue:
        if not nulls_last:
            # Split off the nulls, sort the rest, then put the nulls in front.
            non_null_part = expr.filter(lambda x: x.notnull())
            null_part = expr.filter(lambda x: x.isnull())
            return null_part.concat(non_null_part.sort())
        return expr.sort()

    return self.compliant._with_callable(func)

median = not_implemented()
39 changes: 23 additions & 16 deletions narwhals/_pandas_like/series_list.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

from narwhals._compliant.any_namespace import ListNamespace
Expand Down Expand Up @@ -34,35 +35,28 @@ def len(self) -> PandasLikeSeries:
)
return self.with_native(result.astype(dtype)).alias(self.native.name)

unique = not_implemented()

contains = not_implemented()

def get(self, index: int) -> PandasLikeSeries:
    """Take the element at `index` through the native `.list` accessor.

    The result is given the name of the original native series before it
    is wrapped.
    """
    picked = self.native.list[index]
    picked.name = self.native.name
    return self.with_native(picked)

def _agg(
self, func: Literal["min", "max", "mean", "approximate_median", "sum"]
) -> PandasLikeSeries:
def _raise_if_not_pyarrow_backend(self) -> None:
    """Reject series whose dtype backend is anything other than `"pyarrow"`.

    Raises:
        NotImplementedError: If the dtype backend is not `"pyarrow"`.
    """
    backend = get_dtype_backend(self.native.dtype, self.compliant._implementation)
    if backend == "pyarrow":
        return
    msg = "Only pyarrow backend is currently supported."  # pragma: no cover
    raise NotImplementedError(msg)  # pragma: no cover

from narwhals._arrow.utils import list_agg, native_to_narwhals_dtype
def _agg(
self, func: Literal["min", "max", "mean", "approximate_median", "sum"]
) -> PandasLikeSeries:
self._raise_if_not_pyarrow_backend()

from narwhals._arrow.utils import list_agg

ca = self.native.array._pa_array
result_arr = list_agg(ca, func)
nw_dtype = native_to_narwhals_dtype(result_arr.type, self.version)
out_dtype = narwhals_to_native_dtype(
nw_dtype, "pyarrow", self.implementation, self.version
)
result_native = type(self.native)(
result_arr, dtype=out_dtype, index=self.native.index, name=self.native.name
result_native = self.compliant._apply_pyarrow_compute_func(
self.native, partial(list_agg, func=func)
)
return self.with_native(result_native)

Expand All @@ -80,3 +74,16 @@ def median(self) -> PandasLikeSeries:

def sum(self) -> PandasLikeSeries:
    """Run the shared `_agg` helper with the `"sum"` function."""
    summed = self._agg("sum")
    return summed

def sort(self, *, descending: bool, nulls_last: bool) -> PandasLikeSeries:
self._raise_if_not_pyarrow_backend()

from narwhals._arrow.utils import list_sort

result_native = self.compliant._apply_pyarrow_compute_func(
self.native, partial(list_sort, descending=descending, nulls_last=nulls_last)
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@FBruzzesi factored some repeated logic here out using _apply_pyarrow_compute_func, is it possible to use that here?

return self.with_native(result_native)

unique = not_implemented()
contains = not_implemented()
2 changes: 2 additions & 0 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ def len(self) -> CompliantT: ...

min: Method[CompliantT]

sort: Method[CompliantT]

sum: Method[CompliantT]


Expand Down
12 changes: 12 additions & 0 deletions narwhals/_spark_like/expr_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,15 @@ def func(expr: Column) -> Column: # pragma: no cover
)

return self.compliant._with_elementwise(func)

def sort(self, *, descending: bool, nulls_last: bool) -> SparkLikeExpr:
    """Sort each array, choosing the function by order and null placement.

    Arguments:
        descending: Sort in descending order.
        nulls_last: Place null values last.
    """

    def func(expr: Column) -> Column:
        functions = self.compliant._F
        if descending:
            if nulls_last:
                return functions.sort_array(expr, asc=False)
            # pragma: no cover
            # https://github.com/eakmanrq/sqlframe/issues/559
            return functions.reverse(functions.array_sort(expr))
        if nulls_last:
            return functions.array_sort(expr)
        return functions.sort_array(expr, asc=True)

    return self.compliant._with_elementwise(func)
36 changes: 36 additions & 0 deletions narwhals/expr_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,39 @@ def sum(self) -> ExprT:
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.sum"))

def sort(self, *, descending: bool = False, nulls_last: bool = False) -> ExprT:
    """Sort the lists of the expression.

    Arguments:
        descending: Sort in descending order.
        nulls_last: Place null values last.

    Examples:
        >>> import duckdb
        >>> import narwhals as nw
        >>> df_native = duckdb.sql(
        ...     "SELECT * FROM VALUES ([2, -1, 1]), ([3, -4, NULL]) df(a)"
        ... )
        >>> df = nw.from_native(df_native)
        >>> df.with_columns(a_sorted=nw.col("a").list.sort())
        ┌─────────────────────────────────┐
        |       Narwhals LazyFrame        |
        |---------------------------------|
        |┌───────────────┬───────────────┐|
        |│       a       │   a_sorted    │|
        |│    int32[]    │    int32[]    │|
        |├───────────────┼───────────────┤|
        |│ [2, -1, 1]    │ [-1, 1, 2]    │|
        |│ [3, -4, NULL] │ [NULL, -4, 3] │|
        |└───────────────┴───────────────┘|
        └─────────────────────────────────┘
    """
    # Registered as an element-wise node; both flags travel as node keywords.
    node = ExprNode(
        ExprKind.ELEMENTWISE,
        "list.sort",
        descending=descending,
        nulls_last=nulls_last,
    )
    return self._expr._append_node(node)
26 changes: 26 additions & 0 deletions narwhals/series_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,29 @@ def sum(self) -> SeriesT:
return self._narwhals_series._with_compliant(
self._narwhals_series._compliant_series.list.sum()
)

def sort(self, *, descending: bool = False, nulls_last: bool = False) -> SeriesT:
    """Sort the lists of the series.

    Arguments:
        descending: Sort in descending order.
        nulls_last: Place null values last.

    Examples:
        >>> import polars as pl
        >>> import narwhals as nw
        >>> s_native = pl.Series([[2, -1, 1], [3, -4, None]])
        >>> s = nw.from_native(s_native, series_only=True)
        >>> s.list.sort().to_native()  # doctest: +NORMALIZE_WHITESPACE
        shape: (2,)
        Series: '' [list[i64]]
        [
           [-1, 1, 2]
           [null, -4, 3]
        ]
    """
    series = self._narwhals_series
    list_namespace = series._compliant_series.list
    sorted_compliant = list_namespace.sort(
        descending=descending, nulls_last=nulls_last
    )
    return series._with_compliant(sorted_compliant)
Loading
Loading