-
Notifications
You must be signed in to change notification settings - Fork 172
feat: Adds {Expr,Series}.{first,last}
#2528
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ff661ae
1b77bd7
e84cba3
25ef241
78822aa
4075c50
45f24b9
4041dd1
bb9912d
6a53aa1
4efc939
fc149c1
7489e61
d2719a4
0af11db
afe20f0
6c0bd6f
e0fdf78
aa7c510
4fdc0aa
bd4ab89
9f7f5a9
ddb50d2
7146f60
8c24e6e
54a4cb4
63e0459
9493aad
88535a4
77ae9c0
c1a6173
3c4ff9b
696e35d
b2866d2
cd002f3
1458530
962ebcd
5d310bc
a417341
354da1a
0cea41b
5229096
8d3aaec
a62e3ef
9c36285
ad8e3f7
628f71e
deacc71
5c52ee4
eec2a4f
e003bab
fb2dc1c
211673b
652615f
68fdfe8
ecaca9a
ea30f26
12987ee
7d70a42
5446095
b927340
f62c085
45d20c8
72ab185
e72b115
fae137c
cb363be
bc80a5f
14051fa
3d42dcf
934d09e
3fbf6f2
801a7a8
47bfaba
4618d01
1998ad2
570cdaf
d561027
b77d2b3
2b0bc16
abbb4b7
ccfe532
dd1f89e
54b3188
5f9ff6f
4000b25
1c62ce2
64fdf10
063e5d0
2e4f260
65e6804
22fae20
60624b9
d66fddc
5e444a5
e1a9bc3
0cbe33d
2960736
4ede6b2
2ae4245
b8066c4
2dae6ef
3aa52dc
7c578c7
30bad0e
d269d56
c0e37aa
6f5c05b
20be193
abd027a
1fd9fd3
94d6b19
849a6d9
476c63e
c169104
d77fcd1
1f38bde
3c63726
6d7b09b
3b6301f
bfc55c7
f47ef14
b32db75
f22a497
b5fe1ba
0f301e7
16a2762
09dca76
ffe7e24
0fb0455
6d63ea6
7b00310
016abc9
29d6cb7
3b91e23
c87935d
0393dfe
466c922
4266e4b
555098b
36e38e0
42d2cd6
63f012a
c4ac043
8739b6a
ff22604
d9c4a1b
03b7969
d01a398
18c0861
948d96d
8810d03
843549f
d7be792
363490d
c25d649
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |
| from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar | ||
| from narwhals._compliant import EagerGroupBy | ||
| from narwhals._expression_parsing import evaluate_output_names_and_aliases | ||
| from narwhals._utils import generate_temporary_column_name | ||
| from narwhals._utils import generate_temporary_column_name, requires | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Iterator, Mapping, Sequence | ||
|
|
@@ -39,12 +39,23 @@ class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]): | |
| "count": "count", | ||
| "all": "all", | ||
| "any": "any", | ||
| "first": "first", | ||
| "last": "last", | ||
| } | ||
| _REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = { | ||
| "any": "min", | ||
| "first": "min", | ||
| "last": "max", | ||
| } | ||
| _OPTION_COUNT_ALL: ClassVar[frozenset[NarwhalsAggregation]] = frozenset( | ||
| ("len", "n_unique") | ||
| ) | ||
| _OPTION_COUNT_VALID: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("count",)) | ||
| _OPTION_ORDERED: ClassVar[frozenset[NarwhalsAggregation]] = frozenset( | ||
| ("first", "last") | ||
| ) | ||
| _OPTION_VARIANCE: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("std", "var")) | ||
| _OPTION_SCALAR: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("any", "all")) | ||
|
|
||
| def __init__( | ||
| self, | ||
|
|
@@ -60,12 +71,58 @@ def __init__( | |
| self._grouped = pa.TableGroupBy(self.compliant.native, self._keys) | ||
| self._drop_null_keys = drop_null_keys | ||
|
|
||
| def _configure_agg( | ||
| self, grouped: pa.TableGroupBy, expr: ArrowExpr, / | ||
| ) -> tuple[pa.TableGroupBy, Aggregation, AggregateOptions | None]: | ||
| option: AggregateOptions | None = None | ||
| function_name = self._leaf_name(expr) | ||
| if function_name in self._OPTION_VARIANCE: | ||
| ddof = expr._scalar_kwargs.get("ddof", 1) | ||
| option = pc.VarianceOptions(ddof=ddof) | ||
| elif function_name in self._OPTION_COUNT_ALL: | ||
| option = pc.CountOptions(mode="all") | ||
| elif function_name in self._OPTION_COUNT_VALID: | ||
| option = pc.CountOptions(mode="only_valid") | ||
| elif function_name in self._OPTION_SCALAR: | ||
| option = pc.ScalarAggregateOptions(min_count=0) | ||
| elif function_name in self._OPTION_ORDERED: | ||
| grouped, option = self._ordered_agg(grouped, function_name) | ||
| return grouped, self._remap_expr_name(function_name), option | ||
|
Comment on lines
+74
to
+90
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Possible follow-up: Do another pass on this, since it was written before the […]
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Important: I've redone all of this in a much cleaner way, also solving (https://github.com/narwhals-dev/narwhals/pull/2528/files#r2226587107) in […]. I should be able to upstream some version of that to […]. It works on all versions of […] |
||
|
|
||
| def _ordered_agg( | ||
| self, grouped: pa.TableGroupBy, name: NarwhalsAggregation, / | ||
| ) -> tuple[pa.TableGroupBy, AggregateOptions]: | ||
| """The default behavior of `pyarrow` raises when `first` or `last` are used. | ||
|
|
||
| You'd see an error like: | ||
|
|
||
| ArrowNotImplementedError: Using ordered aggregator in multiple threaded execution is not supported | ||
|
|
||
| We need to **disable** multi-threading to use them, but the ability to do so | ||
| wasn't possible before `14.0.0` ([pyarrow-36709]) | ||
|
|
||
| [pyarrow-36709]: https://github.com/apache/arrow/issues/36709 | ||
| """ | ||
| backend_version = self.compliant._backend_version | ||
| if backend_version >= (14, 0) and grouped._use_threads: | ||
| native = self.compliant.native | ||
| grouped = pa.TableGroupBy(native, grouped.keys, use_threads=False) | ||
| elif backend_version < (14, 0): # pragma: no cover | ||
FBruzzesi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| msg = ( | ||
| f"Using `{name}()` in a `group_by().agg(...)` context is only available in 'pyarrow>=14.0.0', " | ||
| f"found version {requires._unparse_version(backend_version)!r}.\n\n" | ||
| f"See https://github.com/apache/arrow/issues/36709" | ||
| ) | ||
| raise NotImplementedError(msg) | ||
dangotbanned marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return grouped, pc.ScalarAggregateOptions(skip_nulls=False) | ||
|
|
||
| def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: | ||
| self._ensure_all_simple(exprs) | ||
| aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = [] | ||
| expected_pyarrow_column_names: list[str] = self._keys.copy() | ||
| new_column_names: list[str] = self._keys.copy() | ||
| exclude = (*self._keys, *self._output_key_names) | ||
| grouped = self._grouped | ||
|
|
||
| for expr in exprs: | ||
| output_names, aliases = evaluate_output_names_and_aliases( | ||
|
|
@@ -83,20 +140,7 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: | |
| aggs.append((self._keys[0], "count", pc.CountOptions(mode="all"))) | ||
| continue | ||
|
|
||
| function_name = self._leaf_name(expr) | ||
| if function_name in {"std", "var"}: | ||
| assert "ddof" in expr._scalar_kwargs # noqa: S101 | ||
| option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"]) | ||
| elif function_name in {"len", "n_unique"}: | ||
| option = pc.CountOptions(mode="all") | ||
| elif function_name == "count": | ||
| option = pc.CountOptions(mode="only_valid") | ||
| elif function_name in {"all", "any"}: | ||
| option = pc.ScalarAggregateOptions(min_count=0) | ||
| else: | ||
| option = None | ||
|
|
||
| function_name = self._remap_expr_name(function_name) | ||
| grouped, function_name, option = self._configure_agg(grouped, expr) | ||
| new_column_names.extend(aliases) | ||
| expected_pyarrow_column_names.extend( | ||
| [f"{output_name}_{function_name}" for output_name in output_names] | ||
|
|
@@ -105,7 +149,7 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: | |
| [(output_name, function_name, option) for output_name in output_names] | ||
| ) | ||
|
|
||
| result_simple = self._grouped.aggregate(aggs) | ||
| result_simple = grouped.aggregate(aggs) | ||
|
|
||
| # Rename columns, being very careful | ||
| expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.