diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index a006e54468a..fc3e1434684 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -1152,21 +1152,33 @@ def test_groupby_count(self): expected = DataArray([1, 1, 2], coords=[("cat", ["a", "b", "c"])]) assert_identical(actual, expected) - @pytest.mark.skip("needs to be fixed for shortcut=False, keep_attrs=False") - def test_groupby_reduce_attrs(self): + @pytest.mark.parametrize("shortcut", [True, False]) + @pytest.mark.parametrize("keep_attrs", [None, True, False]) + def test_groupby_reduce_keep_attrs(self, shortcut, keep_attrs): + array = self.da + array.attrs["foo"] = "bar" + + actual = array.groupby("abc").reduce( + np.mean, keep_attrs=keep_attrs, shortcut=shortcut + ) + with xr.set_options(use_flox=False): + expected = array.groupby("abc").mean(keep_attrs=keep_attrs) + assert_identical(expected, actual) + + @pytest.mark.parametrize("keep_attrs", [None, True, False]) + def test_groupby_keep_attrs(self, keep_attrs): array = self.da array.attrs["foo"] = "bar" - for shortcut in [True, False]: - for keep_attrs in [True, False]: - print(f"shortcut={shortcut}, keep_attrs={keep_attrs}") - actual = array.groupby("abc").reduce( - np.mean, keep_attrs=keep_attrs, shortcut=shortcut - ) - expected = array.groupby("abc").mean() - if keep_attrs: - expected.attrs["foo"] = "bar" - assert_identical(expected, actual) + with xr.set_options(use_flox=False): + expected = array.groupby("abc").mean(keep_attrs=keep_attrs) + with xr.set_options(use_flox=True): + actual = array.groupby("abc").mean(keep_attrs=keep_attrs) + + # values are tested elsewhere, here we jsut check data + # TODO: add check_attrs kwarg to assert_allclose + actual.data = expected.data + assert_identical(expected, actual) def test_groupby_map_center(self): def center(x):