-
Notifications
You must be signed in to change notification settings - Fork 172
fix: dask group by with kwargs #1676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
eeee16f
cc1fe57
5810ce5
c5c4d91
acfd2d8
90c5583
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,14 +36,38 @@ def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int: | |
| ) | ||
|
|
||
|
|
||
| def var( | ||
| ddof: int = 1, | ||
| ) -> Callable[ | ||
| [pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy | ||
| ]: | ||
| from functools import partial | ||
|
|
||
| import dask_expr as dx | ||
|
|
||
| return partial(dx._groupby.GroupBy.var, ddof=ddof) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The minimum version we support for dask is
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @phofl is there a public variant of this? |
||
|
|
||
|
|
||
| def std( | ||
| ddof: int = 1, | ||
| ) -> Callable[ | ||
| [pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy | ||
| ]: | ||
| from functools import partial | ||
|
|
||
| import dask_expr as dx | ||
|
|
||
| return partial(dx._groupby.GroupBy.std, ddof=ddof) | ||
|
|
||
|
|
||
| POLARS_TO_DASK_AGGREGATIONS = { | ||
| "sum": "sum", | ||
| "mean": "mean", | ||
| "median": "median", | ||
| "max": "max", | ||
| "min": "min", | ||
| "std": "std", | ||
| "var": "var", | ||
| "std": std, | ||
| "var": var, | ||
| "len": "size", | ||
| "n_unique": n_unique, | ||
| "count": "count", | ||
|
|
@@ -137,8 +161,12 @@ def agg_dask( | |
| function_name = POLARS_TO_DASK_AGGREGATIONS.get( | ||
| expr._function_name, expr._function_name | ||
| ) | ||
| for output_name in expr._output_names: | ||
| simple_aggregations[output_name] = (keys[0], function_name) | ||
| simple_aggregations.update( | ||
| { | ||
| output_name: (keys[0], function_name) | ||
| for output_name in expr._output_names | ||
| } | ||
| ) | ||
| continue | ||
|
|
||
| # e.g. agg(nw.mean('a')) # noqa: ERA001 | ||
|
|
@@ -149,13 +177,26 @@ def agg_dask( | |
| raise AssertionError(msg) | ||
|
|
||
| function_name = remove_prefix(expr._function_name, "col->") | ||
| function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name) | ||
| kwargs = ( | ||
| {"ddof": expr._kwargs.get("ddof", 1)} | ||
| if function_name in {"std", "var"} | ||
| else {} | ||
| ) | ||
|
|
||
| agg_function = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name) | ||
| # deal with n_unique case in a "lazy" mode to not depend on dask globally | ||
| function_name = function_name() if callable(function_name) else function_name | ||
|
|
||
| for root_name, output_name in zip(expr._root_names, expr._output_names): | ||
| simple_aggregations[output_name] = (root_name, function_name) | ||
| agg_function = ( | ||
| agg_function(**kwargs) if callable(agg_function) else agg_function | ||
| ) | ||
|
|
||
| simple_aggregations.update( | ||
| { | ||
| output_name: (root_name, agg_function) | ||
| for root_name, output_name in zip( | ||
| expr._root_names, expr._output_names | ||
| ) | ||
| } | ||
| ) | ||
| result_simple = grouped.agg(**simple_aggregations) | ||
| return from_dataframe(result_simple.reset_index()) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I need to double check the return type