1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -14,6 +14,7 @@ repos:
- id: debug-statements
- id: end-of-file-fixer
- id: trailing-whitespace
exclude: '.+\.patch'
- id: mixed-line-ending
args: ['--fix=lf']
- repo: https://github.com/astral-sh/ruff-pre-commit
48 changes: 48 additions & 0 deletions ci/patches/cudf_numba_cuda_compatibility.patch
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause
diff --git i/python/cudf/cudf/core/buffer/buffer.py w/python/cudf/cudf/core/buffer/buffer.py
index 71ab509348..e7e6ea9e5f 100644
--- i/python/cudf/cudf/core/buffer/buffer.py
+++ w/python/cudf/cudf/core/buffer/buffer.py
@@ -414,7 +414,7 @@ class Buffer(Serializable):
"shape": (self.size,),
"strides": None,
"typestr": "|u1",
- "version": 0,
+ "version": 3,
}

def serialize(self) -> tuple[dict, list]:
diff --git i/python/cudf/cudf/core/buffer/spillable_buffer.py w/python/cudf/cudf/core/buffer/spillable_buffer.py
index 52dcd00e2b..a5bc9a461d 100644
--- i/python/cudf/cudf/core/buffer/spillable_buffer.py
+++ w/python/cudf/cudf/core/buffer/spillable_buffer.py
@@ -337,7 +337,7 @@ class SpillableBufferOwner(BufferOwner):
"shape": (self.size,),
"strides": None,
"typestr": "|u1",
- "version": 0,
+ "version": 3,
}

def memoryview(
@@ -460,5 +460,5 @@ class SpillableBuffer(ExposureTrackedBuffer):
"shape": (self.size,),
"strides": None,
"typestr": "|u1",
- "version": 0,
+ "version": 3,
}
diff --git i/python/cudf/cudf/core/column/column.py w/python/cudf/cudf/core/column/column.py
index 1562e4810c..a74a9fe23b 100644
--- i/python/cudf/cudf/core/column/column.py
+++ w/python/cudf/cudf/core/column/column.py
@@ -1971,7 +1971,7 @@ class ColumnBase(Serializable, BinaryOperand, Reducible):
"strides": (self.dtype.itemsize,),
"typestr": self.dtype.str,
"data": (self.data_ptr, False),
- "version": 1,
+ "version": 3,
}
if self.nullable and self.has_nulls():
# Create a simple Python object that exposes the
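For context, all three hunks in this patch bump the `__cuda_array_interface__` version field, presumably so the exported buffers are accepted by consumers that require the version-3 protocol (which added the optional "stream" key and its synchronization semantics). A minimal v3 descriptor for a 1-D byte buffer looks roughly like this; `ptr` and `size` are illustrative placeholders, not names from this diff:

# Sketch of a CUDA Array Interface v3 descriptor matching the dicts
# patched above. `ptr` (a device pointer) and `size` (length in bytes)
# are placeholder values for illustration only.
ptr, size = 0xDEADBEEF, 1024
cai_v3 = {
    "data": (ptr, False),  # (device pointer, read-only flag)
    "shape": (size,),
    "strides": None,       # None means the buffer is contiguous
    "typestr": "|u1",      # raw unsigned bytes
    "version": 3,          # v3 added the optional "stream" key
}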
7 changes: 7 additions & 0 deletions ci/test_thirdparty_cudf.sh
@@ -29,6 +29,13 @@ python -m pip install \
rapids-logger "Shallow clone cuDF repository"
git clone --single-branch --branch 'release/25.12' https://github.com/rapidsai/cudf.git

# TODO: remove the patch and its application after 26.02 is released
patchfile="${PWD}/ci/patches/cudf_numba_cuda_compatibility.patch"
pushd "$(python -c 'import site; print(site.getsitepackages()[0])')"
# strip the first 3 path components (-p3) so the patch applies from the site-packages root
patch -p3 < "${patchfile}"
popd

pushd cudf

rapids-logger "Check GPU usage"
54 changes: 39 additions & 15 deletions numba_cuda/numba/cuda/args.py
@@ -6,6 +6,9 @@
memory transfers before & after the kernel call.
"""

import functools


from numba.cuda.typing.typeof import typeof, Purpose


@@ -23,40 +26,61 @@ def to_device(self, retr, stream=0):
the kernel
"""

@property
@functools.cached_property
def _numba_type_(self):
return typeof(self.value, Purpose.argument)


class In(ArgHint):
def to_device(self, retr, stream=0):
from .cudadrv.devicearray import auto_device
from .cudadrv.devicearray import _to_strided_memory_view

devary, _ = auto_device(self.value, stream=stream)
devary, _ = _to_strided_memory_view(self.value, stream=stream)
# A dummy writeback functor to keep devary alive until the kernel
# is called.
retr.append(lambda: devary)
return devary


class Out(ArgHint):
def to_device(self, retr, stream=0):
from .cudadrv.devicearray import auto_device
copy_input = False

devary, conv = auto_device(self.value, copy=False, stream=stream)
def to_device(self, retr, stream=0):
from .cudadrv.devicearray import _to_strided_memory_view
from .cudadrv.devicearray import _make_strided_memory_view
from .cudadrv.driver import driver
from .cudadrv import devices

devary, conv = _to_strided_memory_view(
value := self.value, copy=self.__class__.copy_input, stream=stream
)
if conv:
retr.append(lambda: devary.copy_to_host(self.value, stream=stream))
stream_ptr = getattr(stream, "handle", stream)

def copy_to_host(devary=devary, value=value, stream_ptr=stream_ptr):
hostary = _make_strided_memory_view(
value, stream_ptr=stream_ptr
)
nbytes = devary.size * devary.dtype.itemsize
hostptr = hostary.ptr
devptr = devary.ptr
if int(stream_ptr):
driver.cuMemcpyDtoHAsync(
hostptr, devptr, nbytes, stream_ptr
)
else:
driver.cuMemcpyDtoH(hostptr, devptr, nbytes)
ctx = devices.get_context()
stream = ctx.get_default_stream()
stream.synchronize()
return hostary

retr.append(copy_to_host)
return devary


class InOut(ArgHint):
def to_device(self, retr, stream=0):
from .cudadrv.devicearray import auto_device

devary, conv = auto_device(self.value, stream=stream)
if conv:
retr.append(lambda: devary.copy_to_host(self.value, stream=stream))
return devary
class InOut(Out):
copy_input = True


def wrap_arg(value, default=InOut):
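As a usage reminder (not part of this diff), these hints wrap kernel arguments to control transfer direction, and after this change InOut is literally Out with copy_input = True. A minimal sketch using an assumed trivial kernel:

import numpy as np
from numba import cuda

@cuda.jit
def scale(out, inp):
    # Illustrative kernel: out[i] = 2 * inp[i]
    i = cuda.grid(1)
    if i < inp.size:
        out[i] = inp[i] * 2

a = np.arange(16, dtype=np.float64)
b = np.zeros_like(a)
# In copies host->device only; Out allocates on device and copies the
# result back after launch; InOut (now Out with copy_input=True) does both.
scale[1, 16](cuda.Out(b), cuda.In(a))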
151 changes: 150 additions & 1 deletion numba_cuda/numba/cuda/cudadrv/devicearray.py
@@ -15,6 +15,9 @@

import numpy as np

from cuda.core.utils import StridedMemoryView
from cuda.core import Buffer

from numba.cuda.cudadrv import devices, dummyarray
from numba.cuda.cudadrv import driver as _driver
from numba.cuda import types
@@ -392,13 +395,28 @@ def view(self, dtype):
gpu_data=self.gpu_data,
)

@property
@functools.cached_property
def nbytes(self):
# Note: not using `alloc_size`. `alloc_size` reports memory
# consumption of the allocation, not the size of the array
# https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.nbytes.html
return self.dtype.itemsize * self.size

@functools.cached_property
def _strided_memory_view_shim(self):
flags = self.flags
return _StridedMemoryViewShim(
ptr=self.device_ctypes_pointer.value,
shape=self.shape,
dtype=self.dtype,
size=self.size,
_layout=_StridedLayoutShim(
strides_in_bytes=self.strides,
is_contiguous_c=flags["C_CONTIGUOUS"],
is_contiguous_f=flags["F_CONTIGUOUS"],
),
)


class DeviceRecord(DeviceNDArrayBase):
"""
@@ -934,6 +952,137 @@ def auto_device(obj, stream=0, copy=True, user_explicit=False):
return devobj, True


_UNSUPPORTED_DLPACK_TYPES = (
np.void,
np.datetime64,
np.timedelta64,
np.bytes_,
np.str_,
)


def _make_strided_memory_view(obj, *, stream_ptr) -> StridedMemoryView:
if isinstance(obj, _UNSUPPORTED_DLPACK_TYPES) or (
isinstance(obj, np.ndarray)
and issubclass(obj.dtype.type, _UNSUPPORTED_DLPACK_TYPES)
):
return StridedMemoryView.from_array_interface(obj)
return StridedMemoryView.from_any_interface(obj, stream_ptr=stream_ptr)


class _StridedLayoutShim:
__slots__ = ("strides_in_bytes", "is_contiguous_c", "is_contiguous_f")

def __init__(
self,
*,
strides_in_bytes: tuple[int, ...],
is_contiguous_c: bool,
is_contiguous_f: bool,
) -> None:
self.strides_in_bytes = strides_in_bytes
self.is_contiguous_c = is_contiguous_c
self.is_contiguous_f = is_contiguous_f


class _StridedMemoryViewShim:
__slots__ = ("ptr", "shape", "dtype", "size", "_layout")

def __init__(
self,
*,
ptr: int,
shape: tuple[int, ...],
dtype: np.dtype,
size: int,
_layout: _StridedLayoutShim,
) -> None:
self.ptr = ptr
self.shape = shape
self.dtype = dtype
self.size = size
self._layout = _layout


def _to_strided_memory_view(
obj, stream=0, copy: bool = True, user_explicit: bool = False
) -> tuple[StridedMemoryView, bool]:
if _driver.is_device_memory(obj):
return obj._strided_memory_view_shim, False
elif (
not isinstance(obj, (np.ndarray, _UNSUPPORTED_DLPACK_TYPES))
and hasattr(obj, "__dlpack__")
and (
(dtype := getattr(obj, "dtype", None)) is None
or not issubclass(
getattr(dtype, "type", type(None)), _UNSUPPORTED_DLPACK_TYPES
)
)
):
# numpy arrays need to be copied to the device
# so we can't view them as SMVs until then
#
# not sure if this is true in general, since what if a numpy array was
# constructed using `np.from_dlpack`?
return StridedMemoryView.from_dlpack(
obj, stream_ptr=getattr(stream, "handle", stream)
), False
elif (desc := getattr(obj, "__cuda_array_interface__", None)) is not None:
smv = StridedMemoryView.from_cuda_array_interface(
obj, stream_ptr=getattr(stream, "handle", stream)
)

if (
external_stream_ptr := desc.get("stream")
) is not None and config.CUDA_ARRAY_INTERFACE_SYNC:
ctx = devices.get_context()
ext_stream = ctx.create_external_stream(external_stream_ptr)
ext_stream.synchronize()
return smv, False
else:
array_obj = np.asanyarray(obj)
nbytes = array_obj.nbytes
stream_ptr = getattr(stream, "handle", stream)

ctx = devices.get_context()
if not nbytes:
# TODO: once cuda-core fixes zero-byte allocation, this branch can go away
assert not array_obj.size
buf = Buffer.from_handle(
ptr=0, size=0, mr=ctx.device._dev.memory_resource
)
else:
# TODO: potentially rebuild EMM around these (cuda-core) APIs instead
# of numba-cuda APIs in the future
buf = ctx.device._dev.allocate(nbytes, stream=stream)

hostobj = _make_strided_memory_view(array_obj, stream_ptr=stream_ptr)
devobj = StridedMemoryView.from_buffer(
buf,
shape=hostobj.shape,
strides=hostobj.strides,
dtype=hostobj.dtype,
)
if copy:
if (
config.CUDA_WARN_ON_IMPLICIT_COPY
and not config.DISABLE_PERFORMANCE_WARNINGS
):
if not user_explicit and (
not isinstance(obj, DeviceNDArray)
and isinstance(obj, np.ndarray)
):
msg = (
"Host array used in CUDA kernel will incur "
"copy overhead to/from device."
)
warn(NumbaPerformanceWarning(msg))
_driver.driver.cuMemcpyHtoDAsync(
devobj.ptr, hostobj.ptr, nbytes, stream_ptr
)
return devobj, True


def check_array_compatibility(ary1, ary2):
ary1sq, ary2sq = ary1.squeeze(), ary2.squeeze()
if ary1.dtype != ary2.dtype:
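The branch order of _to_strided_memory_view above can be summarized by the sketch below (illustrative only; it assumes this module's `_driver` import and elides the DLPack-unsupported dtype checks, stream synchronization, and the zero-byte allocation workaround):

def _dispatch_order(obj):
    # Simplified mirror of the branches in _to_strided_memory_view.
    if _driver.is_device_memory(obj):
        return "device memory: reuse the cached shim view, no copy"
    if hasattr(obj, "__dlpack__"):
        return "DLPack producer: zero-copy view via from_dlpack"
    if hasattr(obj, "__cuda_array_interface__"):
        return "CAI exporter: zero-copy view, optional producer-stream sync"
    return "host object: allocate a device buffer, then cuMemcpyHtoDAsync"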
23 changes: 15 additions & 8 deletions numba_cuda/numba/cuda/dispatcher.py
@@ -21,6 +21,7 @@
from numba.cuda import serialize, utils
from numba import cuda

from numba.cuda.np import numpy_support
from numba.cuda.core.compiler_lock import global_compiler_lock
from numba.cuda.typeconv.rules import default_type_manager
from numba.cuda.typing.templates import fold_arguments
@@ -41,7 +42,6 @@
from numba.cuda.core import sigutils, config, entrypoints
from numba.cuda.flags import Flags
from numba.cuda.cudadrv import driver, nvvm

from numba.cuda.locks import module_init_lock
from numba.cuda.core.caching import Cache, CacheImpl, NullCache
from numba.cuda.descriptor import cuda_target
@@ -560,8 +560,7 @@ def _prepare_args(self, ty, val, stream, retr, kernelargs):
if isinstance(ty, types.Array):
devary = wrap_arg(val).to_device(retr, stream)

meminfo = 0
parent = 0
meminfo = parent = 0

kernelargs.append(meminfo)
kernelargs.append(parent)
@@ -571,10 +570,18 @@ def _prepare_args(self, ty, val, stream, retr, kernelargs):
# however, this saves a noticeable amount of overhead in kernel
# invocation
kernelargs.append(devary.size)
kernelargs.append(devary.dtype.itemsize)
kernelargs.append(devary.device_ctypes_pointer.value)
kernelargs.extend(devary.shape)
kernelargs.extend(devary.strides)
kernelargs.append(itemsize := devary.dtype.itemsize)
kernelargs.append(devary.ptr)
kernelargs.extend(shape := devary.shape)
kernelargs.extend(
(layout := devary._layout).strides_in_bytes
or numpy_support.strides_from_shape(
shape=shape,
itemsize=itemsize,
c_contiguous=layout.is_contiguous_c,
f_contiguous=layout.is_contiguous_f,
)
)

elif isinstance(ty, types.CPointer):
# Pointer arguments should be a pointer-sized integer
@@ -612,7 +619,7 @@ def _prepare_args(self, ty, val, stream, retr, kernelargs):

elif isinstance(ty, types.Record):
devrec = wrap_arg(val).to_device(retr, stream)
kernelargs.append(devrec.device_ctypes_pointer.value)
kernelargs.append(devrec.ptr)

elif isinstance(ty, types.BaseTuple):
assert len(ty) == len(val)
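One behavioral note on the _prepare_args change: strided views may report strides_in_bytes as None for contiguous layouts, so the dispatcher now falls back to numpy_support.strides_from_shape. A hedged sketch of what such a helper computes (the actual helper added in this PR may differ in signature and details):

def strides_from_shape(shape, itemsize, c_contiguous, f_contiguous):
    # Reconstruct byte strides for a contiguous array whose layout
    # reports strides as None (assumed behavior for this sketch).
    if c_contiguous:
        strides, acc = [], itemsize
        for dim in reversed(shape):  # innermost (last) axis varies fastest
            strides.append(acc)
            acc *= dim
        return tuple(reversed(strides))
    if f_contiguous:
        strides, acc = [], itemsize
        for dim in shape:            # first axis varies fastest
            strides.append(acc)
            acc *= dim
        return tuple(strides)
    raise ValueError("non-contiguous layout requires explicit strides")

For example, shape (2, 3) with itemsize 8 yields (24, 8) for a C-contiguous layout and (8, 16) for an F-contiguous one.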