Skip to content

Commit

Permalink
Refactor groupby binary ops code. (#6789)
Browse files: browse the repository at this point in the history.
  • Loading branch information
dcherian authored Jul 20, 2022
1 parent 392a614 commit 9b54b44
Showing 1 changed file with 33 additions and 38 deletions.
71 changes: 33 additions & 38 deletions xarray/core/groupby.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -23,6 +23,7 @@

from . import dtypes, duck_array_ops, nputils, ops
from ._reductions import DataArrayGroupByReductions, DatasetGroupByReductions
from .alignment import align
from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
from .concat import concat
from .formatting import format_array_flat
Expand Down Expand Up @@ -309,7 +310,7 @@ class GroupBy(Generic[T_Xarray]):
"_squeeze",
# Save unstacked object for flox
"_original_obj",
"_unstacked_group",
"_original_group",
"_bins",
)
_obj: T_Xarray
Expand Down Expand Up @@ -374,7 +375,7 @@ def __init__(
group.name = "group"

self._original_obj: T_Xarray = obj
self._unstacked_group = group
self._original_group = group
self._bins = bins

group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj)
Expand Down Expand Up @@ -571,11 +572,22 @@ def _binary_op(self, other, f, reflexive=False):

g = f if not reflexive else lambda x, y: f(y, x)

obj = self._obj
group = self._group
dim = self._group_dim
if self._bins is None:
obj = self._original_obj
group = self._original_group
dims = group.dims
else:
obj = self._maybe_unstack(self._obj)
group = self._maybe_unstack(self._group)
dims = (self._group_dim,)

if isinstance(group, _DummyGroup):
group = obj[dim]
group = obj[group.name]
coord = group
else:
coord = self._unique_coord
if not isinstance(coord, DataArray):
coord = DataArray(self._unique_coord)
name = group.name

if not isinstance(other, (Dataset, DataArray)):
Expand All @@ -592,37 +604,19 @@ def _binary_op(self, other, f, reflexive=False):
"is not a dimension on the other argument"
)

try:
expanded = other.sel({name: group})
except KeyError:
# some labels are absent i.e. other is not aligned
# so we align by reindexing and then rename dimensions.

# Broadcast out scalars for backwards compatibility
# TODO: get rid of this when fixing GH2145
for var in other.coords:
if other[var].ndim == 0:
other[var] = (
other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
)
expanded = (
other.reindex({name: group.data})
.rename({name: dim})
.assign_coords({dim: obj[dim]})
)
# Broadcast out scalars for backwards compatibility
# TODO: get rid of this when fixing GH2145
for var in other.coords:
if other[var].ndim == 0:
other[var] = (
other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
)

if self._bins is not None and name == dim and dim not in obj.xindexes:
# When binning by unindexed coordinate we need to reindex obj.
# _full_index is IntervalIndex, so idx will be -1 where
# a value does not belong to any bin. Using IntervalIndex
# accounts for any non-default cut_kwargs passed to the constructor
idx = pd.cut(group, bins=self._full_index).codes
obj = obj.isel({dim: np.arange(group.size)[idx != -1]})
other, _ = align(other, coord, join="outer")
expanded = other.sel({name: group})

result = g(obj, expanded)

result = self._maybe_unstack(result)
group = self._maybe_unstack(group)
if group.ndim > 1:
# backcompat:
# TODO: get rid of this when fixing GH2145
Expand All @@ -632,8 +626,9 @@ def _binary_op(self, other, f, reflexive=False):

if isinstance(result, Dataset) and isinstance(obj, Dataset):
for var in set(result):
if dim not in obj[var].dims:
result[var] = result[var].transpose(dim, ...)
for d in dims:
if d not in obj[var].dims:
result[var] = result[var].transpose(d, ...)
return result

def _maybe_restore_empty_groups(self, combined):
Expand Down Expand Up @@ -695,10 +690,10 @@ def _flox_reduce(self, dim, keep_attrs=None, **kwargs):
# group is only passed by resample
group = kwargs.pop("group", None)
if group is None:
if isinstance(self._unstacked_group, _DummyGroup):
group = self._unstacked_group.name
if isinstance(self._original_group, _DummyGroup):
group = self._original_group.name
else:
group = self._unstacked_group
group = self._original_group

unindexed_dims = tuple()
if isinstance(group, str):
Expand Down

0 comments on commit 9b54b44

Please sign in to comment.