feat: add list aggregate methods #3332
Conversation
Force-pushed from 7b15d52 to 8e04fc1
FBruzzesi
left a comment
Thanks @raisadz, I have a couple more comments, apologies for the fragmented review 🙈
- Could you add a couple of test cases:
  a. All nulls in a list
  b. Empty list
  c. `polars.Expr.list.sum` says: "If there are no non-null elements in a row, the output is 0." For the other aggregations it's unclear what the output should be, and I wonder how consistent it is across the different backends.
- Could you vary the docstring examples a bit beyond polars?
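For reference, the `list.sum` semantics quoted in (c) can be modelled in plain Python. This is a sketch of the documented behaviour, not the polars implementation:

```python
def list_sum(row):
    """Model of the documented polars list.sum behaviour: a null row
    stays null; otherwise sum the non-null elements (0 when none)."""
    if row is None:
        return None
    return sum(v for v in row if v is not None)

rows = [[2, 2, 3, None], None, [], [None]]
print([list_sum(r) for r in rows])  # → [7, None, 0, 0]
```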
```python
def list_agg(
    array: ChunkedArrayAny,
    func: Literal["min", "max", "mean", "approximate_median", "sum"],
) -> ChunkedArrayAny:
    return (
        pa.Table.from_arrays(
            [pc.list_flatten(array), pc.list_parent_indices(array)],
            names=["values", "offsets"],
        )
        .group_by("offsets")
        .aggregate([("values", func)])
        .column(f"values_{func}")
    )
```
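To see why this construction can drop rows: `pc.list_flatten` emits no values for null or empty lists, so those rows never receive a parent index and simply vanish from the group-by. A plain-Python model of the same trick (a sketch, not the pyarrow code itself) makes that visible:

```python
from collections import defaultdict

def list_agg_model(rows, func):
    # Mimic pc.list_flatten + pc.list_parent_indices: null rows and
    # empty lists contribute no (value, parent_index) pairs at all.
    values, offsets = [], []
    for i, row in enumerate(rows):
        for v in (row or []):
            values.append(v)
            offsets.append(i)
    # Mimic group_by("offsets").aggregate(...): only indices that
    # actually appeared get a group, so the output can be shorter
    # than the input.
    groups = defaultdict(list)
    for v, i in zip(values, offsets):
        groups[i].append(v)
    return [func(vs) for _, vs in sorted(groups.items())]

rows = [[2, 2, 3], None, [], [4]]
print(list_agg_model(rows, sum))  # → [7, 4]  (only 2 of the 4 rows survive)
```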
@raisadz I'm pretty excited by this! 😄
+1 from me on (#3332 (review))
I've just tried this out with the test case for list.unique:
The result for that should be:
[[None, 2, 3], None, [], [None]]
But using list_agg seems to have dropped 2/4 lists and all nulls 🤔
```python
import pyarrow as pa

data = {"a": [[2, 2, 3, None, None], None, [], [None]]}
ca = pa.chunked_array([pa.array(data["a"])])
result = list_agg(ca, "distinct").to_pylist()
print(result)
```
```
[[2, 3], []]
```
I managed to get slightly closer to what we want, by passing in options for the group_by:
from typing import Any
import pyarrow as pa
import pyarrow.compute as pc
```python
from typing import Any

import pyarrow as pa
import pyarrow.compute as pc

def list_agg_opts(
    array: pa.ChunkedArray[Any], func: Any, options: Any = None
) -> pa.ChunkedArray[Any]:
    return (
        pa.Table.from_arrays(
            [pc.list_flatten(array), pc.list_parent_indices(array)],
            names=["values", "offsets"],
        )
        .group_by("offsets")
        .aggregate([("values", func, options)])  # <-------
        .column(f"values_{func}")
    )
```

These are the correct results for 2/4 of the lists 🎉
But where did the other 2 go? 😳

```python
result = list_agg_opts(ca, "distinct", pc.CountOptions("all")).to_pylist()
print(result)
```
```
[[2, 3, None], [None]]
```
Edit: I missed it myself lol, fixed in (d8363e1)
Thank you both! Yes, in fact those tests with empty lists or all-`None` lists fail for multiple backends. I pushed some changes, but this is WIP and I will continue working on it.
Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>
@raisadz the action failing for Python 3.9 seems to be Windows-only (I cannot replicate it on macOS or Ubuntu, but I don't have a Windows machine to try it). I would say let's xfail such a case and consider:
```python
if TYPE_CHECKING:
    from tests.utils import Constructor, ConstructorEager
```

```python
data = {"a": [[3, None, 2, 2, 4, None], [-1], None, [None, None, None], []]}
```
I've been a bit afraid to bring this up, but AFAICT all of the list.* methods have only been tested against one level of nesting.
I think two (or more) levels might be a problem for pyarrow and pandas, because lists aren't hashable yet.
I'd be really happy to be wrong about this though 🙏
If it is an issue, I don't think it needs to be a blocker, just something to keep in mind 🙂
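The hashability point is easy to demonstrate in plain Python: grouping requires hashable keys, and list values aren't hashable:

```python
# Grouping nested lists means using a list as a dict/group key, which fails:
groups = {(1, 2): "ok"}      # tuples are hashable...
try:
    groups[[1, 2]] = "boom"  # ...but lists are not
except TypeError as exc:
    print(exc)  # → unhashable type: 'list'
```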
polars seems to properly support these ops only for `List(<numeric_type>)` (and maybe temporal types?) and either fails or returns nulls otherwise:
```python
import polars as pl

data = [
    [[1], [2, 3]],
    [[4, 5, 6], [7, 8]],
]
series = pl.Series(data)
print(series.dtype)
print()
for op in ("min", "mean", "max", "median", "sum"):
    print(f"Executing {op}")
    try:
        print("result", getattr(series.list, op)())
    except Exception as exc:
        print("error", exc)
    print()
```
```
List(List(Int64))

Executing min
error `min` operation not supported for dtype `list[i64]`

Executing mean
result shape: (2,)
Series: '' [f64]
[
	null
	null
]

Executing max
error `max` operation not supported for dtype `list[i64]`

Executing median
result shape: (2,)
Series: '' [f64]
[
	null
	null
]

Executing sum
error `sum` operation not supported for dtype `list[i64]`
Hint: you may mean to call `concat_list`
```
There are a few other cases for which we check the dtype before performing an operation. We might do the same here
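One possible shape for such a check (the function name and error wording here are hypothetical, not existing narwhals API): validate the inner dtype up front and raise one standardized error for every backend.

```python
def check_list_agg_dtype(inner_dtype: str, op: str) -> None:
    # Hypothetical guard: only numeric inner dtypes are aggregable,
    # so nested lists, strings, etc. get one consistent error.
    numeric = {"int8", "int16", "int32", "int64",
               "uint8", "uint16", "uint32", "uint64",
               "float32", "float64"}
    if inner_dtype not in numeric:
        msg = f"`list.{op}` is not supported for lists of dtype `{inner_dtype}`"
        raise TypeError(msg)

check_list_agg_dtype("int64", "min")  # passes silently
try:
    check_list_agg_dtype("list[i64]", "min")  # nested list: raises
except TypeError as exc:
    print(exc)
```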
@raisadz I think this is the last missing bit before merging. If we can standardize the error message across all backends, it would be amazing. I would be in favor of raising for list.{mean,median} as well, and asking in polars whether the current output is expected.
Update 1: asked in their Discord channel.
Update 2: reply:

> it's an older leftover I think
> unsupported aggregates used to return None and we've been slowly transitioning over time to nice errors
> I think 2 (or more) levels might be a problem for pyarrow and pandas, because lists aren't hashable yet.

tbh i wouldn't worry about it, just leave it to each backend
Thanks @FBruzzesi! I missed this failure and I am skipping it now. I don't think we should worry about Python 3.9 on Windows as it is almost obsolete at this point.
This might be good to mention in #3204?
narwhals/_arrow/utils.py (outdated)

```python
base_array = pc.if_else(
    non_empty_mask, 0, None
)  # zero is just a placeholder which is replaced below
```
does it work to just do `base_array = pa.repeat(lit(None, type=agg.type), len(array))`? the 0 placeholder feels a bit magical
here's something i worked through in the live stream, could it work?

```python
lit_: Incomplete = lit
aggregation = (
    ("values", "sum", pc.ScalarAggregateOptions(min_count=0))
    if func == "sum"
    else ("values", func)
)
agg = pa.array(
    pa.Table.from_arrays(
        [pc.list_flatten(array), pc.list_parent_indices(array)],
        names=["values", "offsets"],
    )
    .group_by("offsets")
    .aggregate([aggregation])
    .sort_by("offsets")
    .column(f"values_{func}")
)
non_empty_mask = pa.array(pc.not_equal(pc.list_value_length(array), lit(0)))
if func == "sum":
    # Make sure sum of empty list is 0.
    base_array = pc.if_else(non_empty_mask.is_null(), None, 0)
else:
    base_array = pa.repeat(lit_(None, type=agg.type), len(array))
return pa.chunked_array(
    [
        pc.replace_with_mask(
            base_array,
            non_empty_mask.fill_null(False),  # type: ignore[arg-type]
            agg,
        )
    ]
)
```

I've got like 97 variations of this now 😂
Pretty much everything that involves list.* needs to do this dance around [] and None
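The "[] and None dance" boils down to three cases per row. A plain-Python reference of the intended semantics (a sketch under the semantics discussed above, not the pyarrow implementation):

```python
def list_agg_reference(rows, func, zero_for_empty_sum=False):
    out = []
    for row in rows:
        if row is None:
            out.append(None)          # null row -> null result
            continue
        values = [v for v in row if v is not None]
        if values:
            out.append(func(values))  # normal aggregation over non-nulls
        else:
            # empty or all-null list: 0 for sum, null for everything else
            out.append(0 if zero_for_empty_sum else None)
    return out

rows = [[3, None, 2], [-1], None, [None], []]
print(list_agg_reference(rows, sum, zero_for_empty_sum=True))
# → [5, -1, None, 0, 0]
print(list_agg_reference(rows, min))
# → [2, -1, None, None, None]
```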
MarcoGorelli
left a comment
thanks all! looks like this is very much on the right track
|
Thanks for the suggestion @MarcoGorelli! I applied it now to the list_agg function.
thanks @raisadz , and @FBruzzesi + @dangotbanned for reviews
Playing catch-up on #3332
```python
if (
    any(backend in str(constructor) for backend in ("pandas", "pyarrow"))
    and sys.version_info < (3, 10)
    and is_windows
):  # pragma: no cover
    reason = "The issue only affects old Python versions on Windows."
    pytest.skip(reason=reason)
```
Sorry to have caught this so late, I just noticed in (oh-nodes...expr-ir/list-agg)
Since is_windows is a function ...
... this condition is always True:
```python
from tests.utils import is_windows

bool(is_windows)  # True
```

Meaning that all platforms skip on sys.version_info < (3, 10) 😱
I usually prefer xfail to skip since the former is the only one that'll tell you when something's amiss 😉
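The bug is plain-Python truthiness: a function object is always truthy, so forgetting the call parentheses makes the guard pass everywhere. A minimal reproduction (stand-in function, not the real tests.utils helper):

```python
def is_windows() -> bool:
    # Stand-in: pretend we are not on Windows.
    return False

print(bool(is_windows))    # → True  (the function object itself is truthy)
print(bool(is_windows()))  # → False (the intended check)
```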
Thanks @dangotbanned ! I added a follow-up PR to fix this #3354
Tried to keep everything as close to the original as possible. Next step is simplifying everything and fixing `list.sum`.
```python
data = {"a": [[3, None, 2, 2, 4, None], [-1], None, [None, None, None], [], [3, 4, None]]}
expected = [2.5, -1, None, None, None, 3.5]
expected_pyarrow = [2.5, -1, None, None, None, 3]
```
@raisadz @FBruzzesi (re #3332 (comment), #3332 (review))
I've done some experimenting and think I've found what the pyarrow issue is:
```python
import pyarrow as pa
import pyarrow.compute as pc

def median(*values: float | None) -> pa.DoubleScalar:
    return pc.approximate_median(pa.array(values, pa.float64()))

def median_pretty(*values: float | None) -> None:
    print(f"median({list(values)!a:21}) = {median(*values)}")
```

I wonder if you can spot it too 😄:

```python
median_pretty()
median_pretty(3)
median_pretty(3, 4)
median_pretty(3, 4, None)
median_pretty(3, 4, None, None)
median_pretty(3, 4, None, None, 5)
median_pretty(3, 4, 5)
median_pretty(5, 3, 4)
median_pretty(5, 3)
median_pretty(5, 2)
median_pretty(5, 2, 50)
median_pretty(None, 2, 50)
median_pretty(None, 2, 50, 2)
median_pretty(50, 2, 50, 2)
```
```
median([]                   ) = None
median([3]                  ) = 3.0
median([3, 4]               ) = 3.0
median([3, 4, None]         ) = 3.0
median([3, 4, None, None]   ) = 3.0
median([3, 4, None, None, 5]) = 4.0
median([3, 4, 5]            ) = 4.0
median([5, 3, 4]            ) = 4.0
median([5, 3]               ) = 3.0
median([5, 2]               ) = 2.0
median([5, 2, 50]           ) = 5.0
median([None, 2, 50]        ) = 2.0
median([None, 2, 50, 2]     ) = 2.0
median([50, 2, 50, 2]       ) = 26.0
```
Demonstrated in (#3332 (comment)). The issue is unrelated to group_by and lists.
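What the table above suggests: for even-length inputs, `pc.approximate_median` often returns one of the middle elements rather than interpolating like an exact median (though `[50, 2, 50, 2] -> 26.0` shows it is not consistent about this either). Comparing an exact median with a "lower middle element" median in plain Python (this models the observed outputs, it is not pyarrow's actual algorithm):

```python
import statistics

def lower_median(values):
    # No interpolation: for an even-length input, return the lower
    # of the two middle elements instead of their midpoint.
    s = sorted(values)
    return s[(len(s) - 1) // 2]

print(statistics.median([3, 4]), lower_median([3, 4]))        # → 3.5 3
print(statistics.median([5, 2]), lower_median([5, 2]))        # → 3.5 2
print(statistics.median([3, 4, 5]), lower_median([3, 4, 5]))  # → 4 4
```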
Description
The following list methods are implemented:
What type of PR is this? (check all applicable)
Related issues
Checklist