Merged
10 changes: 2 additions & 8 deletions narwhals/_dask/expr.py
@@ -443,26 +443,20 @@ def max(self) -> Self:
```diff
         )

     def std(self, ddof: int) -> Self:
-        expr = self._from_call(
+        return self._from_call(
             lambda _input, ddof: _input.std(ddof=ddof),
             "std",
             ddof=ddof,
             returns_scalar=True,
         )
-        if ddof != 1:
-            expr._depth += 1
-        return expr

     def var(self, ddof: int) -> Self:
-        expr = self._from_call(
+        return self._from_call(
             lambda _input, ddof: _input.var(ddof=ddof),
             "var",
             ddof=ddof,
             returns_scalar=True,
         )
-        if ddof != 1:
-            expr._depth += 1
-        return expr

     def skew(self: Self) -> Self:
         return self._from_call(
```
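Note: the deleted `_depth` bump looks like a workaround that kept `std`/`var` with a non-default `ddof` out of the simple-aggregation path; with `ddof` now forwarded inside `group_by.py` (below), both methods can simply return the call. A minimal end-to-end sketch of what this enables (hypothetical data; assumes the public narwhals API with the dask backend):

```python
# Hypothetical check: group-by std with a non-default ddof on the dask
# backend, which previously fell through to a "complex aggregation" path.
import dask.dataframe as dd
import pandas as pd

import narwhals as nw

ddf = dd.from_pandas(
    pd.DataFrame({"a": [1, 1, 2], "b": [4.0, 5.0, 6.0]}), npartitions=1
)
result = nw.from_native(ddf).group_by("a").agg(nw.col("b").std(ddof=0))
print(nw.to_native(result).compute())
```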
59 changes: 50 additions & 9 deletions narwhals/_dask/group_by.py
@@ -36,14 +36,38 @@ def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int:
```diff
     )


+def var(
+    ddof: int = 1,
+) -> Callable[
+    [pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy
+]:
+    from functools import partial
+
+    import dask_expr as dx
+
+    return partial(dx._groupby.GroupBy.var, ddof=ddof)
```

Member Author (on the `Callable` return annotation): I need to double check the return type.

Member Author (on the `return partial(...)` line): The minimum version we support for dask is "dask[dataframe]==2024.7", which installs dask-expr==1.1.8, and dx._groupby.GroupBy is already available.

Member (reply): @phofl is there a public variant of this?
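On the reviewer's question: dask does document a public `dask.dataframe.Aggregation` hook for custom groupby reductions, so a ddof-aware variance could in principle be built without the private `dx._groupby` module. A hedged sketch only, not what this PR does, and untested against the pinned versions:

```python
import dask.dataframe as dd


def make_var(ddof: int = 1) -> dd.Aggregation:
    # Public custom-aggregation route: accumulate count, sum and sum of
    # squares per partition, combine them across partitions, then finalize
    # with var = (sum_sq - sum**2 / n) / (n - ddof).
    return dd.Aggregation(
        name=f"var_ddof_{ddof}",
        chunk=lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())),
        agg=lambda n, t, t2: (n.sum(), t.sum(), t2.sum()),
        finalize=lambda n, t, t2: (t2 - t**2 / n) / (n - ddof),
    )
```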



```diff
+
+
+def std(
+    ddof: int = 1,
+) -> Callable[
+    [pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy
+]:
+    from functools import partial
+
+    import dask_expr as dx
+
+    return partial(dx._groupby.GroupBy.std, ddof=ddof)
```
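The trick in `var` and `std` above is `functools.partial` over an unbound method: the resulting callable still expects the instance as its first positional argument, so dask can supply the group-by object at call time while `ddof` stays bound. A plain-Python analogue of the mechanics (names hypothetical):

```python
from functools import partial


class Stats:
    def __init__(self, values: list[float]) -> None:
        self.values = values

    def total(self, *, scale: float = 1.0) -> float:
        return scale * sum(self.values)


# partial over the unbound method binds only the keyword; the instance is
# still supplied at call time, just like dx._groupby.GroupBy.var above.
agg = partial(Stats.total, scale=2.0)
assert agg(Stats([1.0, 2.0, 3.0])) == Stats([1.0, 2.0, 3.0]).total(scale=2.0)
```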


```diff
 POLARS_TO_DASK_AGGREGATIONS = {
     "sum": "sum",
     "mean": "mean",
     "median": "median",
     "max": "max",
     "min": "min",
-    "std": "std",
-    "var": "var",
+    "std": std,
+    "var": var,
     "len": "size",
     "n_unique": n_unique,
     "count": "count",
```
```diff
@@ -137,8 +161,12 @@ def agg_dask(
             function_name = POLARS_TO_DASK_AGGREGATIONS.get(
                 expr._function_name, expr._function_name
             )
-            for output_name in expr._output_names:
-                simple_aggregations[output_name] = (keys[0], function_name)
+            simple_aggregations.update(
+                {
+                    output_name: (keys[0], function_name)
+                    for output_name in expr._output_names
+                }
+            )
             continue

         # e.g. agg(nw.mean('a'))  # noqa: ERA001
```
```diff
@@ -149,13 +177,26 @@
             raise AssertionError(msg)

         function_name = remove_prefix(expr._function_name, "col->")
-        function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)
+        kwargs = (
+            {"ddof": expr._kwargs.get("ddof", 1)}
+            if function_name in {"std", "var"}
+            else {}
+        )
+
+        agg_function = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)
         # deal with n_unique case in a "lazy" mode to not depend on dask globally
-        function_name = function_name() if callable(function_name) else function_name
+        agg_function = (
+            agg_function(**kwargs) if callable(agg_function) else agg_function
+        )

-        for root_name, output_name in zip(expr._root_names, expr._output_names):
-            simple_aggregations[output_name] = (root_name, function_name)
+        simple_aggregations.update(
+            {
+                output_name: (root_name, agg_function)
+                for root_name, output_name in zip(
+                    expr._root_names, expr._output_names
+                )
+            }
+        )
     result_simple = grouped.agg(**simple_aggregations)
     return from_dataframe(result_simple.reset_index())
```
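What `agg_dask` ultimately hands to dask is a named-aggregation mapping whose values are `(column, aggfunc)` pairs, where `aggfunc` may be a string or, after this change, a callable with `ddof` already bound. A standalone sketch of that call shape, mirroring `grouped.agg(**simple_aggregations)` (hypothetical data):

```python
from functools import partial

import dask.dataframe as dd
import dask_expr as dx
import pandas as pd

ddf = dd.from_pandas(
    pd.DataFrame({"a": [1, 1, 2], "b": [4.0, 5.0, 6.0]}), npartitions=1
)
grouped = ddf.groupby("a")

# String and callable aggfuncs side by side, as simple_aggregations allows.
simple_aggregations = {
    "b_mean": ("b", "mean"),
    "b_std": ("b", partial(dx._groupby.GroupBy.std, ddof=0)),
}
print(grouped.agg(**simple_aggregations).compute())
```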

5 changes: 0 additions & 5 deletions tests/group_by_test.py
@@ -137,12 +137,7 @@ def test_group_by_depth_1_std_var(
```diff
     constructor: Constructor,
     attr: str,
     ddof: int,
-    request: pytest.FixtureRequest,
 ) -> None:
-    if "dask" in str(constructor):
-        # Complex aggregation for dask
-        request.applymarker(pytest.mark.xfail)
-
     data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}
     _pow = 0.5 if attr == "std" else 1
     expected = {
```
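For reference, the expected values in this test are the per-group sample statistics of `b` from the `data` dict above; a quick pure-Python cross-check of both `ddof` cases:

```python
# Cross-check of the test data's group statistics, no dask required.
from statistics import pvariance, variance

g1 = [4, 5, 6]  # rows where a == 1
g2 = [0, 5, 5]  # rows where a == 2

print(variance(g1), variance(g2))    # ddof=1 -> 1.0, 8.333...
print(pvariance(g1), pvariance(g2))  # ddof=0 -> 0.666..., 5.555...
```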