diff --git a/numba_cuda/numba/cuda/api.py b/numba_cuda/numba/cuda/api.py index e58b3f588..8c443567e 100644 --- a/numba_cuda/numba/cuda/api.py +++ b/numba_cuda/numba/cuda/api.py @@ -75,12 +75,11 @@ def as_cuda_array(obj, sync=True): If ``sync`` is ``True``, then the imported stream (if present) will be synchronized. """ - if not is_cuda_array(obj): - raise TypeError("*obj* doesn't implement the cuda array interface.") - else: - return from_cuda_array_interface( - obj.__cuda_array_interface__, owner=obj, sync=sync - ) + if ( + interface := getattr(obj, "__cuda_array_interface__", None) + ) is not None: + return from_cuda_array_interface(interface, owner=obj, sync=sync) + raise TypeError("*obj* doesn't implement the cuda array interface.") def is_cuda_array(obj): diff --git a/numba_cuda/numba/cuda/cudadrv/devicearray.py b/numba_cuda/numba/cuda/cudadrv/devicearray.py index e8dc27f3b..6ef4ab6fb 100644 --- a/numba_cuda/numba/cuda/cudadrv/devicearray.py +++ b/numba_cuda/numba/cuda/cudadrv/devicearray.py @@ -15,7 +15,6 @@ import numpy as np -import numba from numba.cuda.cext import _devicearray from numba.cuda.cudadrv import devices, dummyarray from numba.cuda.cudadrv import driver as _driver @@ -901,8 +900,12 @@ def auto_device(obj, stream=0, copy=True, user_explicit=False): """ if _driver.is_device_memory(obj): return obj, False - elif hasattr(obj, "__cuda_array_interface__"): - return numba.cuda.as_cuda_array(obj), False + elif ( + interface := getattr(obj, "__cuda_array_interface__", None) + ) is not None: + from numba.cuda.api import from_cuda_array_interface + + return from_cuda_array_interface(interface, owner=obj), False else: if isinstance(obj, np.void): devobj = from_record_like(obj, stream=stream)