diff --git a/docs/source/user/globals.rst b/docs/source/user/globals.rst new file mode 100644 index 000000000..f30bf9b76 --- /dev/null +++ b/docs/source/user/globals.rst @@ -0,0 +1,91 @@ +.. + SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + SPDX-License-Identifier: BSD-2-Clause + + +.. _cuda-globals: + +===================================== +Global Variables and Captured Values +===================================== + +Numba CUDA kernels and device functions can reference global variables defined +at module scope. This section describes how these values are captured and the +implications for your code. + + +Capture as constants +==================== + +By default, global variables referenced in kernels are captured as constants at +compilation time. This applies to scalars and host arrays (e.g. NumPy arrays). + +The following example demonstrates this behavior. Both ``TAX_RATE`` and +``PRICES`` are captured when the kernel is first compiled. Because they are +embedded as constants, **modifications to these variables after compilation +have no effect**—the second kernel call still uses the original values: + +.. literalinclude:: ../../../numba_cuda/numba/cuda/tests/doc_examples/test_globals.py + :language: python + :caption: Demonstrating constant capture of global variables + :start-after: magictoken.ex_globals_constant_capture.begin + :end-before: magictoken.ex_globals_constant_capture.end + :dedent: 8 + :linenos: + +Running the above code prints: + +.. code-block:: text + + Value of d_totals: [ 10.8 54. 16.2 64.8 162. ] + Value of d_totals: [ 10.8 54. 16.2 64.8 162. ] + +Note that both outputs are identical—the modifications to ``TAX_RATE`` and +``PRICES`` after the first kernel call have no effect. + +This behaviour is useful for small amounts of truly constant data like +configuration values, lookup tables, or mathematical constants. For larger +arrays, consider using device arrays instead. + + +Device array capture +==================== + +Device arrays are an exception to the constant capture rule. When a kernel +references a global device array—any object implementing +``__cuda_array_interface__``, such as CuPy arrays or Numba device arrays—the +device pointer is captured rather than the data. No copy occurs, and +modifications to the array **are** visible to subsequent kernel calls. + +The following example demonstrates this behavior. The global ``PRICES`` device +array is mutated after the first kernel call, and the second kernel call sees +the updated values: + +.. literalinclude:: ../../../numba_cuda/numba/cuda/tests/doc_examples/test_globals.py + :language: python + :caption: Demonstrating device array capture by pointer + :start-after: magictoken.ex_globals_device_array_capture.begin + :end-before: magictoken.ex_globals_device_array_capture.end + :dedent: 8 + :linenos: + +Running the above code prints: + +.. code-block:: text + + [10. 25. 5. 15. 30.] + [20. 50. 10. 30. 60.] + +Note that the outputs are different—the mutation to ``PRICES`` after the first +kernel call *is* visible to the second call, unlike with host arrays. + +This makes device arrays suitable for global state that needs to be updated +between kernel calls without recompilation. + +.. note:: + + Kernels and device functions that capture global device arrays cannot use + ``cache=True``. Because the device pointer is embedded in the compiled code, + caching would serialize an invalid pointer. Attempting to cache such a kernel + will raise a ``PicklingError``. 
See :doc:`caching` for more information on + kernel caching. diff --git a/docs/source/user/index.rst b/docs/source/user/index.rst index 988fd015f..5a09d7e7c 100644 --- a/docs/source/user/index.rst +++ b/docs/source/user/index.rst @@ -14,6 +14,7 @@ User guide kernels.rst memory.rst device-functions.rst + globals.rst cudapysupported.rst fastmath.rst intrinsics.rst diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index b261b1d3c..9ac4cc1e2 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -12,6 +12,7 @@ from numba.cuda.memory_management.nrt import NRT_LIBRARY import os +import pickle import subprocess import tempfile @@ -189,6 +190,11 @@ def __init__( self.use_cooperative = False + # Objects that need to be kept alive for the lifetime of the + # kernels or device functions generated by this code library, + # e.g., device arrays captured from global scope. + self.referenced_objects = {} + @property def llvm_strs(self): if self._llvm_strs is None: @@ -377,6 +383,9 @@ def add_linking_library(self, library): self._setup_functions.extend(library._setup_functions) self._teardown_functions.extend(library._teardown_functions) self.use_cooperative |= library.use_cooperative + self.referenced_objects.update( + getattr(library, "referenced_objects", {}) + ) def add_linking_file(self, path_or_obj): if isinstance(path_or_obj, LinkableCode): @@ -442,6 +451,18 @@ def _reduce_states(self): but loaded functions are discarded. They are recreated when needed after deserialization. """ + # Check for captured device arrays that cannot be safely cached. + if self.referenced_objects: + if any( + getattr(obj, "__cuda_array_interface__", None) is not None + for obj in self.referenced_objects.values() + ): + raise pickle.PicklingError( + "Cannot serialize kernels or device functions referencing " + "global device arrays. Pass the array(s) as arguments " + "to the kernel instead." + ) + nrt = False if self._linking_files: if ( diff --git a/numba_cuda/numba/cuda/core/typeinfer.py b/numba_cuda/numba/cuda/core/typeinfer.py index 6b71463b8..10965aa6b 100644 --- a/numba_cuda/numba/cuda/core/typeinfer.py +++ b/numba_cuda/numba/cuda/core/typeinfer.py @@ -1738,8 +1738,11 @@ def typeof_global(self, inst, target, gvar): ) if isinstance(typ, types.Array): - # Global array in nopython mode is constant - typ = typ.copy(readonly=True) + # Global array in nopython mode is constant, except for device + # arrays implementing __cuda_array_interface__ which are references + # to mutable device memory + if not hasattr(gvar.value, "__cuda_array_interface__"): + typ = typ.copy(readonly=True) if isinstance(typ, types.BaseAnonymousTuple): # if it's a tuple of literal types, swap the type for the more diff --git a/numba_cuda/numba/cuda/np/arrayobj.py b/numba_cuda/numba/cuda/np/arrayobj.py index de8834e91..0db69cec1 100644 --- a/numba_cuda/numba/cuda/np/arrayobj.py +++ b/numba_cuda/numba/cuda/np/arrayobj.py @@ -31,6 +31,7 @@ type_is_scalar, lt_complex, lt_floats, + strides_from_shape, ) from numba.cuda.np.numpy_support import ( type_can_asarray, @@ -3642,10 +3643,63 @@ def record_static_setitem_int(context, builder, sig, args): def constant_array(context, builder, ty, pyval): """ Create a constant array (mechanism is target-dependent). + + For objects implementing __cuda_array_interface__, + the device pointer is embedded directly as a constant. For other arrays, + the target-specific mechanism is used. 
""" + # Check if this is a device array (implements __cuda_array_interface__) + if getattr(pyval, "__cuda_array_interface__", None) is not None: + return _lower_constant_device_array(context, builder, ty, pyval) + return context.make_constant_array(builder, ty, pyval) +def _lower_constant_device_array(context, builder, ty, pyval): + """ + Lower objects with __cuda_array_interface__ by embedding the device + pointer as a constant. + + This allows device arrays captured from globals to be used in CUDA + kernels and device functions. + """ + interface = pyval.__cuda_array_interface__ + + # Hold on to the device array to prevent garbage collection. + context.active_code_library.referenced_objects[id(pyval)] = pyval + + shape = interface["shape"] + strides = interface.get("strides") + data_ptr = interface["data"][0] + typestr = interface["typestr"] + itemsize = np.dtype(typestr).itemsize + + # Calculate strides if not provided (C-contiguous) + if strides is None: + strides = strides_from_shape(shape, itemsize, order="C") + + # Embed device pointer as constant + llvoidptr = context.get_value_type(types.voidptr) + data = context.get_constant(types.uintp, data_ptr).inttoptr(llvoidptr) + + # Build array structure + ary = context.make_array(ty)(context, builder) + kshape = [context.get_constant(types.intp, s) for s in shape] + kstrides = [context.get_constant(types.intp, s) for s in strides] + + context.populate_array( + ary, + data=builder.bitcast(data, ary.data.type), + shape=kshape, + strides=kstrides, + itemsize=context.get_constant(types.intp, itemsize), + parent=None, + meminfo=None, + ) + + return ary._getvalue() + + @lower_constant(types.Record) def constant_record(context, builder, ty, pyval): """ diff --git a/numba_cuda/numba/cuda/np/numpy_support.py b/numba_cuda/numba/cuda/np/numpy_support.py index 69d048efa..14d24f90b 100644 --- a/numba_cuda/numba/cuda/np/numpy_support.py +++ b/numba_cuda/numba/cuda/np/numpy_support.py @@ -3,7 +3,10 @@ import collections import ctypes +import itertools +import operator import re + import numpy as np from numba.cuda import types @@ -17,6 +20,29 @@ numpy_version = tuple(map(int, np.__version__.split(".")[:2])) + +def strides_from_shape( + shape: tuple[int, ...], itemsize: int, *, order: str +) -> tuple[int, ...]: + """Compute strides for a contiguous array with given shape and order.""" + if len(shape) == 0: + # 0-D arrays have empty strides + return () + limits = slice(1, None) if order == "C" else slice(None, -1) + transform = reversed if order == "C" else lambda x: x + strides = tuple( + map( + itemsize.__mul__, + itertools.accumulate( + transform(shape[limits]), operator.mul, initial=1 + ), + ) + ) + if order == "F": + return strides + return strides[::-1] + + FROM_DTYPE = { np.dtype("bool"): types.boolean, np.dtype("int8"): types.int8, diff --git a/numba_cuda/numba/cuda/serialize.py b/numba_cuda/numba/cuda/serialize.py index f0a689a2e..abe66aa46 100644 --- a/numba_cuda/numba/cuda/serialize.py +++ b/numba_cuda/numba/cuda/serialize.py @@ -197,6 +197,16 @@ def reducer_override(self, obj): # Overridden to disable pickling of certain types if type(obj) in self.disabled_types: _no_pickle(obj) # noreturn + + # Prevent pickling of objects implementing __cuda_array_interface__ + # These contain device pointers that would become stale after unpickling + if getattr(obj, "__cuda_array_interface__", None) is not None: + raise pickle.PicklingError( + "Cannot serialize kernels or device functions referencing " + "global device arrays. 
Pass the array(s) as arguments " + "to the kernel instead." + ) + return super().reducer_override(obj) diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_caching.py b/numba_cuda/numba/cuda/tests/cudapy/test_caching.py index 1bfd03595..96f69f30b 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_caching.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_caching.py @@ -25,6 +25,11 @@ temp_directory, import_dynamic, ) +import numpy as np +from pickle import PicklingError + +# Module-level global for testing that caching rejects global device arrays +GLOBAL_DEVICE_ARRAY = None class BaseCacheTest(TestCase): @@ -368,6 +373,48 @@ def test_cannot_cache_linking_libraries(self): def f(): pass + def test_cannot_cache_captured_device_array(self): + # Test that kernels capturing device arrays from closures cannot + # be cached. The error can come from either NumbaPickler (for closure + # variables) or CUDACodeLibrary._reduce_states (for referenced objects). + host_data = np.array([1.0, 2.0, 3.0], dtype=np.float32) + captured_arr = cuda.to_device(host_data) + + msg = "global device arrays" + with self.assertRaisesRegex(PicklingError, msg): + + @cuda.jit(cache=True) + def cached_kernel(output): + i = cuda.grid(1) + if i < output.size: + output[i] = captured_arr[i] * 2.0 + + output = cuda.device_array(3, dtype=np.float32) + cached_kernel[1, 3](output) + + def test_cannot_cache_global_device_array(self): + # Test that kernels referencing module-level global device arrays + # cannot be cached. + global GLOBAL_DEVICE_ARRAY + + host_data = np.array([1.0, 2.0, 3.0], dtype=np.float32) + GLOBAL_DEVICE_ARRAY = cuda.to_device(host_data) + + try: + msg = "global device arrays" + with self.assertRaisesRegex(PicklingError, msg): + + @cuda.jit(cache=True) + def cached_kernel_global(output): + i = cuda.grid(1) + if i < output.size: + output[i] = GLOBAL_DEVICE_ARRAY[i] * 2.0 + + output = cuda.device_array(3, dtype=np.float32) + cached_kernel_global[1, 3](output) + finally: + GLOBAL_DEVICE_ARRAY = None + @skip_on_cudasim("Simulator does not implement caching") class CUDACooperativeGroupTest(DispatcherCacheUsecasesTest): diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_device_array_capture.py b/numba_cuda/numba/cuda/tests/cudapy/test_device_array_capture.py new file mode 100644 index 000000000..f0899475c --- /dev/null +++ b/numba_cuda/numba/cuda/tests/cudapy/test_device_array_capture.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +Tests for capturing device arrays (objects implementing __cuda_array_interface__) +from global scope in CUDA kernels and device functions. 
+ +This tests the capture of arrays that implement __cuda_array_interface__: +- Numba device arrays (cuda.to_device) +- ForeignArray (wrapper implementing __cuda_array_interface__) +""" + +import numpy as np + +from numba import cuda +from numba.cuda.testing import unittest, CUDATestCase, ForeignArray +from numba.cuda.testing import skip_on_cudasim + + +def make_numba_array(host_arr): + """Create a Numba device array from host array.""" + return cuda.to_device(host_arr) + + +def make_foreign_array(host_arr): + """Create a ForeignArray wrapping a Numba device array.""" + return ForeignArray(cuda.to_device(host_arr)) + + +def get_host_data(arr): + """Copy array data back to host.""" + if isinstance(arr, ForeignArray): + return arr._arr.copy_to_host() + return arr.copy_to_host() + + +# Array factories to test: (name, factory) +ARRAY_FACTORIES = [ + ("numba_device", make_numba_array), + ("foreign", make_foreign_array), +] + + +@skip_on_cudasim("Global device array capture not supported in simulator") +class TestDeviceArrayCapture(CUDATestCase): + """Test capturing device arrays from global scope.""" + + def test_basic_capture(self): + """Test basic global capture with different array types.""" + for name, make_array in ARRAY_FACTORIES: + with self.subTest(array_type=name): + host_data = np.array( + [1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32 + ) + global_array = make_array(host_data) + + @cuda.jit(device=True) + def read_global(idx): + return global_array[idx] + + @cuda.jit + def kernel(output): + i = cuda.grid(1) + if i < output.size: + output[i] = read_global(i) + + n = len(host_data) + output = cuda.device_array(n, dtype=np.float32) + kernel[1, n](output) + + result = output.copy_to_host() + np.testing.assert_array_equal(result, host_data) + + def test_computation(self): + """Test captured global arrays used in computations.""" + for name, make_array in ARRAY_FACTORIES: + with self.subTest(array_type=name): + host_data = np.array( + [1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32 + ) + global_array = make_array(host_data) + + @cuda.jit(device=True) + def double_global_value(idx): + return global_array[idx] * 2.0 + + @cuda.jit + def kernel(output): + i = cuda.grid(1) + if i < output.size: + output[i] = double_global_value(i) + + n = len(host_data) + output = cuda.device_array(n, dtype=np.float32) + kernel[1, n](output) + + result = output.copy_to_host() + expected = host_data * 2.0 + np.testing.assert_array_equal(result, expected) + + def test_mutability(self): + """Test that captured arrays can be written to (mutability).""" + for name, make_array in ARRAY_FACTORIES: + with self.subTest(array_type=name): + host_data = np.zeros(5, dtype=np.float32) + mutable_array = make_array(host_data) + + @cuda.jit + def write_kernel(): + i = cuda.grid(1) + if i < 5: + mutable_array[i] = float(i + 1) + + write_kernel[1, 5]() + + result = get_host_data(mutable_array) + expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_multiple_arrays(self): + """Test capturing multiple arrays from globals.""" + for name, make_array in ARRAY_FACTORIES: + with self.subTest(array_type=name): + host_a = np.array([1.0, 2.0, 3.0], dtype=np.float32) + host_b = np.array([10.0, 20.0, 30.0], dtype=np.float32) + arr_a = make_array(host_a) + arr_b = make_array(host_b) + + @cuda.jit(device=True) + def add_globals(idx): + return arr_a[idx] + arr_b[idx] + + @cuda.jit + def kernel(output): + i = cuda.grid(1) + if i < output.size: + output[i] = add_globals(i) + + 
output = cuda.device_array(3, dtype=np.float32) + kernel[1, 3](output) + + result = output.copy_to_host() + expected = np.array([11.0, 22.0, 33.0], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_multidimensional(self): + """Test capturing multidimensional arrays.""" + for name, make_array in ARRAY_FACTORIES: + with self.subTest(array_type=name): + host_2d = np.array( + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32 + ) + arr_2d = make_array(host_2d) + + @cuda.jit(device=True) + def read_2d(row, col): + return arr_2d[row, col] + + @cuda.jit + def kernel(output): + i = cuda.grid(1) + if i < 6: + row = i // 2 + col = i % 2 + output[i] = read_2d(row, col) + + output = cuda.device_array(6, dtype=np.float32) + kernel[1, 6](output) + + result = output.copy_to_host() + expected = host_2d.flatten() + np.testing.assert_array_equal(result, expected) + + def test_dtypes(self): + """Test capturing arrays with different dtypes.""" + dtypes = [ + (np.int32, [10, 20, 30, 40]), + (np.float64, [1.5, 2.5, 3.5, 4.5]), + ] + + for name, make_array in ARRAY_FACTORIES: + for dtype, values in dtypes: + with self.subTest(array_type=name, dtype=dtype): + host_data = np.array(values, dtype=dtype) + global_arr = make_array(host_data) + + @cuda.jit(device=True) + def read_arr(idx): + return global_arr[idx] + + @cuda.jit + def kernel(output): + i = cuda.grid(1) + if i < output.size: + output[i] = read_arr(i) + + output = cuda.device_array(len(host_data), dtype=dtype) + kernel[1, len(host_data)](output) + np.testing.assert_array_equal( + output.copy_to_host(), host_data + ) + + def test_direct_kernel_access(self): + """Test direct kernel access (not via device function).""" + for name, make_array in ARRAY_FACTORIES: + with self.subTest(array_type=name): + host_data = np.array([7.0, 8.0, 9.0], dtype=np.float32) + global_direct = make_array(host_data) + + @cuda.jit + def direct_access_kernel(output): + i = cuda.grid(1) + if i < output.size: + output[i] = global_direct[i] + 1.0 + + output = cuda.device_array(3, dtype=np.float32) + direct_access_kernel[1, 3](output) + + result = output.copy_to_host() + expected = np.array([8.0, 9.0, 10.0], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_zero_dimensional(self): + """Test capturing 0-D (scalar) device arrays.""" + for name, make_array in ARRAY_FACTORIES: + with self.subTest(array_type=name): + host_0d = np.array(42.0, dtype=np.float32) + global_0d = make_array(host_0d) + + @cuda.jit + def kernel_0d(output): + output[()] = global_0d[()] * 2.0 + + output = cuda.device_array((), dtype=np.float32) + kernel_0d[1, 1](output) + + result = output.copy_to_host() + expected = 84.0 + self.assertEqual(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/numba_cuda/numba/cuda/tests/doc_examples/test_globals.py b/numba_cuda/numba/cuda/tests/doc_examples/test_globals.py new file mode 100644 index 000000000..40913a150 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/doc_examples/test_globals.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import unittest + +from numba.cuda.testing import CUDATestCase, skip_on_cudasim +from numba.cuda.tests.support import captured_stdout + + +@skip_on_cudasim("cudasim doesn't support cuda import at non-top-level") +class TestGlobals(CUDATestCase): + """ + Tests demonstrating how global variables are captured in CUDA kernels. 
+ """ + + def setUp(self): + # Prevent output from this test showing + # up when running the test suite + self._captured_stdout = captured_stdout() + self._captured_stdout.__enter__() + super().setUp() + + def tearDown(self): + # No exception type, value, or traceback + self._captured_stdout.__exit__(None, None, None) + super().tearDown() + + def test_ex_globals_constant_capture(self): + """ + Test demonstrating how global variables are captured as constants. + """ + # magictoken.ex_globals_constant_capture.begin + import numpy as np + from numba import cuda + + TAX_RATE = 0.08 + PRICES = np.array([10.0, 25.0, 5.0, 15.0, 30.0], dtype=np.float64) + + @cuda.jit + def compute_totals(quantities, totals): + i = cuda.grid(1) + if i < totals.size: + totals[i] = quantities[i] * PRICES[i] * (1 + TAX_RATE) + + d_quantities = cuda.to_device( + np.array([1, 2, 3, 4, 5], dtype=np.float64) + ) + d_totals = cuda.device_array(5, dtype=np.float64) + + # First kernel call - compiles and captures values + compute_totals[1, 32](d_quantities, d_totals) + print("Value of d_totals:", d_totals.copy_to_host()) + + # These modifications have no effect on subsequent kernel calls + TAX_RATE = 0.10 # noqa: F841 + PRICES[:] = [20.0, 50.0, 10.0, 30.0, 60.0] + + # Second kernel call still uses the original values + compute_totals[1, 32](d_quantities, d_totals) + print("Value of d_totals:", d_totals.copy_to_host()) + # magictoken.ex_globals_constant_capture.end + + # Verify the values are the same (original values were captured) + expected = np.array([10.8, 54.0, 16.2, 64.8, 162.0]) + np.testing.assert_allclose(d_totals.copy_to_host(), expected) + + def test_ex_globals_device_array_capture(self): + """ + Test demonstrating how global device arrays are captured by pointer. + """ + # magictoken.ex_globals_device_array_capture.begin + import numpy as np + from numba import cuda + + # Global device array - pointer is captured, not data + PRICES = cuda.to_device( + np.array([10.0, 25.0, 5.0, 15.0, 30.0], dtype=np.float32) + ) + + @cuda.jit + def compute_totals(quantities, totals): + i = cuda.grid(1) + if i < totals.size: + totals[i] = quantities[i] * PRICES[i] + + d_quantities = cuda.to_device( + np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32) + ) + d_totals = cuda.device_array(5, dtype=np.float32) + + # First kernel call + compute_totals[1, 32](d_quantities, d_totals) + print(d_totals.copy_to_host()) # [10. 25. 5. 15. 30.] + + # Mutate the device array in-place + PRICES.copy_to_device( + np.array([20.0, 50.0, 10.0, 30.0, 60.0], dtype=np.float32) + ) + + # Second kernel call sees the updated values + compute_totals[1, 32](d_quantities, d_totals) + print(d_totals.copy_to_host()) # [20. 50. 10. 30. 60.] + # magictoken.ex_globals_device_array_capture.end + + # Verify the second call sees updated values + expected = np.array([20.0, 50.0, 10.0, 30.0, 60.0], dtype=np.float32) + np.testing.assert_allclose(d_totals.copy_to_host(), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/numba_cuda/numba/cuda/typing/typeof.py b/numba_cuda/numba/cuda/typing/typeof.py index a3091d282..2d3e1e826 100644 --- a/numba_cuda/numba/cuda/typing/typeof.py +++ b/numba_cuda/numba/cuda/typing/typeof.py @@ -55,6 +55,15 @@ def typeof_impl(val, c): if tp is not None: return tp + # Check for __cuda_array_interface__ objects (third-party device arrays) + + # Numba's own DeviceNDArray is handled above via _numba_type_. 
+    cai = getattr(val, "__cuda_array_interface__", None)
+    if cai is not None:
+        tp = _typeof_cuda_array_interface(cai, c)
+        if tp is not None:
+            return tp
+
     # cffi is handled here as it does not expose a public base class
     # for exported functions or CompiledFFI instances.
     from numba.cuda.typing import cffi_utils
@@ -299,3 +308,50 @@ def typeof_numpy_polynomial(val, c):
     domain = typeof(val.domain)
     window = typeof(val.window)
     return types.PolynomialType(coef, domain, window)
+
+
+def _typeof_cuda_array_interface(val, c):
+    """
+    Determine the Numba array type from a __cuda_array_interface__ dict.
+
+    This handles third-party device arrays that implement the CUDA
+    Array Interface. These are typed as regular Array types, with lowering
+    handled in numba.cuda.np.arrayobj.
+    """
+    # Only handle constants, not arguments (arguments use regular array typing)
+    if c.purpose == Purpose.argument:
+        return None
+
+    dtype = numpy_support.from_dtype(np.dtype(val["typestr"]))
+    shape = val["shape"]
+    ndim = len(shape)
+    strides = val.get("strides")
+
+    # Determine layout
+    if ndim == 0:
+        layout = "C"
+    elif strides is None:
+        layout = "C"
+    else:
+        itemsize = np.dtype(val["typestr"]).itemsize
+        # Quick rejection: C-contiguous has strides[-1] == itemsize,
+        # F-contiguous has strides[0] == itemsize. If neither, it's "A".
+        if strides[-1] == itemsize:
+            c_strides = numpy_support.strides_from_shape(
+                shape, itemsize, order="C"
+            )
+            layout = (
+                "C" if all(x == y for x, y in zip(strides, c_strides)) else "A"
+            )
+        elif strides[0] == itemsize:
+            f_strides = numpy_support.strides_from_shape(
+                shape, itemsize, order="F"
+            )
+            layout = (
+                "F" if all(x == y for x, y in zip(strides, f_strides)) else "A"
+            )
+        else:
+            layout = "A"
+
+    readonly = val["data"][1]
+    return types.Array(dtype, ndim, layout, readonly=readonly)
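
Illustrative note (not part of the patch): the C/F/A layout decision in
_typeof_cuda_array_interface above can be hard to follow from the stride
comparisons alone. The standalone sketch below mirrors that decision for a few
example shapes and strides. The names classify_layout and contiguous_strides
are hypothetical helpers introduced here only for illustration; they are not
functions added by this diff.

import numpy as np


def contiguous_strides(shape, itemsize, order):
    # Strides of a contiguous array: the innermost axis advances by itemsize
    # and each outer axis by the product of the inner extents (C order), or
    # the mirror image of that for F order (same result as strides_from_shape).
    acc, out = itemsize, []
    dims = reversed(shape) if order == "C" else shape
    for extent in dims:
        out.append(acc)
        acc *= extent
    return tuple(reversed(out)) if order == "C" else tuple(out)


def classify_layout(shape, strides, itemsize):
    # Mirrors the decision above: 0-d and stride-less interfaces are treated
    # as C-contiguous; otherwise the innermost (C) or outermost (F) stride
    # selects which contiguous pattern to compare against, else "A".
    if len(shape) == 0 or strides is None:
        return "C"
    if strides[-1] == itemsize:
        return "C" if strides == contiguous_strides(shape, itemsize, "C") else "A"
    if strides[0] == itemsize:
        return "F" if strides == contiguous_strides(shape, itemsize, "F") else "A"
    return "A"


itemsize = np.dtype(np.float32).itemsize  # 4 bytes
# A (3, 4) float32 array: C-contiguous strides are (16, 4), F-contiguous
# strides are (4, 12); a strided view such as (32, 4) matches neither.
assert classify_layout((3, 4), (16, 4), itemsize) == "C"
assert classify_layout((3, 4), (4, 12), itemsize) == "F"
assert classify_layout((3, 4), (32, 4), itemsize) == "A"
assert classify_layout((), None, itemsize) == "C"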