-
Notifications
You must be signed in to change notification settings - Fork 180
feat(expr-ir): Add list.* aggregate methods
#3353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
22efc12
02032f9
2c2fa08
c8d09ed
0cb1f5c
a867206
e9c3656
7cd45d6
501480f
b5a78b0
a3a43a4
e99f97a
5d0376e
f8f9909
abd4843
a7c9ee1
3fefcdb
96b6638
5b310c6
86a3060
d232439
92d0b74
76ba623
f6de206
fc761a9
f74c4dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from itertools import chain | ||
| from typing import TYPE_CHECKING, Any, Literal, overload | ||
| from typing import TYPE_CHECKING, Any, Literal, cast, overload | ||
|
|
||
| import pyarrow as pa # ignore-banned-import | ||
| import pyarrow.compute as pc # ignore-banned-import | ||
|
|
@@ -13,7 +13,7 @@ | |
| from narwhals._plan.common import temp | ||
| from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy | ||
| from narwhals._plan.expressions import aggregation as agg | ||
| from narwhals._utils import Implementation | ||
| from narwhals._utils import Implementation, qualified_type_name | ||
| from narwhals.exceptions import InvalidOperationError | ||
|
|
||
| if TYPE_CHECKING: | ||
|
|
@@ -26,16 +26,17 @@ | |
| ArrayAny, | ||
| ChunkedArray, | ||
| ChunkedArrayAny, | ||
| ChunkedList, | ||
| ChunkedOrScalarAny, | ||
| Indices, | ||
| ListScalar, | ||
| ScalarAny, | ||
| ) | ||
| from narwhals._plan.expressions import NamedIR | ||
| from narwhals._plan.typing import Seq | ||
|
|
||
| Incomplete: TypeAlias = Any | ||
|
|
||
| # NOTE: Unless stated otherwise, all aggregations have 2 variants: | ||
| # - `<function>` (pc.Function.kind == "scalar_aggregate") | ||
| # - `hash_<function>` (pc.Function.kind == "hash_aggregate") | ||
| SUPPORTED_AGG: Mapping[type[agg.AggExpr], acero.Aggregation] = { | ||
| agg.Sum: "hash_sum", | ||
| agg.Mean: "hash_mean", | ||
|
|
@@ -51,9 +52,19 @@ | |
| agg.Last: "hash_last", | ||
| fn.MinMax: "hash_min_max", | ||
| } | ||
| SUPPORTED_LIST_AGG: Mapping[type[ir.lists.Aggregation], type[agg.AggExpr]] = { | ||
| ir.lists.Mean: agg.Mean, | ||
| ir.lists.Median: agg.Median, | ||
| ir.lists.Max: agg.Max, | ||
| ir.lists.Min: agg.Min, | ||
| ir.lists.Sum: agg.Sum, | ||
| ir.lists.First: agg.First, | ||
| ir.lists.Last: agg.Last, | ||
| ir.lists.NUnique: agg.NUnique, | ||
| } | ||
| SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { | ||
| ir.Len: "hash_count_all", | ||
| ir.Column: "hash_list", # `hash_aggregate` only | ||
| ir.Column: "hash_list", | ||
| } | ||
|
|
||
| _version_dependent: dict[Any, acero.Aggregation] = {} | ||
|
|
@@ -65,16 +76,42 @@ | |
| SUPPORTED_FUNCTION: Mapping[type[ir.Function], acero.Aggregation] = { | ||
| ir.boolean.All: "hash_all", | ||
| ir.boolean.Any: "hash_any", | ||
| ir.functions.Unique: "hash_distinct", # `hash_aggregate` only | ||
| ir.functions.Unique: "hash_distinct", | ||
| ir.functions.NullCount: "hash_count", | ||
| **_version_dependent, | ||
| } | ||
|
|
||
| del _version_dependent | ||
|
|
||
|
|
||
| SUPPORTED_LIST_FUNCTION: Mapping[type[ir.lists.Aggregation], type[ir.Function]] = { | ||
| ir.lists.Any: ir.boolean.Any, | ||
| ir.lists.All: ir.boolean.All, | ||
| } | ||
|
Comment on lines +87 to +90
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would expect these to be aggregations as well (and go in
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I too was confused by this initially 😄 But this is something I've inherited from
|
||
|
|
||
| SCALAR_OUTPUT_TYPE: Mapping[acero.Aggregation, pa.DataType] = { | ||
| "all": fn.BOOL, | ||
| "any": fn.BOOL, | ||
| "approximate_median": fn.F64, | ||
| "count": fn.I64, | ||
| "count_all": fn.I64, | ||
| "count_distinct": fn.I64, | ||
| "kurtosis": fn.F64, | ||
| "mean": fn.F64, | ||
| "skew": fn.F64, | ||
| "stddev": fn.F64, | ||
| "variance": fn.F64, | ||
| } | ||
| """Scalar aggregates that have an output type **not** dependent on input types*. | ||
|
|
||
| For use in list aggregates, where the input was null. | ||
|
|
||
| *Except `"mean"` will preserve `Decimal`, if that's where we started. | ||
| """ | ||
|
|
||
|
|
||
| class AggSpec: | ||
| __slots__ = ("agg", "name", "option", "target") | ||
| __slots__ = ("_function", "_name", "_option", "_target") | ||
|
|
||
| def __init__( | ||
| self, | ||
|
|
@@ -83,19 +120,22 @@ def __init__( | |
| option: acero.Opts = None, | ||
| name: acero.OutputName = "", | ||
| ) -> None: | ||
| self.target = target | ||
| self.agg = agg | ||
| self.option = option | ||
| self.name = name or str(target) | ||
| self._target = target | ||
| self._function: acero.Aggregation = agg | ||
| self._option: acero.Opts = option | ||
| self._name: acero.OutputName = name or str(target) | ||
|
|
||
| @property | ||
| def use_threads(self) -> bool: | ||
| """See https://github.com/apache/arrow/issues/36709.""" | ||
| return acero.can_thread(self.agg) | ||
| return acero.can_thread(self._function) | ||
|
|
||
| def __iter__(self) -> Iterator[acero.Target | acero.Aggregation | acero.Opts]: | ||
| """Lets us duck-type as a 4-tuple.""" | ||
| yield from (self.target, self.agg, self.option, self.name) | ||
| yield from (self._target, self._function, self._option, self._name) | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{type(self).__name__}({self._target!r}, {self._function!r}, {self._option!r}, {self._name!r})" | ||
|
|
||
| @classmethod | ||
| def from_named_ir(cls, named_ir: NamedIR) -> Self: | ||
|
|
@@ -141,6 +181,24 @@ def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self: | |
| def _from_function(cls, tp: type[ir.Function], name: str) -> Self: | ||
| return cls(name, SUPPORTED_FUNCTION[tp], options.FUNCTION.get(tp), name) | ||
|
|
||
| @classmethod | ||
| def _from_list_agg(cls, list_agg: ir.lists.Aggregation, /, name: str) -> Self: | ||
| tp = type(list_agg) | ||
| if tp_agg := SUPPORTED_LIST_AGG.get(tp): | ||
| if tp_agg in {agg.Std, agg.Var}: | ||
| msg = ( | ||
| f"TODO: {qualified_type_name(list_agg)!r} needs access to `ddof`.\n" | ||
| "Add some sugar around mapping `ListFunction.<parameter>` -> `AggExpr.<parameter>`\n" | ||
| "or using `Immutable.__immutable_keys__`" | ||
| ) | ||
| raise NotImplementedError(msg) | ||
| fn_name = SUPPORTED_AGG[tp_agg] | ||
| elif tp_func := SUPPORTED_LIST_FUNCTION.get(tp): | ||
| fn_name = SUPPORTED_FUNCTION[tp_func] | ||
| else: | ||
| raise NotImplementedError(tp) | ||
| return cls(name, fn_name, options.LIST_AGG.get(tp), name) | ||
|
|
||
| @classmethod | ||
| def any(cls, name: str) -> Self: | ||
| return cls._from_function(ir.boolean.Any, name) | ||
|
|
@@ -155,6 +213,29 @@ def implode(cls, name: str) -> Self: | |
| # https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-plan/src/dsl/expr/mod.rs#L44 | ||
| return cls(name, SUPPORTED_IR[ir.Column], None, name) | ||
|
|
||
| @overload | ||
| def agg_list(self, native: ChunkedList) -> ChunkedArrayAny: ... | ||
| @overload | ||
| def agg_list(self, native: ListScalar) -> ScalarAny: ... | ||
| def agg_list(self, native: ChunkedList | ListScalar) -> ChunkedOrScalarAny: | ||
| """Execute this aggregation over the values in *each* list, reducing *each* to a single value.""" | ||
| result: ChunkedOrScalarAny | ||
| if isinstance(native, pa.Scalar): | ||
| scalar = cast("pa.ListScalar[Any]", native) | ||
| func = HASH_TO_SCALAR_NAME[self._function] | ||
| if not scalar.is_valid: | ||
| return fn.lit(None, SCALAR_OUTPUT_TYPE.get(func, scalar.type.value_type)) | ||
| result = pc.call_function(func, [scalar.values], self._option) | ||
| return result | ||
| result = self.over_index(fn.ExplodeBuilder().explode_with_indices(native), "idx") | ||
| result = fn.when_then(native.is_valid(), result) | ||
| if self._is_n_unique(): | ||
| # NOTE: Exploding `[]` becomes `[None]` - so we need to adjust the unique count *iff* we were unlucky | ||
| is_sublist_empty = fn.eq(fn.list_len(native), fn.lit(0)) | ||
| if fn.any_(is_sublist_empty).as_py(): | ||
| result = fn.when_then(is_sublist_empty, fn.lit(0), result) | ||
| return result | ||
|
|
||
| def over(self, native: pa.Table, keys: Iterable[acero.Field]) -> pa.Table: | ||
| """Sugar for `native.group_by(keys).aggregate([self])`. | ||
|
|
||
|
|
@@ -167,7 +248,10 @@ def over_index(self, native: pa.Table, index_column: str) -> ChunkedArrayAny: | |
|
|
||
| Returns a single, (unnamed) array, representing the aggregation results. | ||
| """ | ||
| return acero.group_by_table(native, [index_column], [self]).column(self.name) | ||
| return acero.group_by_table(native, [index_column], [self]).column(self._name) | ||
|
|
||
| def _is_n_unique(self) -> bool: | ||
| return self._function == SUPPORTED_AGG[agg.NUnique] | ||
|
|
||
|
|
||
| def group_by_error( | ||
|
|
@@ -321,3 +405,32 @@ def _partition_by_many( | |
| # E.g, to push down column selection to *before* collection | ||
| # Not needed for this task though | ||
| yield acero.collect(source, acero.filter(key == v), select) | ||
|
|
||
|
|
||
| def _generate_hash_to_scalar_name() -> Mapping[acero.Aggregation, acero.Aggregation]: | ||
| nw_to_hash = SUPPORTED_AGG, SUPPORTED_IR, SUPPORTED_FUNCTION | ||
| only_hash = {"hash_distinct", "hash_list", "hash_one"} | ||
| targets = set[str](chain.from_iterable(m.values() for m in nw_to_hash)) - only_hash | ||
| hash_to_scalar = {hash_name: hash_name.removeprefix("hash_") for hash_name in targets} | ||
| # NOTE: Support both of these when using `AggSpec` directly for scalar aggregates | ||
| # `(..., "hash_mean", ..., ...)` | ||
| # `(..., "mean", ..., ...)` | ||
| scalar_names = hash_to_scalar.values() | ||
| scalar_to_scalar = zip(scalar_names, scalar_names) | ||
| hash_to_scalar.update(dict(scalar_to_scalar)) | ||
| return cast("Mapping[acero.Aggregation, acero.Aggregation]", hash_to_scalar) | ||
|
|
||
|
|
||
| # TODO @dangotbanned: Replace this with a lazier version | ||
| # Don't really want this running at import-time, but using `ModuleType.__getattr__` means | ||
| # defining it somewhere else | ||
| HASH_TO_SCALAR_NAME: Mapping[acero.Aggregation, acero.Aggregation] = ( | ||
| _generate_hash_to_scalar_name() | ||
| ) | ||
| """Mapping between [Hash aggregate] and [Scalar aggregate] names. | ||
|
|
||
| Dynamically built for use in `ListScalar` aggregations, accounting for version availability. | ||
|
|
||
| [Hash aggregate]: https://arrow.apache.org/docs/dev/cpp/compute.html#grouped-aggregations-group-by | ||
| [Scalar aggregate]: https://arrow.apache.org/docs/dev/cpp/compute.html#aggregations | ||
| """ | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Eventually, it'll make more sense to define this somewhere in
`_plan/expressions/` — but this'll do for now