diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 9b4b5f6a0b..8bd3588c4c 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -443,26 +443,20 @@ def max(self) -> Self: ) def std(self, ddof: int) -> Self: - expr = self._from_call( + return self._from_call( lambda _input, ddof: _input.std(ddof=ddof), "std", ddof=ddof, returns_scalar=True, ) - if ddof != 1: - expr._depth += 1 - return expr def var(self, ddof: int) -> Self: - expr = self._from_call( + return self._from_call( lambda _input, ddof: _input.var(ddof=ddof), "var", ddof=ddof, returns_scalar=True, ) - if ddof != 1: - expr._depth += 1 - return expr def skew(self: Self) -> Self: return self._from_call( diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index af269abd78..7bda88ee5d 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -36,14 +36,38 @@ def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int: ) +def var( + ddof: int = 1, +) -> Callable[ + [pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy +]: + from functools import partial + + import dask_expr as dx + + return partial(dx._groupby.GroupBy.var, ddof=ddof) + + +def std( + ddof: int = 1, +) -> Callable[ + [pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy +]: + from functools import partial + + import dask_expr as dx + + return partial(dx._groupby.GroupBy.std, ddof=ddof) + + POLARS_TO_DASK_AGGREGATIONS = { "sum": "sum", "mean": "mean", "median": "median", "max": "max", "min": "min", - "std": "std", - "var": "var", + "std": std, + "var": var, "len": "size", "n_unique": n_unique, "count": "count", @@ -137,8 +161,12 @@ def agg_dask( function_name = POLARS_TO_DASK_AGGREGATIONS.get( expr._function_name, expr._function_name ) - for output_name in expr._output_names: - simple_aggregations[output_name] = (keys[0], function_name) + simple_aggregations.update( + { + output_name: (keys[0], function_name) + for output_name in expr._output_names + } + ) continue # e.g. agg(nw.mean('a')) # noqa: ERA001 @@ -149,13 +177,26 @@ def agg_dask( raise AssertionError(msg) function_name = remove_prefix(expr._function_name, "col->") - function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name) + kwargs = ( + {"ddof": expr._kwargs.get("ddof", 1)} + if function_name in {"std", "var"} + else {} + ) + agg_function = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name) # deal with n_unique case in a "lazy" mode to not depend on dask globally - function_name = function_name() if callable(function_name) else function_name - - for root_name, output_name in zip(expr._root_names, expr._output_names): - simple_aggregations[output_name] = (root_name, function_name) + agg_function = ( + agg_function(**kwargs) if callable(agg_function) else agg_function + ) + + simple_aggregations.update( + { + output_name: (root_name, agg_function) + for root_name, output_name in zip( + expr._root_names, expr._output_names + ) + } + ) result_simple = grouped.agg(**simple_aggregations) return from_dataframe(result_simple.reset_index()) diff --git a/tests/group_by_test.py b/tests/group_by_test.py index 3c57ce0276..22c3b6f195 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -137,12 +137,7 @@ def test_group_by_depth_1_std_var( constructor: Constructor, attr: str, ddof: int, - request: pytest.FixtureRequest, ) -> None: - if "dask" in str(constructor): - # Complex aggregation for dask - request.applymarker(pytest.mark.xfail) - data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]} _pow = 0.5 if attr == "std" else 1 expected = {