Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 17 additions & 41 deletions numba_cuda/numba/cuda/cudadrv/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import importlib
import numpy as np
from collections import namedtuple, deque
from uuid import UUID


from numba.cuda.cext import mviewbuf
Expand All @@ -67,6 +66,7 @@
from cuda.bindings.utils import get_cuda_native_handle
from cuda.core.experimental import (
Stream as ExperimentalStream,
Device as ExperimentalDevice,
)


Expand Down Expand Up @@ -527,7 +527,7 @@ def _build_reverse_device_attrs():
DEVICE_ATTRIBUTES = _build_reverse_device_attrs()


class Device(object):
class Device:
"""
The device object owns the CUDA contexts. This is owned by the driver
object. User should not construct devices directly.
Expand All @@ -548,32 +548,12 @@ def from_identity(self, identity):
"Target device may not be visible in this process."
)

def __init__(self, devnum):
result = driver.cuDeviceGet(devnum)
self.id = result
got_devnum = int(result)

msg = f"Driver returned device {got_devnum} instead of {devnum}"
if devnum != got_devnum:
raise RuntimeError(msg)

# Read compute capability
self.compute_capability = (
self.COMPUTE_CAPABILITY_MAJOR,
self.COMPUTE_CAPABILITY_MINOR,
)

# Read name
bufsz = 128
buf = driver.cuDeviceGetName(bufsz, self.id)
name = buf.split(b"\x00", 1)[0]

self.name = name

# Read UUID
uuid = driver.cuDeviceGetUuid(self.id)
self.uuid = f"GPU-{UUID(bytes=uuid.bytes)}"

def __init__(self, devnum: int) -> None:
self._dev = ExperimentalDevice(devnum)
self.id = self._dev.device_id
self.compute_capability = self._dev.compute_capability
self.name = self._dev.name
self.uuid = f"GPU-{self._dev.uuid}"
self.primary_context = None

def get_device_identity(self):
Expand Down Expand Up @@ -613,13 +593,16 @@ def get_primary_context(self):
if (ctx := self.primary_context) is not None:
return ctx

met_requirement_for_device(self)
# create primary context
hctx = driver.cuDevicePrimaryCtxRetain(self.id)
hctx = drvapi.cu_context(int(hctx))
if self.compute_capability < MIN_REQUIRED_CC:
raise CudaSupportError(
f"{self} has compute capability < {MIN_REQUIRED_CC}"
)

ctx = Context(weakref.proxy(self), hctx)
self.primary_context = ctx
self._dev.set_current()
self.primary_context = ctx = Context(
weakref.proxy(self),
ctypes.c_void_p(int(self._dev.context._handle)),
)
return ctx

def release_primary_context(self):
Expand Down Expand Up @@ -648,13 +631,6 @@ def supports_bfloat16(self):
return self.compute_capability >= (8, 0)


def met_requirement_for_device(device):
if device.compute_capability < MIN_REQUIRED_CC:
raise CudaSupportError(
"%s has compute capability < %s" % (device, MIN_REQUIRED_CC)
)


class BaseCUDAMemoryManager(object, metaclass=ABCMeta):
"""Abstract base class for External Memory Management (EMM) Plugins."""

Expand Down
Loading