From 57155abfbfd311c17acfd4ced66604e018ce6ebf Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 2 Mar 2025 19:03:39 +0100 Subject: [PATCH 01/12] explicitly cast the dtype of `condition` to `bool` --- xarray/core/duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index faec5ded04e..7ea283b87c2 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -373,7 +373,7 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition, x, y) - return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) + return xp.where(xp.astype(condition, xp.bool), *as_shared_dtype([x, y], xp=xp)) def where_method(data, cond, other=dtypes.NA): From 84a2c84de488b23a0ff24a43ce9671012e0aed9b Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 21:37:47 +0100 Subject: [PATCH 02/12] cast `condition` to bool in every case for `where` --- xarray/core/duck_array_ops.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 7ea283b87c2..6ddf664faf9 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -224,14 +224,17 @@ def empty_like(a, **kwargs): return xp.empty_like(a, **kwargs) -def astype(data, dtype, **kwargs): - if hasattr(data, "__array_namespace__"): +def astype(data, dtype, *, xp=None, **kwargs): + if not hasattr(data, "__array_namespace__") and xp is None: + return data.astype(dtype, **kwargs) + + if xp is None: xp = get_array_namespace(data) - if xp == np: - # numpy currently doesn't have a astype: - return data.astype(dtype, **kwargs) - return xp.astype(data, dtype, **kwargs) - return data.astype(dtype, **kwargs) + + if xp == np: + # numpy currently doesn't have a astype: + return data.astype(dtype, **kwargs) + return xp.astype(data, dtype, **kwargs) def asarray(data, xp=np, dtype=None): @@ -373,7 +376,13 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition, x, y) - return xp.where(xp.astype(condition, xp.bool), *as_shared_dtype([x, y], xp=xp)) + + if not is_duck_array(condition): + condition = asarray(condition, dtype=xp.bool, xp=xp) + else: + condition = astype(condition, dtype=xp.bool, xp=xp) + + return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) def where_method(data, cond, other=dtypes.NA): From d4ad871bda075debd14a83af8b1c41dccb55d10f Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 21:38:18 +0100 Subject: [PATCH 03/12] don't pass a `DataArray` to `where` --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b28ba390a9f..6a3ce156ce6 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -524,7 +524,7 @@ def factorize(self) -> EncodedGroups: # Restore these after the raveling broadcasted_masks = broadcast(*masks) mask = functools.reduce(np.logical_or, broadcasted_masks) # type: ignore[arg-type] - _flatcodes = where(mask, -1, _flatcodes) + _flatcodes = where(mask.data, -1, _flatcodes) full_index = pd.MultiIndex.from_product( (grouper.full_index.values for grouper in groupers), From 56b254df13f7573e1a9d3fdfcac97a6c26108e5f Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 21:51:05 +0100 Subject: [PATCH 04/12] use strings to specify the dtype for backwards compat --- xarray/core/duck_array_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6ddf664faf9..6cff20b5010 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -378,9 +378,9 @@ def where(condition, x, y): xp = get_array_namespace(condition, x, y) if not is_duck_array(condition): - condition = asarray(condition, dtype=xp.bool, xp=xp) + condition = asarray(condition, dtype="bool", xp=xp) else: - condition = astype(condition, dtype=xp.bool, xp=xp) + condition = astype(condition, dtype="bool", xp=xp) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) From ea3d4f783ce198fd42363b2c8262a2e237cd8eab Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 22:52:13 +0100 Subject: [PATCH 05/12] revert the strings and instead ignore the warning --- pyproject.toml | 3 +++ xarray/core/duck_array_ops.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 817fda6c328..9029ca6b482 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -358,6 +358,9 @@ filterwarnings = [ "default:the `pandas.MultiIndex` object:FutureWarning:xarray.tests.test_variable", "default:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", + # TODO: numpy.bool was deprecated in older versions of numpy, but is in the Array API + # TODO: remove once we can drop numpy<2 + "ignore:In the future `np.bool` will be defined as the corresponding NumpPy scalar:FutureWarning", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "ignore:is currently not part .* the Zarr version 3 specification.", # TODO: remove once we know how to deal with a changed signature in protocols diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6cff20b5010..6ddf664faf9 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -378,9 +378,9 @@ def where(condition, x, y): xp = get_array_namespace(condition, x, y) if not is_duck_array(condition): - condition = asarray(condition, dtype="bool", xp=xp) + condition = asarray(condition, dtype=xp.bool, xp=xp) else: - condition = astype(condition, dtype="bool", xp=xp) + condition = astype(condition, dtype=xp.bool, xp=xp) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) From 8fbfa06344d25bf014d106782b24b677abe87445 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 23:15:37 +0100 Subject: [PATCH 06/12] typo --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9029ca6b482..a7b0b37a759 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,7 +360,7 @@ filterwarnings = [ "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", # TODO: numpy.bool was deprecated in older versions of numpy, but is in the Array API # TODO: remove once we can drop numpy<2 - "ignore:In the future `np.bool` will be defined as the corresponding NumpPy scalar:FutureWarning", + "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "ignore:is currently not part .* the Zarr version 3 specification.", # TODO: remove once we know how to deal with a changed signature in protocols From 00b9720eda7b351f4e596dc4d4e0e87a27ddae3d Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 23:15:50 +0100 Subject: [PATCH 07/12] restrict to just numpy --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a7b0b37a759..1afcc8fef14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,7 +360,7 @@ filterwarnings = [ "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", # TODO: numpy.bool was deprecated in older versions of numpy, but is in the Array API # TODO: remove once we can drop numpy<2 - "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning", + "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning:numpy.*", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "ignore:is currently not part .* the Zarr version 3 specification.", # TODO: remove once we know how to deal with a changed signature in protocols From 58ecc32f68f9423c9db7bb2184b0e5843792a663 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 23:23:53 +0100 Subject: [PATCH 08/12] unrestrict --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1afcc8fef14..a7b0b37a759 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,7 +360,7 @@ filterwarnings = [ "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", # TODO: numpy.bool was deprecated in older versions of numpy, but is in the Array API # TODO: remove once we can drop numpy<2 - "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning:numpy.*", + "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "ignore:is currently not part .* the Zarr version 3 specification.", # TODO: remove once we know how to deal with a changed signature in protocols From afce80b39ea8013e08a707c889c314362a3e9ed1 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 12 Mar 2025 18:12:42 +0100 Subject: [PATCH 09/12] fall back to `xp.bool_` if `xp.bool` doesn't exist --- xarray/core/duck_array_ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6ddf664faf9..30531dafc9c 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -377,10 +377,11 @@ def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition, x, y) + dtype = xp.bool if hasattr(xp, "bool") else xp.bool_ if not is_duck_array(condition): - condition = asarray(condition, dtype=xp.bool, xp=xp) + condition = asarray(condition, dtype=dtype, xp=xp) else: - condition = astype(condition, dtype=xp.bool, xp=xp) + condition = astype(condition, dtype=dtype, xp=xp) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) From 8a3e0d26b0b1e8b61cea384d1328838972c9df37 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 12 Mar 2025 18:18:35 +0100 Subject: [PATCH 10/12] unskip the `where` test --- xarray/tests/test_array_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index c273260d7dd..022d2e3750e 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -139,7 +139,6 @@ def test_unstack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(actual, expected) -@pytest.mark.skip def test_where() -> None: np_arr = xr.DataArray(np.array([1, 0]), dims="x") xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x") From 3e48223733fd84828d7cdd5782fb9543286d4cb9 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 12 Mar 2025 18:20:07 +0100 Subject: [PATCH 11/12] reverse to avoid warnings --- xarray/core/duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 30531dafc9c..262c023059a 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -377,7 +377,7 @@ def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition, x, y) - dtype = xp.bool if hasattr(xp, "bool") else xp.bool_ + dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool if not is_duck_array(condition): condition = asarray(condition, dtype=dtype, xp=xp) else: From 420d1c013fcc1b8ed714f8311a8372d98c30ff8e Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 12 Mar 2025 18:51:02 +0100 Subject: [PATCH 12/12] remove the outdated ignore --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f5d6b004ae3..85c9183b30e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -344,9 +344,6 @@ filterwarnings = [ "default:the `pandas.MultiIndex` object:FutureWarning:xarray.tests.test_variable", "default:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", - # TODO: numpy.bool was deprecated in older versions of numpy, but is in the Array API - # TODO: remove once we can drop numpy<2 - "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "ignore:is currently not part .* the Zarr version 3 specification.", # TODO: remove once we know how to deal with a changed signature in protocols