Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 12 additions & 23 deletions numba_cuda/numba/cuda/cudadrv/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import importlib
import numpy as np
from collections import namedtuple, deque
from uuid import UUID


from numba.cuda.cext import mviewbuf
Expand Down Expand Up @@ -536,11 +537,10 @@ def from_identity(self, identity):
if d.get_device_identity() == identity:
return d
else:
errmsg = (
"No device of {} is found. "
raise RuntimeError(
f"No device of {identity} is found. "
"Target device may not be visible in this process."
).format(identity)
raise RuntimeError(errmsg)
)

def __init__(self, devnum):
result = driver.cuDeviceGet(devnum)
Expand All @@ -551,8 +551,6 @@ def __init__(self, devnum):
if devnum != got_devnum:
raise RuntimeError(msg)

self.attributes = {}

# Read compute capability
self.compute_capability = (
self.COMPUTE_CAPABILITY_MAJOR,
Expand All @@ -562,20 +560,13 @@ def __init__(self, devnum):
# Read name
bufsz = 128
buf = driver.cuDeviceGetName(bufsz, self.id)
name = buf.split(b"\x00")[0]
name = buf.split(b"\x00", 1)[0]

self.name = name

# Read UUID
uuid = driver.cuDeviceGetUuid(self.id)
uuid_vals = tuple(uuid.bytes)

b = "%02x"
b2 = b * 2
b4 = b * 4
b6 = b * 6
fmt = f"GPU-{b4}-{b2}-{b2}-{b2}-{b6}"
self.uuid = fmt % uuid_vals
self.uuid = f"GPU-{UUID(bytes=uuid.bytes)}"

self.primary_context = None

Expand All @@ -587,7 +578,7 @@ def get_device_identity(self):
}

def __repr__(self):
return "<CUDA device %d '%s'>" % (self.id, self.name)
return f"<CUDA device {self.id:d} '{self.name}'>"

def __getattr__(self, attr):
"""Read attributes lazily"""
Expand All @@ -603,9 +594,7 @@ def __hash__(self):
return hash(self.id)

def __eq__(self, other):
if isinstance(other, Device):
return self.id == other.id
return False
return isinstance(other, Device) and self.id == other.id

def __ne__(self, other):
return not (self == other)
Expand All @@ -615,8 +604,8 @@ def get_primary_context(self):
Returns the primary context for the device.
Note: it is not pushed to the CPU thread.
"""
if self.primary_context is not None:
return self.primary_context
if (ctx := self.primary_context) is not None:
return ctx

met_requirement_for_device(self)
# create primary context
Expand All @@ -637,8 +626,8 @@ def release_primary_context(self):

def reset(self):
try:
if self.primary_context is not None:
self.primary_context.reset()
if (ctx := self.primary_context) is not None:
ctx.reset()
self.release_primary_context()
finally:
# reset at the driver level
Expand Down