Skip to content

Commit 1f5c633

Browse files
dcherianmathause
andauthored
Refactor dataset groupby tests (#5506)
Co-authored-by: Mathias Hauser <[email protected]>
1 parent 6a101a9 commit 1f5c633

File tree

3 files changed

+206
-195
lines changed

3 files changed

+206
-195
lines changed

xarray/tests/__init__.py

+29
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
from unittest import mock # noqa: F401
77

88
import numpy as np
9+
import pandas as pd
910
import pytest
1011
from numpy.testing import assert_array_equal # noqa: F401
1112
from pandas.testing import assert_frame_equal # noqa: F401
1213

1314
import xarray.testing
15+
from xarray import Dataset
1416
from xarray.core import utils
1517
from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401
1618
from xarray.core.indexing import ExplicitlyIndexed
@@ -200,3 +202,30 @@ def assert_allclose(a, b, **kwargs):
200202
xarray.testing.assert_allclose(a, b, **kwargs)
201203
xarray.testing._assert_internal_invariants(a)
202204
xarray.testing._assert_internal_invariants(b)
205+
206+
207+
def create_test_data(seed=None, add_attrs=True):
208+
rs = np.random.RandomState(seed)
209+
_vars = {
210+
"var1": ["dim1", "dim2"],
211+
"var2": ["dim1", "dim2"],
212+
"var3": ["dim3", "dim1"],
213+
}
214+
_dims = {"dim1": 8, "dim2": 9, "dim3": 10}
215+
216+
obj = Dataset()
217+
obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"]))
218+
obj["dim3"] = ("dim3", list("abcdefghij"))
219+
obj["time"] = ("time", pd.date_range("2000-01-01", periods=20))
220+
for v, dims in sorted(_vars.items()):
221+
data = rs.normal(size=tuple(_dims[d] for d in dims))
222+
obj[v] = (dims, data)
223+
if add_attrs:
224+
obj[v].attrs = {"foo": "variable"}
225+
obj.coords["numbers"] = (
226+
"dim3",
227+
np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"),
228+
)
229+
obj.encoding = {"foo": "bar"}
230+
assert all(obj.data.flags.writeable for obj in obj.variables.values())
231+
return obj

xarray/tests/test_dataset.py

+1-194
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
assert_array_equal,
4040
assert_equal,
4141
assert_identical,
42+
create_test_data,
4243
has_cftime,
4344
has_dask,
4445
requires_bottleneck,
@@ -62,33 +63,6 @@
6263
]
6364

6465

65-
def create_test_data(seed=None, add_attrs=True):
66-
rs = np.random.RandomState(seed)
67-
_vars = {
68-
"var1": ["dim1", "dim2"],
69-
"var2": ["dim1", "dim2"],
70-
"var3": ["dim3", "dim1"],
71-
}
72-
_dims = {"dim1": 8, "dim2": 9, "dim3": 10}
73-
74-
obj = Dataset()
75-
obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"]))
76-
obj["dim3"] = ("dim3", list("abcdefghij"))
77-
obj["time"] = ("time", pd.date_range("2000-01-01", periods=20))
78-
for v, dims in sorted(_vars.items()):
79-
data = rs.normal(size=tuple(_dims[d] for d in dims))
80-
obj[v] = (dims, data)
81-
if add_attrs:
82-
obj[v].attrs = {"foo": "variable"}
83-
obj.coords["numbers"] = (
84-
"dim3",
85-
np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"),
86-
)
87-
obj.encoding = {"foo": "bar"}
88-
assert all(obj.data.flags.writeable for obj in obj.variables.values())
89-
return obj
90-
91-
9266
def create_append_test_data(seed=None):
9367
rs = np.random.RandomState(seed)
9468

@@ -3785,173 +3759,6 @@ def test_squeeze_drop(self):
37853759
selected = data.squeeze(drop=True)
37863760
assert_identical(data, selected)
37873761

3788-
def test_groupby(self):
3789-
data = Dataset(
3790-
{"z": (["x", "y"], np.random.randn(3, 5))},
3791-
{"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)},
3792-
)
3793-
groupby = data.groupby("x")
3794-
assert len(groupby) == 3
3795-
expected_groups = {"a": 0, "b": 1, "c": 2}
3796-
assert groupby.groups == expected_groups
3797-
expected_items = [
3798-
("a", data.isel(x=0)),
3799-
("b", data.isel(x=1)),
3800-
("c", data.isel(x=2)),
3801-
]
3802-
for actual, expected in zip(groupby, expected_items):
3803-
assert actual[0] == expected[0]
3804-
assert_equal(actual[1], expected[1])
3805-
3806-
def identity(x):
3807-
return x
3808-
3809-
for k in ["x", "c", "y"]:
3810-
actual = data.groupby(k, squeeze=False).map(identity)
3811-
assert_equal(data, actual)
3812-
3813-
def test_groupby_returns_new_type(self):
3814-
data = Dataset({"z": (["x", "y"], np.random.randn(3, 5))})
3815-
3816-
actual = data.groupby("x").map(lambda ds: ds["z"])
3817-
expected = data["z"]
3818-
assert_identical(expected, actual)
3819-
3820-
actual = data["z"].groupby("x").map(lambda x: x.to_dataset())
3821-
expected = data
3822-
assert_identical(expected, actual)
3823-
3824-
def test_groupby_iter(self):
3825-
data = create_test_data()
3826-
for n, (t, sub) in enumerate(list(data.groupby("dim1"))[:3]):
3827-
assert data["dim1"][n] == t
3828-
assert_equal(data["var1"][n], sub["var1"])
3829-
assert_equal(data["var2"][n], sub["var2"])
3830-
assert_equal(data["var3"][:, n], sub["var3"])
3831-
3832-
def test_groupby_errors(self):
3833-
data = create_test_data()
3834-
with pytest.raises(TypeError, match=r"`group` must be"):
3835-
data.groupby(np.arange(10))
3836-
with pytest.raises(ValueError, match=r"length does not match"):
3837-
data.groupby(data["dim1"][:3])
3838-
with pytest.raises(TypeError, match=r"`group` must be"):
3839-
data.groupby(data.coords["dim1"].to_index())
3840-
3841-
def test_groupby_reduce(self):
3842-
data = Dataset(
3843-
{
3844-
"xy": (["x", "y"], np.random.randn(3, 4)),
3845-
"xonly": ("x", np.random.randn(3)),
3846-
"yonly": ("y", np.random.randn(4)),
3847-
"letters": ("y", ["a", "a", "b", "b"]),
3848-
}
3849-
)
3850-
3851-
expected = data.mean("y")
3852-
expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3})
3853-
actual = data.groupby("x").mean(...)
3854-
assert_allclose(expected, actual)
3855-
3856-
actual = data.groupby("x").mean("y")
3857-
assert_allclose(expected, actual)
3858-
3859-
letters = data["letters"]
3860-
expected = Dataset(
3861-
{
3862-
"xy": data["xy"].groupby(letters).mean(...),
3863-
"xonly": (data["xonly"].mean().variable.set_dims({"letters": 2})),
3864-
"yonly": data["yonly"].groupby(letters).mean(),
3865-
}
3866-
)
3867-
actual = data.groupby("letters").mean(...)
3868-
assert_allclose(expected, actual)
3869-
3870-
def test_groupby_math(self):
3871-
def reorder_dims(x):
3872-
return x.transpose("dim1", "dim2", "dim3", "time")
3873-
3874-
ds = create_test_data()
3875-
ds["dim1"] = ds["dim1"]
3876-
for squeeze in [True, False]:
3877-
grouped = ds.groupby("dim1", squeeze=squeeze)
3878-
3879-
expected = reorder_dims(ds + ds.coords["dim1"])
3880-
actual = grouped + ds.coords["dim1"]
3881-
assert_identical(expected, reorder_dims(actual))
3882-
3883-
actual = ds.coords["dim1"] + grouped
3884-
assert_identical(expected, reorder_dims(actual))
3885-
3886-
ds2 = 2 * ds
3887-
expected = reorder_dims(ds + ds2)
3888-
actual = grouped + ds2
3889-
assert_identical(expected, reorder_dims(actual))
3890-
3891-
actual = ds2 + grouped
3892-
assert_identical(expected, reorder_dims(actual))
3893-
3894-
grouped = ds.groupby("numbers")
3895-
zeros = DataArray([0, 0, 0, 0], [("numbers", range(4))])
3896-
expected = (ds + Variable("dim3", np.zeros(10))).transpose(
3897-
"dim3", "dim1", "dim2", "time"
3898-
)
3899-
actual = grouped + zeros
3900-
assert_equal(expected, actual)
3901-
3902-
actual = zeros + grouped
3903-
assert_equal(expected, actual)
3904-
3905-
with pytest.raises(ValueError, match=r"incompat.* grouped binary"):
3906-
grouped + ds
3907-
with pytest.raises(ValueError, match=r"incompat.* grouped binary"):
3908-
ds + grouped
3909-
with pytest.raises(TypeError, match=r"only support binary ops"):
3910-
grouped + 1
3911-
with pytest.raises(TypeError, match=r"only support binary ops"):
3912-
grouped + grouped
3913-
with pytest.raises(TypeError, match=r"in-place operations"):
3914-
ds += grouped
3915-
3916-
ds = Dataset(
3917-
{
3918-
"x": ("time", np.arange(100)),
3919-
"time": pd.date_range("2000-01-01", periods=100),
3920-
}
3921-
)
3922-
with pytest.raises(ValueError, match=r"incompat.* grouped binary"):
3923-
ds + ds.groupby("time.month")
3924-
3925-
def test_groupby_math_virtual(self):
3926-
ds = Dataset(
3927-
{"x": ("t", [1, 2, 3])}, {"t": pd.date_range("20100101", periods=3)}
3928-
)
3929-
grouped = ds.groupby("t.day")
3930-
actual = grouped - grouped.mean(...)
3931-
expected = Dataset({"x": ("t", [0, 0, 0])}, ds[["t", "t.day"]])
3932-
assert_identical(actual, expected)
3933-
3934-
def test_groupby_nan(self):
3935-
# nan should be excluded from groupby
3936-
ds = Dataset({"foo": ("x", [1, 2, 3, 4])}, {"bar": ("x", [1, 1, 2, np.nan])})
3937-
actual = ds.groupby("bar").mean(...)
3938-
expected = Dataset({"foo": ("bar", [1.5, 3]), "bar": [1, 2]})
3939-
assert_identical(actual, expected)
3940-
3941-
def test_groupby_order(self):
3942-
# groupby should preserve variables order
3943-
ds = Dataset()
3944-
for vn in ["a", "b", "c"]:
3945-
ds[vn] = DataArray(np.arange(10), dims=["t"])
3946-
data_vars_ref = list(ds.data_vars.keys())
3947-
ds = ds.groupby("t").mean(...)
3948-
data_vars = list(ds.data_vars.keys())
3949-
assert data_vars == data_vars_ref
3950-
# coords are now at the end of the list, so the test below fails
3951-
# all_vars = list(ds.variables.keys())
3952-
# all_vars_ref = list(ds.variables.keys())
3953-
# self.assertEqual(all_vars, all_vars_ref)
3954-
39553762
def test_resample_and_first(self):
39563763
times = pd.date_range("2000-01-01", freq="6H", periods=10)
39573764
ds = Dataset(

0 commit comments

Comments
 (0)