diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index b45483d3f..5d781b1bb 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -43,6 +43,7 @@ import importlib import numpy as np from collections import namedtuple, deque +from uuid import UUID from numba.cuda.cext import mviewbuf @@ -536,11 +537,10 @@ def from_identity(self, identity): if d.get_device_identity() == identity: return d else: - errmsg = ( - "No device of {} is found. " + raise RuntimeError( + f"No device of {identity} is found. " "Target device may not be visible in this process." - ).format(identity) - raise RuntimeError(errmsg) + ) def __init__(self, devnum): result = driver.cuDeviceGet(devnum) @@ -551,8 +551,6 @@ def __init__(self, devnum): if devnum != got_devnum: raise RuntimeError(msg) - self.attributes = {} - # Read compute capability self.compute_capability = ( self.COMPUTE_CAPABILITY_MAJOR, @@ -562,20 +560,13 @@ def __init__(self, devnum): # Read name bufsz = 128 buf = driver.cuDeviceGetName(bufsz, self.id) - name = buf.split(b"\x00")[0] + name = buf.split(b"\x00", 1)[0] self.name = name # Read UUID uuid = driver.cuDeviceGetUuid(self.id) - uuid_vals = tuple(uuid.bytes) - - b = "%02x" - b2 = b * 2 - b4 = b * 4 - b6 = b * 6 - fmt = f"GPU-{b4}-{b2}-{b2}-{b2}-{b6}" - self.uuid = fmt % uuid_vals + self.uuid = f"GPU-{UUID(bytes=uuid.bytes)}" self.primary_context = None @@ -587,7 +578,7 @@ def get_device_identity(self): } def __repr__(self): - return "" % (self.id, self.name) + return f"" def __getattr__(self, attr): """Read attributes lazily""" @@ -603,9 +594,7 @@ def __hash__(self): return hash(self.id) def __eq__(self, other): - if isinstance(other, Device): - return self.id == other.id - return False + return isinstance(other, Device) and self.id == other.id def __ne__(self, other): return not (self == other) @@ -615,8 +604,8 @@ def get_primary_context(self): Returns the primary context for the device. Note: it is not pushed to the CPU thread. """ - if self.primary_context is not None: - return self.primary_context + if (ctx := self.primary_context) is not None: + return ctx met_requirement_for_device(self) # create primary context @@ -637,8 +626,8 @@ def release_primary_context(self): def reset(self): try: - if self.primary_context is not None: - self.primary_context.reset() + if (ctx := self.primary_context) is not None: + ctx.reset() self.release_primary_context() finally: # reset at the driver level