@@ -1925,16 +1925,63 @@ def test_where() -> None:
1925
1925
1926
1926
1927
1927
def test_where_attrs () -> None :
1928
- cond = xr .DataArray ([True , False ], dims = "x" , attrs = {"attr" : "cond" })
1929
- x = xr .DataArray ([1 , 1 ], dims = "x" , attrs = {"attr" : "x" })
1930
- y = xr .DataArray ([0 , 0 ], dims = "x" , attrs = {"attr" : "y" })
1928
+ cond = xr .DataArray ([True , False ], coords = {"a" : [0 , 1 ]}, attrs = {"attr" : "cond_da" })
1929
+ cond ["a" ].attrs = {"attr" : "cond_coord" }
1930
+ x = xr .DataArray ([1 , 1 ], coords = {"a" : [0 , 1 ]}, attrs = {"attr" : "x_da" })
1931
+ x ["a" ].attrs = {"attr" : "x_coord" }
1932
+ y = xr .DataArray ([0 , 0 ], coords = {"a" : [0 , 1 ]}, attrs = {"attr" : "y_da" })
1933
+ y ["a" ].attrs = {"attr" : "y_coord" }
1934
+
1935
+ # 3 DataArrays, takes attrs from x
1931
1936
actual = xr .where (cond , x , y , keep_attrs = True )
1932
- expected = xr .DataArray ([1 , 0 ], dims = "x" , attrs = {"attr" : "x" })
1937
+ expected = xr .DataArray ([1 , 0 ], coords = {"a" : [0 , 1 ]}, attrs = {"attr" : "x_da" })
1938
+ expected ["a" ].attrs = {"attr" : "x_coord" }
1933
1939
assert_identical (expected , actual )
1934
1940
1935
- # ensure keep_attrs can handle scalar values
1941
+ # x as a scalar, takes no attrs
1942
+ actual = xr .where (cond , 0 , y , keep_attrs = True )
1943
+ expected = xr .DataArray ([0 , 0 ], coords = {"a" : [0 , 1 ]})
1944
+ assert_identical (expected , actual )
1945
+
1946
+ # y as a scalar, takes attrs from x
1947
+ actual = xr .where (cond , x , 0 , keep_attrs = True )
1948
+ expected = xr .DataArray ([1 , 0 ], coords = {"a" : [0 , 1 ]}, attrs = {"attr" : "x_da" })
1949
+ expected ["a" ].attrs = {"attr" : "x_coord" }
1950
+ assert_identical (expected , actual )
1951
+
1952
+ # x and y as a scalar, takes no attrs
1936
1953
actual = xr .where (cond , 1 , 0 , keep_attrs = True )
1937
- assert actual .attrs == {}
1954
+ expected = xr .DataArray ([1 , 0 ], coords = {"a" : [0 , 1 ]})
1955
+ assert_identical (expected , actual )
1956
+
1957
+ # cond and y as a scalar, takes attrs from x
1958
+ actual = xr .where (True , x , y , keep_attrs = True )
1959
+ expected = xr .DataArray ([1 , 1 ], coords = {"a" : [0 , 1 ]}, attrs = {"attr" : "x_da" })
1960
+ expected ["a" ].attrs = {"attr" : "x_coord" }
1961
+ assert_identical (expected , actual )
1962
+
1963
+ # DataArray and 2 Datasets, takes attrs from x
1964
+ ds_x = xr .Dataset (data_vars = {"x" : x }, attrs = {"attr" : "x_ds" })
1965
+ ds_y = xr .Dataset (data_vars = {"x" : y }, attrs = {"attr" : "y_ds" })
1966
+ ds_actual = xr .where (cond , ds_x , ds_y , keep_attrs = True )
1967
+ ds_expected = xr .Dataset (
1968
+ data_vars = {
1969
+ "x" : xr .DataArray ([1 , 0 ], coords = {"a" : [0 , 1 ]}, attrs = {"attr" : "x_da" })
1970
+ },
1971
+ attrs = {"attr" : "x_ds" },
1972
+ )
1973
+ ds_expected ["a" ].attrs = {"attr" : "x_coord" }
1974
+ assert_identical (ds_expected , ds_actual )
1975
+
1976
+ # 2 DataArrays and 1 Dataset, takes attrs from x
1977
+ ds_actual = xr .where (cond , x .rename ("x" ), ds_y , keep_attrs = True )
1978
+ ds_expected = xr .Dataset (
1979
+ data_vars = {
1980
+ "x" : xr .DataArray ([1 , 0 ], coords = {"a" : [0 , 1 ]}, attrs = {"attr" : "x_da" })
1981
+ },
1982
+ )
1983
+ ds_expected ["a" ].attrs = {"attr" : "x_coord" }
1984
+ assert_identical (ds_expected , ds_actual )
1938
1985
1939
1986
1940
1987
@pytest .mark .parametrize (
0 commit comments