From eab233ef8ddefe7addb3331f22b8acd81a9bdd2d Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Wed, 5 Nov 2025 11:39:43 -0500
Subject: [PATCH] refactor: replace device functionality with `cuda.core` APIs

---
 numba_cuda/numba/cuda/cudadrv/driver.py | 58 ++++++++-----------------
 1 file changed, 17 insertions(+), 41 deletions(-)

diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py
index 8039596a5..b5ffb6b73 100644
--- a/numba_cuda/numba/cuda/cudadrv/driver.py
+++ b/numba_cuda/numba/cuda/cudadrv/driver.py
@@ -43,7 +43,6 @@ import importlib
 import numpy as np

 from collections import namedtuple, deque
-from uuid import UUID

 from numba.cuda.cext import mviewbuf

@@ -67,6 +66,7 @@ from cuda.bindings.utils import get_cuda_native_handle

 from cuda.core.experimental import (
     Stream as ExperimentalStream,
+    Device as ExperimentalDevice,
 )


@@ -527,7 +527,7 @@ def _build_reverse_device_attrs():
 DEVICE_ATTRIBUTES = _build_reverse_device_attrs()


-class Device(object):
+class Device:
     """
     The device object owns the CUDA contexts. This is owned by the driver
     object. User should not construct devices directly.
@@ -548,32 +548,12 @@ def from_identity(self, identity):
                 "Target device may not be visible in this process."
             )

-    def __init__(self, devnum):
-        result = driver.cuDeviceGet(devnum)
-        self.id = result
-        got_devnum = int(result)
-
-        msg = f"Driver returned device {got_devnum} instead of {devnum}"
-        if devnum != got_devnum:
-            raise RuntimeError(msg)
-
-        # Read compute capability
-        self.compute_capability = (
-            self.COMPUTE_CAPABILITY_MAJOR,
-            self.COMPUTE_CAPABILITY_MINOR,
-        )
-
-        # Read name
-        bufsz = 128
-        buf = driver.cuDeviceGetName(bufsz, self.id)
-        name = buf.split(b"\x00", 1)[0]
-
-        self.name = name
-
-        # Read UUID
-        uuid = driver.cuDeviceGetUuid(self.id)
-        self.uuid = f"GPU-{UUID(bytes=uuid.bytes)}"
-
+    def __init__(self, devnum: int) -> None:
+        self._dev = ExperimentalDevice(devnum)
+        self.id = self._dev.device_id
+        self.compute_capability = self._dev.compute_capability
+        self.name = self._dev.name
+        self.uuid = f"GPU-{self._dev.uuid}"
         self.primary_context = None

     def get_device_identity(self):
@@ -613,13 +593,16 @@ def get_primary_context(self):
         if (ctx := self.primary_context) is not None:
             return ctx

-        met_requirement_for_device(self)
-
-        # create primary context
-        hctx = driver.cuDevicePrimaryCtxRetain(self.id)
-        hctx = drvapi.cu_context(int(hctx))
+        if self.compute_capability < MIN_REQUIRED_CC:
+            raise CudaSupportError(
+                f"{self} has compute capability < {MIN_REQUIRED_CC}"
+            )

-        ctx = Context(weakref.proxy(self), hctx)
-        self.primary_context = ctx
+        self._dev.set_current()
+        self.primary_context = ctx = Context(
+            weakref.proxy(self),
+            ctypes.c_void_p(int(self._dev.context._handle)),
+        )
         return ctx

     def release_primary_context(self):
@@ -648,13 +631,6 @@ def supports_bfloat16(self):
         return self.compute_capability >= (8, 0)


-def met_requirement_for_device(device):
-    if device.compute_capability < MIN_REQUIRED_CC:
-        raise CudaSupportError(
-            "%s has compute capability < %s" % (device, MIN_REQUIRED_CC)
-        )
-
-
 class BaseCUDAMemoryManager(object, metaclass=ABCMeta):
     """Abstract base class for External Memory Management (EMM) Plugins."""
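
The refactor above obtains through a single `cuda.core.experimental.Device`
object everything the old code fetched with raw driver calls. Below is a
minimal sketch of that API surface as it is used in this patch (device
ordinal, name, compute capability, UUID, and primary-context activation); it
assumes the cuda-core package is installed and at least one CUDA-capable GPU
is visible, and it is not a complete description of the cuda.core API.

    # Sketch: the cuda.core calls that the refactored Device.__init__ and
    # get_primary_context rely on, in place of direct CUDA driver calls.
    from cuda.core.experimental import Device

    dev = Device(0)                # replaces driver.cuDeviceGet(0)
    print(dev.device_id)           # device ordinal
    print(dev.name)                # replaces cuDeviceGetName + buffer decoding
    print(dev.compute_capability)  # replaces the COMPUTE_CAPABILITY_* queries
    print(f"GPU-{dev.uuid}")       # replaces cuDeviceGetUuid + uuid.UUID formatting

    # Making the device current binds its primary context; the patch wraps
    # the resulting context handle in numba's Context object.
    dev.set_current()
    ctx = dev.context

The patch also inlines the compute-capability check that the module-level
`met_requirement_for_device` helper used to perform, which is why that helper
is deleted.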