Skip to content
forked from pydata/xarray

Commit a848044

Browse files
authored
Use flox for grouped first, last (pydata#9986)
1 parent 99dfffa commit a848044

File tree

4 files changed

+44
-7
lines changed

4 files changed

+44
-7
lines changed

Diff for: doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ New Features
5454
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_ and `Spencer Clark <https://github.com/spencerkclark>`_.
5555
- Improve the error message raised when no key is matching the available variables in a dataset. (:pull:`9943`)
5656
By `Jimmy Westling <https://github.com/illviljan>`_.
57+
- :py:meth:`DatasetGroupBy.first` and :py:meth:`DatasetGroupBy.last` can now use ``flox`` if available. (:issue:`9647`)
58+
By `Deepak Cherian <https://github.com/dcherian>`_.
5759

5860
Breaking changes
5961
~~~~~~~~~~~~~~~~

Diff for: xarray/core/groupby.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -1357,7 +1357,12 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray:
13571357
"""
13581358
return ops.where_method(self, cond, other)
13591359

1360-
def _first_or_last(self, op, skipna, keep_attrs):
1360+
def _first_or_last(
1361+
self,
1362+
op: Literal["first", "last"],
1363+
skipna: bool | None,
1364+
keep_attrs: bool | None,
1365+
):
13611366
if all(
13621367
isinstance(maybe_slice, slice)
13631368
and (maybe_slice.stop == maybe_slice.start + 1)
@@ -1368,17 +1373,30 @@ def _first_or_last(self, op, skipna, keep_attrs):
13681373
return self._obj
13691374
if keep_attrs is None:
13701375
keep_attrs = _get_keep_attrs(default=True)
1371-
return self.reduce(
1372-
op, dim=[self._group_dim], skipna=skipna, keep_attrs=keep_attrs
1373-
)
1376+
if (
1377+
module_available("flox", minversion="0.9.16")
1378+
and OPTIONS["use_flox"]
1379+
and contains_only_chunked_or_numpy(self._obj)
1380+
):
1381+
result, *_ = self._flox_reduce(
1382+
dim=None, func=op, skipna=skipna, keep_attrs=keep_attrs
1383+
)
1384+
else:
1385+
result = self.reduce(
1386+
getattr(duck_array_ops, op),
1387+
dim=[self._group_dim],
1388+
skipna=skipna,
1389+
keep_attrs=keep_attrs,
1390+
)
1391+
return result
13741392

13751393
def first(self, skipna: bool | None = None, keep_attrs: bool | None = None):
    """Select the leading element of every group along the group dimension.

    Parameters
    ----------
    skipna : bool or None, optional
        Forwarded unchanged to ``_first_or_last``.
    keep_attrs : bool or None, optional
        Forwarded unchanged to ``_first_or_last``.

    Returns
    -------
    Whatever ``_first_or_last`` returns for the ``"first"`` operation.
    """
    # The operation is identified by name so the shared helper can decide
    # how to carry it out.
    return self._first_or_last(op="first", skipna=skipna, keep_attrs=keep_attrs)
13781396

13791397
def last(self, skipna: bool | None = None, keep_attrs: bool | None = None):
    """Select the trailing element of every group along the group dimension.

    Parameters
    ----------
    skipna : bool or None, optional
        Forwarded unchanged to ``_first_or_last``.
    keep_attrs : bool or None, optional
        Forwarded unchanged to ``_first_or_last``.

    Returns
    -------
    Whatever ``_first_or_last`` returns for the ``"last"`` operation.
    """
    # The operation is identified by name so the shared helper can decide
    # how to carry it out.
    return self._first_or_last(op="last", skipna=skipna, keep_attrs=keep_attrs)
13821400

13831401
def assign_coords(self, coords=None, **coords_kwargs):
13841402
"""Assign coordinates by group.

Diff for: xarray/core/resample.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from collections.abc import Callable, Hashable, Iterable, Sequence
5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Any, Literal
66

77
from xarray.core._aggregations import (
88
DataArrayResampleAggregations,
@@ -103,6 +103,21 @@ def shuffle_to_chunks(self, chunks: T_Chunks = None):
103103
(grouper,) = self.groupers
104104
return self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM)
105105

106+
def _first_or_last(
    self, op: Literal["first", "last"], skipna: bool | None, keep_attrs: bool | None
) -> T_Xarray:
    """Run the parent class's first/last reduction, then restore dimension order.

    Parameters
    ----------
    op : {"first", "last"}
        Name of the reduction, passed through to the parent implementation.
    skipna : bool or None
        Passed through to the parent implementation.
    keep_attrs : bool or None
        Passed through to the parent implementation.

    Returns
    -------
    The parent implementation's result. When that result is a ``Dataset``,
    each of its data variables is first transposed to the dimension order
    of the same-named variable on ``self._obj``.
    """
    from xarray.core.dataset import Dataset

    reduced = super()._first_or_last(op=op, skipna=skipna, keep_attrs=keep_attrs)
    if not isinstance(reduced, Dataset):
        # Only Dataset results get their variables re-ordered.
        return reduced

    # This fix-up cannot live in the base class: there the group dimension
    # is RESAMPLE_DIM, which does not exist on the original object.
    source_variables = self._obj._variables
    for name in reduced.data_vars:
        target_dims = source_variables[name].dims
        reduced._variables[name] = reduced._variables[name].transpose(*target_dims)
    return reduced
120+
106121
def _drop_coords(self) -> T_Xarray:
107122
"""Drop non-dimension coordinates along the resampled dimension."""
108123
obj = self._obj

Diff for: xarray/tests/test_groupby.py

+2
Original file line numberDiff line numberDiff line change
@@ -1618,6 +1618,8 @@ def test_groupby_first_and_last(self) -> None:
16181618
expected = array # should be a no-op
16191619
assert_identical(expected, actual)
16201620

1621+
# TODO: extend this first/last test to cover groupby_bins as well
1622+
16211623
def make_groupby_multidim_example_array(self) -> DataArray:
16221624
return DataArray(
16231625
[[[0, 1], [2, 3]], [[5, 10], [15, 20]]],

0 commit comments

Comments
 (0)