Skip to content

Commit

Permalink
Update duck_array_ops.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hmaarrfk authored May 5, 2024
1 parent 6ac5cd5 commit e5b46b5
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ def asarray(data, xp=np):

def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
array_type_cupy = array_type("cupy")
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
Expand All @@ -234,7 +233,10 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
raise ValueError(
f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
)
elif any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):

# Avoid calling array_type("cupy") repeatidely in the any check
array_type_cupy = array_type("cupy")
if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
import cupy as cp

arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
Expand Down

0 comments on commit e5b46b5

Please sign in to comment.