Skip to content

Commit 675a3ff

Browse files
slevang and dcherian authored
Fix coordinate attr handling in xr.where(..., keep_attrs=True) (#7229)
* better tests, use modified attrs[1]
* add whats new
* update keep_attrs docstring
* cast to DataArray
* whats-new
* fix whats new
* Update doc/whats-new.rst
* rebuild attrs after apply_ufunc
* fix mypy
* better comment

Co-authored-by: Deepak Cherian <[email protected]>
1 parent 2fb22cf commit 675a3ff

File tree

3 files changed

+79
-12
lines changed

3 files changed

+79
-12
lines changed

doc/whats-new.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ Deprecations
5757

5858
Bug fixes
5959
~~~~~~~~~
60-
60+
- Fix handling of coordinate attributes in :py:func:`where`. (:issue:`7220`, :pull:`7229`)
61+
By `Sam Levang <https://github.com/slevang>`_.
6162
- Import ``nc_time_axis`` when needed (:issue:`7275`, :pull:`7276`).
6263
By `Michael Niklas <https://github.com/headtr1ck>`_.
6364
- Fix static typing of :py:meth:`xr.polyval` (:issue:`7312`, :pull:`7315`).

xarray/core/computation.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -1855,15 +1855,13 @@ def where(cond, x, y, keep_attrs=None):
18551855
Dataset.where, DataArray.where :
18561856
equivalent methods
18571857
"""
1858+
from .dataset import Dataset
1859+
18581860
if keep_attrs is None:
18591861
keep_attrs = _get_keep_attrs(default=False)
1860-
if keep_attrs is True:
1861-
# keep the attributes of x, the second parameter, by default to
1862-
# be consistent with the `where` method of `DataArray` and `Dataset`
1863-
keep_attrs = lambda attrs, context: getattr(x, "attrs", {})
18641862

18651863
# alignment for three arguments is complicated, so don't support it yet
1866-
return apply_ufunc(
1864+
result = apply_ufunc(
18671865
duck_array_ops.where,
18681866
cond,
18691867
x,
@@ -1874,6 +1872,27 @@ def where(cond, x, y, keep_attrs=None):
18741872
keep_attrs=keep_attrs,
18751873
)
18761874

1875+
# keep the attributes of x, the second parameter, by default to
1876+
# be consistent with the `where` method of `DataArray` and `Dataset`
1877+
# rebuild the attrs from x at each level of the output, which could be
1878+
# Dataset, DataArray, or Variable, and also handle coords
1879+
if keep_attrs is True:
1880+
if isinstance(y, Dataset) and not isinstance(x, Dataset):
1881+
# handle special case where x gets promoted to Dataset
1882+
result.attrs = {}
1883+
if getattr(x, "name", None) in result.data_vars:
1884+
result[x.name].attrs = getattr(x, "attrs", {})
1885+
else:
1886+
# otherwise, fill in global attrs and variable attrs (if they exist)
1887+
result.attrs = getattr(x, "attrs", {})
1888+
for v in getattr(result, "data_vars", []):
1889+
result[v].attrs = getattr(getattr(x, v, None), "attrs", {})
1890+
for c in getattr(result, "coords", []):
1891+
# always fill coord attrs of x
1892+
result[c].attrs = getattr(getattr(x, c, None), "attrs", {})
1893+
1894+
return result
1895+
18771896

18781897
@overload
18791898
def polyval(

xarray/tests/test_computation.py

+53-6
Original file line numberDiff line numberDiff line change
@@ -1925,16 +1925,63 @@ def test_where() -> None:
19251925

19261926

19271927
def test_where_attrs() -> None:
1928-
cond = xr.DataArray([True, False], dims="x", attrs={"attr": "cond"})
1929-
x = xr.DataArray([1, 1], dims="x", attrs={"attr": "x"})
1930-
y = xr.DataArray([0, 0], dims="x", attrs={"attr": "y"})
1928+
cond = xr.DataArray([True, False], coords={"a": [0, 1]}, attrs={"attr": "cond_da"})
1929+
cond["a"].attrs = {"attr": "cond_coord"}
1930+
x = xr.DataArray([1, 1], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
1931+
x["a"].attrs = {"attr": "x_coord"}
1932+
y = xr.DataArray([0, 0], coords={"a": [0, 1]}, attrs={"attr": "y_da"})
1933+
y["a"].attrs = {"attr": "y_coord"}
1934+
1935+
# 3 DataArrays, takes attrs from x
19311936
actual = xr.where(cond, x, y, keep_attrs=True)
1932-
expected = xr.DataArray([1, 0], dims="x", attrs={"attr": "x"})
1937+
expected = xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
1938+
expected["a"].attrs = {"attr": "x_coord"}
19331939
assert_identical(expected, actual)
19341940

1935-
# ensure keep_attrs can handle scalar values
1941+
# x as a scalar, takes no attrs
1942+
actual = xr.where(cond, 0, y, keep_attrs=True)
1943+
expected = xr.DataArray([0, 0], coords={"a": [0, 1]})
1944+
assert_identical(expected, actual)
1945+
1946+
# y as a scalar, takes attrs from x
1947+
actual = xr.where(cond, x, 0, keep_attrs=True)
1948+
expected = xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
1949+
expected["a"].attrs = {"attr": "x_coord"}
1950+
assert_identical(expected, actual)
1951+
1952+
# x and y as a scalar, takes no attrs
19361953
actual = xr.where(cond, 1, 0, keep_attrs=True)
1937-
assert actual.attrs == {}
1954+
expected = xr.DataArray([1, 0], coords={"a": [0, 1]})
1955+
assert_identical(expected, actual)
1956+
1957+
# cond and y as a scalar, takes attrs from x
1958+
actual = xr.where(True, x, y, keep_attrs=True)
1959+
expected = xr.DataArray([1, 1], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
1960+
expected["a"].attrs = {"attr": "x_coord"}
1961+
assert_identical(expected, actual)
1962+
1963+
# DataArray and 2 Datasets, takes attrs from x
1964+
ds_x = xr.Dataset(data_vars={"x": x}, attrs={"attr": "x_ds"})
1965+
ds_y = xr.Dataset(data_vars={"x": y}, attrs={"attr": "y_ds"})
1966+
ds_actual = xr.where(cond, ds_x, ds_y, keep_attrs=True)
1967+
ds_expected = xr.Dataset(
1968+
data_vars={
1969+
"x": xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
1970+
},
1971+
attrs={"attr": "x_ds"},
1972+
)
1973+
ds_expected["a"].attrs = {"attr": "x_coord"}
1974+
assert_identical(ds_expected, ds_actual)
1975+
1976+
# 2 DataArrays and 1 Dataset, takes attrs from x
1977+
ds_actual = xr.where(cond, x.rename("x"), ds_y, keep_attrs=True)
1978+
ds_expected = xr.Dataset(
1979+
data_vars={
1980+
"x": xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"})
1981+
},
1982+
)
1983+
ds_expected["a"].attrs = {"attr": "x_coord"}
1984+
assert_identical(ds_expected, ds_actual)
19381985

19391986

19401987
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)