
Commit 2e93d54

dcherian authored
Vectorize groupby binary ops (#6160)
Co-authored-by: Illviljan <[email protected]>
Co-authored-by: keewis <[email protected]>
1 parent d8fc346 commit 2e93d54

File tree

3 files changed: +99 -38 lines changed


doc/whats-new.rst

+6
@@ -72,6 +72,12 @@ Bug fixes
 Documentation
 ~~~~~~~~~~~~~
 
+Performance
+~~~~~~~~~~~
+
+- GroupBy binary operations are now vectorized.
+  Previously this involved looping over all groups. (:issue:`5804`, :pull:`6160`)
+  By `Deepak Cherian <https://github.com/dcherian>`_.
 
 Internal Changes
 ~~~~~~~~~~~~~~~~
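
For context on the entry above: a GroupBy binary operation is arithmetic between a GroupBy object and a DataArray or Dataset, for example subtracting a per-group mean to compute anomalies. A minimal sketch of such an operation (illustrative only; the data, variable names, and random seed below are made up and not part of this commit):

    import numpy as np
    import pandas as pd
    import xarray as xr

    # Illustrative data: a year of 6-hourly values.
    da = xr.DataArray(
        np.random.default_rng(0).random(4 * 365),
        dims="time",
        coords={"time": pd.date_range("2001-01-01", periods=4 * 365, freq="6H")},
    )

    # Subtracting the monthly climatology is a GroupBy binary op; with this
    # change it is applied in one vectorized step rather than once per group.
    climatology = da.groupby("time.month").mean("time")
    anomaly = da.groupby("time.month") - climatology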

xarray/core/groupby.py

+66 -24
@@ -264,6 +264,7 @@ class GroupBy:
         "_stacked_dim",
         "_unique_coord",
         "_dims",
+        "_bins",
     )
 
     def __init__(
@@ -401,6 +402,7 @@ def __init__(
         self._inserted_dims = inserted_dims
         self._full_index = full_index
         self._restore_coord_dims = restore_coord_dims
+        self._bins = bins
 
         # cached attributes
         self._groups = None
@@ -478,35 +480,75 @@ def _infer_concat_args(self, applied_example):
         return coord, dim, positions
 
     def _binary_op(self, other, f, reflexive=False):
+        from .dataarray import DataArray
+        from .dataset import Dataset
+
         g = f if not reflexive else lambda x, y: f(y, x)
-        applied = self._yield_binary_applied(g, other)
-        return self._combine(applied)
 
-    def _yield_binary_applied(self, func, other):
-        dummy = None
+        obj = self._obj
+        group = self._group
+        dim = self._group_dim
+        if isinstance(group, _DummyGroup):
+            group = obj[dim]
+        name = group.name
+
+        if not isinstance(other, (Dataset, DataArray)):
+            raise TypeError(
+                "GroupBy objects only support binary ops "
+                "when the other argument is a Dataset or "
+                "DataArray"
+            )
 
-        for group_value, obj in self:
-            try:
-                other_sel = other.sel(**{self._group.name: group_value})
-            except AttributeError:
-                raise TypeError(
-                    "GroupBy objects only support binary ops "
-                    "when the other argument is a Dataset or "
-                    "DataArray"
-                )
-            except (KeyError, ValueError):
-                if self._group.name not in other.dims:
-                    raise ValueError(
-                        "incompatible dimensions for a grouped "
-                        f"binary operation: the group variable {self._group.name!r} "
-                        "is not a dimension on the other argument"
+        if name not in other.dims:
+            raise ValueError(
+                "incompatible dimensions for a grouped "
+                f"binary operation: the group variable {name!r} "
+                "is not a dimension on the other argument"
+            )
+
+        try:
+            expanded = other.sel({name: group})
+        except KeyError:
+            # some labels are absent i.e. other is not aligned
+            # so we align by reindexing and then rename dimensions.
+
+            # Broadcast out scalars for backwards compatibility
+            # TODO: get rid of this when fixing GH2145
+            for var in other.coords:
+                if other[var].ndim == 0:
+                    other[var] = (
+                        other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
                     )
-                if dummy is None:
-                    dummy = _dummy_copy(other)
-                other_sel = dummy
+            expanded = (
+                other.reindex({name: group.data})
+                .rename({name: dim})
+                .assign_coords({dim: obj[dim]})
+            )
 
-            result = func(obj, other_sel)
-            yield result
+        if self._bins is not None and name == dim and dim not in obj.xindexes:
+            # When binning by unindexed coordinate we need to reindex obj.
+            # _full_index is IntervalIndex, so idx will be -1 where
+            # a value does not belong to any bin. Using IntervalIndex
+            # accounts for any non-default cut_kwargs passed to the constructor
+            idx = pd.cut(group, bins=self._full_index).codes
+            obj = obj.isel({dim: np.arange(group.size)[idx != -1]})
+
+        result = g(obj, expanded)
+
+        result = self._maybe_unstack(result)
+        group = self._maybe_unstack(group)
+        if group.ndim > 1:
+            # backcompat:
+            # TODO: get rid of this when fixing GH2145
+            for var in set(obj.coords) - set(obj.xindexes):
+                if set(obj[var].dims) < set(group.dims):
+                    result[var] = obj[var].reset_coords(drop=True).broadcast_like(group)
+
+        if isinstance(result, Dataset) and isinstance(obj, Dataset):
+            for var in set(result):
+                if dim not in obj[var].dims:
+                    result[var] = result[var].transpose(dim, ...)
+        return result
 
     def _maybe_restore_empty_groups(self, combined):
         """Our index contained empty groups (e.g., from a resampling). If we

xarray/tests/test_groupby.py

+27 -14
@@ -857,6 +857,17 @@ def test_groupby_dataset_math_virtual() -> None:
     assert_identical(actual, expected)
 
 
+def test_groupby_math_dim_order() -> None:
+    da = DataArray(
+        np.ones((10, 10, 12)),
+        dims=("x", "y", "time"),
+        coords={"time": pd.date_range("2001-01-01", periods=12, freq="6H")},
+    )
+    grouped = da.groupby("time.day")
+    result = grouped - grouped.mean()
+    assert result.dims == da.dims
+
+
 def test_groupby_dataset_nan() -> None:
     # nan should be excluded from groupby
     ds = Dataset({"foo": ("x", [1, 2, 3, 4])}, {"bar": ("x", [1, 1, 2, np.nan])})
@@ -1155,26 +1166,28 @@ def change_metadata(x):
         expected = change_metadata(expected)
         assert_equal(expected, actual)
 
-    def test_groupby_math(self):
+    @pytest.mark.parametrize("squeeze", [True, False])
+    def test_groupby_math_squeeze(self, squeeze):
         array = self.da
-        for squeeze in [True, False]:
-            grouped = array.groupby("x", squeeze=squeeze)
+        grouped = array.groupby("x", squeeze=squeeze)
 
-            expected = array + array.coords["x"]
-            actual = grouped + array.coords["x"]
-            assert_identical(expected, actual)
+        expected = array + array.coords["x"]
+        actual = grouped + array.coords["x"]
+        assert_identical(expected, actual)
 
-            actual = array.coords["x"] + grouped
-            assert_identical(expected, actual)
+        actual = array.coords["x"] + grouped
+        assert_identical(expected, actual)
 
-            ds = array.coords["x"].to_dataset(name="X")
-            expected = array + ds
-            actual = grouped + ds
-            assert_identical(expected, actual)
+        ds = array.coords["x"].to_dataset(name="X")
+        expected = array + ds
+        actual = grouped + ds
+        assert_identical(expected, actual)
 
-            actual = ds + grouped
-            assert_identical(expected, actual)
+        actual = ds + grouped
+        assert_identical(expected, actual)
 
+    def test_groupby_math(self):
+        array = self.da
         grouped = array.groupby("abc")
         expected_agg = (grouped.mean(...) - np.arange(3)).rename(None)
         actual = grouped - DataArray(range(3), [("abc", ["a", "b", "c"])])
