From cc17b5d7db1264d3cbe551b74c80db47ef822087 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 21 Oct 2025 16:07:37 +0000 Subject: [PATCH] perf: reduce the number of `__cuda_array_interface__` accesses --- numba_cuda/numba/cuda/api.py | 11 +++++------ numba_cuda/numba/cuda/cudadrv/devicearray.py | 9 ++++++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/numba_cuda/numba/cuda/api.py b/numba_cuda/numba/cuda/api.py index e58b3f588..8c443567e 100644 --- a/numba_cuda/numba/cuda/api.py +++ b/numba_cuda/numba/cuda/api.py @@ -75,12 +75,11 @@ def as_cuda_array(obj, sync=True): If ``sync`` is ``True``, then the imported stream (if present) will be synchronized. """ - if not is_cuda_array(obj): - raise TypeError("*obj* doesn't implement the cuda array interface.") - else: - return from_cuda_array_interface( - obj.__cuda_array_interface__, owner=obj, sync=sync - ) + if ( + interface := getattr(obj, "__cuda_array_interface__", None) + ) is not None: + return from_cuda_array_interface(interface, owner=obj, sync=sync) + raise TypeError("*obj* doesn't implement the cuda array interface.") def is_cuda_array(obj): diff --git a/numba_cuda/numba/cuda/cudadrv/devicearray.py b/numba_cuda/numba/cuda/cudadrv/devicearray.py index e8dc27f3b..6ef4ab6fb 100644 --- a/numba_cuda/numba/cuda/cudadrv/devicearray.py +++ b/numba_cuda/numba/cuda/cudadrv/devicearray.py @@ -15,7 +15,6 @@ import numpy as np -import numba from numba.cuda.cext import _devicearray from numba.cuda.cudadrv import devices, dummyarray from numba.cuda.cudadrv import driver as _driver @@ -901,8 +900,12 @@ def auto_device(obj, stream=0, copy=True, user_explicit=False): """ if _driver.is_device_memory(obj): return obj, False - elif hasattr(obj, "__cuda_array_interface__"): - return numba.cuda.as_cuda_array(obj), False + elif ( + interface := getattr(obj, "__cuda_array_interface__", None) + ) is not None: + from numba.cuda.api import from_cuda_array_interface + + return from_cuda_array_interface(interface, owner=obj), False else: if isinstance(obj, np.void): devobj = from_record_like(obj, stream=stream)