Conversation
|
@raisadz I still need to loop back to it and clean things up, but I think Here's a rough equivalent in Outdated
Note: I've hidden this because of (#3356 (comment)) import polars as pl
def list_sort(
native: pl.Series, *, descending: bool = False, nulls_last: bool = False
) -> pl.Series:
idx, name = "index", native.name
indexed = native.to_frame().with_row_index(name=idx)
len_gt_1 = pl.col(name).list.len() > 1
valid = indexed.filter(len_gt_1)
invalid = indexed.filter(pl.col(name).is_null() | ~len_gt_1)
exploded = valid.explode(name, empty_as_null=False, keep_nulls=False)
valid_finished = (
exploded.sort(idx, name, descending=[False, descending], nulls_last=nulls_last)
.group_by(idx, maintain_order=True)
.agg(name)
)
return pl.concat([valid_finished, invalid], how="vertical").sort(idx).get_column(name)And then against the test suite (https://github.com/narwhals-dev/narwhals/pull/3356/files#diff-24b32e940c35026188c24063b47ae536444a402c28879bdf7fa4853c7c5f5a0b) data = {"a": [[3, 2, 2, 4, -10, None, None], [-1], None, [None, None, None], []]}
ser = pl.DataFrame(data).to_series()print(list_sort(ser, descending=True, nulls_last=True).to_list())
print(list_sort(ser, descending=True, nulls_last=False).to_list())
print(list_sort(ser, descending=False, nulls_last=True).to_list())
print(list_sort(ser, descending=False, nulls_last=False).to_list())[[4, 3, 2, 2, -10, None, None], [-1], None, [None, None, None], []]
[[None, None, 4, 3, 2, 2, -10], [-1], None, [None, None, None], []]
[[-10, 2, 2, 3, 4, None, None], [-1], None, [None, None, None], []]
[[None, None, -10, 2, 2, 3, 4], [-1], None, [None, None, None], []] |
MarcoGorelli
left a comment
There was a problem hiding this comment.
thanks for working on this! just some comments
| def test_sort_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: | ||
| if any(backend in str(constructor) for backend in ("dask", "cudf")): | ||
| request.applymarker(pytest.mark.xfail) | ||
| if "sqlframe" in str(constructor): | ||
| # https://github.com/eakmanrq/sqlframe/issues/559 | ||
| # https://github.com/eakmanrq/sqlframe/issues/560 | ||
| request.applymarker(pytest.mark.xfail) | ||
| if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 5): | ||
| pytest.skip() | ||
| if "pandas" in str(constructor): | ||
| if PANDAS_VERSION < (2, 2): | ||
| pytest.skip() | ||
| pytest.importorskip("pyarrow") | ||
| result = nw.from_native(constructor(data)).select( | ||
| nw.col("a").cast(nw.List(nw.Int32())).list.sort() | ||
| ) | ||
| assert_equal_data(result, {"a": expected_asc_nulls_first}) |
There was a problem hiding this comment.
maybe we can skip this test if the below already tests all possibilities?
| ) | ||
| result_native = type(self.native)( | ||
| result_arr, dtype=out_dtype, index=self.native.index, name=self.native.name | ||
| ) |
There was a problem hiding this comment.
@FBruzzesi factored some repeated logic here out using _apply_pyarrow_compute_func, is it possible to use that here?
| assert_equal_data(result, {"a": expected}) | ||
|
|
||
|
|
||
| def test_sort_series( |
|
Thanks for the review, @MarcoGorelli ! I addressed your comments |
MarcoGorelli
left a comment
There was a problem hiding this comment.
thanks @raisadz and @dangotbanned !

Description
I tried to have some workarounds for the edge cases for pyarrow that we discussed #3332 (comment) (None, empty lists and lists with only None elements) with
pc.sort_indicesandpc.replace_with_maskbut the latter doesn't seem to work for the list types. There is an open issue apache/arrow#48060 that can make it work for pyarrow, pandas and modin when solved.Also, I opened a couple of issues in sqlframe and ibis related to this PR:
eakmanrq/sqlframe#559
eakmanrq/sqlframe#560
ibis-project/ibis#11735
What type of PR is this? (check all applicable)
Related issues
Checklist