From 2827d0da0bfaabc42bf0da2b16bb6f0027112108 Mon Sep 17 00:00:00 2001 From: isVoid Date: Mon, 3 Mar 2025 13:53:15 -0800 Subject: [PATCH 01/36] initial --- numba_cuda/numba/cuda/codegen.py | 16 +++++- .../numba/cuda/cudadrv/linkable_code.py | 3 +- .../cuda/tests/cudadrv/test_module_init.py | 57 +++++++++++++++++++ 3 files changed, 73 insertions(+), 3 deletions(-) create mode 100644 numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index 426eb82b3..e3b4376b2 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -4,6 +4,7 @@ from numba.core.codegen import Codegen, CodeLibrary from .cudadrv import devices, driver, nvvm, runtime from numba.cuda.cudadrv.libs import get_cudalib +from numba.cuda.cudadrv.linkable_code import LinkableCode import os import subprocess @@ -94,6 +95,9 @@ def __init__( # Files to link with the generated PTX. These are linked using the # Driver API at link time. self._linking_files = set() + # Module Init functions to be applied to loaded module, the order is + # determined by the order they are added to the codelib. + self._init_functions = [] # Should we link libcudadevrt? self.needs_cudadevrt = False @@ -248,6 +252,10 @@ def get_cufunc(self): cubin = self.get_cubin(cc=device.compute_capability) module = ctx.create_module_image(cubin) + # Init + for init_fn in self._init_functions: + init_fn(module) + # Load cufunc = module.get_function(self._entry_name) @@ -284,8 +292,12 @@ def add_linking_library(self, library): self._linking_libraries.add(library) - def add_linking_file(self, filepath): - self._linking_files.add(filepath) + def add_linking_file(self, path_or_obj): + if isinstance(path_or_obj, LinkableCode): + if path_or_obj.init_callback is not None: + self._init_functions.append(path_or_obj.init_callback) + + self._linking_files.add(path_or_obj) def get_function(self, name): for fn in self._module.functions: diff --git a/numba_cuda/numba/cuda/cudadrv/linkable_code.py b/numba_cuda/numba/cuda/cudadrv/linkable_code.py index d1d715930..de1e64fba 100644 --- a/numba_cuda/numba/cuda/cudadrv/linkable_code.py +++ b/numba_cuda/numba/cuda/cudadrv/linkable_code.py @@ -9,9 +9,10 @@ class LinkableCode: linking errors that may be produced. 
""" - def __init__(self, data, name=None): + def __init__(self, data, name=None, init_callback=None): self.data = data self._name = name + self.init_callback = init_callback @property def name(self): diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py new file mode 100644 index 000000000..8c84c71a8 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py @@ -0,0 +1,57 @@ +import numpy as np + +from numba import cuda +from numba.cuda.cudadrv.linkable_code import CUSource +from numba.cuda.testing import CUDATestCase + +from cuda.bindings.driver import cuModuleGetGlobal, cuMemcpyHtoD + + +class TestModuleInitCallback(CUDATestCase): + + def setUp(self): + super().setUp() + + module = """ +__device__ int num = 0; +extern "C" +__device__ int get_num(int &retval) { + retval = num; + return 0; +} +""" + + def set_fourty_two(mod): + # Initialize 42 to global variable `num` + res, dptr, size = cuModuleGetGlobal( + mod.handle.value, "num".encode() + ) + self.assertEqual(res, 0) + self.assertEqual(size, 4) + + arr = np.array([42], np.int32) + cuMemcpyHtoD(dptr, arr.ctypes.data, size) + + self.lib = CUSource(module, init_callback=set_fourty_two) + + def test_decldevice_arg(self): + get_num = cuda.declare_device("get_num", "int32()", link=[self.lib]) + + @cuda.jit + def kernel(arr): + arr[0] = get_num() + + arr = np.zeros(1, np.int32) + kernel[1, 1](arr) + self.assertEqual(arr[0], 42) + + def test_jitarg(self): + get_num = cuda.declare_device("get_num", "int32()") + + @cuda.jit(link=[self.lib]) + def kernel(arr): + arr[0] = get_num() + + arr = np.zeros(1, np.int32) + kernel[1, 1](arr) + self.assertEqual(arr[0], 42) From 7806707da78b9dee8e4f73c2234ee9405b8b7989 Mon Sep 17 00:00:00 2001 From: isVoid Date: Wed, 5 Mar 2025 15:30:39 -0800 Subject: [PATCH 02/36] mocking object type using cuda.core objects --- numba_cuda/numba/cuda/codegen.py | 13 ++++++++----- .../numba/cuda/tests/cudadrv/test_module_init.py | 16 +++++++++++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index e3b4376b2..82929cde1 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -6,6 +6,8 @@ from numba.cuda.cudadrv.libs import get_cudalib from numba.cuda.cudadrv.linkable_code import LinkableCode +from cuda.core.experimental import ObjectCode + import os import subprocess import tempfile @@ -250,14 +252,15 @@ def get_cufunc(self): return cufunc cubin = self.get_cubin(cc=device.compute_capability) - module = ctx.create_module_image(cubin) + + # just a mock, https://github.com/NVIDIA/numba-cuda/pull/133 will + # formalize the object code interface + obj_code = ObjectCode.from_cubin(cubin) + cufunc = obj_code.get_kernel(self._entry_name) # Init for init_fn in self._init_functions: - init_fn(module) - - # Load - cufunc = module.get_function(self._entry_name) + init_fn(obj_code) # Populate caches self._cufunc_cache[device.id] = cufunc diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py index 8c84c71a8..d72da293e 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py @@ -4,7 +4,11 @@ from numba.cuda.cudadrv.linkable_code import CUSource from numba.cuda.testing import CUDATestCase -from cuda.bindings.driver import cuModuleGetGlobal, cuMemcpyHtoD +from cuda.bindings.driver import ( + 
cuModuleGetGlobal, + cuMemcpyHtoD, + cuLibraryGetModule +) class TestModuleInitCallback(CUDATestCase): @@ -21,11 +25,13 @@ def setUp(self): } """ - def set_fourty_two(mod): + def set_fourty_two(obj): # Initialize 42 to global variable `num` - res, dptr, size = cuModuleGetGlobal( - mod.handle.value, "num".encode() - ) + culib = obj._handle + res, mod = cuLibraryGetModule(culib) + self.assertEqual(res, 0) + + res, dptr, size = cuModuleGetGlobal(mod, "num".encode()) self.assertEqual(res, 0) self.assertEqual(size, 4) From bc10ce3c2a5d2f1acb5d5524bd8f2fae3fb24776 Mon Sep 17 00:00:00 2001 From: isVoid Date: Thu, 13 Mar 2025 14:55:57 +0000 Subject: [PATCH 03/36] Update to ctypes module wrapper and type checks --- numba_cuda/numba/cuda/codegen.py | 23 ++++++--- .../numba/cuda/cudadrv/linkable_code.py | 9 +++- .../numba/cuda/cudadrv/managed_module.py | 30 +++++++++++ ...odule_init.py => test_module_callbacks.py} | 50 +++++++++++++++++-- 4 files changed, 97 insertions(+), 15 deletions(-) create mode 100644 numba_cuda/numba/cuda/cudadrv/managed_module.py rename numba_cuda/numba/cuda/tests/cudadrv/{test_module_init.py => test_module_callbacks.py} (52%) diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index e3b4376b2..dc1cf21d4 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -5,6 +5,7 @@ from .cudadrv import devices, driver, nvvm, runtime from numba.cuda.cudadrv.libs import get_cudalib from numba.cuda.cudadrv.linkable_code import LinkableCode +from numba.cuda.cudadrv.managed_module import ManagedModule import os import subprocess @@ -95,9 +96,12 @@ def __init__( # Files to link with the generated PTX. These are linked using the # Driver API at link time. self._linking_files = set() - # Module Init functions to be applied to loaded module, the order is - # determined by the order they are added to the codelib. - self._init_functions = [] + # List of setup functions to the loaded module + # the order is determined by the order they are added to the codelib. + self._setup_functions = [] + # List of teardown functions to the loaded module + # the order is determined by the order they are added to the codelib. + self._teardown_functions = [] # Should we link libcudadevrt? self.needs_cudadevrt = False @@ -252,9 +256,10 @@ def get_cufunc(self): cubin = self.get_cubin(cc=device.compute_capability) module = ctx.create_module_image(cubin) - # Init - for init_fn in self._init_functions: - init_fn(module) + # Wrap ctypesmodule with managed module with auto module setup/teardown + module = ManagedModule( + module, self._setup_functions, self._teardown_functions + ) # Load cufunc = module.get_function(self._entry_name) @@ -294,8 +299,10 @@ def add_linking_library(self, library): def add_linking_file(self, path_or_obj): if isinstance(path_or_obj, LinkableCode): - if path_or_obj.init_callback is not None: - self._init_functions.append(path_or_obj.init_callback) + if path_or_obj.setup_callback is not None: + self._setup_functions.append(path_or_obj.setup_callback) + if path_or_obj.teardown_callback is not None: + self._teardown_functions.append(path_or_obj.teardown_callback) self._linking_files.add(path_or_obj) diff --git a/numba_cuda/numba/cuda/cudadrv/linkable_code.py b/numba_cuda/numba/cuda/cudadrv/linkable_code.py index de1e64fba..7303c4832 100644 --- a/numba_cuda/numba/cuda/cudadrv/linkable_code.py +++ b/numba_cuda/numba/cuda/cudadrv/linkable_code.py @@ -7,12 +7,17 @@ class LinkableCode: :param data: A buffer containing the data to link. 
:param name: The name of the file to be referenced in any compilation or linking errors that may be produced. + :param setup_callback Callback function on kernel before launching. + :param teardown_callback Callback function on kernel before tearing down. """ - def __init__(self, data, name=None, init_callback=None): + def __init__(self, data, name=None, + setup_callback=None, + teardown_callback=None): self.data = data self._name = name - self.init_callback = init_callback + self.setup_callback = setup_callback + self.teardown_callback = teardown_callback @property def name(self): diff --git a/numba_cuda/numba/cuda/cudadrv/managed_module.py b/numba_cuda/numba/cuda/cudadrv/managed_module.py new file mode 100644 index 000000000..ff7ccf64a --- /dev/null +++ b/numba_cuda/numba/cuda/cudadrv/managed_module.py @@ -0,0 +1,30 @@ +from .driver import CtypesModule + + +class ManagedModule: + def __init__(self, module, setup_callbacks, teardown_callbacks): + # To be updated to object code + if not isinstance(module, CtypesModule): + raise TypeError("module must be a CtypesModule") + + if not isinstance(setup_callbacks, list): + raise TypeError("setup_callbacks must be a list") + if not isinstance(teardown_callbacks, list): + raise TypeError("teardown_callbacks must be a list") + + self._module = module + self._setup_callbacks = setup_callbacks + self._teardown_callbacks = teardown_callbacks + + for initialize in self._setup_callbacks: + if not callable(initialize): + raise TypeError("setup_callbacks must be callable") + initialize(self._module) + + def __del__(self): + for teardown in self._teardown_callbacks: + if not callable(teardown): + teardown(self._module) + + def __getattr__(self, name): + return getattr(self._module, name) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py similarity index 52% rename from numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py rename to numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index 8c84c71a8..e3d76d1aa 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -1,3 +1,5 @@ +import gc + import numpy as np from numba import cuda @@ -7,7 +9,36 @@ from cuda.bindings.driver import cuModuleGetGlobal, cuMemcpyHtoD -class TestModuleInitCallback(CUDATestCase): +class TestModuleCallbacksBasic(CUDATestCase): + + def test_basic(self): + counter = [0] + + def setup(mod, counter=counter): + counter[0] += 1 + + def teardown(mod, counter=counter): + counter[0] -= 1 + + lib = CUSource("", setup_callback=setup, teardown_callback=teardown) + + @cuda.jit(link=[lib]) + def kernel(): + pass + + self.assertEqual(counter, [0]) + kernel[1, 1]() + self.assertEqual(counter, [1]) + kernel[1, 1]() # cached + self.assertEqual(counter, [1]) + breakpoint() + del kernel + gc.collect() + # When does the cache gets cleared? 
+ self.assertEqual(counter, [0]) # FAILS + + +class TestModuleCallbacks(CUDATestCase): def setUp(self): super().setUp() @@ -21,18 +52,27 @@ def setUp(self): } """ - def set_fourty_two(mod): + self.counter = 0 + + def set_forty_two(mod): + self.counter += 1 # Initialize 42 to global variable `num` res, dptr, size = cuModuleGetGlobal( mod.handle.value, "num".encode() ) - self.assertEqual(res, 0) - self.assertEqual(size, 4) arr = np.array([42], np.int32) cuMemcpyHtoD(dptr, arr.ctypes.data, size) - self.lib = CUSource(module, init_callback=set_fourty_two) + def teardown(mod): + self.counter -= 1 + + self.lib = CUSource( + module, setup_callback=set_forty_two, teardown_callback=teardown) + + def tearDown(self): + super().tearDown() + del self.lib def test_decldevice_arg(self): get_num = cuda.declare_device("get_num", "int32()", link=[self.lib]) From 54bc46cb6670037b3335f299d60dbb19fb85ea36 Mon Sep 17 00:00:00 2001 From: isVoid Date: Fri, 14 Mar 2025 19:57:43 +0000 Subject: [PATCH 04/36] add managed module object --- .../numba/cuda/cudadrv/managed_module.py | 67 ++++++++++++++++--- numba_cuda/numba/cuda/dispatcher.py | 7 ++ .../tests/cudadrv/test_module_callbacks.py | 66 +++++++++++++----- 3 files changed, 115 insertions(+), 25 deletions(-) diff --git a/numba_cuda/numba/cuda/cudadrv/managed_module.py b/numba_cuda/numba/cuda/cudadrv/managed_module.py index ff7ccf64a..fd755cda2 100644 --- a/numba_cuda/numba/cuda/cudadrv/managed_module.py +++ b/numba_cuda/numba/cuda/cudadrv/managed_module.py @@ -1,5 +1,26 @@ +import weakref + +from numba import config +from . import devices from .driver import CtypesModule +USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING + + +class CuFuncProxy: + def __init__(self, module, cufunc): + self._module = module + self._cufunc = cufunc + + def init_module(self, stream): + self._module._init(stream) + + def lazy_finalize_module(self, stream): + self._module._lazy_finalize(stream) + + def __getattr__(self, name): + return getattr(self._cufunc, name) + class ManagedModule: def __init__(self, module, setup_callbacks, teardown_callbacks): @@ -12,19 +33,49 @@ def __init__(self, module, setup_callbacks, teardown_callbacks): if not isinstance(teardown_callbacks, list): raise TypeError("teardown_callbacks must be a list") + for callback in setup_callbacks: + if not callable(callback): + raise TypeError("all callbacks must be callable") + for callback in teardown_callbacks: + if not callable(callback): + raise TypeError("all callbacks must be callable") + + self._initialized = False self._module = module self._setup_callbacks = setup_callbacks self._teardown_callbacks = teardown_callbacks - for initialize in self._setup_callbacks: - if not callable(initialize): - raise TypeError("setup_callbacks must be callable") - initialize(self._module) + def _init(self, stream): + for setup in self._setup_callbacks: + setup(self._module, stream) + + def _lazy_finalize(self, stream): + + def lazy_callback(callbacks, module, stream): + for teardown in callbacks: + teardown(module, stream) + + ctx = devices.get_context() + if USE_NV_BINDING: + key = self._module.handle + else: + key = self._module.handle.value + module_obj = ctx.modules.get(key, None) + + if module_obj is not None: + weakref.finalize( + module_obj, + lazy_callback, + self._teardown_callbacks, + self._module, + stream + ) - def __del__(self): - for teardown in self._teardown_callbacks: - if not callable(teardown): - teardown(self._module) + def get_function(self, name): + ctypesfunc = self._module.get_function(name) + return 
CuFuncProxy(self, ctypesfunc) def __getattr__(self, name): + if name == "get_function": + return getattr(self, "get_function") return getattr(self._module, name) diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index 2c2682555..8ffff3792 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -93,6 +93,7 @@ def __init__(self, py_func, argtypes, link=None, debug=False, self.debug = debug self.lineinfo = lineinfo self.extensions = extensions or [] + self.initialized = False nvvm_options = { 'fastmath': fastmath, @@ -405,6 +406,12 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0): stream_handle = stream and stream.handle or zero_stream + # Set init and finalize module callbacks + if not self.initialized: + cufunc.init_module(stream_handle) + cufunc.lazy_finalize_module(stream_handle) + self.initialized = True + # Invoke kernel driver.launch_kernel(cufunc.handle, *griddim, diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index e3d76d1aa..ed6e96c2a 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -1,5 +1,3 @@ -import gc - import numpy as np from numba import cuda @@ -12,13 +10,15 @@ class TestModuleCallbacksBasic(CUDATestCase): def test_basic(self): - counter = [0] + counter = 0 - def setup(mod, counter=counter): - counter[0] += 1 + def setup(mod, stream): + nonlocal counter + counter += 1 - def teardown(mod, counter=counter): - counter[0] -= 1 + def teardown(mod, stream): + nonlocal counter + counter -= 1 lib = CUSource("", setup_callback=setup, teardown_callback=teardown) @@ -26,16 +26,48 @@ def teardown(mod, counter=counter): def kernel(): pass - self.assertEqual(counter, [0]) + self.assertEqual(counter, 0) kernel[1, 1]() - self.assertEqual(counter, [1]) + self.assertEqual(counter, 1) kernel[1, 1]() # cached - self.assertEqual(counter, [1]) - breakpoint() - del kernel - gc.collect() - # When does the cache gets cleared? - self.assertEqual(counter, [0]) # FAILS + self.assertEqual(counter, 1) + # del kernel + # gc.collect() + # cuda.current_context().deallocations.clear() + # self.assertEqual(counter, 0) + # We don't have a way to explicitly evict kernel and its modules at + # the moment. + + def test_different_argtypes(self): + counter = 0 + + def setup(mod, stream): + nonlocal counter + counter += 1 + + def teardown(mod, stream): + nonlocal counter + counter -= 1 + + lib = CUSource("", setup_callback=setup, teardown_callback=teardown) + + @cuda.jit(link=[lib]) + def kernel(arg): + pass + + self.assertEqual(counter, 0) + kernel[1, 1](42) # (int64)->() : module 1 + self.assertEqual(counter, 1) + kernel[1, 1](100) # (int64)->() : module 1, cached + self.assertEqual(counter, 1) + kernel[1, 1](3.14) # (float64)->() : module 2 + self.assertEqual(counter, 2) + + # del kernel + # gc.collect() + # cuda.current_context().deallocations.clear() + # self.assertEqual(counter, 0) # We don't have a way to explicitly + # evict kernel and its modules at the moment. 
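
At this point in the series each specialization of a kernel loads its own module, and the setup/teardown callbacks receive the loaded module object together with the stream used for the first launch from it. A minimal sketch of the pattern these tests exercise, assuming the ctypes binding (the callback names and the `seen` set are illustrative only, and the callback signature is changed again in later commits)::

    from numba import cuda
    from numba.cuda.cudadrv.linkable_code import CUSource

    seen = set()

    def setup(mod, stream):
        # `mod` is the loaded CtypesModule; `stream` is the stream passed to
        # the first launch of a kernel built from this module.
        seen.add(mod.handle.value)

    def teardown(mod, stream):
        seen.discard(mod.handle.value)

    lib = CUSource("", setup_callback=setup, teardown_callback=teardown)

    @cuda.jit(link=[lib])
    def kernel(x):
        pass

    kernel[1, 1](42)    # int64 specialization: first module, setup runs
    kernel[1, 1](3.14)  # float64 specialization: second module, setup runs again
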
class TestModuleCallbacks(CUDATestCase): @@ -54,7 +86,7 @@ def setUp(self): self.counter = 0 - def set_forty_two(mod): + def set_forty_two(mod, stream): self.counter += 1 # Initialize 42 to global variable `num` res, dptr, size = cuModuleGetGlobal( @@ -64,7 +96,7 @@ def set_forty_two(mod): arr = np.array([42], np.int32) cuMemcpyHtoD(dptr, arr.ctypes.data, size) - def teardown(mod): + def teardown(mod, stream): self.counter -= 1 self.lib = CUSource( From 51dd29330696b392a70ae33daefde2210a42a388 Mon Sep 17 00:00:00 2001 From: isVoid Date: Sat, 15 Mar 2025 00:15:19 +0000 Subject: [PATCH 05/36] add two kernels test --- .../tests/cudadrv/test_module_callbacks.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index ed6e96c2a..1879a9f7a 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -69,6 +69,38 @@ def kernel(arg): # self.assertEqual(counter, 0) # We don't have a way to explicitly # evict kernel and its modules at the moment. + def test_two_kernels(self): + counter = 0 + + def setup(mod, stream): + nonlocal counter + counter += 1 + + def teardown(mod, stream): + nonlocal counter + counter -= 1 + + lib = CUSource("", setup_callback=setup, teardown_callback=teardown) + + @cuda.jit(link=[lib]) + def kernel(): + pass + + @cuda.jit(link=[lib]) + def kernel2(): + pass + + kernel[1, 1]() + self.assertEqual(counter, 1) + kernel2[1, 1]() + self.assertEqual(counter, 2) + + # del kernel + # gc.collect() + # cuda.current_context().deallocations.clear() + # self.assertEqual(counter, 0) # We don't have a way to explicitly + # evict kernel and its modules at the moment. + class TestModuleCallbacks(CUDATestCase): From d05f76b6d828c74a0f7868e41f2924bb9b5c8cc4 Mon Sep 17 00:00:00 2001 From: isVoid Date: Mon, 17 Mar 2025 16:54:03 +0000 Subject: [PATCH 06/36] add kernel finalizer --- .../numba/cuda/cudadrv/managed_module.py | 9 +++------ numba_cuda/numba/cuda/dispatcher.py | 18 ++++++++++++++++++ .../tests/cudadrv/test_module_callbacks.py | 11 +++++++---- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/numba_cuda/numba/cuda/cudadrv/managed_module.py b/numba_cuda/numba/cuda/cudadrv/managed_module.py index fd755cda2..3c820fc13 100644 --- a/numba_cuda/numba/cuda/cudadrv/managed_module.py +++ b/numba_cuda/numba/cuda/cudadrv/managed_module.py @@ -1,13 +1,10 @@ import weakref -from numba import config from . 
import devices -from .driver import CtypesModule +from .driver import CtypesModule, USE_NV_BINDING -USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING - -class CuFuncProxy: +class _CuFuncProxy: def __init__(self, module, cufunc): self._module = module self._cufunc = cufunc @@ -73,7 +70,7 @@ def lazy_callback(callbacks, module, stream): def get_function(self, name): ctypesfunc = self._module.get_function(name) - return CuFuncProxy(self, ctypesfunc) + return _CuFuncProxy(self, ctypesfunc) def __getattr__(self, name): if name == "get_function": diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index 8ffff3792..f9d464ec3 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -1,4 +1,5 @@ import numpy as np +import weakref import os import re import sys @@ -247,6 +248,7 @@ def _rebuild(cls, cooperative, name, signature, codelibrary, instance.lineinfo = lineinfo instance.call_helper = call_helper instance.extensions = extensions + instance.initialized = False return instance def _reduce_states(self): @@ -1023,6 +1025,7 @@ def compile(self, sig): raise RuntimeError("Compilation disabled") kernel = _Kernel(self.py_func, argtypes, **self.targetoptions) + weakref.finalize(kernel, _kernel_finalize_callback, kernel) # We call bind to force codegen, so that there is a cubin to cache kernel.bind() self._cache.save_overload(sig, kernel) @@ -1148,3 +1151,18 @@ def _reduce_states(self): """ return dict(py_func=self.py_func, targetoptions=self.targetoptions) + + +def _kernel_finalize_callback(kernel): + module = kernel.library.get_cufunc().module + try: + if driver.USE_NV_BINDING: + key = module.handle + else: + key = module.handle.value + except ReferenceError: + return + + ctx = cuda.current_context() + if key in ctx.modules: + del ctx.modules[key] diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index 1879a9f7a..f7b253952 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -1,3 +1,5 @@ +import gc + import numpy as np from numba import cuda @@ -31,10 +33,11 @@ def kernel(): self.assertEqual(counter, 1) kernel[1, 1]() # cached self.assertEqual(counter, 1) - # del kernel - # gc.collect() - # cuda.current_context().deallocations.clear() - # self.assertEqual(counter, 0) + breakpoint() + del kernel + gc.collect() + cuda.current_context().deallocations.clear() + self.assertEqual(counter, 0) # We don't have a way to explicitly evict kernel and its modules at # the moment. 
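
The eviction added in this commit hinges on `weakref.finalize`: the registered callback runs when its referent becomes garbage (or at interpreter shutdown). A small CUDA-free illustration of that mechanism, with purely illustrative names::

    import weakref

    class Kernel:
        pass

    evicted = []

    def evict(key):
        # Stand-in for dropping a module from the context's module table.
        evicted.append(key)

    k = Kernel()
    weakref.finalize(k, evict, "module-key")

    del k  # with CPython's reference counting the finalizer fires here
    assert evicted == ["module-key"]

Note that `weakref.finalize` requires the callback and its arguments to hold no reference to the referent; passing the kernel itself as an argument to its own finalizer, as this commit does, keeps the kernel alive indefinitely, which may be part of why the following commit removes these finalizers again.
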
From b32983109de44cd98a736ca5feb29382198400bc Mon Sep 17 00:00:00 2001 From: isVoid Date: Mon, 17 Mar 2025 17:59:39 +0000 Subject: [PATCH 07/36] removing kernel finalizers --- numba_cuda/numba/cuda/dispatcher.py | 17 -------- .../tests/cudadrv/test_module_callbacks.py | 39 ++++++++----------- 2 files changed, 17 insertions(+), 39 deletions(-) diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index f9d464ec3..3424a0866 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -1,5 +1,4 @@ import numpy as np -import weakref import os import re import sys @@ -1025,7 +1024,6 @@ def compile(self, sig): raise RuntimeError("Compilation disabled") kernel = _Kernel(self.py_func, argtypes, **self.targetoptions) - weakref.finalize(kernel, _kernel_finalize_callback, kernel) # We call bind to force codegen, so that there is a cubin to cache kernel.bind() self._cache.save_overload(sig, kernel) @@ -1151,18 +1149,3 @@ def _reduce_states(self): """ return dict(py_func=self.py_func, targetoptions=self.targetoptions) - - -def _kernel_finalize_callback(kernel): - module = kernel.library.get_cufunc().module - try: - if driver.USE_NV_BINDING: - key = module.handle - else: - key = module.handle.value - except ReferenceError: - return - - ctx = cuda.current_context() - if key in ctx.modules: - del ctx.modules[key] diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index f7b253952..4b13c582a 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -1,14 +1,23 @@ -import gc +import unittest import numpy as np -from numba import cuda +from numba import cuda, config from numba.cuda.cudadrv.linkable_code import CUSource from numba.cuda.testing import CUDATestCase from cuda.bindings.driver import cuModuleGetGlobal, cuMemcpyHtoD +def wipe_all_modules_in_context(): + ctx = cuda.current_context() + ctx.modules.clear() + + +@unittest.skipIf( + config.CUDA_USE_NVIDIA_BINDING, + "NV binding support superceded by cuda.bindings." +) class TestModuleCallbacksBasic(CUDATestCase): def test_basic(self): @@ -33,13 +42,9 @@ def kernel(): self.assertEqual(counter, 1) kernel[1, 1]() # cached self.assertEqual(counter, 1) - breakpoint() - del kernel - gc.collect() - cuda.current_context().deallocations.clear() + + wipe_all_modules_in_context() self.assertEqual(counter, 0) - # We don't have a way to explicitly evict kernel and its modules at - # the moment. def test_different_argtypes(self): counter = 0 @@ -66,11 +71,8 @@ def kernel(arg): kernel[1, 1](3.14) # (float64)->() : module 2 self.assertEqual(counter, 2) - # del kernel - # gc.collect() - # cuda.current_context().deallocations.clear() - # self.assertEqual(counter, 0) # We don't have a way to explicitly - # evict kernel and its modules at the moment. + wipe_all_modules_in_context() + self.assertEqual(counter, 0) def test_two_kernels(self): counter = 0 @@ -98,11 +100,8 @@ def kernel2(): kernel2[1, 1]() self.assertEqual(counter, 2) - # del kernel - # gc.collect() - # cuda.current_context().deallocations.clear() - # self.assertEqual(counter, 0) # We don't have a way to explicitly - # evict kernel and its modules at the moment. 
+ wipe_all_modules_in_context() + self.assertEqual(counter, 0) class TestModuleCallbacks(CUDATestCase): @@ -137,10 +136,6 @@ def teardown(mod, stream): self.lib = CUSource( module, setup_callback=set_forty_two, teardown_callback=teardown) - def tearDown(self): - super().tearDown() - del self.lib - def test_decldevice_arg(self): get_num = cuda.declare_device("get_num", "int32()", link=[self.lib]) From c6a9b731089c92e163eddf2b3fee787e3862021a Mon Sep 17 00:00:00 2001 From: isVoid Date: Mon, 17 Mar 2025 18:10:53 +0000 Subject: [PATCH 08/36] add docstring --- numba_cuda/numba/cuda/cudadrv/managed_module.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/numba_cuda/numba/cuda/cudadrv/managed_module.py b/numba_cuda/numba/cuda/cudadrv/managed_module.py index 3c820fc13..5cd174fad 100644 --- a/numba_cuda/numba/cuda/cudadrv/managed_module.py +++ b/numba_cuda/numba/cuda/cudadrv/managed_module.py @@ -5,6 +5,8 @@ class _CuFuncProxy: + """See `ManagedModule` for more details + """ def __init__(self, module, cufunc): self._module = module self._cufunc = cufunc @@ -21,6 +23,14 @@ def __getattr__(self, name): class ManagedModule: def __init__(self, module, setup_callbacks, teardown_callbacks): + """ctypes module with setup and teardown callbacks + The use of managedmodule is the same as a ctypes module, + with the exception of `get_function`, which returns a wrapped + cufunc object. The wrapped object provides `init_module` + and `lazy_finalize_module` method. Which are used + to initialize and finalize the module when stream is available + in the later stage of the compilation pipeline. + """ # To be updated to object code if not isinstance(module, CtypesModule): raise TypeError("module must be a CtypesModule") @@ -43,11 +53,14 @@ def __init__(self, module, setup_callbacks, teardown_callbacks): self._teardown_callbacks = teardown_callbacks def _init(self, stream): + """Eagerly call the setup functions for cumodule on `stream` + """ for setup in self._setup_callbacks: setup(self._module, stream) def _lazy_finalize(self, stream): - + """Set teardown function for cumodule via weakref finalizer. + """ def lazy_callback(callbacks, module, stream): for teardown in callbacks: teardown(module, stream) @@ -69,6 +82,8 @@ def lazy_callback(callbacks, module, stream): ) def get_function(self, name): + """Returns wrapped CtypesFunc object. + """ ctypesfunc = self._module.get_function(name) return _CuFuncProxy(self, ctypesfunc) From 881648c6ecf0a06e49f75660926f2f681c35dc8b Mon Sep 17 00:00:00 2001 From: isVoid Date: Mon, 17 Mar 2025 18:29:03 +0000 Subject: [PATCH 09/36] add doc --- docs/source/user/cuda_ffi.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/source/user/cuda_ffi.rst b/docs/source/user/cuda_ffi.rst index 48d48f8f9..cc6d49bee 100644 --- a/docs/source/user/cuda_ffi.rst +++ b/docs/source/user/cuda_ffi.rst @@ -160,6 +160,19 @@ CUDA C/C++ source code will be compiled with the `NVIDIA Runtime Compiler kernel as either PTX or LTOIR, depending on whether LTO is enabled. Other files will be passed directly to the CUDA Linker. +For cumodules linked to a linkable code object, additional setup +and teardown callbacks are available. Setup functions are invoked once for +every new module loaded. Teardown functions are invoked before the module is +garbage collected by python. These callbacks are defined as follows: + +.. 
code:: + + def setup_callback( + mod: cuda.driver.CtypeModule, stream: cuda.driver.Stream + ) + +The `stream` argument is the same stream user passes in to launch the kernel. + :class:`LinkableCode ` objects are initialized using the parameters of their base class: From 6f7fe2a947de44d9bc82dbd3fbca66344bbef385 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Tue, 18 Mar 2025 13:07:26 -0700 Subject: [PATCH 10/36] Update docs/source/user/cuda_ffi.rst Co-authored-by: Graham Markall <535640+gmarkall@users.noreply.github.com> --- docs/source/user/cuda_ffi.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/user/cuda_ffi.rst b/docs/source/user/cuda_ffi.rst index cc6d49bee..e1f7f54dc 100644 --- a/docs/source/user/cuda_ffi.rst +++ b/docs/source/user/cuda_ffi.rst @@ -162,8 +162,8 @@ will be passed directly to the CUDA Linker. For cumodules linked to a linkable code object, additional setup and teardown callbacks are available. Setup functions are invoked once for -every new module loaded. Teardown functions are invoked before the module is -garbage collected by python. These callbacks are defined as follows: +every new module loaded. Teardown functions are invoked just prior to +module unloading. These callbacks are defined as follows: .. code:: From d4cc4acd6d75003c4cd998e3bcae83069c951869 Mon Sep 17 00:00:00 2001 From: isVoid Date: Wed, 19 Mar 2025 15:17:18 -0700 Subject: [PATCH 11/36] add a test that involves two streams --- .../tests/cudadrv/test_module_callbacks.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index 4b13c582a..c8211794a 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -103,6 +103,42 @@ def kernel2(): wipe_all_modules_in_context() self.assertEqual(counter, 0) + def test_different_streams(self): + + s1 = cuda.stream() + s2 = cuda.stream() + + counter = 0 + + def setup(mod, stream, self=self): + nonlocal counter, s1, s2 + if counter == 0: + self.assertEqual(stream, s1.handle) + elif counter == 1: + self.assertEqual(stream, s2.handle) + else: + raise RuntimeError("Setup should only be invoked twice.") + counter += 1 + + def teardown(mod, stream, self=self): + nonlocal counter, s1, s2 + if counter == 2: + self.assertEqual(stream, s2.handle) + elif counter == 1: + self.assertEqual(stream, s1.handle) + else: + raise RuntimeError("Teardown should only be invoked twice.") + counter -= 1 + + lib = CUSource("", setup_callback=setup, teardown_callback=teardown) + + @cuda.jit(link=[lib]) + def kernel(): + pass + + kernel[1, 1, s1]() + kernel[1, 1, s2]() + class TestModuleCallbacks(CUDATestCase): From af2d618517d0e394d2dcd48f86223f6b25cb6b26 Mon Sep 17 00:00:00 2001 From: isVoid Date: Wed, 19 Mar 2025 15:24:57 -0700 Subject: [PATCH 12/36] add wipe all module call --- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index c8211794a..be9cc35bb 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -138,6 +138,8 @@ def kernel(): kernel[1, 1, s1]() kernel[1, 1, s2]() + wipe_all_modules_in_context() + self.assertEqual(counter, 0) class 
TestModuleCallbacks(CUDATestCase): From a5a367100add50075fcd0fb0ac894d1829b3b58c Mon Sep 17 00:00:00 2001 From: isVoid Date: Wed, 19 Mar 2025 15:41:14 -0700 Subject: [PATCH 13/36] use context reset is a better option to unload modules --- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index be9cc35bb..1f95d6878 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -11,7 +11,7 @@ def wipe_all_modules_in_context(): ctx = cuda.current_context() - ctx.modules.clear() + ctx.reset() @unittest.skipIf( From 577c8ffd1a264081437b95a7810764cd5a0fb445 Mon Sep 17 00:00:00 2001 From: isVoid Date: Wed, 19 Mar 2025 15:53:37 -0700 Subject: [PATCH 14/36] add test for stream completeness --- .../cuda/tests/cudadrv/test_module_callbacks.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index 1f95d6878..76ad272dc 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -108,26 +108,33 @@ def test_different_streams(self): s1 = cuda.stream() s2 = cuda.stream() + setup_seen = set() + teardown_seen = set() + counter = 0 def setup(mod, stream, self=self): - nonlocal counter, s1, s2 + nonlocal counter, s1, s2, setup_seen + print("setup!") if counter == 0: self.assertEqual(stream, s1.handle) elif counter == 1: self.assertEqual(stream, s2.handle) else: raise RuntimeError("Setup should only be invoked twice.") + setup_seen.add(stream.value) counter += 1 def teardown(mod, stream, self=self): - nonlocal counter, s1, s2 + print("teardown!") + nonlocal counter, s1, s2, teardown_seen if counter == 2: self.assertEqual(stream, s2.handle) elif counter == 1: self.assertEqual(stream, s1.handle) else: raise RuntimeError("Teardown should only be invoked twice.") + teardown_seen.add(stream.value) counter -= 1 lib = CUSource("", setup_callback=setup, teardown_callback=teardown) @@ -140,6 +147,8 @@ def kernel(): kernel[1, 1, s2]() wipe_all_modules_in_context() self.assertEqual(counter, 0) + self.assertEqual(setup_seen, {s1.handle.value, s2.handle.value}) + self.assertEqual(teardown_seen, {s1.handle.value, s2.handle.value}) class TestModuleCallbacks(CUDATestCase): From d8f2f230089b6f4d52703c5fe546d0d1d52e8588 Mon Sep 17 00:00:00 2001 From: isVoid Date: Wed, 19 Mar 2025 22:17:27 -0700 Subject: [PATCH 15/36] move logic into CtypesModule --- numba_cuda/numba/cuda/codegen.py | 9 +- numba_cuda/numba/cuda/cudadrv/driver.py | 51 ++++++++++ .../numba/cuda/cudadrv/managed_module.py | 93 ------------------- numba_cuda/numba/cuda/dispatcher.py | 5 +- .../tests/cudadrv/test_module_callbacks.py | 2 - 5 files changed, 56 insertions(+), 104 deletions(-) delete mode 100644 numba_cuda/numba/cuda/cudadrv/managed_module.py diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index dc1cf21d4..85fa35b79 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -5,7 +5,6 @@ from .cudadrv import devices, driver, nvvm, runtime from numba.cuda.cudadrv.libs import get_cudalib from numba.cuda.cudadrv.linkable_code import LinkableCode -from numba.cuda.cudadrv.managed_module 
import ManagedModule import os import subprocess @@ -254,12 +253,8 @@ def get_cufunc(self): return cufunc cubin = self.get_cubin(cc=device.compute_capability) - module = ctx.create_module_image(cubin) - - # Wrap ctypesmodule with managed module with auto module setup/teardown - module = ManagedModule( - module, self._setup_functions, self._teardown_functions - ) + module = ctx.create_module_image_with_callbacks( + cubin, self._setup_functions, self._teardown_functions) # Load cufunc = module.get_function(self._entry_name) diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index 1641bf779..58e07811e 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -1442,6 +1442,13 @@ def create_module_image(self, image): self.modules[key] = module return weakref.proxy(module) + def create_module_image_with_callbacks( + self, image, setup_callbacks, teardown_callbacks): + mod = self.create_module_image(image) + mod.set_setup_functions(setup_callbacks) + mod.set_teardown_functions(teardown_callbacks) + return mod + def unload_module(self, module): if USE_NV_BINDING: key = module.handle @@ -2394,6 +2401,9 @@ def __init__(self, context, handle, info_log, finalizer=None): if finalizer is not None: self._finalizer = weakref.finalize(self, finalizer) + self.setup_functions = None + self.teardown_functions = None + def unload(self): """Unload this module from the context""" self.context.unload_module(self) @@ -2406,6 +2416,47 @@ def get_function(self, name): def get_global_symbol(self, name): """Return a MemoryPointer referring to the named symbol""" + def _set_callables(self, callbacks, attr): + if not isinstance(callbacks, list): + raise TypeError("callbacks must be a list") + + for f in callbacks: + if not callable(f): + raise TypeError("callback must be callable") + + setattr(self, attr, callbacks) + + def set_setup_functions(self, callbacks): + """Set the setup callback for this module""" + self._set_callables(callbacks, "setup_functions") + + def set_teardown_functions(self, callbacks): + """Set the finalize callback for this module""" + self._set_callables(callbacks, "teardown_functions") + + def setup(self, stream): + if self.setup_functions is None: + return + + for f in self.setup_functions: + f(weakref.proxy(self), stream) + + def set_finalizers(self, stream): + if self.teardown_functions is None: + return + + def _teardown(teardowns, modref, stream): + for f in teardowns: + f(modref, stream) + + weakref.finalize( + self, + _teardown, + self.teardown_functions, + weakref.proxy(self), + stream + ) + class CtypesModule(Module): diff --git a/numba_cuda/numba/cuda/cudadrv/managed_module.py b/numba_cuda/numba/cuda/cudadrv/managed_module.py deleted file mode 100644 index 5cd174fad..000000000 --- a/numba_cuda/numba/cuda/cudadrv/managed_module.py +++ /dev/null @@ -1,93 +0,0 @@ -import weakref - -from . 
import devices -from .driver import CtypesModule, USE_NV_BINDING - - -class _CuFuncProxy: - """See `ManagedModule` for more details - """ - def __init__(self, module, cufunc): - self._module = module - self._cufunc = cufunc - - def init_module(self, stream): - self._module._init(stream) - - def lazy_finalize_module(self, stream): - self._module._lazy_finalize(stream) - - def __getattr__(self, name): - return getattr(self._cufunc, name) - - -class ManagedModule: - def __init__(self, module, setup_callbacks, teardown_callbacks): - """ctypes module with setup and teardown callbacks - The use of managedmodule is the same as a ctypes module, - with the exception of `get_function`, which returns a wrapped - cufunc object. The wrapped object provides `init_module` - and `lazy_finalize_module` method. Which are used - to initialize and finalize the module when stream is available - in the later stage of the compilation pipeline. - """ - # To be updated to object code - if not isinstance(module, CtypesModule): - raise TypeError("module must be a CtypesModule") - - if not isinstance(setup_callbacks, list): - raise TypeError("setup_callbacks must be a list") - if not isinstance(teardown_callbacks, list): - raise TypeError("teardown_callbacks must be a list") - - for callback in setup_callbacks: - if not callable(callback): - raise TypeError("all callbacks must be callable") - for callback in teardown_callbacks: - if not callable(callback): - raise TypeError("all callbacks must be callable") - - self._initialized = False - self._module = module - self._setup_callbacks = setup_callbacks - self._teardown_callbacks = teardown_callbacks - - def _init(self, stream): - """Eagerly call the setup functions for cumodule on `stream` - """ - for setup in self._setup_callbacks: - setup(self._module, stream) - - def _lazy_finalize(self, stream): - """Set teardown function for cumodule via weakref finalizer. - """ - def lazy_callback(callbacks, module, stream): - for teardown in callbacks: - teardown(module, stream) - - ctx = devices.get_context() - if USE_NV_BINDING: - key = self._module.handle - else: - key = self._module.handle.value - module_obj = ctx.modules.get(key, None) - - if module_obj is not None: - weakref.finalize( - module_obj, - lazy_callback, - self._teardown_callbacks, - self._module, - stream - ) - - def get_function(self, name): - """Returns wrapped CtypesFunc object. 
- """ - ctypesfunc = self._module.get_function(name) - return _CuFuncProxy(self, ctypesfunc) - - def __getattr__(self, name): - if name == "get_function": - return getattr(self, "get_function") - return getattr(self._module, name) diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index ba712a071..66be7e57f 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -452,8 +452,9 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0): # Set init and finalize module callbacks if not self.initialized: - cufunc.init_module(stream_handle) - cufunc.lazy_finalize_module(stream_handle) + mod = cufunc.module + mod.setup(stream_handle) + mod.set_finalizers(stream_handle) self.initialized = True # Invoke kernel diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index 76ad272dc..6c3a5ca54 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -115,7 +115,6 @@ def test_different_streams(self): def setup(mod, stream, self=self): nonlocal counter, s1, s2, setup_seen - print("setup!") if counter == 0: self.assertEqual(stream, s1.handle) elif counter == 1: @@ -126,7 +125,6 @@ def setup(mod, stream, self=self): counter += 1 def teardown(mod, stream, self=self): - print("teardown!") nonlocal counter, s1, s2, teardown_seen if counter == 2: self.assertEqual(stream, s2.handle) From bb13580a7d5c1373ad27c0dc7831d2e065a34a11 Mon Sep 17 00:00:00 2001 From: isVoid Date: Wed, 19 Mar 2025 22:23:08 -0700 Subject: [PATCH 16/36] update docstrings --- numba_cuda/numba/cuda/cudadrv/driver.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index 58e07811e..3872448aa 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -2435,6 +2435,7 @@ def set_teardown_functions(self, callbacks): self._set_callables(callbacks, "teardown_functions") def setup(self, stream): + """Call the setup functions for cumodule with the given `stream`""" if self.setup_functions is None: return @@ -2442,6 +2443,7 @@ def setup(self, stream): f(weakref.proxy(self), stream) def set_finalizers(self, stream): + """Create finalizers that tears down the cumodule on `stream`""" if self.teardown_functions is None: return From 379a69b06498db2c71d49d36d40ef58614a2ae5f Mon Sep 17 00:00:00 2001 From: isVoid Date: Thu, 20 Mar 2025 15:28:42 -0700 Subject: [PATCH 17/36] consolidate changes into create_module_image --- numba_cuda/numba/cuda/codegen.py | 2 +- numba_cuda/numba/cuda/cudadrv/driver.py | 49 +++++++++++-------------- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index 85fa35b79..d26dce5d6 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -253,7 +253,7 @@ def get_cufunc(self): return cufunc cubin = self.get_cubin(cc=device.compute_capability) - module = ctx.create_module_image_with_callbacks( + module = ctx.create_module_image( cubin, self._setup_functions, self._teardown_functions) # Load diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index 3872448aa..2a70fc4e9 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -1433,8 +1433,10 @@ def 
create_module_ptx(self, ptx): image = c_char_p(ptx) return self.create_module_image(image) - def create_module_image(self, image): - module = load_module_image(self, image) + def create_module_image(self, image, + setup_callbacks=None, teardown_callbacks=None): + module = load_module_image(self, image, + setup_callbacks, teardown_callbacks) if USE_NV_BINDING: key = module.handle else: @@ -1442,13 +1444,6 @@ def create_module_image(self, image): self.modules[key] = module return weakref.proxy(module) - def create_module_image_with_callbacks( - self, image, setup_callbacks, teardown_callbacks): - mod = self.create_module_image(image) - mod.set_setup_functions(setup_callbacks) - mod.set_teardown_functions(teardown_callbacks) - return mod - def unload_module(self, module): if USE_NV_BINDING: key = module.handle @@ -1535,17 +1530,21 @@ def __ne__(self, other): return not self.__eq__(other) -def load_module_image(context, image): +def load_module_image(context, image, + setup_callbacks=None, teardown_callbacks=None): """ image must be a pointer """ if USE_NV_BINDING: - return load_module_image_cuda_python(context, image) + return load_module_image_cuda_python( + context, image, setup_callbacks, teardown_callbacks) else: - return load_module_image_ctypes(context, image) + return load_module_image_ctypes( + context, image, setup_callbacks, teardown_callbacks) -def load_module_image_ctypes(context, image): +def load_module_image_ctypes(context, image, + setup_callbacks, teardown_callbacks): logsz = config.CUDA_LOG_SIZE jitinfo = (c_char * logsz)() @@ -1573,10 +1572,12 @@ def load_module_image_ctypes(context, image): info_log = jitinfo.value return CtypesModule(weakref.proxy(context), handle, info_log, - _module_finalizer(context, handle)) + _module_finalizer(context, handle), + setup_callbacks, teardown_callbacks) -def load_module_image_cuda_python(context, image): +def load_module_image_cuda_python(context, image, + setup_callbacks, teardown_callbacks): """ image must be a pointer """ @@ -1608,7 +1609,8 @@ def load_module_image_cuda_python(context, image): info_log = jitinfo.decode('utf-8') return CudaPythonModule(weakref.proxy(context), handle, info_log, - _module_finalizer(context, handle)) + _module_finalizer(context, handle), + setup_callbacks, teardown_callbacks) def _alloc_finalizer(memory_manager, ptr, alloc_key, size): @@ -2394,15 +2396,16 @@ def event_elapsed_time(evtstart, evtend): class Module(metaclass=ABCMeta): """Abstract base class for modules""" - def __init__(self, context, handle, info_log, finalizer=None): + def __init__(self, context, handle, info_log, finalizer=None, + setup_callbacks=None, teardown_callbacks=None): self.context = context self.handle = handle self.info_log = info_log if finalizer is not None: self._finalizer = weakref.finalize(self, finalizer) - self.setup_functions = None - self.teardown_functions = None + self.setup_functions = setup_callbacks + self.teardown_functions = teardown_callbacks def unload(self): """Unload this module from the context""" @@ -2426,14 +2429,6 @@ def _set_callables(self, callbacks, attr): setattr(self, attr, callbacks) - def set_setup_functions(self, callbacks): - """Set the setup callback for this module""" - self._set_callables(callbacks, "setup_functions") - - def set_teardown_functions(self, callbacks): - """Set the finalize callback for this module""" - self._set_callables(callbacks, "teardown_functions") - def setup(self, stream): """Call the setup functions for cumodule with the given `stream`""" if self.setup_functions is 
None: From c5d21fed16e216546f96444d7c116d01b52b0252 Mon Sep 17 00:00:00 2001 From: isVoid Date: Thu, 20 Mar 2025 15:30:10 -0700 Subject: [PATCH 18/36] explicitly delete kernel reference --- .../numba/cuda/tests/cudadrv/test_module_callbacks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index 6c3a5ca54..34f89e9d4 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -44,6 +44,7 @@ def kernel(): self.assertEqual(counter, 1) wipe_all_modules_in_context() + del kernel self.assertEqual(counter, 0) def test_different_argtypes(self): @@ -72,6 +73,7 @@ def kernel(arg): self.assertEqual(counter, 2) wipe_all_modules_in_context() + del kernel self.assertEqual(counter, 0) def test_two_kernels(self): @@ -101,6 +103,7 @@ def kernel2(): self.assertEqual(counter, 2) wipe_all_modules_in_context() + del kernel self.assertEqual(counter, 0) def test_different_streams(self): @@ -144,9 +147,8 @@ def kernel(): kernel[1, 1, s1]() kernel[1, 1, s2]() wipe_all_modules_in_context() + del kernel self.assertEqual(counter, 0) - self.assertEqual(setup_seen, {s1.handle.value, s2.handle.value}) - self.assertEqual(teardown_seen, {s1.handle.value, s2.handle.value}) class TestModuleCallbacks(CUDATestCase): From 36ad1158129e4e657073bdd1198fa6ee06fb7d3f Mon Sep 17 00:00:00 2001 From: isVoid Date: Thu, 20 Mar 2025 17:07:26 -0700 Subject: [PATCH 19/36] remove stream from setup and teardown callbacks --- numba_cuda/numba/cuda/cudadrv/driver.py | 26 ++++-- numba_cuda/numba/cuda/dispatcher.py | 12 +-- .../tests/cudadrv/test_module_callbacks.py | 84 ++++++------------- 3 files changed, 47 insertions(+), 75 deletions(-) diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index 2a70fc4e9..dc0c95fb2 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -2404,9 +2404,12 @@ def __init__(self, context, handle, info_log, finalizer=None, if finalizer is not None: self._finalizer = weakref.finalize(self, finalizer) + self.initialized = False self.setup_functions = setup_callbacks self.teardown_functions = teardown_callbacks + self._set_finalizers() + def unload(self): """Unload this module from the context""" self.context.unload_module(self) @@ -2429,29 +2432,34 @@ def _set_callables(self, callbacks, attr): setattr(self, attr, callbacks) - def setup(self, stream): + def setup(self): """Call the setup functions for cumodule with the given `stream`""" - if self.setup_functions is None: + if self.setup_functions is None or self.initialized: return for f in self.setup_functions: - f(weakref.proxy(self), stream) + f(self.handle) + + self.initialized = True - def set_finalizers(self, stream): - """Create finalizers that tears down the cumodule on `stream`""" + def _set_finalizers(self): + """Create finalizers that tears down the cumodule. + + Unlike the setup functions which takes in streams, Numba does not + provide a stream object to the teardown function. 
+ """ if self.teardown_functions is None: return - def _teardown(teardowns, modref, stream): + def _teardown(teardowns, modref): for f in teardowns: - f(modref, stream) + f(modref) weakref.finalize( self, _teardown, self.teardown_functions, - weakref.proxy(self), - stream + self.handle, ) diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index 66be7e57f..2afbed853 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -143,7 +143,6 @@ def __init__(self, py_func, argtypes, link=None, debug=False, self.debug = debug self.lineinfo = lineinfo self.extensions = extensions or [] - self.initialized = False nvvm_options = { 'fastmath': fastmath, @@ -290,7 +289,6 @@ def _rebuild(cls, cooperative, name, signature, codelibrary, instance.lineinfo = lineinfo instance.call_helper = call_helper instance.extensions = extensions - instance.initialized = False return instance def _reduce_states(self): @@ -312,6 +310,9 @@ def bind(self): """ cufunc = self._codelibrary.get_cufunc() + mod = cufunc.module + mod.setup() + if ( hasattr(self, "target_context") and self.target_context.enable_nrt @@ -450,13 +451,6 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0): stream_handle = stream and stream.handle or zero_stream - # Set init and finalize module callbacks - if not self.initialized: - mod = cufunc.module - mod.setup(stream_handle) - mod.set_finalizers(stream_handle) - self.initialized = True - # Invoke kernel driver.launch_kernel(cufunc.handle, *griddim, diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index 34f89e9d4..487dff033 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -4,7 +4,7 @@ from numba import cuda, config from numba.cuda.cudadrv.linkable_code import CUSource -from numba.cuda.testing import CUDATestCase +from numba.cuda.testing import CUDATestCase, ContextResettingTestCase from cuda.bindings.driver import cuModuleGetGlobal, cuMemcpyHtoD @@ -18,16 +18,16 @@ def wipe_all_modules_in_context(): config.CUDA_USE_NVIDIA_BINDING, "NV binding support superceded by cuda.bindings." 
) -class TestModuleCallbacksBasic(CUDATestCase): +class TestModuleCallbacksBasic(ContextResettingTestCase): def test_basic(self): counter = 0 - def setup(mod, stream): + def setup(handle): nonlocal counter counter += 1 - def teardown(mod, stream): + def teardown(handle): nonlocal counter counter -= 1 @@ -49,14 +49,18 @@ def kernel(): def test_different_argtypes(self): counter = 0 + setup_seen = set() + teardown_seen = set() - def setup(mod, stream): - nonlocal counter + def setup(handle): + nonlocal counter, setup_seen counter += 1 + setup_seen.add(handle.value) - def teardown(mod, stream): + def teardown(handle): nonlocal counter counter -= 1 + teardown_seen.add(handle.value) lib = CUSource("", setup_callback=setup, teardown_callback=teardown) @@ -76,16 +80,23 @@ def kernel(arg): del kernel self.assertEqual(counter, 0) + self.assertEqual(len(setup_seen), 2) + self.assertEqual(len(teardown_seen), 2) + def test_two_kernels(self): counter = 0 + setup_seen = set() + teardown_seen = set() - def setup(mod, stream): - nonlocal counter + def setup(handle): + nonlocal counter, setup_seen counter += 1 + setup_seen.add(handle.value) - def teardown(mod, stream): - nonlocal counter + def teardown(handle): + nonlocal counter, teardown_seen counter -= 1 + teardown_seen.add(handle.value) lib = CUSource("", setup_callback=setup, teardown_callback=teardown) @@ -106,49 +117,8 @@ def kernel2(): del kernel self.assertEqual(counter, 0) - def test_different_streams(self): - - s1 = cuda.stream() - s2 = cuda.stream() - - setup_seen = set() - teardown_seen = set() - - counter = 0 - - def setup(mod, stream, self=self): - nonlocal counter, s1, s2, setup_seen - if counter == 0: - self.assertEqual(stream, s1.handle) - elif counter == 1: - self.assertEqual(stream, s2.handle) - else: - raise RuntimeError("Setup should only be invoked twice.") - setup_seen.add(stream.value) - counter += 1 - - def teardown(mod, stream, self=self): - nonlocal counter, s1, s2, teardown_seen - if counter == 2: - self.assertEqual(stream, s2.handle) - elif counter == 1: - self.assertEqual(stream, s1.handle) - else: - raise RuntimeError("Teardown should only be invoked twice.") - teardown_seen.add(stream.value) - counter -= 1 - - lib = CUSource("", setup_callback=setup, teardown_callback=teardown) - - @cuda.jit(link=[lib]) - def kernel(): - pass - - kernel[1, 1, s1]() - kernel[1, 1, s2]() - wipe_all_modules_in_context() - del kernel - self.assertEqual(counter, 0) + self.assertEqual(len(setup_seen), 2) + self.assertEqual(len(teardown_seen), 2) class TestModuleCallbacks(CUDATestCase): @@ -167,17 +137,17 @@ def setUp(self): self.counter = 0 - def set_forty_two(mod, stream): + def set_forty_two(handle): self.counter += 1 # Initialize 42 to global variable `num` res, dptr, size = cuModuleGetGlobal( - mod.handle.value, "num".encode() + handle.value, "num".encode() ) arr = np.array([42], np.int32) cuMemcpyHtoD(dptr, arr.ctypes.data, size) - def teardown(mod, stream): + def teardown(handle): self.counter -= 1 self.lib = CUSource( From 2cc4c28e08baeeec78202cd857c0dc0600c9ca8c Mon Sep 17 00:00:00 2001 From: isVoid Date: Thu, 20 Mar 2025 17:16:59 -0700 Subject: [PATCH 20/36] remove counter --- .../numba/cuda/tests/cudadrv/test_module_callbacks.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index 487dff033..c83bd0786 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ 
b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -135,10 +135,7 @@ def setUp(self): } """ - self.counter = 0 - def set_forty_two(handle): - self.counter += 1 # Initialize 42 to global variable `num` res, dptr, size = cuModuleGetGlobal( handle.value, "num".encode() @@ -147,11 +144,8 @@ def set_forty_two(handle): arr = np.array([42], np.int32) cuMemcpyHtoD(dptr, arr.ctypes.data, size) - def teardown(handle): - self.counter -= 1 - self.lib = CUSource( - module, setup_callback=set_forty_two, teardown_callback=teardown) + module, setup_callback=set_forty_two, teardown_callback=None) def test_decldevice_arg(self): get_num = cuda.declare_device("get_num", "int32()", link=[self.lib]) From 01d2d85bbcdd81660ef8807da6a2858d5a35dc07 Mon Sep 17 00:00:00 2001 From: isVoid Date: Thu, 20 Mar 2025 17:19:57 -0700 Subject: [PATCH 21/36] add API coverage test --- .../tests/cudadrv/test_module_callbacks.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index c83bd0786..d2bf39a7a 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -121,6 +121,34 @@ def kernel2(): self.assertEqual(len(teardown_seen), 2) +class TestModuleCallbacksAPICompleteness(CUDATestCase): + + def test_api(self): + def setup(handle): + pass + + def teardown(handle): + pass + + api_combo = [ + (setup, teardown), + (setup, None), + (None, teardown), + (None, None) + ] + + for setup, teardown in api_combo: + with self.subTest(setup=setup, teardown=teardown): + lib = CUSource( + "", setup_callback=setup, teardown_callback=teardown) + + @cuda.jit(link=[lib]) + def kernel(): + pass + + kernel[1, 1]() + + class TestModuleCallbacks(CUDATestCase): def setUp(self): From 24e426029ecf79de73d56a594b05b03391852423 Mon Sep 17 00:00:00 2001 From: isVoid Date: Thu, 20 Mar 2025 17:25:05 -0700 Subject: [PATCH 22/36] asserting types of passed in module handle --- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index d2bf39a7a..53afe203e 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -5,6 +5,7 @@ from numba import cuda, config from numba.cuda.cudadrv.linkable_code import CUSource from numba.cuda.testing import CUDATestCase, ContextResettingTestCase +from numba.cuda.cudadrv.drvapi import cu_module from cuda.bindings.driver import cuModuleGetGlobal, cuMemcpyHtoD @@ -24,10 +25,12 @@ def test_basic(self): counter = 0 def setup(handle): + self.assertTrue(isinstance(handle, cu_module)) nonlocal counter counter += 1 def teardown(handle): + self.assertTrue(isinstance(handle, cu_module)) nonlocal counter counter -= 1 From 328c4308df11bd507c7fc97325ef933f8fb8b84f Mon Sep 17 00:00:00 2001 From: isVoid Date: Thu, 20 Mar 2025 17:35:57 -0700 Subject: [PATCH 23/36] update documentation --- docs/source/user/cuda_ffi.rst | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/source/user/cuda_ffi.rst b/docs/source/user/cuda_ffi.rst index e1f7f54dc..03d78e0c5 100644 --- a/docs/source/user/cuda_ffi.rst +++ b/docs/source/user/cuda_ffi.rst @@ -167,11 +167,8 @@ module unloading. These callbacks are defined as follows: .. 
code:: - def setup_callback( - mod: cuda.driver.CtypeModule, stream: cuda.driver.Stream - ) - -The `stream` argument is the same stream user passes in to launch the kernel. + def setup_callback(mod: cuda.cudadrv.drvapi.cu_module):... + def teardown_callback(mod: cuda.cudadrv.drvapi.cu_module):... :class:`LinkableCode ` objects are initialized using the parameters of their base class: From 43bcfa226fbc80f2a907af841716ae5a00d7a29d Mon Sep 17 00:00:00 2001 From: isVoid Date: Fri, 21 Mar 2025 10:50:20 -0700 Subject: [PATCH 24/36] address review comments --- numba_cuda/numba/cuda/cudadrv/driver.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index dc0c95fb2..f6e7c462c 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -2422,18 +2422,8 @@ def get_function(self, name): def get_global_symbol(self, name): """Return a MemoryPointer referring to the named symbol""" - def _set_callables(self, callbacks, attr): - if not isinstance(callbacks, list): - raise TypeError("callbacks must be a list") - - for f in callbacks: - if not callable(f): - raise TypeError("callback must be callable") - - setattr(self, attr, callbacks) - def setup(self): - """Call the setup functions for cumodule with the given `stream`""" + """Call the setup functions for the module""" if self.setup_functions is None or self.initialized: return @@ -2443,10 +2433,7 @@ def setup(self): self.initialized = True def _set_finalizers(self): - """Create finalizers that tears down the cumodule. - - Unlike the setup functions which takes in streams, Numba does not - provide a stream object to the teardown function. + """Create finalizers that tears down the module. """ if self.teardown_functions is None: return From 76630933a446374f32c1588a0bef03b22a239368 Mon Sep 17 00:00:00 2001 From: isVoid Date: Fri, 21 Mar 2025 10:58:09 -0700 Subject: [PATCH 25/36] update linkable code doc --- numba_cuda/numba/cuda/cudadrv/linkable_code.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/numba_cuda/numba/cuda/cudadrv/linkable_code.py b/numba_cuda/numba/cuda/cudadrv/linkable_code.py index 7303c4832..67ed50f21 100644 --- a/numba_cuda/numba/cuda/cudadrv/linkable_code.py +++ b/numba_cuda/numba/cuda/cudadrv/linkable_code.py @@ -7,8 +7,11 @@ class LinkableCode: :param data: A buffer containing the data to link. :param name: The name of the file to be referenced in any compilation or linking errors that may be produced. - :param setup_callback Callback function on kernel before launching. - :param teardown_callback Callback function on kernel before tearing down. + :param setup_callback A function called prior to the launch of a kernel + contained within a module that has this code object + linked into it. + :param teardown_callback A function called just prior to the unloading of + a module that has this code object linked into it. 
""" def __init__(self, data, name=None, From 97367ce5534e36e7abc742ae948698f458b72b41 Mon Sep 17 00:00:00 2001 From: isVoid Date: Fri, 21 Mar 2025 11:11:24 -0700 Subject: [PATCH 26/36] update the tests to acommodate nvidia bindings --- numba_cuda/numba/cuda/cudadrv/driver.py | 4 +-- .../tests/cudadrv/test_module_callbacks.py | 34 ++++++++++++------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index f6e7c462c..53aba9332 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -2438,9 +2438,9 @@ def _set_finalizers(self): if self.teardown_functions is None: return - def _teardown(teardowns, modref): + def _teardown(teardowns, handle): for f in teardowns: - f(modref) + f(handle) weakref.finalize( self, diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index 53afe203e..ac8f9af5e 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -5,32 +5,38 @@ from numba import cuda, config from numba.cuda.cudadrv.linkable_code import CUSource from numba.cuda.testing import CUDATestCase, ContextResettingTestCase -from numba.cuda.cudadrv.drvapi import cu_module from cuda.bindings.driver import cuModuleGetGlobal, cuMemcpyHtoD +if config.CUDA_USE_NVIDIA_BINDING: + from cuda.cuda import CUmodule as cu_module_type +else: + from numba.cuda.cudadrv.drvapi import cu_module as cu_module_type + def wipe_all_modules_in_context(): ctx = cuda.current_context() ctx.reset() -@unittest.skipIf( - config.CUDA_USE_NVIDIA_BINDING, - "NV binding support superceded by cuda.bindings." 
-) +def get_hashable_handle_value(handle): + if not config.CUDA_USE_NVIDIA_BINDING: + handle = handle.value + return handle + + class TestModuleCallbacksBasic(ContextResettingTestCase): def test_basic(self): counter = 0 def setup(handle): - self.assertTrue(isinstance(handle, cu_module)) + self.assertTrue(isinstance(handle, cu_module_type)) nonlocal counter counter += 1 def teardown(handle): - self.assertTrue(isinstance(handle, cu_module)) + self.assertTrue(isinstance(handle, cu_module_type)) nonlocal counter counter -= 1 @@ -58,12 +64,12 @@ def test_different_argtypes(self): def setup(handle): nonlocal counter, setup_seen counter += 1 - setup_seen.add(handle.value) + setup_seen.add(get_hashable_handle_value(handle)) def teardown(handle): nonlocal counter counter -= 1 - teardown_seen.add(handle.value) + teardown_seen.add(get_hashable_handle_value(handle)) lib = CUSource("", setup_callback=setup, teardown_callback=teardown) @@ -94,12 +100,12 @@ def test_two_kernels(self): def setup(handle): nonlocal counter, setup_seen counter += 1 - setup_seen.add(handle.value) + setup_seen.add(get_hashable_handle_value(handle)) def teardown(handle): nonlocal counter, teardown_seen counter -= 1 - teardown_seen.add(handle.value) + teardown_seen.add(get_hashable_handle_value(handle)) lib = CUSource("", setup_callback=setup, teardown_callback=teardown) @@ -169,7 +175,7 @@ def setUp(self): def set_forty_two(handle): # Initialize 42 to global variable `num` res, dptr, size = cuModuleGetGlobal( - handle.value, "num".encode() + get_hashable_handle_value(handle), "num".encode() ) arr = np.array([42], np.int32) @@ -199,3 +205,7 @@ def kernel(arr): arr = np.zeros(1, np.int32) kernel[1, 1](arr) self.assertEqual(arr[0], 42) + + +if __name__ == '__main__': + unittest.main() From 1535233cbb7158e1872a008ee14bbbf7789889c8 Mon Sep 17 00:00:00 2001 From: isVoid Date: Fri, 21 Mar 2025 11:16:17 -0700 Subject: [PATCH 27/36] setup should raise an error if module is already initialized --- numba_cuda/numba/cuda/cudadrv/driver.py | 5 ++++- numba_cuda/numba/cuda/dispatcher.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index 53aba9332..e9a1b5a8c 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -2424,7 +2424,10 @@ def get_global_symbol(self, name): def setup(self): """Call the setup functions for the module""" - if self.setup_functions is None or self.initialized: + if self.initialized: + raise RuntimeError("The module has already been initialized.") + + if self.setup_functions is None: return for f in self.setup_functions: diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index 2afbed853..63f9edac8 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -311,7 +311,8 @@ def bind(self): cufunc = self._codelibrary.get_cufunc() mod = cufunc.module - mod.setup() + if not mod.initialized: + mod.setup() if ( hasattr(self, "target_context") From 79dd76e836b75cf1a5d51a57d84e07e4b54edb00 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Fri, 21 Mar 2025 11:57:33 -0700 Subject: [PATCH 28/36] Update docs/source/user/cuda_ffi.rst Co-authored-by: Graham Markall <535640+gmarkall@users.noreply.github.com> --- docs/source/user/cuda_ffi.rst | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/docs/source/user/cuda_ffi.rst b/docs/source/user/cuda_ffi.rst index 03d78e0c5..88871098e 100644 
--- a/docs/source/user/cuda_ffi.rst +++ b/docs/source/user/cuda_ffi.rst @@ -160,10 +160,17 @@ CUDA C/C++ source code will be compiled with the `NVIDIA Runtime Compiler kernel as either PTX or LTOIR, depending on whether LTO is enabled. Other files will be passed directly to the CUDA Linker. -For cumodules linked to a linkable code object, additional setup -and teardown callbacks are available. Setup functions are invoked once for -every new module loaded. Teardown functions are invoked just prior to -module unloading. These callbacks are defined as follows: +A ``LinkableCode`` object may have setup and teardown callback functions that +perform module-specific initialization and cleanup tasks. + +* Setup functions are invoked once for every new module loaded. +* Teardown functions are invoked just prior to module unloading. + +Both setup and teardown callbacks are called with a handle to the relevant +module. In practice, Numba creates a new module each time a kernel is compiled +for a specific set of argument types. + +The callbacks are defined as follows: .. code:: From b0ff099ba12be556f2ca74bac4bece2e57f8b6fa Mon Sep 17 00:00:00 2001 From: isVoid Date: Fri, 21 Mar 2025 12:04:15 -0700 Subject: [PATCH 29/36] add input type guards for linkable code --- numba_cuda/numba/cuda/codegen.py | 4 ++-- numba_cuda/numba/cuda/cudadrv/linkable_code.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index d26dce5d6..f57899e0c 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -294,9 +294,9 @@ def add_linking_library(self, library): def add_linking_file(self, path_or_obj): if isinstance(path_or_obj, LinkableCode): - if path_or_obj.setup_callback is not None: + if path_or_obj.setup_callback: self._setup_functions.append(path_or_obj.setup_callback) - if path_or_obj.teardown_callback is not None: + if path_or_obj.teardown_callback: self._teardown_functions.append(path_or_obj.teardown_callback) self._linking_files.add(path_or_obj) diff --git a/numba_cuda/numba/cuda/cudadrv/linkable_code.py b/numba_cuda/numba/cuda/cudadrv/linkable_code.py index 67ed50f21..aa4f3fcd8 100644 --- a/numba_cuda/numba/cuda/cudadrv/linkable_code.py +++ b/numba_cuda/numba/cuda/cudadrv/linkable_code.py @@ -17,6 +17,12 @@ class LinkableCode: def __init__(self, data, name=None, setup_callback=None, teardown_callback=None): + + if setup_callback and not callable(setup_callback): + raise TypeError("setup_callback must be callable") + if teardown_callback and not callable(teardown_callback): + raise TypeError("teardown_callback must be callable") + self.data = data self._name = name self.setup_callback = setup_callback From 497f0eb094ac7232cfef0bea5568f96a2ea308b6 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Mon, 24 Mar 2025 15:07:44 +0000 Subject: [PATCH 30/36] Fix docstrings --- numba_cuda/numba/cuda/cudadrv/driver.py | 3 +-- numba_cuda/numba/cuda/cudadrv/linkable_code.py | 11 ++++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index e9a1b5a8c..5e4e56d3c 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -2436,8 +2436,7 @@ def setup(self): self.initialized = True def _set_finalizers(self): - """Create finalizers that tears down the module. - """ + """Create finalizers that tear down the module. 
""" if self.teardown_functions is None: return diff --git a/numba_cuda/numba/cuda/cudadrv/linkable_code.py b/numba_cuda/numba/cuda/cudadrv/linkable_code.py index aa4f3fcd8..d9f18a834 100644 --- a/numba_cuda/numba/cuda/cudadrv/linkable_code.py +++ b/numba_cuda/numba/cuda/cudadrv/linkable_code.py @@ -7,11 +7,12 @@ class LinkableCode: :param data: A buffer containing the data to link. :param name: The name of the file to be referenced in any compilation or linking errors that may be produced. - :param setup_callback A function called prior to the launch of a kernel - contained within a module that has this code object - linked into it. - :param teardown_callback A function called just prior to the unloading of - a module that has this code object linked into it. + :param setup_callback: A function called prior to the launch of a kernel + contained within a module that has this code object + linked into it. + :param teardown_callback: A function called just prior to the unloading of + a module that has this code object linked into + it. """ def __init__(self, data, name=None, From df7427e88b7874ceb11ac4dc44ad498ef7952c17 Mon Sep 17 00:00:00 2001 From: isVoid Date: Tue, 25 Mar 2025 16:36:54 +0000 Subject: [PATCH 31/36] add lock to protect initialization secton --- numba_cuda/numba/cuda/dispatcher.py | 10 +++++++--- numba_cuda/numba/cuda/locks.py | 15 +++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 numba_cuda/numba/cuda/locks.py diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index 63f9edac8..f072f6073 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -24,6 +24,7 @@ normalize_kernel_dimensions) from numba.cuda import types as cuda_types from numba.cuda.runtime.nrt import rtsys +from numba.cuda.locks import module_init_lock from numba import cuda from numba import _dispatcher @@ -304,15 +305,18 @@ def _reduce_states(self): debug=self.debug, lineinfo=self.lineinfo, call_helper=self.call_helper, extensions=self.extensions) + @module_init_lock + def initialize_once(self, mod): + if not mod.initialized: + mod.setup() + def bind(self): """ Force binding to current CUDA context """ cufunc = self._codelibrary.get_cufunc() - mod = cufunc.module - if not mod.initialized: - mod.setup() + self.initialize_once(cufunc.module) if ( hasattr(self, "target_context") diff --git a/numba_cuda/numba/cuda/locks.py b/numba_cuda/numba/cuda/locks.py new file mode 100644 index 000000000..133e19d9b --- /dev/null +++ b/numba_cuda/numba/cuda/locks.py @@ -0,0 +1,15 @@ +from threading import Lock +from functools import wraps + +# Thread safety guard for module initialization. +_module_init_lock = Lock() + + +def module_init_lock(func): + """Decorator to make sure initialization is invoked once for all threads. + """ + @wraps(func) + def wrapper(*args, **kwargs): + with _module_init_lock: + return func(*args, **kwargs) + return wrapper From 6aa1120699ba637207fa978f8efa61e1882fb155 Mon Sep 17 00:00:00 2001 From: isVoid Date: Tue, 25 Mar 2025 16:42:21 +0000 Subject: [PATCH 32/36] add documentation --- docs/source/user/cuda_ffi.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/user/cuda_ffi.rst b/docs/source/user/cuda_ffi.rst index 88871098e..59d514473 100644 --- a/docs/source/user/cuda_ffi.rst +++ b/docs/source/user/cuda_ffi.rst @@ -170,6 +170,10 @@ Both setup and teardown callbacks are called with a handle to the relevant module. 
In practice, Numba creates a new module each time a kernel is compiled for a specific set of argument types. +For each module, the setup callback is invoked once only. When a module is +executed by multiple threads, only one thread will execute the setup +callback. + The callbacks are defined as follows: .. code:: From 40f4e9bb9b13c78366136264f3ea3b4ea4c4ec2c Mon Sep 17 00:00:00 2001 From: isVoid Date: Wed, 26 Mar 2025 17:19:04 -0700 Subject: [PATCH 33/36] add multithreaded callback behavior test --- .../tests/cudadrv/test_module_callbacks.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index ac8f9af5e..208db1ab7 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -1,4 +1,5 @@ import unittest +import threading import numpy as np @@ -15,6 +16,12 @@ def wipe_all_modules_in_context(): + """Cleans all modules reference held by current context. + This simulates the behavior on interpreter shutdown. + + TODO: This is a temp solution until + https://github.com/NVIDIA/numba-cuda/issues/171 is implemented. + """ ctx = cuda.current_context() ctx.reset() @@ -207,5 +214,80 @@ def kernel(arr): self.assertEqual(arr[0], 42) +class TestMultithreadedCallbacks(CUDATestCase): + + def test_concurrent_initialization(self): + seen_mods = set() + max_seen_mods = 0 + + def setup(mod): + nonlocal seen_mods, max_seen_mods + seen_mods.add(get_hashable_handle_value(mod)) + max_seen_mods = max(max_seen_mods, len(seen_mods)) + + def teardown(mod): + nonlocal seen_mods + # Raises an error if the module is not found in the seen_mods + seen_mods.remove(get_hashable_handle_value(mod)) + + lib = CUSource("", setup_callback=setup, teardown_callback=teardown) + + @cuda.jit(link=[lib]) + def kernel(): + pass + + def concurrent_compilation_launch(kernel): + kernel[1, 1]() + + threads = [ + threading.Thread( + target=concurrent_compilation_launch, args=(kernel,) + ) for _ in range(4) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + wipe_all_modules_in_context() + self.assertEqual(len(seen_mods), 0) + self.assertEqual(max_seen_mods, 4) + + def test_concurrent_initialization_different_args(self): + seen_mods = set() + max_seen_mods = 0 + + def setup(mod): + nonlocal seen_mods, max_seen_mods + seen_mods.add(get_hashable_handle_value(mod)) + max_seen_mods = max(max_seen_mods, len(seen_mods)) + + def teardown(mod): + nonlocal seen_mods + seen_mods.remove(get_hashable_handle_value(mod)) + + lib = CUSource("", setup_callback=setup, teardown_callback=teardown) + + @cuda.jit(link=[lib]) + def kernel(a): + pass + + def concurrent_compilation_launch(): + kernel[1, 1](42) # (int64)->() : module 1 + kernel[1, 1](9) # (int64)->() : module 1 from cache + kernel[1, 1](3.14) # (float64)->() : module 2 + + threads = [threading.Thread(target=concurrent_compilation_launch) + for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + wipe_all_modules_in_context() + assert len(seen_mods) == 0 + self.assertEqual(max_seen_mods, 8) # 2 kernels per thread + + if __name__ == '__main__': unittest.main() From c24e0b24800c7a0033341c7de3bfa62e452c5142 Mon Sep 17 00:00:00 2001 From: Yevhenii Havrylko Date: Fri, 21 Mar 2025 10:44:59 -0400 Subject: [PATCH 34/36] Replace flake8 with ruff and pre-commit-hooks --- .flake8 | 52 ---------------------------------- 
.pre-commit-config.yaml | 24 ++++++++++++++-- pyproject.toml | 63 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 55 deletions(-) delete mode 100644 .flake8 diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 591e4a2af..000000000 --- a/.flake8 +++ /dev/null @@ -1,52 +0,0 @@ -[flake8] -ignore = - # Extra space in brackets - E20, - # Multiple spaces around "," - E231,E241, - # Comments - E26, - # Assigning lambda expression - E731, - # Ambiguous variable names - E741, - # line break before binary operator - W503, - # line break after binary operator - W504, -max-line-length = 80 - -exclude = - __pycache__ - .git - *.pyc - *~ - *.o - *.so - *.cpp - *.c - *.h - -per-file-ignores = - # Slightly long line in the standard version file - numba_cuda/_version.py: E501 - # "Unused" imports / potentially undefined names in init files - numba_cuda/numba/cuda/__init__.py:F401,F403,F405 - numba_cuda/numba/cuda/simulator/__init__.py:F401,F403 - numba_cuda/numba/cuda/simulator/cudadrv/__init__.py:F401 - # Ignore star imports, unused imports, and "may be defined by star imports" - # errors in device_init because its purpose is to bring together a lot of - # the public API to be star-imported in numba.cuda.__init__ - numba_cuda/numba/cuda/device_init.py:F401,F403,F405 - # libdevice.py is an autogenerated file containing stubs for all the device - # functions. Some of the lines in docstrings are a little over-long, as they - # contain the URLs of the reference pages in the online libdevice - # documentation. - numba_cuda/numba/cuda/libdevice.py:E501 - # Ignore too-long lines in the doc examples, prioritising readability - # in the docs over line length in the example source (especially given that - # the test code is already indented by 8 spaces) - numba_cuda/numba/cuda/tests/doc_examples/test_random.py:E501 - numba_cuda/numba/cuda/tests/doc_examples/test_cg.py:E501 - numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py:E501 - numba_cuda/numba/tests/doc_examples/test_interval_example.py:E501 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0a114cd32..478cd1ef5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,23 @@ repos: -- repo: https://github.com/PyCQA/flake8 - rev: 7.1.0 +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 # Use the latest version or a specific tag hooks: - - id: flake8 + - id: check-added-large-files + - id: check-ast + - id: check-json + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + exclude: ^conda/recipes/numba-cuda/meta.yaml + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: trailing-whitespace + - id: mixed-line-ending + args: ['--fix=lf'] +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.2 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/pyproject.toml b/pyproject.toml index 2a484d9da..6dbf04e16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,3 +37,66 @@ include = ["numba_cuda*"] [tool.setuptools.package-data] "*" = ["*.cu", "*.h", "*.hpp", "*.ptx", "*.cuh", "VERSION", "Makefile"] + +[tool.ruff] +line-length = 80 + +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 80 + +[tool.ruff.lint.pycodestyle] +max-doc-length = 80 +max-line-length = 80 + +[tool.ruff.lint] +ignore = [ + # Extra space in brackets + "E20", + # Multiple spaces around "," + "E231", + "E241", + # Comments + "E26", + # Assigning lambda expression + "E731", + # Ambiguous variable 
names + "E741", +] +fixable = ["ALL"] + +exclude = [ + "__pycache__", + ".git", + "*.pyc", + "*~", + "*.o", + "*.so", + "*.cpp", + "*.c", + "*.h", +] + +[tool.ruff.lint.per-file-ignores] +# Slightly long line in the standard version file +"numba_cuda/_version.py" = ["E501"] +# "Unused" imports / potentially undefined names in init files +"numba_cuda/numba/cuda/__init__.py" = ["F401", "F403", "F405"] +"numba_cuda/numba/cuda/simulator/__init__.py" = ["F401", "F403"] +"numba_cuda/numba/cuda/simulator/cudadrv/__init__.py" = ["F401"] +# Ignore star imports", " unused imports", " and "may be defined by star imports" +# errors in device_init because its purpose is to bring together a lot of +# the public API to be star-imported in numba.cuda.__init__ +"numba_cuda/numba/cuda/device_init.py" = ["F401", "F403", "F405"] +# libdevice.py is an autogenerated file containing stubs for all the device +# functions. Some of the lines in docstrings are a little over-long", " as they +# contain the URLs of the reference pages in the online libdevice +# documentation. +"numba_cuda/numba/cuda/libdevice.py" = ["E501"] +# Ignore too-long lines in the doc examples", " prioritising readability +# in the docs over line length in the example source (especially given that +# the test code is already indented by 8 spaces) +"numba_cuda/numba/cuda/tests/doc_examples/test_random.py" = ["E501"] +"numba_cuda/numba/cuda/tests/doc_examples/test_cg.py" = ["E501"] +"numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py" = ["E501"] +"numba_cuda/numba/tests/doc_examples/test_interval_example.py" = ["E501"] From 075b7bde93e5a69969f3a88450cd52af2f5fd94a Mon Sep 17 00:00:00 2001 From: isVoid Date: Tue, 8 Apr 2025 17:11:29 -0700 Subject: [PATCH 35/36] Apply precommit --- numba_cuda/numba/cuda/codegen.py | 99 ++- numba_cuda/numba/cuda/compiler.py | 345 ++++++--- numba_cuda/numba/cuda/cudadrv/driver.py | 730 ++++++++++-------- .../numba/cuda/cudadrv/linkable_code.py | 7 +- numba_cuda/numba/cuda/cudadrv/nvvm.py | 449 +++++++---- numba_cuda/numba/cuda/dispatcher.py | 504 +++++++----- numba_cuda/numba/cuda/locks.py | 5 +- .../tests/cudadrv/test_module_callbacks.py | 35 +- .../cuda/tests/cudadrv/test_nvjitlink.py | 39 +- .../numba/cuda/tests/cudapy/test_debuginfo.py | 59 +- .../cuda/tests/cudapy/test_device_func.py | 97 +-- .../numba/cuda/tests/cudapy/test_overload.py | 55 +- numba_cuda/numba/cuda/tests/nrt/test_nrt.py | 29 +- .../numba/cuda/tests/nrt/test_nrt_refct.py | 22 +- 14 files changed, 1489 insertions(+), 986 deletions(-) diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index f57899e0c..768d87c77 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -10,7 +10,7 @@ import subprocess import tempfile -CUDA_TRIPLE = 'nvptx64-nvidia-cuda' +CUDA_TRIPLE = "nvptx64-nvidia-cuda" def run_nvdisasm(cubin, flags): @@ -20,19 +20,24 @@ def run_nvdisasm(cubin, flags): fname = None try: fd, fname = tempfile.mkstemp() - with open(fname, 'wb') as f: + with open(fname, "wb") as f: f.write(cubin) try: - cp = subprocess.run(['nvdisasm', *flags, fname], check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + cp = subprocess.run( + ["nvdisasm", *flags, fname], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) except FileNotFoundError as e: - msg = ("nvdisasm has not been found. You may need " - "to install the CUDA toolkit and ensure that " - "it is available on your PATH.\n") + msg = ( + "nvdisasm has not been found. 
You may need " + "to install the CUDA toolkit and ensure that " + "it is available on your PATH.\n" + ) raise RuntimeError(msg) from e - return cp.stdout.decode('utf-8') + return cp.stdout.decode("utf-8") finally: if fd is not None: os.close(fd) @@ -42,13 +47,13 @@ def run_nvdisasm(cubin, flags): def disassemble_cubin(cubin): # Request lineinfo in disassembly - flags = ['-gi'] + flags = ["-gi"] return run_nvdisasm(cubin, flags) def disassemble_cubin_for_cfg(cubin): # Request control flow graph in disassembly - flags = ['-cfg'] + flags = ["-cfg"] return run_nvdisasm(cubin, flags) @@ -66,7 +71,7 @@ def __init__( entry_name=None, max_registers=None, lto=False, - nvvm_options=None + nvvm_options=None, ): """ codegen: @@ -149,7 +154,7 @@ def get_asm_str(self, cc=None): arch = nvvm.get_arch_option(*cc) options = self._nvvm_options.copy() - options['arch'] = arch + options["arch"] = arch irs = self.llvm_strs @@ -158,12 +163,12 @@ def get_asm_str(self, cc=None): # Sometimes the result from NVVM contains trailing whitespace and # nulls, which we strip so that the assembly dump looks a little # tidier. - ptx = ptx.decode().strip('\x00').strip() + ptx = ptx.decode().strip("\x00").strip() if config.DUMP_ASSEMBLY: - print(("ASSEMBLY %s" % self._name).center(80, '-')) + print(("ASSEMBLY %s" % self._name).center(80, "-")) print(ptx) - print('=' * 80) + print("=" * 80) self._ptx_cache[cc] = ptx @@ -178,8 +183,8 @@ def get_ltoir(self, cc=None): arch = nvvm.get_arch_option(*cc) options = self._nvvm_options.copy() - options['arch'] = arch - options['gen-lto'] = None + options["arch"] = arch + options["gen-lto"] = None irs = self.llvm_strs ltoir = nvvm.compile_ir(irs, **options) @@ -199,7 +204,7 @@ def _link_all(self, linker, cc, ignore_nonlto=False): linker.add_file_guess_ext(path, ignore_nonlto) if self.needs_cudadevrt: linker.add_file_guess_ext( - get_cudalib('cudadevrt', static=True), ignore_nonlto + get_cudalib("cudadevrt", static=True), ignore_nonlto ) def get_cubin(self, cc=None): @@ -214,22 +219,20 @@ def get_cubin(self, cc=None): max_registers=self._max_registers, cc=cc, additional_flags=["-ptx"], - lto=self._lto + lto=self._lto, ) # `-ptx` flag is meant to view the optimized PTX for LTO objects. # Non-LTO objects are not passed to linker. self._link_all(linker, cc, ignore_nonlto=True) - ptx = linker.get_linked_ptx().decode('utf-8') + ptx = linker.get_linked_ptx().decode("utf-8") - print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, '-')) + print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, "-")) print(ptx) - print('=' * 80) + print("=" * 80) linker = driver.Linker.new( - max_registers=self._max_registers, - cc=cc, - lto=self._lto + max_registers=self._max_registers, cc=cc, lto=self._lto ) self._link_all(linker, cc, ignore_nonlto=False) cubin = linker.complete() @@ -241,8 +244,10 @@ def get_cubin(self, cc=None): def get_cufunc(self): if self._entry_name is None: - msg = "Missing entry_name - are you trying to get the cufunc " \ - "for a device function?" + msg = ( + "Missing entry_name - are you trying to get the cufunc " + "for a device function?" 
+ ) raise RuntimeError(msg) ctx = devices.get_context() @@ -254,7 +259,8 @@ def get_cufunc(self): cubin = self.get_cubin(cc=device.compute_capability) module = ctx.create_module_image( - cubin, self._setup_functions, self._teardown_functions) + cubin, self._setup_functions, self._teardown_functions + ) # Load cufunc = module.get_function(self._entry_name) @@ -268,7 +274,7 @@ def get_linkerinfo(self, cc): try: return self._linkerinfo_cache[cc] except KeyError: - raise KeyError(f'No linkerinfo for CC {cc}') + raise KeyError(f"No linkerinfo for CC {cc}") def get_sass(self, cc=None): return disassemble_cubin(self.get_cubin(cc=cc)) @@ -279,7 +285,7 @@ def get_sass_cfg(self, cc=None): def add_ir_module(self, mod): self._raise_if_finalized() if self._module is not None: - raise RuntimeError('CUDACodeLibrary only supports one module') + raise RuntimeError("CUDACodeLibrary only supports one module") self._module = mod def add_linking_library(self, library): @@ -305,12 +311,13 @@ def get_function(self, name): for fn in self._module.functions: if fn.name == name: return fn - raise KeyError(f'Function {name} not found') + raise KeyError(f"Function {name} not found") @property def modules(self): - return [self._module] + [mod for lib in self._linking_libraries - for mod in lib.modules] + return [self._module] + [ + mod for lib in self._linking_libraries for mod in lib.modules + ] @property def linking_libraries(self): @@ -345,7 +352,7 @@ def finalize(self): for mod in library.modules: for fn in mod.functions: if not fn.is_declaration: - fn.linkage = 'linkonce_odr' + fn.linkage = "linkonce_odr" self._finalized = True @@ -356,10 +363,10 @@ def _reduce_states(self): after deserialization. """ if self._linking_files: - msg = 'Cannot pickle CUDACodeLibrary with linking files' + msg = "Cannot pickle CUDACodeLibrary with linking files" raise RuntimeError(msg) if not self._finalized: - raise RuntimeError('Cannot pickle unfinalized CUDACodeLibrary') + raise RuntimeError("Cannot pickle unfinalized CUDACodeLibrary") return dict( codegen=None, name=self.name, @@ -370,13 +377,23 @@ def _reduce_states(self): linkerinfo_cache=self._linkerinfo_cache, max_registers=self._max_registers, nvvm_options=self._nvvm_options, - needs_cudadevrt=self.needs_cudadevrt + needs_cudadevrt=self.needs_cudadevrt, ) @classmethod - def _rebuild(cls, codegen, name, entry_name, llvm_strs, ptx_cache, - cubin_cache, linkerinfo_cache, max_registers, nvvm_options, - needs_cudadevrt): + def _rebuild( + cls, + codegen, + name, + entry_name, + llvm_strs, + ptx_cache, + cubin_cache, + linkerinfo_cache, + max_registers, + nvvm_options, + needs_cudadevrt, + ): """ Rebuild an instance. 
""" diff --git a/numba_cuda/numba/cuda/compiler.py b/numba_cuda/numba/cuda/compiler.py index 49968890e..2009e777f 100644 --- a/numba_cuda/numba/cuda/compiler.py +++ b/numba_cuda/numba/cuda/compiler.py @@ -1,19 +1,39 @@ from llvmlite import ir from numba.core.typing.templates import ConcreteTemplate from numba.core import ir as numba_ir -from numba.core import (cgutils, types, typing, funcdesc, config, compiler, - sigutils, utils) -from numba.core.compiler import (sanitize_compile_result_entries, CompilerBase, - DefaultPassBuilder, Flags, Option, - CompileResult) +from numba.core import ( + cgutils, + types, + typing, + funcdesc, + config, + compiler, + sigutils, + utils, +) +from numba.core.compiler import ( + sanitize_compile_result_entries, + CompilerBase, + DefaultPassBuilder, + Flags, + Option, + CompileResult, +) from numba.core.compiler_lock import global_compiler_lock -from numba.core.compiler_machinery import (FunctionPass, LoweringPass, - PassManager, register_pass) +from numba.core.compiler_machinery import ( + FunctionPass, + LoweringPass, + PassManager, + register_pass, +) from numba.core.interpreter import Interpreter from numba.core.errors import NumbaInvalidConfigWarning from numba.core.untyped_passes import TranslateByteCode -from numba.core.typed_passes import (IRLegalization, NativeLowering, - AnnotateTypes) +from numba.core.typed_passes import ( + IRLegalization, + NativeLowering, + AnnotateTypes, +) from warnings import warn from numba.cuda import nvvmutils from numba.cuda.api import get_current_device @@ -52,15 +72,9 @@ class CUDAFlags(Flags): doc="Compute Capability", ) max_registers = Option( - type=_optional_int_type, - default=None, - doc="Max registers" - ) - lto = Option( - type=bool, - default=False, - doc="Enable Link-time Optimization" + type=_optional_int_type, default=None, doc="Max registers" ) + lto = Option(type=bool, default=False, doc="Enable Link-time Optimization") # The CUDACompileResult (CCR) has a specially-defined entry point equal to its @@ -79,6 +93,7 @@ class CUDAFlags(Flags): # point will no longer need to be a synthetic value, but will instead be a # pointer to the compiled function as in the CPU target. + class CUDACompileResult(CompileResult): @property def entry_point(self): @@ -92,7 +107,6 @@ def cuda_compile_result(**entries): @register_pass(mutates_CFG=True, analysis_only=False) class CUDABackend(LoweringPass): - _name = "cuda_backend" def __init__(self): @@ -102,7 +116,7 @@ def run_pass(self, state): """ Back-end: Packages lowering output in a compile result """ - lowered = state['cr'] + lowered = state["cr"] signature = typing.signature(state.return_type, *state.args) state.cr = cuda_compile_result( @@ -137,9 +151,12 @@ def run_pass(self, state): nvvm_options = state.flags.nvvm_options max_registers = state.flags.max_registers lto = state.flags.lto - state.library = codegen.create_library(name, nvvm_options=nvvm_options, - max_registers=max_registers, - lto=lto) + state.library = codegen.create_library( + name, + nvvm_options=nvvm_options, + max_registers=max_registers, + lto=lto, + ) # Enable object caching upfront so that the library can be serialized. 
state.library.enable_object_caching() @@ -165,13 +182,15 @@ def _op_JUMP_IF(self, inst, pred, iftrue): gv_fn = numba_ir.Global("bool", bool, loc=self.loc) self.store(value=gv_fn, name=name) - callres = numba_ir.Expr.call(self.get(name), (self.get(pred),), (), - loc=self.loc) + callres = numba_ir.Expr.call( + self.get(name), (self.get(pred),), (), loc=self.loc + ) pname = "$%spred" % (inst.offset) predicate = self.store(value=callres, name=pname) - bra = numba_ir.Branch(cond=predicate, truebr=truebr, falsebr=falsebr, - loc=self.loc) + bra = numba_ir.Branch( + cond=predicate, truebr=truebr, falsebr=falsebr, loc=self.loc + ) self.current_block.append(bra) @@ -183,18 +202,18 @@ def __init__(self): FunctionPass.__init__(self) def run_pass(self, state): - func_id = state['func_id'] - bc = state['bc'] + func_id = state["func_id"] + bc = state["bc"] interp = CUDABytecodeInterpreter(func_id) func_ir = interp.interpret(bc) - state['func_ir'] = func_ir + state["func_ir"] = func_ir return True class CUDACompiler(CompilerBase): def define_pipelines(self): dpb = DefaultPassBuilder - pm = PassManager('cuda') + pm = PassManager("cuda") untyped_passes = dpb.define_untyped_pipeline(self.state) @@ -225,10 +244,9 @@ def replace_translate_pass(implementation, description): return [pm] def define_cuda_lowering_pipeline(self, state): - pm = PassManager('cuda_lowering') + pm = PassManager("cuda_lowering") # legalise - pm.add_pass(IRLegalization, - "ensure IR is legal prior to lowering") + pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering") pm.add_pass(AnnotateTypes, "annotate types") # lower @@ -241,13 +259,24 @@ def define_cuda_lowering_pipeline(self, state): @global_compiler_lock -def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False, - inline=False, fastmath=False, nvvm_options=None, - cc=None, max_registers=None, lto=False): +def compile_cuda( + pyfunc, + return_type, + args, + debug=False, + lineinfo=False, + inline=False, + fastmath=False, + nvvm_options=None, + cc=None, + max_registers=None, + lto=False, +): if cc is None: - raise ValueError('Compute Capability must be supplied') + raise ValueError("Compute Capability must be supplied") from .descriptor import cuda_target + typingctx = cuda_target.typing_context targetctx = cuda_target.target_context @@ -269,10 +298,10 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False, flags.dbg_directives_only = True if debug: - flags.error_model = 'python' + flags.error_model = "python" flags.dbg_extend_lifetimes = True else: - flags.error_model = 'numpy' + flags.error_model = "numpy" if inline: flags.forceinline = True @@ -286,15 +315,18 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False, # Run compilation pipeline from numba.core.target_extension import target_override - with target_override('cuda'): - cres = compiler.compile_extra(typingctx=typingctx, - targetctx=targetctx, - func=pyfunc, - args=args, - return_type=return_type, - flags=flags, - locals={}, - pipeline_class=CUDACompiler) + + with target_override("cuda"): + cres = compiler.compile_extra( + typingctx=typingctx, + targetctx=targetctx, + func=pyfunc, + args=args, + return_type=return_type, + flags=flags, + locals={}, + pipeline_class=CUDACompiler, + ) library = cres.library library.finalize() @@ -302,8 +334,9 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False, return cres -def cabi_wrap_function(context, lib, fndesc, wrapper_function_name, - nvvm_options): +def cabi_wrap_function( + context, lib, 
fndesc, wrapper_function_name, nvvm_options +): """ Wrap a Numba ABI function in a C ABI wrapper at the NVVM IR level. @@ -311,9 +344,11 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name, """ # The wrapper will be contained in a new library that links to the wrapped # function's library - library = lib.codegen.create_library(f'{lib.name}_function_', - entry_name=wrapper_function_name, - nvvm_options=nvvm_options) + library = lib.codegen.create_library( + f"{lib.name}_function_", + entry_name=wrapper_function_name, + nvvm_options=nvvm_options, + ) library.add_linking_library(lib) # Determine the caller (C ABI) and wrapper (Numba ABI) function types @@ -331,14 +366,15 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name, # its return value wrapfn = ir.Function(wrapper_module, wrapfnty, wrapper_function_name) - builder = ir.IRBuilder(wrapfn.append_basic_block('')) + builder = ir.IRBuilder(wrapfn.append_basic_block("")) arginfo = context.get_arg_packer(argtypes) callargs = arginfo.from_arguments(builder, wrapfn.args) # We get (status, return_value), but we ignore the status since we # can't propagate it through the C ABI anyway _, return_value = context.call_conv.call_function( - builder, func, restype, argtypes, callargs) + builder, func, restype, argtypes, callargs + ) builder.ret(return_value) if config.DUMP_LLVM: @@ -395,8 +431,10 @@ def kernel_fixup(kernel, debug): # Find all stores first for inst in block.instructions: - if (isinstance(inst, ir.StoreInstr) - and inst.operands[1] == return_value): + if ( + isinstance(inst, ir.StoreInstr) + and inst.operands[1] == return_value + ): remove_list.append(inst) # Remove all stores @@ -407,8 +445,9 @@ def kernel_fixup(kernel, debug): # value if isinstance(kernel.type, ir.PointerType): - new_type = ir.PointerType(ir.FunctionType(ir.VoidType(), - kernel.type.pointee.args[1:])) + new_type = ir.PointerType( + ir.FunctionType(ir.VoidType(), kernel.type.pointee.args[1:]) + ) else: new_type = ir.FunctionType(ir.VoidType(), kernel.type.args[1:]) @@ -418,13 +457,13 @@ def kernel_fixup(kernel, debug): # If debug metadata is present, remove the return value from it - if kernel_metadata := getattr(kernel, 'metadata', None): - if dbg_metadata := kernel_metadata.get('dbg', None): + if kernel_metadata := getattr(kernel, "metadata", None): + if dbg_metadata := kernel_metadata.get("dbg", None): for name, value in dbg_metadata.operands: if name == "type": type_metadata = value for tm_name, tm_value in type_metadata.operands: - if tm_name == 'types': + if tm_name == "types": types = tm_value types.operands = types.operands[1:] if config.DUMP_LLVM: @@ -435,26 +474,24 @@ def kernel_fixup(kernel, debug): nvvm.set_cuda_kernel(kernel) if config.DUMP_LLVM: - print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, '-')) + print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, "-")) print(kernel.module) - print('=' * 80) + print("=" * 80) def add_exception_store_helper(kernel): - # Create global variables for exception state def define_error_gv(postfix): name = kernel.name + postfix - gv = cgutils.add_global_variable(kernel.module, ir.IntType(32), - name) + gv = cgutils.add_global_variable(kernel.module, ir.IntType(32), name) gv.initializer = ir.Constant(gv.type.pointee, None) return gv gv_exc = define_error_gv("__errcode__") gv_tid = [] gv_ctaid = [] - for i in 'xyz': + for i in "xyz": gv_tid.append(define_error_gv("__tid%s__" % i)) gv_ctaid.append(define_error_gv("__ctaid%s__" % i)) @@ -484,18 +521,25 @@ def 
define_error_gv(postfix): # Use atomic cmpxchg to prevent rewriting the error status # Only the first error is recorded - xchg = builder.cmpxchg(gv_exc, old, status.code, - 'monotonic', 'monotonic') + xchg = builder.cmpxchg( + gv_exc, old, status.code, "monotonic", "monotonic" + ) changed = builder.extract_value(xchg, 1) # If the xchange is successful, save the thread ID. sreg = nvvmutils.SRegBuilder(builder) with builder.if_then(changed): - for dim, ptr, in zip("xyz", gv_tid): + for ( + dim, + ptr, + ) in zip("xyz", gv_tid): val = sreg.tid(dim) builder.store(val, ptr) - for dim, ptr, in zip("xyz", gv_ctaid): + for ( + dim, + ptr, + ) in zip("xyz", gv_ctaid): val = sreg.ctaid(dim) builder.store(val, ptr) @@ -505,9 +549,19 @@ def define_error_gv(postfix): @global_compiler_lock -def compile(pyfunc, sig, debug=None, lineinfo=False, device=True, - fastmath=False, cc=None, opt=None, abi="c", abi_info=None, - output='ptx'): +def compile( + pyfunc, + sig, + debug=None, + lineinfo=False, + device=True, + fastmath=False, + cc=None, + opt=None, + abi="c", + abi_info=None, + output="ptx", +): """Compile a Python function to PTX or LTO-IR for a given set of argument types. @@ -551,43 +605,49 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True, :rtype: tuple """ if abi not in ("numba", "c"): - raise NotImplementedError(f'Unsupported ABI: {abi}') + raise NotImplementedError(f"Unsupported ABI: {abi}") - if abi == 'c' and not device: - raise NotImplementedError('The C ABI is not supported for kernels') + if abi == "c" and not device: + raise NotImplementedError("The C ABI is not supported for kernels") if output not in ("ptx", "ltoir"): - raise NotImplementedError(f'Unsupported output type: {output}') + raise NotImplementedError(f"Unsupported output type: {output}") debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug opt = (config.OPT != 0) if opt is None else opt if debug and opt: - msg = ("debug=True with opt=True " - "is not supported by CUDA. This may result in a crash" - " - set debug=False or opt=False.") + msg = ( + "debug=True with opt=True " + "is not supported by CUDA. This may result in a crash" + " - set debug=False or opt=False." 
+ ) warn(NumbaInvalidConfigWarning(msg)) - lto = (output == 'ltoir') + lto = output == "ltoir" abi_info = abi_info or dict() - nvvm_options = { - 'fastmath': fastmath, - 'opt': 3 if opt else 0 - } + nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0} if debug: - nvvm_options['g'] = None + nvvm_options["g"] = None if lto: - nvvm_options['gen-lto'] = None + nvvm_options["gen-lto"] = None args, return_type = sigutils.normalize_signature(sig) cc = cc or config.CUDA_DEFAULT_PTX_CC - cres = compile_cuda(pyfunc, return_type, args, debug=debug, - lineinfo=lineinfo, fastmath=fastmath, - nvvm_options=nvvm_options, cc=cc) + cres = compile_cuda( + pyfunc, + return_type, + args, + debug=debug, + lineinfo=lineinfo, + fastmath=fastmath, + nvvm_options=nvvm_options, + cc=cc, + ) resty = cres.signature.return_type if resty and not device and resty != types.void: @@ -598,9 +658,10 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True, if device: lib = cres.library if abi == "c": - wrapper_name = abi_info.get('abi_name', pyfunc.__name__) - lib = cabi_wrap_function(tgt, lib, cres.fndesc, wrapper_name, - nvvm_options) + wrapper_name = abi_info.get("abi_name", pyfunc.__name__) + lib = cabi_wrap_function( + tgt, lib, cres.fndesc, wrapper_name, nvvm_options + ) else: lib = cres.library kernel = lib.get_function(cres.fndesc.llvm_func_name) @@ -614,38 +675,94 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True, return code, resty -def compile_for_current_device(pyfunc, sig, debug=None, lineinfo=False, - device=True, fastmath=False, opt=None, - abi="c", abi_info=None, output='ptx'): +def compile_for_current_device( + pyfunc, + sig, + debug=None, + lineinfo=False, + device=True, + fastmath=False, + opt=None, + abi="c", + abi_info=None, + output="ptx", +): """Compile a Python function to PTX or LTO-IR for a given signature for the current device's compute capabilility. This calls :func:`compile` with an appropriate ``cc`` value for the current device.""" cc = get_current_device().compute_capability - return compile(pyfunc, sig, debug=debug, lineinfo=lineinfo, device=device, - fastmath=fastmath, cc=cc, opt=opt, abi=abi, - abi_info=abi_info, output=output) + return compile( + pyfunc, + sig, + debug=debug, + lineinfo=lineinfo, + device=device, + fastmath=fastmath, + cc=cc, + opt=opt, + abi=abi, + abi_info=abi_info, + output=output, + ) -def compile_ptx(pyfunc, sig, debug=None, lineinfo=False, device=False, - fastmath=False, cc=None, opt=None, abi="numba", abi_info=None): +def compile_ptx( + pyfunc, + sig, + debug=None, + lineinfo=False, + device=False, + fastmath=False, + cc=None, + opt=None, + abi="numba", + abi_info=None, +): """Compile a Python function to PTX for a given signature. See :func:`compile`. 
The defaults for this function are to compile a kernel with the Numba ABI, rather than :func:`compile`'s default of compiling a device function with the C ABI.""" - return compile(pyfunc, sig, debug=debug, lineinfo=lineinfo, device=device, - fastmath=fastmath, cc=cc, opt=opt, abi=abi, - abi_info=abi_info, output='ptx') + return compile( + pyfunc, + sig, + debug=debug, + lineinfo=lineinfo, + device=device, + fastmath=fastmath, + cc=cc, + opt=opt, + abi=abi, + abi_info=abi_info, + output="ptx", + ) -def compile_ptx_for_current_device(pyfunc, sig, debug=None, lineinfo=False, - device=False, fastmath=False, opt=None, - abi="numba", abi_info=None): +def compile_ptx_for_current_device( + pyfunc, + sig, + debug=None, + lineinfo=False, + device=False, + fastmath=False, + opt=None, + abi="numba", + abi_info=None, +): """Compile a Python function to PTX for a given signature for the current device's compute capabilility. See :func:`compile_ptx`.""" cc = get_current_device().compute_capability - return compile_ptx(pyfunc, sig, debug=debug, lineinfo=lineinfo, - device=device, fastmath=fastmath, cc=cc, opt=opt, - abi=abi, abi_info=abi_info) + return compile_ptx( + pyfunc, + sig, + debug=debug, + lineinfo=lineinfo, + device=device, + fastmath=fastmath, + cc=cc, + opt=opt, + abi=abi, + abi_info=abi_info, + ) def declare_device_function(name, restype, argtypes, link): @@ -654,6 +771,7 @@ def declare_device_function(name, restype, argtypes, link): def declare_device_function_template(name, restype, argtypes, link): from .descriptor import cuda_target + typingctx = cuda_target.typing_context targetctx = cuda_target.target_context sig = typing.signature(restype, *argtypes) @@ -664,7 +782,8 @@ class device_function_template(ConcreteTemplate): cases = [sig] fndesc = funcdesc.ExternalFunctionDescriptor( - name=name, restype=restype, argtypes=argtypes) + name=name, restype=restype, argtypes=argtypes + ) typingctx.insert_user_function(extfn, device_function_template) targetctx.insert_user_function(extfn, fndesc) diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index 5e4e56d3c..9c5237b30 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -10,6 +10,7 @@ system to freeze in some cases. """ + import sys import os import ctypes @@ -25,8 +26,17 @@ import re from itertools import product from abc import ABCMeta, abstractmethod -from ctypes import (c_int, byref, c_size_t, c_char, c_char_p, addressof, - c_void_p, c_float, c_uint) +from ctypes import ( + c_int, + byref, + c_size_t, + c_char, + c_char_p, + addressof, + c_void_p, + c_float, + c_uint, +) import contextlib import importlib import numpy as np @@ -51,13 +61,14 @@ if USE_NV_BINDING: from cuda import cuda as binding + # There is no definition of the default stream in the Nvidia bindings (nor # is there at the C/C++ level), so we define it here so we don't need to # use a magic number 0 in places where we want the default stream. 
CU_STREAM_DEFAULT = 0 MIN_REQUIRED_CC = (3, 5) -SUPPORTS_IPC = sys.platform.startswith('linux') +SUPPORTS_IPC = sys.platform.startswith("linux") _py_decref = ctypes.pythonapi.Py_DecRef @@ -71,10 +82,9 @@ "to be available" ) -ENABLE_PYNVJITLINK = ( - _readenv("NUMBA_CUDA_ENABLE_PYNVJITLINK", bool, False) - or getattr(config, "CUDA_ENABLE_PYNVJITLINK", False) -) +ENABLE_PYNVJITLINK = _readenv( + "NUMBA_CUDA_ENABLE_PYNVJITLINK", bool, False +) or getattr(config, "CUDA_ENABLE_PYNVJITLINK", False) if not hasattr(config, "CUDA_ENABLE_PYNVJITLINK"): config.CUDA_ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK @@ -94,7 +104,7 @@ def make_logger(): if config.CUDA_LOG_LEVEL: # create a simple handler that prints to stderr handler = logging.StreamHandler(sys.stderr) - fmt = '== CUDA [%(relativeCreated)d] %(levelname)5s -- %(message)s' + fmt = "== CUDA [%(relativeCreated)d] %(levelname)5s -- %(message)s" handler.setFormatter(logging.Formatter(fmt=fmt)) logger.addHandler(handler) else: @@ -122,50 +132,52 @@ def __str__(self): def locate_driver_and_loader(): - envpath = config.CUDA_DRIVER - if envpath == '0': + if envpath == "0": # Force fail _raise_driver_not_found() # Determine DLL type - if sys.platform == 'win32': + if sys.platform == "win32": dlloader = ctypes.WinDLL - dldir = ['\\windows\\system32'] - dlnames = ['nvcuda.dll'] - elif sys.platform == 'darwin': + dldir = ["\\windows\\system32"] + dlnames = ["nvcuda.dll"] + elif sys.platform == "darwin": dlloader = ctypes.CDLL - dldir = ['/usr/local/cuda/lib'] - dlnames = ['libcuda.dylib'] + dldir = ["/usr/local/cuda/lib"] + dlnames = ["libcuda.dylib"] else: # Assume to be *nix like dlloader = ctypes.CDLL - dldir = ['/usr/lib', '/usr/lib64'] - dlnames = ['libcuda.so', 'libcuda.so.1'] + dldir = ["/usr/lib", "/usr/lib64"] + dlnames = ["libcuda.so", "libcuda.so.1"] if envpath: try: envpath = os.path.abspath(envpath) except ValueError: - raise ValueError("NUMBA_CUDA_DRIVER %s is not a valid path" % - envpath) + raise ValueError( + "NUMBA_CUDA_DRIVER %s is not a valid path" % envpath + ) if not os.path.isfile(envpath): - raise ValueError("NUMBA_CUDA_DRIVER %s is not a valid file " - "path. Note it must be a filepath of the .so/" - ".dll/.dylib or the driver" % envpath) + raise ValueError( + "NUMBA_CUDA_DRIVER %s is not a valid file " + "path. Note it must be a filepath of the .so/" + ".dll/.dylib or the driver" % envpath + ) candidates = [envpath] else: # First search for the name in the default library path. # If that is not found, try the specific path. - candidates = dlnames + [os.path.join(x, y) - for x, y in product(dldir, dlnames)] + candidates = dlnames + [ + os.path.join(x, y) for x, y in product(dldir, dlnames) + ] return dlloader, candidates def load_driver(dlloader, candidates): - # Load the driver; Collect driver error information path_not_exist = [] driver_load_error = [] @@ -184,7 +196,7 @@ def load_driver(dlloader, candidates): if all(path_not_exist): _raise_driver_not_found() else: - errmsg = '\n'.join(str(e) for e in driver_load_error) + errmsg = "\n".join(str(e) for e in driver_load_error) _raise_driver_error(errmsg) @@ -216,7 +228,7 @@ def _raise_driver_error(e): def _build_reverse_error_map(): - prefix = 'CUDA_ERROR' + prefix = "CUDA_ERROR" map = utils.UniqueDict() for name in dir(enums): if name.startswith(prefix): @@ -236,6 +248,7 @@ class Driver(object): """ Driver API functions are lazily bound. 
""" + _singleton = None def __new__(cls): @@ -254,9 +267,11 @@ def __init__(self): self.pid = None try: if config.DISABLE_CUDA: - msg = ("CUDA is disabled due to setting NUMBA_DISABLE_CUDA=1 " - "in the environment, or because CUDA is unsupported on " - "32-bit systems.") + msg = ( + "CUDA is disabled due to setting NUMBA_DISABLE_CUDA=1 " + "in the environment, or because CUDA is unsupported on " + "32-bit systems." + ) raise CudaSupportError(msg) self.lib = find_driver() except CudaSupportError as e: @@ -273,7 +288,7 @@ def ensure_initialized(self): self.is_initialized = True try: - _logger.info('init') + _logger.info("init") self.cuInit(0) except CudaAPIError as e: description = f"{e.msg} ({e.code})" @@ -292,8 +307,9 @@ def __getattr__(self, fname): self.ensure_initialized() if self.initialization_error is not None: - raise CudaSupportError("Error at driver init: \n%s:" % - self.initialization_error) + raise CudaSupportError( + "Error at driver init: \n%s:" % self.initialization_error + ) if USE_NV_BINDING: return self._cuda_python_wrap_fn(fname) @@ -317,12 +333,12 @@ def _ctypes_wrap_fn(self, fname, libfn=None): def verbose_cuda_api_call(*args): argstr = ", ".join([str(arg) for arg in args]) - _logger.debug('call driver api: %s(%s)', libfn.__name__, argstr) + _logger.debug("call driver api: %s(%s)", libfn.__name__, argstr) retcode = libfn(*args) self._check_ctypes_error(fname, retcode) def safe_cuda_api_call(*args): - _logger.debug('call driver api: %s', libfn.__name__) + _logger.debug("call driver api: %s", libfn.__name__) retcode = libfn(*args) self._check_ctypes_error(fname, retcode) @@ -340,11 +356,11 @@ def _cuda_python_wrap_fn(self, fname): def verbose_cuda_api_call(*args): argstr = ", ".join([str(arg) for arg in args]) - _logger.debug('call driver api: %s(%s)', libfn.__name__, argstr) + _logger.debug("call driver api: %s(%s)", libfn.__name__, argstr) return self._check_cuda_python_error(fname, libfn(*args)) def safe_cuda_api_call(*args): - _logger.debug('call driver api: %s', libfn.__name__) + _logger.debug("call driver api: %s", libfn.__name__) return self._check_cuda_python_error(fname, libfn(*args)) if config.CUDA_LOG_API_ARGS: @@ -361,27 +377,27 @@ def _find_api(self, fname): # binding. For the NVidia binding, it handles linking to the correct # variant. if config.CUDA_PER_THREAD_DEFAULT_STREAM and not USE_NV_BINDING: - variants = ('_v2_ptds', '_v2_ptsz', '_ptds', '_ptsz', '_v2', '') + variants = ("_v2_ptds", "_v2_ptsz", "_ptds", "_ptsz", "_v2", "") else: - variants = ('_v2', '') + variants = ("_v2", "") for variant in variants: try: - return getattr(self.lib, f'{fname}{variant}') + return getattr(self.lib, f"{fname}{variant}") except AttributeError: pass # Not found. 
        # Delay missing function error until use
         def absent_function(*args, **kws):
-            raise CudaDriverError(f'Driver missing function: {fname}')
+            raise CudaDriverError(f"Driver missing function: {fname}")
 
         setattr(self, fname, absent_function)
         return absent_function
 
     def _detect_fork(self):
         if self.pid is not None and _getpid() != self.pid:
-            msg = 'pid %s forked from pid %s after CUDA driver init'
+            msg = "pid %s forked from pid %s after CUDA driver init"
             _logger.critical(msg, _getpid(), self.pid)
             raise CudaDriverError("CUDA initialized before forking")
 
@@ -425,13 +441,11 @@ def get_device_count(self):
         return count.value
 
     def list_devices(self):
-        """Returns a list of active devices
-        """
+        """Returns a list of active devices"""
         return list(self.devices.values())
 
     def reset(self):
-        """Reset all devices
-        """
+        """Reset all devices"""
         for dev in self.devices.values():
             dev.reset()
 
@@ -449,8 +463,7 @@ def pop_active_context(self):
         return popped
 
     def get_active_context(self):
-        """Returns an instance of ``_ActiveContext``.
-        """
+        """Returns an instance of ``_ActiveContext``."""
         return _ActiveContext()
 
     def get_version(self):
@@ -477,12 +490,13 @@ class _ActiveContext(object):
     Once entering the context, it is assumed that the active CUDA context is
     not changed until the context is exited.
     """
+
     _tls_cache = threading.local()
 
     def __enter__(self):
         is_top = False
         # check TLS cache
-        if hasattr(self._tls_cache, 'ctx_devnum'):
+        if hasattr(self._tls_cache, "ctx_devnum"):
             hctx, devnum = self._tls_cache.ctx_devnum
         # Not cached. Query the driver API.
         else:
@@ -515,11 +529,10 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if self._is_top:
-            delattr(self._tls_cache, 'ctx_devnum')
+            delattr(self._tls_cache, "ctx_devnum")
 
     def __bool__(self):
-        """Returns True is there's a valid and active CUDA context.
-        """
+        """Returns True if there's a valid and active CUDA context."""
         return self.context_handle is not None
 
     __nonzero__ = __bool__
@@ -533,7 +546,7 @@ def _build_reverse_device_attrs():
     map = utils.UniqueDict()
     for name in dir(enums):
         if name.startswith(prefix):
-            map[name[len(prefix):]] = getattr(enums, name)
+            map[name[len(prefix) :]] = getattr(enums, name)
     return map
 
 
@@ -545,6 +558,7 @@ class Device(object):
     The device object owns the CUDA contexts.  This is owned by the driver
     object.  User should not construct devices directly.
""" + @classmethod def from_identity(self, identity): """Create Device object from device identity created by @@ -579,15 +593,17 @@ def __init__(self, devnum): self.attributes = {} # Read compute capability - self.compute_capability = (self.COMPUTE_CAPABILITY_MAJOR, - self.COMPUTE_CAPABILITY_MINOR) + self.compute_capability = ( + self.COMPUTE_CAPABILITY_MAJOR, + self.COMPUTE_CAPABILITY_MINOR, + ) # Read name bufsz = 128 if USE_NV_BINDING: buf = driver.cuDeviceGetName(bufsz, self.id) - name = buf.decode('utf-8').rstrip('\0') + name = buf.decode("utf-8").rstrip("\0") else: buf = (c_char * bufsz)() driver.cuDeviceGetName(buf, bufsz, self.id) @@ -604,31 +620,31 @@ def __init__(self, devnum): driver.cuDeviceGetUuid(byref(uuid), self.id) uuid_vals = tuple(bytes(uuid)) - b = '%02x' + b = "%02x" b2 = b * 2 b4 = b * 4 b6 = b * 6 - fmt = f'GPU-{b4}-{b2}-{b2}-{b2}-{b6}' + fmt = f"GPU-{b4}-{b2}-{b2}-{b2}-{b6}" self.uuid = fmt % uuid_vals self.primary_context = None def get_device_identity(self): return { - 'pci_domain_id': self.PCI_DOMAIN_ID, - 'pci_bus_id': self.PCI_BUS_ID, - 'pci_device_id': self.PCI_DEVICE_ID, + "pci_domain_id": self.PCI_DOMAIN_ID, + "pci_bus_id": self.PCI_BUS_ID, + "pci_device_id": self.PCI_DEVICE_ID, } def __repr__(self): return "" % (self.id, self.name) def __getattr__(self, attr): - """Read attributes lazily - """ + """Read attributes lazily""" if USE_NV_BINDING: - code = getattr(binding.CUdevice_attribute, - f'CU_DEVICE_ATTRIBUTE_{attr}') + code = getattr( + binding.CUdevice_attribute, f"CU_DEVICE_ATTRIBUTE_{attr}" + ) value = driver.cuDeviceGetAttribute(code, self.id) else: try: @@ -698,17 +714,18 @@ def supports_float16(self): def met_requirement_for_device(device): if device.compute_capability < MIN_REQUIRED_CC: - raise CudaSupportError("%s has compute capability < %s" % - (device, MIN_REQUIRED_CC)) + raise CudaSupportError( + "%s has compute capability < %s" % (device, MIN_REQUIRED_CC) + ) class BaseCUDAMemoryManager(object, metaclass=ABCMeta): """Abstract base class for External Memory Management (EMM) Plugins.""" def __init__(self, *args, **kwargs): - if 'context' not in kwargs: + if "context" not in kwargs: raise RuntimeError("Memory manager requires a context") - self.context = kwargs.pop('context') + self.context = kwargs.pop("context") @abstractmethod def memalloc(self, size): @@ -864,8 +881,7 @@ def _attempt_allocation(self, allocator): else: raise - def memhostalloc(self, size, mapped=False, portable=False, - wc=False): + def memhostalloc(self, size, mapped=False, portable=False, wc=False): """Implements the allocation of pinned host memory. 
It is recommended that this method is not overridden by EMM Plugin @@ -880,6 +896,7 @@ def memhostalloc(self, size, mapped=False, portable=False, flags |= enums.CU_MEMHOSTALLOC_WRITECOMBINED if USE_NV_BINDING: + def allocator(): return driver.cuMemHostAlloc(size, flags) @@ -946,16 +963,19 @@ def allocator(): ctx = weakref.proxy(self.context) if mapped: - mem = MappedMemory(ctx, pointer, size, owner=owner, - finalizer=finalizer) + mem = MappedMemory( + ctx, pointer, size, owner=owner, finalizer=finalizer + ) self.allocations[alloc_key] = mem return mem.own() else: - return PinnedMemory(ctx, pointer, size, owner=owner, - finalizer=finalizer) + return PinnedMemory( + ctx, pointer, size, owner=owner, finalizer=finalizer + ) def memallocmanaged(self, size, attach_global): if USE_NV_BINDING: + def allocator(): ma_flags = binding.CUmemAttach_flags @@ -1014,8 +1034,7 @@ def defer_cleanup(self): class GetIpcHandleMixin: - """A class that provides a default implementation of ``get_ipc_handle()``. - """ + """A class that provides a default implementation of ``get_ipc_handle()``.""" def get_ipc_handle(self, memory): """Open an IPC memory handle by using ``cuMemGetAddressRange`` to @@ -1034,8 +1053,9 @@ def get_ipc_handle(self, memory): offset = memory.handle.value - base source_info = self.context.device.get_device_identity() - return IpcHandle(memory, ipchandle, memory.size, source_info, - offset=offset) + return IpcHandle( + memory, ipchandle, memory.size, source_info, offset=offset + ) class NumbaCUDAMemoryManager(GetIpcHandleMixin, HostOnlyCUDAMemoryManager): @@ -1050,6 +1070,7 @@ def initialize(self): def memalloc(self, size): if USE_NV_BINDING: + def allocator(): return driver.cuMemAlloc(size) @@ -1098,7 +1119,7 @@ def _ensure_memory_manager(): if _memory_manager: return - if config.CUDA_MEMORY_MANAGER == 'default': + if config.CUDA_MEMORY_MANAGER == "default": _memory_manager = NumbaCUDAMemoryManager return @@ -1106,8 +1127,9 @@ def _ensure_memory_manager(): mgr_module = importlib.import_module(config.CUDA_MEMORY_MANAGER) set_memory_manager(mgr_module._numba_memory_manager) except Exception: - raise RuntimeError("Failed to use memory manager from %s" % - config.CUDA_MEMORY_MANAGER) + raise RuntimeError( + "Failed to use memory manager from %s" % config.CUDA_MEMORY_MANAGER + ) def set_memory_manager(mm_plugin): @@ -1124,8 +1146,10 @@ def set_memory_manager(mm_plugin): dummy = mm_plugin(context=None) iv = dummy.interface_version if iv != _SUPPORTED_EMM_INTERFACE_VERSION: - err = "EMM Plugin interface has version %d - version %d required" \ - % (iv, _SUPPORTED_EMM_INTERFACE_VERSION) + err = "EMM Plugin interface has version %d - version %d required" % ( + iv, + _SUPPORTED_EMM_INTERFACE_VERSION, + ) raise RuntimeError(err) _memory_manager = mm_plugin @@ -1140,7 +1164,7 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls, 0) def __str__(self): - return '?' + return "?" _SizeNotSet = _SizeNotSet() @@ -1153,6 +1177,7 @@ class _PendingDeallocs(object): modified later once the driver is initialized and the total memory capacity known. """ + def __init__(self, capacity=_SizeNotSet): self._cons = deque() self._disable_count = 0 @@ -1172,11 +1197,13 @@ def add_item(self, dtor, handle, size=_SizeNotSet): byte size of the resource added. It is an optional argument. Some resources (e.g. CUModule) has an unknown memory footprint on the device. 
""" - _logger.info('add pending dealloc: %s %s bytes', dtor.__name__, size) + _logger.info("add pending dealloc: %s %s bytes", dtor.__name__, size) self._cons.append((dtor, handle, size)) self._size += int(size) - if (len(self._cons) > config.CUDA_DEALLOCS_COUNT or - self._size > self._max_pending_bytes): + if ( + len(self._cons) > config.CUDA_DEALLOCS_COUNT + or self._size > self._max_pending_bytes + ): self.clear() def clear(self): @@ -1187,7 +1214,7 @@ def clear(self): if not self.is_disabled: while self._cons: [dtor, handle, size] = self._cons.popleft() - _logger.info('dealloc: %s %s bytes', dtor.__name__, size) + _logger.info("dealloc: %s %s bytes", dtor.__name__, size) dtor(handle) self._size = 0 @@ -1251,19 +1278,19 @@ def reset(self): Clean up all owned resources in this context. """ # Free owned resources - _logger.info('reset context of device %s', self.device.id) + _logger.info("reset context of device %s", self.device.id) self.memory_manager.reset() self.modules.clear() # Clear trash self.deallocations.clear() def get_memory_info(self): - """Returns (free, total) memory in bytes in the context. - """ + """Returns (free, total) memory in bytes in the context.""" return self.memory_manager.get_memory_info() - def get_active_blocks_per_multiprocessor(self, func, blocksize, memsize, - flags=None): + def get_active_blocks_per_multiprocessor( + self, func, blocksize, memsize, flags=None + ): """Return occupancy of a function. :param func: kernel for which occupancy is calculated :param blocksize: block size the kernel is intended to be launched with @@ -1275,8 +1302,9 @@ def get_active_blocks_per_multiprocessor(self, func, blocksize, memsize, else: return self._ctypes_active_blocks_per_multiprocessor(*args) - def _cuda_python_active_blocks_per_multiprocessor(self, func, blocksize, - memsize, flags): + def _cuda_python_active_blocks_per_multiprocessor( + self, func, blocksize, memsize, flags + ): ps = [func.handle, blocksize, memsize] if not flags: @@ -1285,8 +1313,9 @@ def _cuda_python_active_blocks_per_multiprocessor(self, func, blocksize, ps.append(flags) return driver.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(*ps) - def _ctypes_active_blocks_per_multiprocessor(self, func, blocksize, - memsize, flags): + def _ctypes_active_blocks_per_multiprocessor( + self, func, blocksize, memsize, flags + ): retval = c_int() args = (byref(retval), func.handle, blocksize, memsize) @@ -1297,8 +1326,9 @@ def _ctypes_active_blocks_per_multiprocessor(self, func, blocksize, return retval.value - def get_max_potential_block_size(self, func, b2d_func, memsize, - blocksizelimit, flags=None): + def get_max_potential_block_size( + self, func, b2d_func, memsize, blocksizelimit, flags=None + ): """Suggest a launch configuration with reasonable occupancy. 
:param func: kernel for which occupancy is calculated :param b2d_func: function that calculates how much per-block dynamic @@ -1315,13 +1345,20 @@ def get_max_potential_block_size(self, func, b2d_func, memsize, else: return self._ctypes_max_potential_block_size(*args) - def _ctypes_max_potential_block_size(self, func, b2d_func, memsize, - blocksizelimit, flags): + def _ctypes_max_potential_block_size( + self, func, b2d_func, memsize, blocksizelimit, flags + ): gridsize = c_int() blocksize = c_int() b2d_cb = cu_occupancy_b2d_size(b2d_func) - args = [byref(gridsize), byref(blocksize), func.handle, b2d_cb, - memsize, blocksizelimit] + args = [ + byref(gridsize), + byref(blocksize), + func.handle, + b2d_cb, + memsize, + blocksizelimit, + ] if not flags: driver.cuOccupancyMaxPotentialBlockSize(*args) @@ -1331,10 +1368,11 @@ def _ctypes_max_potential_block_size(self, func, b2d_func, memsize, return (gridsize.value, blocksize.value) - def _cuda_python_max_potential_block_size(self, func, b2d_func, memsize, - blocksizelimit, flags): + def _cuda_python_max_potential_block_size( + self, func, b2d_func, memsize, blocksizelimit, flags + ): b2d_cb = ctypes.CFUNCTYPE(c_size_t, c_int)(b2d_func) - ptr = int.from_bytes(b2d_cb, byteorder='little') + ptr = int.from_bytes(b2d_cb, byteorder="little") driver_b2d_cb = binding.CUoccupancyB2DSize(ptr) args = [func.handle, driver_b2d_cb, memsize, blocksizelimit] @@ -1387,7 +1425,7 @@ def get_ipc_handle(self, memory): Returns an *IpcHandle* from a GPU allocation. """ if not SUPPORTS_IPC: - raise OSError('OS does not support CUDA IPC') + raise OSError("OS does not support CUDA IPC") return self.memory_manager.get_ipc_handle(memory) def open_ipc_handle(self, handle, size): @@ -1400,13 +1438,13 @@ def open_ipc_handle(self, handle, size): driver.cuIpcOpenMemHandle(byref(dptr), handle, flags) # wrap it - return MemoryPointer(context=weakref.proxy(self), pointer=dptr, - size=size) + return MemoryPointer( + context=weakref.proxy(self), pointer=dptr, size=size + ) def enable_peer_access(self, peer_context, flags=0): - """Enable peer access between the current context and the peer context - """ - assert flags == 0, '*flags* is reserved and MUST be zero' + """Enable peer access between the current context and the peer context""" + assert flags == 0, "*flags* is reserved and MUST be zero" driver.cuCtxEnablePeerAccess(peer_context, flags) def can_access_peer(self, peer_device): @@ -1415,28 +1453,34 @@ def can_access_peer(self, peer_device): """ if USE_NV_BINDING: peer_device = binding.CUdevice(peer_device) - can_access_peer = driver.cuDeviceCanAccessPeer(self.device.id, - peer_device) + can_access_peer = driver.cuDeviceCanAccessPeer( + self.device.id, peer_device + ) else: can_access_peer = c_int() - driver.cuDeviceCanAccessPeer(byref(can_access_peer), - self.device.id, peer_device,) + driver.cuDeviceCanAccessPeer( + byref(can_access_peer), + self.device.id, + peer_device, + ) return bool(can_access_peer) def create_module_ptx(self, ptx): if isinstance(ptx, str): - ptx = ptx.encode('utf8') + ptx = ptx.encode("utf8") if USE_NV_BINDING: image = ptx else: image = c_char_p(ptx) return self.create_module_image(image) - def create_module_image(self, image, - setup_callbacks=None, teardown_callbacks=None): - module = load_module_image(self, image, - setup_callbacks, teardown_callbacks) + def create_module_image( + self, image, setup_callbacks=None, teardown_callbacks=None + ): + module = load_module_image( + self, image, setup_callbacks, teardown_callbacks + ) if USE_NV_BINDING: key = 
module.handle else: @@ -1483,8 +1527,11 @@ def create_stream(self): else: handle = drvapi.cu_stream() driver.cuStreamCreate(byref(handle), 0) - return Stream(weakref.proxy(self), handle, - _stream_finalizer(self.deallocations, handle)) + return Stream( + weakref.proxy(self), + handle, + _stream_finalizer(self.deallocations, handle), + ) def create_external_stream(self, ptr): if not isinstance(ptr, int): @@ -1493,8 +1540,7 @@ def create_external_stream(self, ptr): handle = binding.CUstream(ptr) else: handle = drvapi.cu_stream(ptr) - return Stream(weakref.proxy(self), handle, None, - external=True) + return Stream(weakref.proxy(self), handle, None, external=True) def create_event(self, timing=True): flags = 0 @@ -1505,8 +1551,11 @@ def create_event(self, timing=True): else: handle = drvapi.cu_event() driver.cuEventCreate(byref(handle), flags) - return Event(weakref.proxy(self), handle, - finalizer=_event_finalizer(self.deallocations, handle)) + return Event( + weakref.proxy(self), + handle, + finalizer=_event_finalizer(self.deallocations, handle), + ) def synchronize(self): driver.cuCtxSynchronize() @@ -1530,21 +1579,25 @@ def __ne__(self, other): return not self.__eq__(other) -def load_module_image(context, image, - setup_callbacks=None, teardown_callbacks=None): +def load_module_image( + context, image, setup_callbacks=None, teardown_callbacks=None +): """ image must be a pointer """ if USE_NV_BINDING: return load_module_image_cuda_python( - context, image, setup_callbacks, teardown_callbacks) + context, image, setup_callbacks, teardown_callbacks + ) else: return load_module_image_ctypes( - context, image, setup_callbacks, teardown_callbacks) + context, image, setup_callbacks, teardown_callbacks + ) -def load_module_image_ctypes(context, image, - setup_callbacks, teardown_callbacks): +def load_module_image_ctypes( + context, image, setup_callbacks, teardown_callbacks +): logsz = config.CUDA_LOG_SIZE jitinfo = (c_char * logsz)() @@ -1563,21 +1616,28 @@ def load_module_image_ctypes(context, image, handle = drvapi.cu_module() try: - driver.cuModuleLoadDataEx(byref(handle), image, len(options), - option_keys, option_vals) + driver.cuModuleLoadDataEx( + byref(handle), image, len(options), option_keys, option_vals + ) except CudaAPIError as e: msg = "cuModuleLoadDataEx error:\n%s" % jiterrors.value.decode("utf8") raise CudaAPIError(e.code, msg) info_log = jitinfo.value - return CtypesModule(weakref.proxy(context), handle, info_log, - _module_finalizer(context, handle), - setup_callbacks, teardown_callbacks) + return CtypesModule( + weakref.proxy(context), + handle, + info_log, + _module_finalizer(context, handle), + setup_callbacks, + teardown_callbacks, + ) -def load_module_image_cuda_python(context, image, - setup_callbacks, teardown_callbacks): +def load_module_image_cuda_python( + context, image, setup_callbacks, teardown_callbacks +): """ image must be a pointer """ @@ -1599,18 +1659,24 @@ def load_module_image_cuda_python(context, image, option_vals = [v for v in options.values()] try: - handle = driver.cuModuleLoadDataEx(image, len(options), option_keys, - option_vals) + handle = driver.cuModuleLoadDataEx( + image, len(options), option_keys, option_vals + ) except CudaAPIError as e: - err_string = jiterrors.decode('utf-8') + err_string = jiterrors.decode("utf-8") msg = "cuModuleLoadDataEx error:\n%s" % err_string raise CudaAPIError(e.code, msg) - info_log = jitinfo.decode('utf-8') + info_log = jitinfo.decode("utf-8") - return CudaPythonModule(weakref.proxy(context), handle, info_log, - 
_module_finalizer(context, handle), - setup_callbacks, teardown_callbacks) + return CudaPythonModule( + weakref.proxy(context), + handle, + info_log, + _module_finalizer(context, handle), + setup_callbacks, + teardown_callbacks, + ) def _alloc_finalizer(memory_manager, ptr, alloc_key, size): @@ -1713,6 +1779,7 @@ class _CudaIpcImpl(object): """Implementation of GPU IPC using CUDA driver API. This requires the devices to be peer accessible. """ + def __init__(self, parent): self.base = parent.base self.handle = parent.handle @@ -1726,10 +1793,10 @@ def open(self, context): Import the IPC memory and returns a raw CUDA memory pointer object """ if self.base is not None: - raise ValueError('opening IpcHandle from original process') + raise ValueError("opening IpcHandle from original process") if self._opened_mem is not None: - raise ValueError('IpcHandle is already opened') + raise ValueError("IpcHandle is already opened") mem = context.open_ipc_handle(self.handle, self.offset + self.size) # this object owns the opened allocation @@ -1740,7 +1807,7 @@ def open(self, context): def close(self): if self._opened_mem is None: - raise ValueError('IpcHandle not opened') + raise ValueError("IpcHandle not opened") driver.cuIpcCloseMemHandle(self._opened_mem.handle) self._opened_mem = None @@ -1749,6 +1816,7 @@ class _StagedIpcImpl(object): """Implementation of GPU IPC using custom staging logic to workaround CUDA IPC limitation on peer accessibility between devices. """ + def __init__(self, parent, source_info): self.parent = parent self.base = parent.base @@ -1804,6 +1872,7 @@ class IpcHandle(object): referred to by this IPC handle. :type offset: int """ + def __init__(self, base, handle, size, source_info=None, offset=0): self.base = base self.handle = handle @@ -1827,12 +1896,11 @@ def can_access_peer(self, context): return context.can_access_peer(source_device.id) def open_staged(self, context): - """Open the IPC by allowing staging on the host memory first. 
- """ + """Open the IPC by allowing staging on the host memory first.""" self._sentry_source_info() if self._impl is not None: - raise ValueError('IpcHandle is already opened') + raise ValueError("IpcHandle is already opened") self._impl = _StagedIpcImpl(self, self.source_info) return self._impl.open(context) @@ -1842,7 +1910,7 @@ def open_direct(self, context): Import the IPC memory and returns a raw CUDA memory pointer object """ if self._impl is not None: - raise ValueError('IpcHandle is already opened') + raise ValueError("IpcHandle is already opened") self._impl = _CudaIpcImpl(self) return self._impl.open(context) @@ -1873,12 +1941,13 @@ def open_array(self, context, shape, dtype, strides=None): strides = dtype.itemsize dptr = self.open(context) # read the device pointer as an array - return devicearray.DeviceNDArray(shape=shape, strides=strides, - dtype=dtype, gpu_data=dptr) + return devicearray.DeviceNDArray( + shape=shape, strides=strides, dtype=dtype, gpu_data=dptr + ) def close(self): if self._impl is None: - raise ValueError('IpcHandle not opened') + raise ValueError("IpcHandle not opened") self._impl.close() self._impl = None @@ -1904,8 +1973,13 @@ def _rebuild(cls, handle_ary, size, source_info, offset): else: handle = drvapi.cu_ipc_mem_handle() handle.reserved = handle_ary - return cls(base=None, handle=handle, size=size, - source_info=source_info, offset=offset) + return cls( + base=None, + handle=handle, + size=size, + source_info=source_info, + offset=offset, + ) class MemoryPointer(object): @@ -1939,6 +2013,7 @@ class MemoryPointer(object): :param finalizer: A function that is called when the buffer is to be freed. :type finalizer: function """ + __cuda_memory__ = True def __init__(self, context, pointer, size, owner=None, finalizer=None): @@ -1974,8 +2049,9 @@ def free(self): def memset(self, byte, count=None, stream=0): count = self.size if count is None else count if stream: - driver.cuMemsetD8Async(self.device_pointer, byte, count, - stream.handle) + driver.cuMemsetD8Async( + self.device_pointer, byte, count, stream.handle + ) else: driver.cuMemsetD8(self.device_pointer, byte, count) @@ -1989,12 +2065,12 @@ def view(self, start, stop=None): if not self.device_pointer_value: if size != 0: raise RuntimeError("non-empty slice into empty slice") - view = self # new view is just a reference to self + view = self # new view is just a reference to self # Handle normal case else: base = self.device_pointer_value + start if size < 0: - raise RuntimeError('size cannot be negative') + raise RuntimeError("size cannot be negative") if USE_NV_BINDING: pointer = binding.CUdeviceptr() ctypes_ptr = drvapi.cu_device_ptr.from_address(pointer.getPtr()) @@ -2030,6 +2106,7 @@ class AutoFreePointer(MemoryPointer): Constructor arguments are the same as for :class:`MemoryPointer`. 
""" + def __init__(self, *args, **kwargs): super(AutoFreePointer, self).__init__(*args, **kwargs) # Releease the self reference to the buffer, so that the finalizer @@ -2072,8 +2149,9 @@ def __init__(self, context, pointer, size, owner=None, finalizer=None): self._bufptr_ = self.host_pointer.value self.device_pointer = devptr - super(MappedMemory, self).__init__(context, devptr, size, - finalizer=finalizer) + super(MappedMemory, self).__init__( + context, devptr, size, finalizer=finalizer + ) self.handle = self.host_pointer # For buffer interface @@ -2188,8 +2266,7 @@ def deref(): weakref.finalize(self, deref) def __getattr__(self, fname): - """Proxy MemoryPointer methods - """ + """Proxy MemoryPointer methods""" return getattr(self._view, fname) @@ -2220,18 +2297,15 @@ def __repr__(self): if USE_NV_BINDING: default_streams = { CU_STREAM_DEFAULT: "", - binding.CU_STREAM_LEGACY: - "", - binding.CU_STREAM_PER_THREAD: - "", + binding.CU_STREAM_LEGACY: "", + binding.CU_STREAM_PER_THREAD: "", } ptr = int(self.handle) or 0 else: default_streams = { drvapi.CU_STREAM_DEFAULT: "", drvapi.CU_STREAM_LEGACY: "", - drvapi.CU_STREAM_PER_THREAD: - "", + drvapi.CU_STREAM_PER_THREAD: "", } ptr = self.handle.value or drvapi.CU_STREAM_DEFAULT @@ -2243,18 +2317,18 @@ def __repr__(self): return "" % (ptr, self.context) def synchronize(self): - ''' + """ Wait for all commands in this stream to execute. This will commit any pending memory transfers. - ''' + """ driver.cuStreamSynchronize(self.handle) @contextlib.contextmanager def auto_synchronize(self): - ''' + """ A context manager that waits for all commands in this stream to execute and commits any pending memory transfers upon exiting the context. - ''' + """ yield self self.synchronize() @@ -2281,7 +2355,7 @@ def add_callback(self, callback, arg=None): data = (self, callback, arg) _py_incref(data) if USE_NV_BINDING: - ptr = int.from_bytes(self._stream_callback, byteorder='little') + ptr = int.from_bytes(self._stream_callback, byteorder="little") stream_callback = binding.CUstreamCallback(ptr) # The callback needs to receive a pointer to the data PyObject data = id(data) @@ -2382,9 +2456,9 @@ def elapsed_time(self, evtend): def event_elapsed_time(evtstart, evtend): - ''' + """ Compute the elapsed time between two events in milliseconds. - ''' + """ if USE_NV_BINDING: return driver.cuEventElapsedTime(evtstart.handle, evtend.handle) else: @@ -2396,8 +2470,15 @@ def event_elapsed_time(evtstart, evtend): class Module(metaclass=ABCMeta): """Abstract base class for modules""" - def __init__(self, context, handle, info_log, finalizer=None, - setup_callbacks=None, teardown_callbacks=None): + def __init__( + self, + context, + handle, + info_log, + finalizer=None, + setup_callbacks=None, + teardown_callbacks=None, + ): self.context = context self.handle = handle self.info_log = info_log @@ -2436,7 +2517,7 @@ def setup(self): self.initialized = True def _set_finalizers(self): - """Create finalizers that tear down the module. 
""" + """Create finalizers that tear down the module.""" if self.teardown_functions is None: return @@ -2453,34 +2534,35 @@ def _teardown(teardowns, handle): class CtypesModule(Module): - def get_function(self, name): handle = drvapi.cu_function() - driver.cuModuleGetFunction(byref(handle), self.handle, - name.encode('utf8')) + driver.cuModuleGetFunction( + byref(handle), self.handle, name.encode("utf8") + ) return CtypesFunction(weakref.proxy(self), handle, name) def get_global_symbol(self, name): ptr = drvapi.cu_device_ptr() size = drvapi.c_size_t() - driver.cuModuleGetGlobal(byref(ptr), byref(size), self.handle, - name.encode('utf8')) + driver.cuModuleGetGlobal( + byref(ptr), byref(size), self.handle, name.encode("utf8") + ) return MemoryPointer(self.context, ptr, size), size.value class CudaPythonModule(Module): - def get_function(self, name): - handle = driver.cuModuleGetFunction(self.handle, name.encode('utf8')) + handle = driver.cuModuleGetFunction(self.handle, name.encode("utf8")) return CudaPythonFunction(weakref.proxy(self), handle, name) def get_global_symbol(self, name): - ptr, size = driver.cuModuleGetGlobal(self.handle, name.encode('utf8')) + ptr, size = driver.cuModuleGetGlobal(self.handle, name.encode("utf8")) return MemoryPointer(self.context, ptr, size), size -FuncAttr = namedtuple("FuncAttr", ["regs", "shared", "local", "const", - "maxthreads"]) +FuncAttr = namedtuple( + "FuncAttr", ["regs", "shared", "local", "const", "maxthreads"] +) class Function(metaclass=ABCMeta): @@ -2503,8 +2585,9 @@ def device(self): return self.module.context.device @abstractmethod - def cache_config(self, prefer_equal=False, prefer_cache=False, - prefer_shared=False): + def cache_config( + self, prefer_equal=False, prefer_cache=False, prefer_shared=False + ): """Set the cache configuration for this function.""" @abstractmethod @@ -2518,9 +2601,9 @@ def read_func_attr_all(self): class CtypesFunction(Function): - - def cache_config(self, prefer_equal=False, prefer_cache=False, - prefer_shared=False): + def cache_config( + self, prefer_equal=False, prefer_cache=False, prefer_shared=False + ): prefer_equal = prefer_equal or (prefer_cache and prefer_shared) if prefer_equal: flag = enums.CU_FUNC_CACHE_PREFER_EQUAL @@ -2543,15 +2626,17 @@ def read_func_attr_all(self): lmem = self.read_func_attr(enums.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES) smem = self.read_func_attr(enums.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES) maxtpb = self.read_func_attr( - enums.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK) - return FuncAttr(regs=nregs, const=cmem, local=lmem, shared=smem, - maxthreads=maxtpb) + enums.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK + ) + return FuncAttr( + regs=nregs, const=cmem, local=lmem, shared=smem, maxthreads=maxtpb + ) class CudaPythonFunction(Function): - - def cache_config(self, prefer_equal=False, prefer_cache=False, - prefer_shared=False): + def cache_config( + self, prefer_equal=False, prefer_cache=False, prefer_shared=False + ): prefer_equal = prefer_equal or (prefer_cache and prefer_shared) attr = binding.CUfunction_attribute if prefer_equal: @@ -2574,19 +2659,26 @@ def read_func_attr_all(self): lmem = self.read_func_attr(attr.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES) smem = self.read_func_attr(attr.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES) maxtpb = self.read_func_attr( - attr.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK) - return FuncAttr(regs=nregs, const=cmem, local=lmem, shared=smem, - maxthreads=maxtpb) - + attr.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK + ) + return FuncAttr( + regs=nregs, const=cmem, local=lmem, 
shared=smem, maxthreads=maxtpb + ) -def launch_kernel(cufunc_handle, - gx, gy, gz, - bx, by, bz, - sharedmem, - hstream, - args, - cooperative=False): +def launch_kernel( + cufunc_handle, + gx, + gy, + gz, + bx, + by, + bz, + sharedmem, + hstream, + args, + cooperative=False, +): param_ptrs = [addressof(arg) for arg in args] params = (c_void_p * len(param_ptrs))(*param_ptrs) @@ -2598,46 +2690,54 @@ def launch_kernel(cufunc_handle, extra = None if cooperative: - driver.cuLaunchCooperativeKernel(cufunc_handle, - gx, gy, gz, - bx, by, bz, - sharedmem, - hstream, - params_for_launch) + driver.cuLaunchCooperativeKernel( + cufunc_handle, + gx, + gy, + gz, + bx, + by, + bz, + sharedmem, + hstream, + params_for_launch, + ) else: - driver.cuLaunchKernel(cufunc_handle, - gx, gy, gz, - bx, by, bz, - sharedmem, - hstream, - params_for_launch, - extra) + driver.cuLaunchKernel( + cufunc_handle, + gx, + gy, + gz, + bx, + by, + bz, + sharedmem, + hstream, + params_for_launch, + extra, + ) class Linker(metaclass=ABCMeta): """Abstract base class for linkers""" @classmethod - def new(cls, - max_registers=0, - lineinfo=False, - cc=None, - lto=None, - additional_flags=None - ): - + def new( + cls, + max_registers=0, + lineinfo=False, + cc=None, + lto=None, + additional_flags=None, + ): driver_ver = driver.get_version() - if ( - config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY - and driver_ver >= (12, 0) + if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and driver_ver >= ( + 12, + 0, ): - raise ValueError( - "Use CUDA_ENABLE_PYNVJITLINK for CUDA >= 12.0 MVC" - ) + raise ValueError("Use CUDA_ENABLE_PYNVJITLINK for CUDA >= 12.0 MVC") if config.CUDA_ENABLE_PYNVJITLINK and driver_ver < (12, 0): - raise ValueError( - "Enabling pynvjitlink requires CUDA 12." - ) + raise ValueError("Enabling pynvjitlink requires CUDA 12.") if config.CUDA_ENABLE_PYNVJITLINK: linker = PyNvJitLinker @@ -2686,9 +2786,9 @@ def add_cu(self, cu, name): ptx, log = nvrtc.compile(cu, name, cc) if config.DUMP_ASSEMBLY: - print(("ASSEMBLY %s" % name).center(80, '-')) + print(("ASSEMBLY %s" % name).center(80, "-")) print(ptx) - print('=' * 80) + print("=" * 80) # Link the program's PTX using the normal linker mechanism ptx_name = os.path.splitext(name)[0] + ".ptx" @@ -2699,7 +2799,7 @@ def add_file(self, path, kind): """Add code from a file to the link""" def add_cu_file(self, path): - with open(path, 'rb') as f: + with open(path, "rb") as f: cu = f.read() self.add_cu(cu, os.path.basename(path)) @@ -2717,24 +2817,24 @@ def add_file_guess_ext(self, path_or_code, ignore_nonlto=False): if isinstance(path_or_code, str): ext = pathlib.Path(path_or_code).suffix - if ext == '': + if ext == "": raise RuntimeError( "Don't know how to link file with no extension" ) - elif ext == '.cu': + elif ext == ".cu": self.add_cu_file(path_or_code) else: - kind = FILE_EXTENSION_MAP.get(ext.lstrip('.'), None) + kind = FILE_EXTENSION_MAP.get(ext.lstrip("."), None) if kind is None: raise RuntimeError( - "Don't know how to link file with extension " - f"{ext}" + f"Don't know how to link file with extension {ext}" ) if ignore_nonlto: warn_and_return = False if kind in ( - FILE_EXTENSION_MAP["fatbin"], FILE_EXTENSION_MAP["o"] + FILE_EXTENSION_MAP["fatbin"], + FILE_EXTENSION_MAP["o"], ): entry_types = inspect_obj_content(path_or_code) if "nvvm" not in entry_types: @@ -2799,6 +2899,7 @@ class MVCLinker(Linker): Linker supporting Minor Version Compatibility, backed by the cubinlinker package. 
""" + def __init__(self, max_registers=None, lineinfo=False, cc=None): try: from cubinlinker import CubinLinker @@ -2806,18 +2907,20 @@ def __init__(self, max_registers=None, lineinfo=False, cc=None): raise ImportError(_MVC_ERROR_MESSAGE) from err if cc is None: - raise RuntimeError("MVCLinker requires Compute Capability to be " - "specified, but cc is None") + raise RuntimeError( + "MVCLinker requires Compute Capability to be " + "specified, but cc is None" + ) super().__init__(max_registers, lineinfo, cc) arch = f"sm_{cc[0] * 10 + cc[1]}" - ptx_compile_opts = ['--gpu-name', arch, '-c'] + ptx_compile_opts = ["--gpu-name", arch, "-c"] if max_registers: arg = f"--maxrregcount={max_registers}" ptx_compile_opts.append(arg) if lineinfo: - ptx_compile_opts.append('--generate-line-info') + ptx_compile_opts.append("--generate-line-info") self.ptx_compile_options = tuple(ptx_compile_opts) self._linker = CubinLinker(f"--arch={arch}") @@ -2830,7 +2933,7 @@ def info_log(self): def error_log(self): return self._linker.error_log - def add_ptx(self, ptx, name=''): + def add_ptx(self, ptx, name=""): try: from ptxcompiler import compile_ptx from cubinlinker import CubinLinkerError @@ -2849,19 +2952,19 @@ def add_file(self, path, kind): raise ImportError(_MVC_ERROR_MESSAGE) from err try: - with open(path, 'rb') as f: + with open(path, "rb") as f: data = f.read() except FileNotFoundError: - raise LinkerError(f'{path} not found') + raise LinkerError(f"{path} not found") name = pathlib.Path(path).name - if kind == FILE_EXTENSION_MAP['cubin']: + if kind == FILE_EXTENSION_MAP["cubin"]: fn = self._linker.add_cubin - elif kind == FILE_EXTENSION_MAP['fatbin']: + elif kind == FILE_EXTENSION_MAP["fatbin"]: fn = self._linker.add_fatbin - elif kind == FILE_EXTENSION_MAP['a']: + elif kind == FILE_EXTENSION_MAP["a"]: raise LinkerError(f"Don't know how to link {kind}") - elif kind == FILE_EXTENSION_MAP['ptx']: + elif kind == FILE_EXTENSION_MAP["ptx"]: return self.add_ptx(data, name) else: raise LinkerError(f"Don't know how to link {kind}") @@ -2887,6 +2990,7 @@ class CtypesLinker(Linker): """ Links for current device if no CC given """ + def __init__(self, max_registers=0, lineinfo=False, cc=None): super().__init__(max_registers, lineinfo, cc) @@ -2920,8 +3024,9 @@ def __init__(self, max_registers=0, lineinfo=False, cc=None): option_vals = (c_void_p * len(raw_values))(*raw_values) self.handle = handle = drvapi.cu_link_state() - driver.cuLinkCreate(len(raw_keys), option_keys, option_vals, - byref(self.handle)) + driver.cuLinkCreate( + len(raw_keys), option_keys, option_vals, byref(self.handle) + ) weakref.finalize(self, driver.cuLinkDestroy, handle) @@ -2932,19 +3037,27 @@ def __init__(self, max_registers=0, lineinfo=False, cc=None): @property def info_log(self): - return self.linker_info_buf.value.decode('utf8') + return self.linker_info_buf.value.decode("utf8") @property def error_log(self): - return self.linker_errors_buf.value.decode('utf8') + return self.linker_errors_buf.value.decode("utf8") - def add_ptx(self, ptx, name=''): + def add_ptx(self, ptx, name=""): ptxbuf = c_char_p(ptx) - namebuf = c_char_p(name.encode('utf8')) + namebuf = c_char_p(name.encode("utf8")) self._keep_alive += [ptxbuf, namebuf] try: - driver.cuLinkAddData(self.handle, enums.CU_JIT_INPUT_PTX, - ptxbuf, len(ptx), namebuf, 0, None, None) + driver.cuLinkAddData( + self.handle, + enums.CU_JIT_INPUT_PTX, + ptxbuf, + len(ptx), + namebuf, + 0, + None, + None, + ) except CudaAPIError as e: raise LinkerError("%s\n%s" % (e, self.error_log)) @@ -2956,7 
+3069,7 @@ def add_file(self, path, kind): driver.cuLinkAddFile(self.handle, kind, pathbuf, 0, None, None) except CudaAPIError as e: if e.code == enums.CUDA_ERROR_FILE_NOT_FOUND: - msg = f'{path} not found' + msg = f"{path} not found" else: msg = "%s\n%s" % (e, self.error_log) raise LinkerError(msg) @@ -2971,7 +3084,7 @@ def complete(self): raise LinkerError("%s\n%s" % (e, self.error_log)) size = size.value - assert size > 0, 'linker returned a zero sized cubin' + assert size > 0, "linker returned a zero sized cubin" del self._keep_alive[:] # We return a copy of the cubin because it's owned by the linker @@ -2983,6 +3096,7 @@ class CudaPythonLinker(Linker): """ Links for current device if no CC given """ + def __init__(self, max_registers=0, lineinfo=False, cc=None): super().__init__(max_registers, lineinfo, cc) @@ -3009,8 +3123,9 @@ def __init__(self, max_registers=0, lineinfo=False, cc=None): options[jit_option.CU_JIT_TARGET_FROM_CUCONTEXT] = 1 else: cc_val = cc[0] * 10 + cc[1] - cc_enum = getattr(binding.CUjit_target, - f'CU_TARGET_COMPUTE_{cc_val}') + cc_enum = getattr( + binding.CUjit_target, f"CU_TARGET_COMPUTE_{cc_val}" + ) options[jit_option.CU_JIT_TARGET] = cc_enum raw_keys = list(options.keys()) @@ -3027,19 +3142,20 @@ def __init__(self, max_registers=0, lineinfo=False, cc=None): @property def info_log(self): - return self.linker_info_buf.decode('utf8') + return self.linker_info_buf.decode("utf8") @property def error_log(self): - return self.linker_errors_buf.decode('utf8') + return self.linker_errors_buf.decode("utf8") - def add_ptx(self, ptx, name=''): - namebuf = name.encode('utf8') + def add_ptx(self, ptx, name=""): + namebuf = name.encode("utf8") self._keep_alive += [ptx, namebuf] try: input_ptx = binding.CUjitInputType.CU_JIT_INPUT_PTX - driver.cuLinkAddData(self.handle, input_ptx, ptx, len(ptx), - namebuf, 0, [], []) + driver.cuLinkAddData( + self.handle, input_ptx, ptx, len(ptx), namebuf, 0, [], [] + ) except CudaAPIError as e: raise LinkerError("%s\n%s" % (e, self.error_log)) @@ -3051,7 +3167,7 @@ def add_file(self, path, kind): driver.cuLinkAddFile(self.handle, kind, pathbuf, 0, [], []) except CudaAPIError as e: if e.code == binding.CUresult.CUDA_ERROR_FILE_NOT_FOUND: - msg = f'{path} not found' + msg = f"{path} not found" else: msg = "%s\n%s" % (e, self.error_log) raise LinkerError(msg) @@ -3062,7 +3178,7 @@ def complete(self): except CudaAPIError as e: raise LinkerError("%s\n%s" % (e, self.error_log)) - assert size > 0, 'linker returned a zero sized cubin' + assert size > 0, "linker returned a zero sized cubin" del self._keep_alive[:] # We return a copy of the cubin because it's owned by the linker cubin_ptr = ctypes.cast(cubin_buf, ctypes.POINTER(ctypes.c_char)) @@ -3196,6 +3312,7 @@ def complete(self): except NvJitLinkError as e: raise LinkerError from e + # ----------------------------------------------------------------------------- @@ -3245,7 +3362,7 @@ def device_memory_size(devmem): The result is cached in the device memory object. It may query the driver for the memory size of the device memory allocation. 
""" - sz = getattr(devmem, '_cuda_memsize_', None) + sz = getattr(devmem, "_cuda_memsize_", None) if sz is None: s, e = device_extents(devmem) if USE_NV_BINDING: @@ -3258,10 +3375,9 @@ def device_memory_size(devmem): def _is_datetime_dtype(obj): - """Returns True if the obj.dtype is datetime64 or timedelta64 - """ - dtype = getattr(obj, 'dtype', None) - return dtype is not None and dtype.char in 'Mm' + """Returns True if the obj.dtype is datetime64 or timedelta64""" + dtype = getattr(obj, "dtype", None) + return dtype is not None and dtype.char in "Mm" def _workaround_for_datetime(obj): @@ -3340,12 +3456,11 @@ def is_device_memory(obj): "device_pointer" which value is an int object carrying the pointer value of the device memory address. This is not tested in this method. """ - return getattr(obj, '__cuda_memory__', False) + return getattr(obj, "__cuda_memory__", False) def require_device_memory(obj): - """A sentry for methods that accept CUDA memory object. - """ + """A sentry for methods that accept CUDA memory object.""" if not is_device_memory(obj): raise Exception("Not a CUDA memory object.") @@ -3436,16 +3551,16 @@ def device_memset(dst, val, size, stream=0): def profile_start(): - ''' + """ Enable profile collection in the current context. - ''' + """ driver.cuProfilerStart() def profile_stop(): - ''' + """ Disable profile collection in the current context. - ''' + """ driver.cuProfilerStop() @@ -3472,18 +3587,21 @@ def inspect_obj_content(objpath: str): Given path to a fatbin or object, use `cuobjdump` to examine its content Return the set of entries in the object. """ - code_types :set[str] = set() + code_types: set[str] = set() try: - out = subprocess.run(["cuobjdump", objpath], check=True, - capture_output=True) + out = subprocess.run( + ["cuobjdump", objpath], check=True, capture_output=True + ) except FileNotFoundError as e: - msg = ("cuobjdump has not been found. You may need " - "to install the CUDA toolkit and ensure that " - "it is available on your PATH.\n") + msg = ( + "cuobjdump has not been found. You may need " + "to install the CUDA toolkit and ensure that " + "it is available on your PATH.\n" + ) raise RuntimeError(msg) from e - objtable = out.stdout.decode('utf-8') + objtable = out.stdout.decode("utf-8") entry_pattern = r"Fatbin (.*) code" for line in objtable.split("\n"): if match := re.match(entry_pattern, line): diff --git a/numba_cuda/numba/cuda/cudadrv/linkable_code.py b/numba_cuda/numba/cuda/cudadrv/linkable_code.py index d9f18a834..448a739ff 100644 --- a/numba_cuda/numba/cuda/cudadrv/linkable_code.py +++ b/numba_cuda/numba/cuda/cudadrv/linkable_code.py @@ -15,10 +15,9 @@ class LinkableCode: it. 
""" - def __init__(self, data, name=None, - setup_callback=None, - teardown_callback=None): - + def __init__( + self, data, name=None, setup_callback=None, teardown_callback=None + ): if setup_callback and not callable(setup_callback): raise TypeError("setup_callback must be callable") if teardown_callback and not callable(teardown_callback): diff --git a/numba_cuda/numba/cuda/cudadrv/nvvm.py b/numba_cuda/numba/cuda/cudadrv/nvvm.py index 0844661e2..b46fb0a39 100644 --- a/numba_cuda/numba/cuda/cudadrv/nvvm.py +++ b/numba_cuda/numba/cuda/cudadrv/nvvm.py @@ -1,12 +1,12 @@ """ This is a direct translation of nvvm.h """ + import logging import re import sys import warnings -from ctypes import (c_void_p, c_int, POINTER, c_char_p, c_size_t, byref, - c_char) +from ctypes import c_void_p, c_int, POINTER, c_char_p, c_size_t, byref, c_char import threading @@ -31,7 +31,7 @@ # Result code nvvm_result = c_int -RESULT_CODE_NAMES = ''' +RESULT_CODE_NAMES = """ NVVM_SUCCESS NVVM_ERROR_OUT_OF_MEMORY NVVM_ERROR_PROGRAM_CREATION_FAILURE @@ -42,19 +42,23 @@ NVVM_ERROR_INVALID_OPTION NVVM_ERROR_NO_MODULE_IN_PROGRAM NVVM_ERROR_COMPILATION -'''.split() +""".split() for i, k in enumerate(RESULT_CODE_NAMES): setattr(sys.modules[__name__], k, i) # Data layouts. NVVM IR 1.8 (CUDA 11.6) introduced 128-bit integer support. -_datalayout_original = ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-' - 'i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-' - 'v64:64:64-v128:128:128-n16:32:64') -_datalayout_i128 = ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-' - 'i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-' - 'v64:64:64-v128:128:128-n16:32:64') +_datalayout_original = ( + "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-" + "i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-" + "v64:64:64-v128:128:128-n16:32:64" +) +_datalayout_i128 = ( + "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-" + "i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-" + "v64:64:64-v128:128:128-n16:32:64" +) def is_available(): @@ -73,59 +77,74 @@ def is_available(): class NVVM(object): - '''Process-wide singleton. 
- ''' - _PROTOTYPES = { + """Process-wide singleton.""" + _PROTOTYPES = { # nvvmResult nvvmVersion(int *major, int *minor) - 'nvvmVersion': (nvvm_result, POINTER(c_int), POINTER(c_int)), - + "nvvmVersion": (nvvm_result, POINTER(c_int), POINTER(c_int)), # nvvmResult nvvmCreateProgram(nvvmProgram *cu) - 'nvvmCreateProgram': (nvvm_result, POINTER(nvvm_program)), - + "nvvmCreateProgram": (nvvm_result, POINTER(nvvm_program)), # nvvmResult nvvmDestroyProgram(nvvmProgram *cu) - 'nvvmDestroyProgram': (nvvm_result, POINTER(nvvm_program)), - + "nvvmDestroyProgram": (nvvm_result, POINTER(nvvm_program)), # nvvmResult nvvmAddModuleToProgram(nvvmProgram cu, const char *buffer, # size_t size, const char *name) - 'nvvmAddModuleToProgram': ( - nvvm_result, nvvm_program, c_char_p, c_size_t, c_char_p), - + "nvvmAddModuleToProgram": ( + nvvm_result, + nvvm_program, + c_char_p, + c_size_t, + c_char_p, + ), # nvvmResult nvvmLazyAddModuleToProgram(nvvmProgram cu, # const char* buffer, # size_t size, # const char *name) - 'nvvmLazyAddModuleToProgram': ( - nvvm_result, nvvm_program, c_char_p, c_size_t, c_char_p), - + "nvvmLazyAddModuleToProgram": ( + nvvm_result, + nvvm_program, + c_char_p, + c_size_t, + c_char_p, + ), # nvvmResult nvvmCompileProgram(nvvmProgram cu, int numOptions, # const char **options) - 'nvvmCompileProgram': ( - nvvm_result, nvvm_program, c_int, POINTER(c_char_p)), - + "nvvmCompileProgram": ( + nvvm_result, + nvvm_program, + c_int, + POINTER(c_char_p), + ), # nvvmResult nvvmGetCompiledResultSize(nvvmProgram cu, # size_t *bufferSizeRet) - 'nvvmGetCompiledResultSize': ( - nvvm_result, nvvm_program, POINTER(c_size_t)), - + "nvvmGetCompiledResultSize": ( + nvvm_result, + nvvm_program, + POINTER(c_size_t), + ), # nvvmResult nvvmGetCompiledResult(nvvmProgram cu, char *buffer) - 'nvvmGetCompiledResult': (nvvm_result, nvvm_program, c_char_p), - + "nvvmGetCompiledResult": (nvvm_result, nvvm_program, c_char_p), # nvvmResult nvvmGetProgramLogSize(nvvmProgram cu, # size_t *bufferSizeRet) - 'nvvmGetProgramLogSize': (nvvm_result, nvvm_program, POINTER(c_size_t)), - + "nvvmGetProgramLogSize": (nvvm_result, nvvm_program, POINTER(c_size_t)), # nvvmResult nvvmGetProgramLog(nvvmProgram cu, char *buffer) - 'nvvmGetProgramLog': (nvvm_result, nvvm_program, c_char_p), - + "nvvmGetProgramLog": (nvvm_result, nvvm_program, c_char_p), # nvvmResult nvvmIRVersion (int* majorIR, int* minorIR, int* majorDbg, # int* minorDbg ) - 'nvvmIRVersion': (nvvm_result, POINTER(c_int), POINTER(c_int), - POINTER(c_int), POINTER(c_int)), + "nvvmIRVersion": ( + nvvm_result, + POINTER(c_int), + POINTER(c_int), + POINTER(c_int), + POINTER(c_int), + ), # nvvmResult nvvmVerifyProgram (nvvmProgram prog, int numOptions, # const char** options) - 'nvvmVerifyProgram': (nvvm_result, nvvm_program, c_int, - POINTER(c_char_p)) + "nvvmVerifyProgram": ( + nvvm_result, + nvvm_program, + c_int, + POINTER(c_char_p), + ), } # Singleton reference @@ -136,11 +155,13 @@ def __new__(cls): if cls.__INSTANCE is None: cls.__INSTANCE = inst = object.__new__(cls) try: - inst.driver = open_cudalib('nvvm') + inst.driver = open_cudalib("nvvm") except OSError as e: cls.__INSTANCE = None - errmsg = ("libNVVM cannot be found. Do `conda install " - "cudatoolkit`:\n%s") + errmsg = ( + "libNVVM cannot be found. 
Do `conda install " + "cudatoolkit`:\n%s" + ) raise NvvmSupportError(errmsg % e) # Find & populate functions @@ -175,7 +196,7 @@ def get_version(self): major = c_int() minor = c_int() err = self.nvvmVersion(byref(major), byref(minor)) - self.check_error(err, 'Failed to get version.') + self.check_error(err, "Failed to get version.") return major.value, minor.value def get_ir_version(self): @@ -183,9 +204,10 @@ def get_ir_version(self): minorIR = c_int() majorDbg = c_int() minorDbg = c_int() - err = self.nvvmIRVersion(byref(majorIR), byref(minorIR), - byref(majorDbg), byref(minorDbg)) - self.check_error(err, 'Failed to get IR version.') + err = self.nvvmIRVersion( + byref(majorIR), byref(minorIR), byref(majorDbg), byref(minorDbg) + ) + self.check_error(err, "Failed to get IR version.") return majorIR.value, minorIR.value, majorDbg.value, minorDbg.value def check_error(self, error, msg, exit=False): @@ -223,18 +245,18 @@ def __init__(self, options): self.driver = NVVM() self._handle = nvvm_program() err = self.driver.nvvmCreateProgram(byref(self._handle)) - self.driver.check_error(err, 'Failed to create CU') + self.driver.check_error(err, "Failed to create CU") def stringify_option(k, v): - k = k.replace('_', '-') + k = k.replace("_", "-") if v is None: - return f'-{k}'.encode('utf-8') + return f"-{k}".encode("utf-8") if isinstance(v, bool): v = int(v) - return f'-{k}={v}'.encode('utf-8') + return f"-{k}={v}".encode("utf-8") options = [stringify_option(k, v) for k, v in options.items()] option_ptrs = (c_char_p * len(options))(*[c_char_p(x) for x in options]) @@ -248,17 +270,18 @@ def stringify_option(k, v): def __del__(self): driver = NVVM() err = driver.nvvmDestroyProgram(byref(self._handle)) - driver.check_error(err, 'Failed to destroy CU', exit=True) + driver.check_error(err, "Failed to destroy CU", exit=True) def add_module(self, buffer): """ - Add a module level NVVM IR to a compilation unit. - - The buffer should contain an NVVM module IR either in the bitcode - representation (LLVM3.0) or in the text representation. + Add a module level NVVM IR to a compilation unit. + - The buffer should contain an NVVM module IR either in the bitcode + representation (LLVM3.0) or in the text representation. """ - err = self.driver.nvvmAddModuleToProgram(self._handle, buffer, - len(buffer), None) - self.driver.check_error(err, 'Failed to add module') + err = self.driver.nvvmAddModuleToProgram( + self._handle, buffer, len(buffer), None + ) + self.driver.check_error(err, "Failed to add module") def lazy_add_module(self, buffer): """ @@ -266,37 +289,41 @@ def lazy_add_module(self, buffer): The buffer should contain NVVM module IR either in the bitcode representation or in the text representation. """ - err = self.driver.nvvmLazyAddModuleToProgram(self._handle, buffer, - len(buffer), None) - self.driver.check_error(err, 'Failed to add module') + err = self.driver.nvvmLazyAddModuleToProgram( + self._handle, buffer, len(buffer), None + ) + self.driver.check_error(err, "Failed to add module") def verify(self): """ Run the NVVM verifier on all code added to the compilation unit. """ - err = self.driver.nvvmVerifyProgram(self._handle, self.n_options, - self.option_ptrs) - self._try_error(err, 'Failed to verify\n') + err = self.driver.nvvmVerifyProgram( + self._handle, self.n_options, self.option_ptrs + ) + self._try_error(err, "Failed to verify\n") def compile(self): """ Compile all modules added to the compilation unit and return the resulting PTX or LTO-IR (depending on the options). 
""" - err = self.driver.nvvmCompileProgram(self._handle, self.n_options, - self.option_ptrs) - self._try_error(err, 'Failed to compile\n') + err = self.driver.nvvmCompileProgram( + self._handle, self.n_options, self.option_ptrs + ) + self._try_error(err, "Failed to compile\n") # Get result result_size = c_size_t() - err = self.driver.nvvmGetCompiledResultSize(self._handle, - byref(result_size)) + err = self.driver.nvvmGetCompiledResultSize( + self._handle, byref(result_size) + ) - self._try_error(err, 'Failed to get size of compiled result.') + self._try_error(err, "Failed to get size of compiled result.") output_buffer = (c_char * result_size.value)() err = self.driver.nvvmGetCompiledResult(self._handle, output_buffer) - self._try_error(err, 'Failed to get compiled result.') + self._try_error(err, "Failed to get compiled result.") # Get log self.log = self.get_log() @@ -311,26 +338,37 @@ def _try_error(self, err, msg): def get_log(self): reslen = c_size_t() err = self.driver.nvvmGetProgramLogSize(self._handle, byref(reslen)) - self.driver.check_error(err, 'Failed to get compilation log size.') + self.driver.check_error(err, "Failed to get compilation log size.") if reslen.value > 1: logbuf = (c_char * reslen.value)() err = self.driver.nvvmGetProgramLog(self._handle, logbuf) - self.driver.check_error(err, 'Failed to get compilation log.') + self.driver.check_error(err, "Failed to get compilation log.") - return logbuf.value.decode('utf8') # populate log attribute + return logbuf.value.decode("utf8") # populate log attribute - return '' + return "" COMPUTE_CAPABILITIES = ( - (3, 5), (3, 7), - (5, 0), (5, 2), (5, 3), - (6, 0), (6, 1), (6, 2), - (7, 0), (7, 2), (7, 5), - (8, 0), (8, 6), (8, 7), (8, 9), + (3, 5), + (3, 7), + (5, 0), + (5, 2), + (5, 3), + (6, 0), + (6, 1), + (6, 2), + (7, 0), + (7, 2), + (7, 5), + (8, 0), + (8, 6), + (8, 7), + (8, 9), (9, 0), - (10, 0), (10, 1), + (10, 0), + (10, 1), (12, 0), ) @@ -358,20 +396,27 @@ def ccs_supported_by_ctk(ctk_version): try: # For supported versions, we look up the range of supported CCs min_cc, max_cc = CTK_SUPPORTED[ctk_version] - return tuple([cc for cc in COMPUTE_CAPABILITIES - if min_cc <= cc <= max_cc]) + return tuple( + [cc for cc in COMPUTE_CAPABILITIES if min_cc <= cc <= max_cc] + ) except KeyError: # For unsupported CUDA toolkit versions, all we can do is assume all # non-deprecated versions we are aware of are supported. - return tuple([cc for cc in COMPUTE_CAPABILITIES - if cc >= config.CUDA_DEFAULT_PTX_CC]) + return tuple( + [ + cc + for cc in COMPUTE_CAPABILITIES + if cc >= config.CUDA_DEFAULT_PTX_CC + ] + ) def get_supported_ccs(): try: from numba.cuda.cudadrv.runtime import runtime + cudart_version = runtime.get_version() - except: # noqa: E722 + except: # noqa: E722 # We can't support anything if there's an error getting the runtime # version (e.g. if it's not present or there's another issue) _supported_cc = () @@ -382,9 +427,11 @@ def get_supported_ccs(): if cudart_version < min_cudart: _supported_cc = () ctk_ver = f"{cudart_version[0]}.{cudart_version[1]}" - unsupported_ver = (f"CUDA Toolkit {ctk_ver} is unsupported by Numba - " - f"{min_cudart[0]}.{min_cudart[1]} is the minimum " - "required version.") + unsupported_ver = ( + f"CUDA Toolkit {ctk_ver} is unsupported by Numba - " + f"{min_cudart[0]}.{min_cudart[1]} is the minimum " + "required version." 
+ ) warnings.warn(unsupported_ver) return _supported_cc @@ -403,8 +450,10 @@ def find_closest_arch(mycc): supported_ccs = NVVM().supported_ccs if not supported_ccs: - msg = "No supported GPU compute capabilities found. " \ - "Please check your cudatoolkit version matches your CUDA version." + msg = ( + "No supported GPU compute capabilities found. " + "Please check your cudatoolkit version matches your CUDA version." + ) raise NvvmSupportError(msg) for i, cc in enumerate(supported_ccs): @@ -415,8 +464,10 @@ def find_closest_arch(mycc): # Exceeded if i == 0: # CC lower than supported - msg = "GPU compute capability %d.%d is not supported" \ - "(requires >=%d.%d)" % (mycc + cc) + msg = ( + "GPU compute capability %d.%d is not supported" + "(requires >=%d.%d)" % (mycc + cc) + ) raise NvvmSupportError(msg) else: # return the previous CC @@ -427,16 +478,15 @@ def find_closest_arch(mycc): def get_arch_option(major, minor): - """Matches with the closest architecture option - """ + """Matches with the closest architecture option""" if config.FORCE_CUDA_CC: arch = config.FORCE_CUDA_CC else: arch = find_closest_arch((major, minor)) - return 'compute_%d%d' % arch + return "compute_%d%d" % arch -MISSING_LIBDEVICE_FILE_MSG = '''Missing libdevice file. +MISSING_LIBDEVICE_FILE_MSG = """Missing libdevice file. Please ensure you have a CUDA Toolkit 11.2 or higher. For CUDA 12, ``cuda-nvcc`` and ``cuda-nvrtc`` are required: @@ -445,7 +495,7 @@ def get_arch_option(major, minor): For CUDA 11, ``cudatoolkit`` is required: $ conda install -c conda-forge cudatoolkit "cuda-version>=11.2,<12.0" -''' +""" class LibDevice(object): @@ -466,7 +516,7 @@ def get(self): cas_nvvm = """ %cas_success = cmpxchg volatile {Ti}* %iptr, {Ti} %old, {Ti} %new monotonic monotonic %cas = extractvalue {{ {Ti}, i1 }} %cas_success, 0 -""" # noqa: E501 +""" # noqa: E501 # Translation of code from CUDA Programming Guide v6.5, section B.12 @@ -490,7 +540,7 @@ def get(self): %result = bitcast {Ti} %old to {T} ret {T} %result }} -""" # noqa: E501 +""" # noqa: E501 ir_numba_atomic_inc_template = """ define internal {T} @___numba_atomic_{Tu}_inc({T}* %iptr, {T} %val) alwaysinline {{ @@ -510,7 +560,7 @@ def get(self): done: ret {T} %old }} -""" # noqa: E501 +""" # noqa: E501 ir_numba_atomic_dec_template = """ define internal {T} @___numba_atomic_{Tu}_dec({T}* %iptr, {T} %val) alwaysinline {{ @@ -530,7 +580,7 @@ def get(self): done: ret {T} %old }} -""" # noqa: E501 +""" # noqa: E501 ir_numba_atomic_minmax_template = """ define internal {T} @___numba_atomic_{T}_{NAN}{FUNC}({T}* %ptr, {T} %val) alwaysinline {{ @@ -561,7 +611,7 @@ def get(self): done: ret {T} %ptrval }} -""" # noqa: E501 +""" # noqa: E501 def ir_cas(Ti): @@ -574,8 +624,15 @@ def ir_numba_atomic_binary(T, Ti, OP, FUNC): def ir_numba_atomic_minmax(T, Ti, NAN, OP, PTR_OR_VAL, FUNC): - params = dict(T=T, Ti=Ti, NAN=NAN, OP=OP, PTR_OR_VAL=PTR_OR_VAL, - FUNC=FUNC, CAS=ir_cas(Ti)) + params = dict( + T=T, + Ti=Ti, + NAN=NAN, + OP=OP, + PTR_OR_VAL=PTR_OR_VAL, + FUNC=FUNC, + CAS=ir_cas(Ti), + ) return ir_numba_atomic_minmax_template.format(**params) @@ -590,41 +647,115 @@ def ir_numba_atomic_dec(T, Tu): def llvm_replace(llvmir): replacements = [ - ('declare double @"___numba_atomic_double_add"(double* %".1", double %".2")', # noqa: E501 - ir_numba_atomic_binary(T='double', Ti='i64', OP='fadd', FUNC='add')), - ('declare float @"___numba_atomic_float_sub"(float* %".1", float %".2")', # noqa: E501 - ir_numba_atomic_binary(T='float', Ti='i32', OP='fsub', FUNC='sub')), - ('declare double 
@"___numba_atomic_double_sub"(double* %".1", double %".2")', # noqa: E501 - ir_numba_atomic_binary(T='double', Ti='i64', OP='fsub', FUNC='sub')), - ('declare i64 @"___numba_atomic_u64_inc"(i64* %".1", i64 %".2")', - ir_numba_atomic_inc(T='i64', Tu='u64')), - ('declare i64 @"___numba_atomic_u64_dec"(i64* %".1", i64 %".2")', - ir_numba_atomic_dec(T='i64', Tu='u64')), - ('declare float @"___numba_atomic_float_max"(float* %".1", float %".2")', # noqa: E501 - ir_numba_atomic_minmax(T='float', Ti='i32', NAN='', OP='nnan olt', - PTR_OR_VAL='ptr', FUNC='max')), - ('declare double @"___numba_atomic_double_max"(double* %".1", double %".2")', # noqa: E501 - ir_numba_atomic_minmax(T='double', Ti='i64', NAN='', OP='nnan olt', - PTR_OR_VAL='ptr', FUNC='max')), - ('declare float @"___numba_atomic_float_min"(float* %".1", float %".2")', # noqa: E501 - ir_numba_atomic_minmax(T='float', Ti='i32', NAN='', OP='nnan ogt', - PTR_OR_VAL='ptr', FUNC='min')), - ('declare double @"___numba_atomic_double_min"(double* %".1", double %".2")', # noqa: E501 - ir_numba_atomic_minmax(T='double', Ti='i64', NAN='', OP='nnan ogt', - PTR_OR_VAL='ptr', FUNC='min')), - ('declare float @"___numba_atomic_float_nanmax"(float* %".1", float %".2")', # noqa: E501 - ir_numba_atomic_minmax(T='float', Ti='i32', NAN='nan', OP='ult', - PTR_OR_VAL='', FUNC='max')), - ('declare double @"___numba_atomic_double_nanmax"(double* %".1", double %".2")', # noqa: E501 - ir_numba_atomic_minmax(T='double', Ti='i64', NAN='nan', OP='ult', - PTR_OR_VAL='', FUNC='max')), - ('declare float @"___numba_atomic_float_nanmin"(float* %".1", float %".2")', # noqa: E501 - ir_numba_atomic_minmax(T='float', Ti='i32', NAN='nan', OP='ugt', - PTR_OR_VAL='', FUNC='min')), - ('declare double @"___numba_atomic_double_nanmin"(double* %".1", double %".2")', # noqa: E501 - ir_numba_atomic_minmax(T='double', Ti='i64', NAN='nan', OP='ugt', - PTR_OR_VAL='', FUNC='min')), - ('immarg', '') + ( + 'declare double @"___numba_atomic_double_add"(double* %".1", double %".2")', # noqa: E501 + ir_numba_atomic_binary(T="double", Ti="i64", OP="fadd", FUNC="add"), + ), + ( + 'declare float @"___numba_atomic_float_sub"(float* %".1", float %".2")', # noqa: E501 + ir_numba_atomic_binary(T="float", Ti="i32", OP="fsub", FUNC="sub"), + ), + ( + 'declare double @"___numba_atomic_double_sub"(double* %".1", double %".2")', # noqa: E501 + ir_numba_atomic_binary(T="double", Ti="i64", OP="fsub", FUNC="sub"), + ), + ( + 'declare i64 @"___numba_atomic_u64_inc"(i64* %".1", i64 %".2")', + ir_numba_atomic_inc(T="i64", Tu="u64"), + ), + ( + 'declare i64 @"___numba_atomic_u64_dec"(i64* %".1", i64 %".2")', + ir_numba_atomic_dec(T="i64", Tu="u64"), + ), + ( + 'declare float @"___numba_atomic_float_max"(float* %".1", float %".2")', # noqa: E501 + ir_numba_atomic_minmax( + T="float", + Ti="i32", + NAN="", + OP="nnan olt", + PTR_OR_VAL="ptr", + FUNC="max", + ), + ), + ( + 'declare double @"___numba_atomic_double_max"(double* %".1", double %".2")', # noqa: E501 + ir_numba_atomic_minmax( + T="double", + Ti="i64", + NAN="", + OP="nnan olt", + PTR_OR_VAL="ptr", + FUNC="max", + ), + ), + ( + 'declare float @"___numba_atomic_float_min"(float* %".1", float %".2")', # noqa: E501 + ir_numba_atomic_minmax( + T="float", + Ti="i32", + NAN="", + OP="nnan ogt", + PTR_OR_VAL="ptr", + FUNC="min", + ), + ), + ( + 'declare double @"___numba_atomic_double_min"(double* %".1", double %".2")', # noqa: E501 + ir_numba_atomic_minmax( + T="double", + Ti="i64", + NAN="", + OP="nnan ogt", + PTR_OR_VAL="ptr", + FUNC="min", + ), + ), + ( + 
'declare float @"___numba_atomic_float_nanmax"(float* %".1", float %".2")', # noqa: E501 + ir_numba_atomic_minmax( + T="float", + Ti="i32", + NAN="nan", + OP="ult", + PTR_OR_VAL="", + FUNC="max", + ), + ), + ( + 'declare double @"___numba_atomic_double_nanmax"(double* %".1", double %".2")', # noqa: E501 + ir_numba_atomic_minmax( + T="double", + Ti="i64", + NAN="nan", + OP="ult", + PTR_OR_VAL="", + FUNC="max", + ), + ), + ( + 'declare float @"___numba_atomic_float_nanmin"(float* %".1", float %".2")', # noqa: E501 + ir_numba_atomic_minmax( + T="float", + Ti="i32", + NAN="nan", + OP="ugt", + PTR_OR_VAL="", + FUNC="min", + ), + ), + ( + 'declare double @"___numba_atomic_double_nanmin"(double* %".1", double %".2")', # noqa: E501 + ir_numba_atomic_minmax( + T="double", + Ti="i64", + NAN="nan", + OP="ugt", + PTR_OR_VAL="", + FUNC="min", + ), + ), + ("immarg", ""), ] for decl, fn in replacements: @@ -639,19 +770,21 @@ def compile_ir(llvmir, **options): if isinstance(llvmir, str): llvmir = [llvmir] - if options.pop('fastmath', False): - options.update({ - 'ftz': True, - 'fma': True, - 'prec_div': False, - 'prec_sqrt': False, - }) + if options.pop("fastmath", False): + options.update( + { + "ftz": True, + "fma": True, + "prec_div": False, + "prec_sqrt": False, + } + ) cu = CompilationUnit(options) for mod in llvmir: mod = llvm_replace(mod) - cu.add_module(mod.encode('utf8')) + cu.add_module(mod.encode("utf8")) cu.verify() # We add libdevice following verification so that it is not subject to the @@ -671,16 +804,16 @@ def llvm150_to_70_ir(ir): """ buf = [] for line in ir.splitlines(): - if line.startswith('attributes #'): + if line.startswith("attributes #"): # Remove function attributes unsupported by LLVM 7.0 m = re_attributes_def.match(line) attrs = m.group(1).split() - attrs = ' '.join(a for a in attrs if a != 'willreturn') + attrs = " ".join(a for a in attrs if a != "willreturn") line = line.replace(m.group(1), attrs) buf.append(line) - return '\n'.join(buf) + return "\n".join(buf) def set_cuda_kernel(function): @@ -704,7 +837,7 @@ def set_cuda_kernel(function): mdvalue = ir.Constant(ir.IntType(32), 1) md = module.add_metadata((function, mdstr, mdvalue)) - nmd = cgutils.get_or_insert_named_metadata(module, 'nvvm.annotations') + nmd = cgutils.get_or_insert_named_metadata(module, "nvvm.annotations") nmd.add(md) # Create the used list @@ -713,13 +846,13 @@ def set_cuda_kernel(function): fnptr = function.bitcast(ptrty) - llvm_used = ir.GlobalVariable(module, usedty, 'llvm.used') - llvm_used.linkage = 'appending' - llvm_used.section = 'llvm.metadata' + llvm_used = ir.GlobalVariable(module, usedty, "llvm.used") + llvm_used.linkage = "appending" + llvm_used.section = "llvm.metadata" llvm_used.initializer = ir.Constant(usedty, [fnptr]) # Remove 'noinline' if it is present. 
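A worked example may make the option handling above more concrete: compile_ir expands fastmath=True into the four precision-related flags, the dispatcher adds g: None when debug info is requested, and stringify_option (reproduced from CompilationUnit.__init__ above) renders each entry as an NVVM command-line flag. The particular option values below are illustrative.

def stringify_option(k, v):
    # Same helper as in CompilationUnit.__init__ above.
    k = k.replace("_", "-")
    if v is None:
        return f"-{k}".encode("utf-8")
    if isinstance(v, bool):
        v = int(v)
    return f"-{k}={v}".encode("utf-8")

# Roughly what a fastmath, opt=3, debug compilation boils down to at the
# NVVM flag level:
options = {"opt": 3, "ftz": True, "fma": True,
           "prec_div": False, "prec_sqrt": False, "g": None}
print([stringify_option(k, v) for k, v in options.items()])
# [b'-opt=3', b'-ftz=1', b'-fma=1', b'-prec-div=0', b'-prec-sqrt=0', b'-g']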
- function.attributes.discard('noinline') + function.attributes.discard("noinline") def add_ir_version(mod): @@ -728,4 +861,4 @@ def add_ir_version(mod): i32 = ir.IntType(32) ir_versions = [i32(v) for v in NVVM().get_ir_version()] md_ver = mod.add_metadata(ir_versions) - mod.add_named_metadata('nvvmir.version', md_ver) + mod.add_named_metadata("nvvmir.version", md_ver) diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index f072f6073..e68c3b745 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -15,13 +15,19 @@ from numba.core.types.functions import Function from numba.cuda.api import get_current_device from numba.cuda.args import wrap_arg -from numba.cuda.compiler import (compile_cuda, CUDACompiler, kernel_fixup, - ExternFunction) +from numba.cuda.compiler import ( + compile_cuda, + CUDACompiler, + kernel_fixup, + ExternFunction, +) from numba.cuda.cudadrv import driver from numba.cuda.cudadrv.devices import get_context from numba.cuda.descriptor import cuda_target -from numba.cuda.errors import (missing_launch_config_msg, - normalize_kernel_dimensions) +from numba.cuda.errors import ( + missing_launch_config_msg, + normalize_kernel_dimensions, +) from numba.cuda import types as cuda_types from numba.cuda.runtime.nrt import rtsys from numba.cuda.locks import module_init_lock @@ -31,17 +37,26 @@ from warnings import warn -cuda_fp16_math_funcs = ['hsin', 'hcos', - 'hlog', 'hlog10', - 'hlog2', - 'hexp', 'hexp10', - 'hexp2', - 'hsqrt', 'hrsqrt', - 'hfloor', 'hceil', - 'hrcp', 'hrint', - 'htrunc', 'hdiv'] - -reshape_funcs = ['nocopy_empty_reshape', 'numba_attempt_nocopy_reshape'] +cuda_fp16_math_funcs = [ + "hsin", + "hcos", + "hlog", + "hlog10", + "hlog2", + "hexp", + "hexp10", + "hexp2", + "hsqrt", + "hrsqrt", + "hfloor", + "hceil", + "hrcp", + "hrint", + "htrunc", + "hdiv", +] + +reshape_funcs = ["nocopy_empty_reshape", "numba_attempt_nocopy_reshape"] def get_cres_link_objects(cres): @@ -52,17 +67,16 @@ def get_cres_link_objects(cres): # List of calls into declared device functions device_func_calls = [ - (name, v) for name, v in cres.fndesc.typemap.items() if ( - isinstance(v, cuda_types.CUDADispatcher) - ) + (name, v) + for name, v in cres.fndesc.typemap.items() + if (isinstance(v, cuda_types.CUDADispatcher)) ] # List of tuples with SSA name of calls and corresponding signature call_signatures = [ (call.func.name, sig) - for call, sig in cres.fndesc.calltypes.items() if ( - isinstance(call, ir.Expr) and call.op == 'call' - ) + for call, sig in cres.fndesc.calltypes.items() + if (isinstance(call, ir.Expr) and call.op == "call") ] # Map SSA names to all invoked signatures @@ -94,10 +108,10 @@ def get_cres_link_objects(cres): class _Kernel(serialize.ReduceMixin): - ''' + """ CUDA Kernel specialized for a given set of argument types. When called, this object launches the kernel on the device. 
- ''' + """ NRT_functions = [ "NRT_Allocate", @@ -111,16 +125,27 @@ class _Kernel(serialize.ReduceMixin): "NRT_MemInfo_alloc_aligned", "NRT_Allocate_External", "NRT_decref", - "NRT_incref" + "NRT_incref", ] @global_compiler_lock - def __init__(self, py_func, argtypes, link=None, debug=False, - lineinfo=False, inline=False, fastmath=False, extensions=None, - max_registers=None, lto=False, opt=True, device=False): - + def __init__( + self, + py_func, + argtypes, + link=None, + debug=False, + lineinfo=False, + inline=False, + fastmath=False, + extensions=None, + max_registers=None, + lto=False, + opt=True, + device=False, + ): if device: - raise RuntimeError('Cannot compile a device function as a kernel') + raise RuntimeError("Cannot compile a device function as a kernel") super().__init__() @@ -145,24 +170,25 @@ def __init__(self, py_func, argtypes, link=None, debug=False, self.lineinfo = lineinfo self.extensions = extensions or [] - nvvm_options = { - 'fastmath': fastmath, - 'opt': 3 if opt else 0 - } + nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0} if debug: - nvvm_options['g'] = None + nvvm_options["g"] = None cc = get_current_device().compute_capability - cres = compile_cuda(self.py_func, types.void, self.argtypes, - debug=self.debug, - lineinfo=lineinfo, - inline=inline, - fastmath=fastmath, - nvvm_options=nvvm_options, - cc=cc, - max_registers=max_registers, - lto=lto) + cres = compile_cuda( + self.py_func, + types.void, + self.argtypes, + debug=self.debug, + lineinfo=lineinfo, + inline=inline, + fastmath=fastmath, + nvvm_options=nvvm_options, + cc=cc, + max_registers=max_registers, + lto=lto, + ) tgt_ctx = cres.target_context lib = cres.library kernel = lib.get_function(cres.fndesc.llvm_func_name) @@ -175,24 +201,25 @@ def __init__(self, py_func, argtypes, link=None, debug=False, asm = lib.get_asm_str() # A kernel needs cooperative launch if grid_sync is being used. - self.cooperative = 'cudaCGGetIntrinsicHandle' in asm + self.cooperative = "cudaCGGetIntrinsicHandle" in asm # We need to link against cudadevrt if grid sync is being used. if self.cooperative: lib.needs_cudadevrt = True - def link_to_library_functions(library_functions, library_path, - prefix=None): + def link_to_library_functions( + library_functions, library_path, prefix=None + ): """ Dynamically links to library functions by searching for their names in the specified library and linking to the corresponding source file. 
""" if prefix is not None: - library_functions = [f"{prefix}{fn}" for fn in - library_functions] + library_functions = [ + f"{prefix}{fn}" for fn in library_functions + ] - found_functions = [fn for fn in library_functions - if f'{fn}' in asm] + found_functions = [fn for fn in library_functions if f"{fn}" in asm] if found_functions: basedir = os.path.dirname(os.path.abspath(__file__)) @@ -202,11 +229,11 @@ def link_to_library_functions(library_functions, library_path, return found_functions # Link to the helper library functions if needed - link_to_library_functions(reshape_funcs, 'reshape_funcs.cu') + link_to_library_functions(reshape_funcs, "reshape_funcs.cu") # Link to the CUDA FP16 math library functions if needed - link_to_library_functions(cuda_fp16_math_funcs, - 'cpp_function_wrappers.cu', - '__numba_wrapper_') + link_to_library_functions( + cuda_fp16_math_funcs, "cpp_function_wrappers.cu", "__numba_wrapper_" + ) self.maybe_link_nrt(link, tgt_ctx, asm) @@ -240,15 +267,16 @@ def maybe_link_nrt(self, link, tgt_ctx, asm): all_nrt = "|".join(self.NRT_functions) pattern = ( - r'\.extern\s+\.func\s+(?:\s*\(.+\)\s*)?(' - + all_nrt + r')\s*\([^)]*\)\s*;' + r"\.extern\s+\.func\s+(?:\s*\(.+\)\s*)?(" + + all_nrt + + r")\s*\([^)]*\)\s*;" ) nrt_in_asm = re.findall(pattern, asm) basedir = os.path.dirname(os.path.abspath(__file__)) if nrt_in_asm: - nrt_path = os.path.join(basedir, 'runtime', 'nrt.cu') + nrt_path = os.path.join(basedir, "runtime", "nrt.cu") link.append(nrt_path) @property @@ -271,8 +299,17 @@ def argument_types(self): return tuple(self.signature.args) @classmethod - def _rebuild(cls, cooperative, name, signature, codelibrary, - debug, lineinfo, call_helper, extensions): + def _rebuild( + cls, + cooperative, + name, + signature, + codelibrary, + debug, + lineinfo, + call_helper, + extensions, + ): """ Rebuild an instance. """ @@ -300,10 +337,16 @@ def _reduce_states(self): Thread, block and shared memory configuration are serialized. Stream information is discarded. """ - return dict(cooperative=self.cooperative, name=self.entry_name, - signature=self.signature, codelibrary=self._codelibrary, - debug=self.debug, lineinfo=self.lineinfo, - call_helper=self.call_helper, extensions=self.extensions) + return dict( + cooperative=self.cooperative, + name=self.entry_name, + signature=self.signature, + codelibrary=self._codelibrary, + debug=self.debug, + lineinfo=self.lineinfo, + call_helper=self.call_helper, + extensions=self.extensions, + ) @module_init_lock def initialize_once(self, mod): @@ -331,73 +374,73 @@ def bind(self): @property def regs_per_thread(self): - ''' + """ The number of registers used by each thread for this kernel. - ''' + """ return self._codelibrary.get_cufunc().attrs.regs @property def const_mem_size(self): - ''' + """ The amount of constant memory used by this kernel. - ''' + """ return self._codelibrary.get_cufunc().attrs.const @property def shared_mem_per_block(self): - ''' + """ The amount of shared memory used per block for this kernel. - ''' + """ return self._codelibrary.get_cufunc().attrs.shared @property def max_threads_per_block(self): - ''' + """ The maximum allowable threads per block. - ''' + """ return self._codelibrary.get_cufunc().attrs.maxthreads @property def local_mem_per_thread(self): - ''' + """ The amount of local memory used per thread for this kernel. - ''' + """ return self._codelibrary.get_cufunc().attrs.local def inspect_llvm(self): - ''' + """ Returns the LLVM IR for this kernel. 
- ''' + """ return self._codelibrary.get_llvm_str() def inspect_asm(self, cc): - ''' + """ Returns the PTX code for this kernel. - ''' + """ return self._codelibrary.get_asm_str(cc=cc) def inspect_sass_cfg(self): - ''' + """ Returns the CFG of the SASS for this kernel. Requires nvdisasm to be available on the PATH. - ''' + """ return self._codelibrary.get_sass_cfg() def inspect_sass(self): - ''' + """ Returns the SASS code for this kernel. Requires nvdisasm to be available on the PATH. - ''' + """ return self._codelibrary.get_sass() def inspect_types(self, file=None): - ''' + """ Produce a dump of the Python source of this function annotated with the corresponding Numba IR and type information. The dump is written to *file*, or *sys.stdout* if *file* is *None*. - ''' + """ if self._type_annotation is None: raise ValueError("Type annotation is not available") @@ -405,12 +448,12 @@ def inspect_types(self, file=None): file = sys.stdout print("%s %s" % (self.entry_name, self.argument_types), file=file) - print('-' * 80, file=file) + print("-" * 80, file=file) print(self._type_annotation, file=file) - print('=' * 80, file=file) + print("=" * 80, file=file) def max_cooperative_grid_blocks(self, blockdim, dynsmemsize=0): - ''' + """ Calculates the maximum number of blocks that can be launched for this kernel in a cooperative grid in the current context, for the given block and dynamic shared memory sizes. @@ -419,15 +462,15 @@ def max_cooperative_grid_blocks(self, blockdim, dynsmemsize=0): a tuple for 2D or 3D blocks. :param dynsmemsize: Dynamic shared memory size in bytes. :return: The maximum number of blocks in the grid. - ''' + """ ctx = get_context() cufunc = self._codelibrary.get_cufunc() if isinstance(blockdim, tuple): blockdim = functools.reduce(lambda x, y: x * y, blockdim) - active_per_sm = ctx.get_active_blocks_per_multiprocessor(cufunc, - blockdim, - dynsmemsize) + active_per_sm = ctx.get_active_blocks_per_multiprocessor( + cufunc, blockdim, dynsmemsize + ) sm_count = ctx.device.MULTIPROCESSOR_COUNT return active_per_sm * sm_count @@ -443,7 +486,7 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0): excmem.memset(0, stream=stream) # Prepare arguments - retr = [] # hold functors for writeback + retr = [] # hold functors for writeback kernelargs = [] for t, v in zip(self.argument_types, args): @@ -457,46 +500,51 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0): stream_handle = stream and stream.handle or zero_stream # Invoke kernel - driver.launch_kernel(cufunc.handle, - *griddim, - *blockdim, - sharedmem, - stream_handle, - kernelargs, - cooperative=self.cooperative) + driver.launch_kernel( + cufunc.handle, + *griddim, + *blockdim, + sharedmem, + stream_handle, + kernelargs, + cooperative=self.cooperative, + ) if self.debug: driver.device_to_host(ctypes.addressof(excval), excmem, excsz) if excval.value != 0: # An error occurred def load_symbol(name): - mem, sz = cufunc.module.get_global_symbol("%s__%s__" % - (cufunc.name, - name)) + mem, sz = cufunc.module.get_global_symbol( + "%s__%s__" % (cufunc.name, name) + ) val = ctypes.c_int() driver.device_to_host(ctypes.addressof(val), mem, sz) return val.value - tid = [load_symbol("tid" + i) for i in 'zyx'] - ctaid = [load_symbol("ctaid" + i) for i in 'zyx'] + tid = [load_symbol("tid" + i) for i in "zyx"] + ctaid = [load_symbol("ctaid" + i) for i in "zyx"] code = excval.value exccls, exc_args, loc = self.call_helper.get_exception(code) # Prefix the exception message with the source location if loc is None: - 
locinfo = '' + locinfo = "" else: sym, filepath, lineno = loc filepath = os.path.abspath(filepath) - locinfo = 'In function %r, file %s, line %s, ' % (sym, - filepath, - lineno,) + locinfo = "In function %r, file %s, line %s, " % ( + sym, + filepath, + lineno, + ) # Prefix the exception message with the thread position prefix = "%stid=%s ctaid=%s" % (locinfo, tid, ctaid) if exc_args: - exc_args = ("%s: %s" % (prefix, exc_args[0]),) + \ - exc_args[1:] + exc_args = ("%s: %s" % (prefix, exc_args[0]),) + exc_args[ + 1: + ] else: - exc_args = prefix, + exc_args = (prefix,) raise exccls(*exc_args) # retrieve auto converted arrays @@ -510,11 +558,7 @@ def _prepare_args(self, ty, val, stream, retr, kernelargs): # map the arguments using any extension you've registered for extension in reversed(self.extensions): - ty, val = extension.prepare_args( - ty, - val, - stream=stream, - retr=retr) + ty, val = extension.prepare_args(ty, val, stream=stream, retr=retr) if isinstance(ty, types.Array): devary = wrap_arg(val).to_device(retr, stream) @@ -600,8 +644,9 @@ def _prepare_args(self, ty, val, stream, retr, kernelargs): class ForAll(object): def __init__(self, dispatcher, ntasks, tpb, stream, sharedmem): if ntasks < 0: - raise ValueError("Can't create ForAll with negative task count: %s" - % ntasks) + raise ValueError( + "Can't create ForAll with negative task count: %s" % ntasks + ) self.dispatcher = dispatcher self.ntasks = ntasks self.thread_per_block = tpb @@ -619,8 +664,9 @@ def __call__(self, *args): blockdim = self._compute_thread_per_block(specialized) griddim = (self.ntasks + blockdim - 1) // blockdim - return specialized[griddim, blockdim, self.stream, - self.sharedmem](*args) + return specialized[griddim, blockdim, self.stream, self.sharedmem]( + *args + ) def _compute_thread_per_block(self, dispatcher): tpb = self.thread_per_block @@ -635,7 +681,7 @@ def _compute_thread_per_block(self, dispatcher): kernel = next(iter(dispatcher.overloads.values())) kwargs = dict( func=kernel._codelibrary.get_cufunc(), - b2d_func=0, # dynamic-shared memory is constant to blksz + b2d_func=0, # dynamic-shared memory is constant to blksz memsize=self.sharedmem, blocksizelimit=1024, ) @@ -666,13 +712,16 @@ def __init__(self, dispatcher, griddim, blockdim, stream, sharedmem): min_grid_size = 128 grid_size = griddim[0] * griddim[1] * griddim[2] if grid_size < min_grid_size: - msg = (f"Grid size {grid_size} will likely result in GPU " - "under-utilization due to low occupancy.") + msg = ( + f"Grid size {grid_size} will likely result in GPU " + "under-utilization due to low occupancy." + ) warn(NumbaPerformanceWarning(msg)) def __call__(self, *args): - return self.dispatcher.call(args, self.griddim, self.blockdim, - self.stream, self.sharedmem) + return self.dispatcher.call( + args, self.griddim, self.blockdim, self.stream, self.sharedmem + ) class CUDACacheImpl(CacheImpl): @@ -697,6 +746,7 @@ class CUDACache(Cache): """ Implements a cache that saves and loads CUDA kernels and compile results. """ + _impl_class = CUDACacheImpl def load_overload(self, sig, target_context): @@ -704,12 +754,13 @@ def load_overload(self, sig, target_context): # initialized. To initialize the correct (i.e. CUDA) target, we need to # enforce that the current target is the CUDA target. 
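For reference, the ForAll helper above is what backs dispatcher.forall(ntasks): the block size comes from the occupancy calculator (or an explicit tpb), and the grid size is rounded up as (ntasks + blockdim - 1) // blockdim so every task is covered. A minimal usage sketch; the kernel and data are illustrative.

import numpy as np
from numba import cuda

@cuda.jit
def double(x):
    i = cuda.grid(1)
    if i < x.size:
        x[i] *= 2

x = cuda.to_device(np.arange(1 << 20, dtype=np.float32))
# One task per element; block and grid sizes are chosen by forall().
double.forall(x.size)(x)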
from numba.core.target_extension import target_override - with target_override('cuda'): + + with target_override("cuda"): return super().load_overload(sig, target_context) class CUDADispatcher(Dispatcher, serialize.ReduceMixin): - ''' + """ CUDA Dispatcher object. When configured and called, the dispatcher will specialize itself for the given arguments (if no suitable specialized version already exists) & compute capability, and launch on the device @@ -717,7 +768,7 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin): Dispatcher objects are not to be constructed by the user, but instead are created using the :func:`numba.cuda.jit` decorator. - ''' + """ # Whether to fold named arguments and default values. Default values are # presently unsupported on CUDA, so we can leave this as False in all @@ -727,8 +778,9 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin): targetdescr = cuda_target def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler): - super().__init__(py_func, targetoptions=targetoptions, - pipeline_class=pipeline_class) + super().__init__( + py_func, targetoptions=targetoptions, pipeline_class=pipeline_class + ) # The following properties are for specialization of CUDADispatchers. A # specialized CUDADispatcher is one that is compiled for exactly one @@ -756,7 +808,7 @@ def configure(self, griddim, blockdim, stream=0, sharedmem=0): def __getitem__(self, args): if len(args) not in [2, 3, 4]: - raise ValueError('must specify at least the griddim and blockdim') + raise ValueError("must specify at least the griddim and blockdim") return self.configure(*args) def forall(self, ntasks, tpb=0, stream=0, sharedmem=0): @@ -783,7 +835,7 @@ def forall(self, ntasks, tpb=0, stream=0, sharedmem=0): @property def extensions(self): - ''' + """ A list of objects that must have a `prepare_args` function. When a specialized kernel is called, each argument will be passed through to the `prepare_args` (from the last object in this list to the @@ -799,17 +851,17 @@ def extensions(self): will be passed in turn to the next right-most `extension`. After all the extensions have been called, the resulting `(ty, val)` will be passed into Numba's default argument marshalling logic. - ''' - return self.targetoptions.get('extensions') + """ + return self.targetoptions.get("extensions") def __call__(self, *args, **kwargs): # An attempt to launch an unconfigured kernel raise ValueError(missing_launch_config_msg) def call(self, args, griddim, blockdim, stream, sharedmem): - ''' + """ Compile if necessary and invoke this kernel with *args*. - ''' + """ if self.specialized: kernel = next(iter(self.overloads.values())) else: @@ -832,28 +884,30 @@ def typeof_pyval(self, val): if cuda.is_cuda_array(val): # When typing, we don't need to synchronize on the array's # stream - this is done when the kernel is launched. - return typeof(cuda.as_cuda_array(val, sync=False), - Purpose.argument) + return typeof( + cuda.as_cuda_array(val, sync=False), Purpose.argument + ) else: raise def specialize(self, *args): - ''' + """ Create a new instance of this dispatcher specialized for the given *args*. 
- ''' + """ cc = get_current_device().compute_capability argtypes = tuple(self.typeof_pyval(a) for a in args) if self.specialized: - raise RuntimeError('Dispatcher already specialized') + raise RuntimeError("Dispatcher already specialized") specialization = self.specializations.get((cc, argtypes)) if specialization: return specialization targetoptions = self.targetoptions - specialization = CUDADispatcher(self.py_func, - targetoptions=targetoptions) + specialization = CUDADispatcher( + self.py_func, targetoptions=targetoptions + ) specialization.compile(argtypes) specialization.disable_compile() specialization._specialized = True @@ -868,7 +922,7 @@ def specialized(self): return self._specialized def get_regs_per_thread(self, signature=None): - ''' + """ Returns the number of registers used by each thread in this kernel for the device in the current context. @@ -877,17 +931,19 @@ def get_regs_per_thread(self, signature=None): kernel. :return: The number of registers used by the compiled variant of the kernel for the given signature and current device. - ''' + """ if signature is not None: return self.overloads[signature.args].regs_per_thread if self.specialized: return next(iter(self.overloads.values())).regs_per_thread else: - return {sig: overload.regs_per_thread - for sig, overload in self.overloads.items()} + return { + sig: overload.regs_per_thread + for sig, overload in self.overloads.items() + } def get_const_mem_size(self, signature=None): - ''' + """ Returns the size in bytes of constant memory used by this kernel for the device in the current context. @@ -897,17 +953,19 @@ def get_const_mem_size(self, signature=None): :return: The size in bytes of constant memory allocated by the compiled variant of the kernel for the given signature and current device. - ''' + """ if signature is not None: return self.overloads[signature.args].const_mem_size if self.specialized: return next(iter(self.overloads.values())).const_mem_size else: - return {sig: overload.const_mem_size - for sig, overload in self.overloads.items()} + return { + sig: overload.const_mem_size + for sig, overload in self.overloads.items() + } def get_shared_mem_per_block(self, signature=None): - ''' + """ Returns the size in bytes of statically allocated shared memory for this kernel. @@ -916,17 +974,19 @@ def get_shared_mem_per_block(self, signature=None): specialized kernel. :return: The amount of shared memory allocated by the compiled variant of the kernel for the given signature and current device. - ''' + """ if signature is not None: return self.overloads[signature.args].shared_mem_per_block if self.specialized: return next(iter(self.overloads.values())).shared_mem_per_block else: - return {sig: overload.shared_mem_per_block - for sig, overload in self.overloads.items()} + return { + sig: overload.shared_mem_per_block + for sig, overload in self.overloads.items() + } def get_max_threads_per_block(self, signature=None): - ''' + """ Returns the maximum allowable number of threads per block for this kernel. Exceeding this threshold will result in the kernel failing to launch. @@ -937,17 +997,19 @@ def get_max_threads_per_block(self, signature=None): :return: The maximum allowable threads per block for the compiled variant of the kernel for the given signature and current device. 
- ''' + """ if signature is not None: return self.overloads[signature.args].max_threads_per_block if self.specialized: return next(iter(self.overloads.values())).max_threads_per_block else: - return {sig: overload.max_threads_per_block - for sig, overload in self.overloads.items()} + return { + sig: overload.max_threads_per_block + for sig, overload in self.overloads.items() + } def get_local_mem_per_thread(self, signature=None): - ''' + """ Returns the size in bytes of local memory per thread for this kernel. @@ -956,14 +1018,16 @@ def get_local_mem_per_thread(self, signature=None): specialized kernel. :return: The amount of local memory allocated by the compiled variant of the kernel for the given signature and current device. - ''' + """ if signature is not None: return self.overloads[signature.args].local_mem_per_thread if self.specialized: return next(iter(self.overloads.values())).local_mem_per_thread else: - return {sig: overload.local_mem_per_thread - for sig, overload in self.overloads.items()} + return { + sig: overload.local_mem_per_thread + for sig, overload in self.overloads.items() + } def get_call_template(self, args, kws): # Originally copied from _DispatcherBase.get_call_template. This @@ -991,7 +1055,8 @@ def get_call_template(self, args, kws): name = "CallTemplate({0})".format(func_name) call_template = typing.make_concrete_template( - name, key=func_name, signatures=self.nopython_signatures) + name, key=func_name, signatures=self.nopython_signatures + ) pysig = utils.pysignature(self.py_func) return call_template, pysig, args, kws @@ -1006,33 +1071,36 @@ def compile_device(self, args, return_type=None): """ if args not in self.overloads: with self._compiling_counter: - - debug = self.targetoptions.get('debug') - lineinfo = self.targetoptions.get('lineinfo') - inline = self.targetoptions.get('inline') - fastmath = self.targetoptions.get('fastmath') + debug = self.targetoptions.get("debug") + lineinfo = self.targetoptions.get("lineinfo") + inline = self.targetoptions.get("inline") + fastmath = self.targetoptions.get("fastmath") nvvm_options = { - 'opt': 3 if self.targetoptions.get('opt') else 0, - 'fastmath': fastmath + "opt": 3 if self.targetoptions.get("opt") else 0, + "fastmath": fastmath, } if debug: - nvvm_options['g'] = None + nvvm_options["g"] = None cc = get_current_device().compute_capability - cres = compile_cuda(self.py_func, return_type, args, - debug=debug, - lineinfo=lineinfo, - inline=inline, - fastmath=fastmath, - nvvm_options=nvvm_options, - cc=cc) + cres = compile_cuda( + self.py_func, + return_type, + args, + debug=debug, + lineinfo=lineinfo, + inline=inline, + fastmath=fastmath, + nvvm_options=nvvm_options, + cc=cc, + ) self.overloads[args] = cres - cres.target_context.insert_user_function(cres.entry_point, - cres.fndesc, - [cres.library]) + cres.target_context.insert_user_function( + cres.entry_point, cres.fndesc, [cres.library] + ) else: cres = self.overloads[args] @@ -1044,10 +1112,10 @@ def add_overload(self, kernel, argtypes): self.overloads[argtypes] = kernel def compile(self, sig): - ''' + """ Compile and bind to the current context a version of this kernel specialized for the given signature. - ''' + """ argtypes, return_type = sigutils.normalize_signature(sig) assert return_type is None or return_type == types.none @@ -1080,15 +1148,15 @@ def compile(self, sig): return kernel def inspect_llvm(self, signature=None): - ''' + """ Return the LLVM IR for this kernel. :param signature: A tuple of argument types. 
:return: The LLVM IR for the given signature, or a dict of LLVM IR for all previously-encountered signatures. - ''' - device = self.targetoptions.get('device') + """ + device = self.targetoptions.get("device") if signature is not None: if device: return self.overloads[signature].library.get_llvm_str() @@ -1096,23 +1164,27 @@ def inspect_llvm(self, signature=None): return self.overloads[signature].inspect_llvm() else: if device: - return {sig: overload.library.get_llvm_str() - for sig, overload in self.overloads.items()} + return { + sig: overload.library.get_llvm_str() + for sig, overload in self.overloads.items() + } else: - return {sig: overload.inspect_llvm() - for sig, overload in self.overloads.items()} + return { + sig: overload.inspect_llvm() + for sig, overload in self.overloads.items() + } def inspect_asm(self, signature=None): - ''' + """ Return this kernel's PTX assembly code for for the device in the current context. :param signature: A tuple of argument types. :return: The PTX code for the given signature, or a dict of PTX codes for all previously-encountered signatures. - ''' + """ cc = get_current_device().compute_capability - device = self.targetoptions.get('device') + device = self.targetoptions.get("device") if signature is not None: if device: return self.overloads[signature].library.get_asm_str(cc) @@ -1120,14 +1192,18 @@ def inspect_asm(self, signature=None): return self.overloads[signature].inspect_asm(cc) else: if device: - return {sig: overload.library.get_asm_str(cc) - for sig, overload in self.overloads.items()} + return { + sig: overload.library.get_asm_str(cc) + for sig, overload in self.overloads.items() + } else: - return {sig: overload.inspect_asm(cc) - for sig, overload in self.overloads.items()} + return { + sig: overload.inspect_asm(cc) + for sig, overload in self.overloads.items() + } def inspect_sass_cfg(self, signature=None): - ''' + """ Return this kernel's CFG for the device in the current context. :param signature: A tuple of argument types. @@ -1137,18 +1213,20 @@ def inspect_sass_cfg(self, signature=None): The CFG for the device in the current context is returned. Requires nvdisasm to be available on the PATH. - ''' - if self.targetoptions.get('device'): - raise RuntimeError('Cannot get the CFG of a device function') + """ + if self.targetoptions.get("device"): + raise RuntimeError("Cannot get the CFG of a device function") if signature is not None: return self.overloads[signature].inspect_sass_cfg() else: - return {sig: defn.inspect_sass_cfg() - for sig, defn in self.overloads.items()} + return { + sig: defn.inspect_sass_cfg() + for sig, defn in self.overloads.items() + } def inspect_sass(self, signature=None): - ''' + """ Return this kernel's SASS assembly code for for the device in the current context. @@ -1159,22 +1237,23 @@ def inspect_sass(self, signature=None): SASS for the device in the current context is returned. Requires nvdisasm to be available on the PATH. 
- ''' - if self.targetoptions.get('device'): - raise RuntimeError('Cannot inspect SASS of a device function') + """ + if self.targetoptions.get("device"): + raise RuntimeError("Cannot inspect SASS of a device function") if signature is not None: return self.overloads[signature].inspect_sass() else: - return {sig: defn.inspect_sass() - for sig, defn in self.overloads.items()} + return { + sig: defn.inspect_sass() for sig, defn in self.overloads.items() + } def inspect_types(self, file=None): - ''' + """ Produce a dump of the Python source of this function annotated with the corresponding Numba IR and type information. The dump is written to *file*, or *sys.stdout* if *file* is *None*. - ''' + """ if file is None: file = sys.stdout @@ -1194,5 +1273,4 @@ def _reduce_states(self): Reduce the instance for serialization. Compiled definitions are discarded. """ - return dict(py_func=self.py_func, - targetoptions=self.targetoptions) + return dict(py_func=self.py_func, targetoptions=self.targetoptions) diff --git a/numba_cuda/numba/cuda/locks.py b/numba_cuda/numba/cuda/locks.py index 133e19d9b..5b03a06d8 100644 --- a/numba_cuda/numba/cuda/locks.py +++ b/numba_cuda/numba/cuda/locks.py @@ -6,10 +6,11 @@ def module_init_lock(func): - """Decorator to make sure initialization is invoked once for all threads. - """ + """Decorator to make sure initialization is invoked once for all threads.""" + @wraps(func) def wrapper(*args, **kwargs): with _module_init_lock: return func(*args, **kwargs) + return wrapper diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py index 208db1ab7..4e37bec5b 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py @@ -33,7 +33,6 @@ def get_hashable_handle_value(handle): class TestModuleCallbacksBasic(ContextResettingTestCase): - def test_basic(self): counter = 0 @@ -56,7 +55,7 @@ def kernel(): self.assertEqual(counter, 0) kernel[1, 1]() self.assertEqual(counter, 1) - kernel[1, 1]() # cached + kernel[1, 1]() # cached self.assertEqual(counter, 1) wipe_all_modules_in_context() @@ -85,9 +84,9 @@ def kernel(arg): pass self.assertEqual(counter, 0) - kernel[1, 1](42) # (int64)->() : module 1 + kernel[1, 1](42) # (int64)->() : module 1 self.assertEqual(counter, 1) - kernel[1, 1](100) # (int64)->() : module 1, cached + kernel[1, 1](100) # (int64)->() : module 1, cached self.assertEqual(counter, 1) kernel[1, 1](3.14) # (float64)->() : module 2 self.assertEqual(counter, 2) @@ -138,7 +137,6 @@ def kernel2(): class TestModuleCallbacksAPICompleteness(CUDATestCase): - def test_api(self): def setup(handle): pass @@ -150,13 +148,14 @@ def teardown(handle): (setup, teardown), (setup, None), (None, teardown), - (None, None) + (None, None), ] for setup, teardown in api_combo: with self.subTest(setup=setup, teardown=teardown): lib = CUSource( - "", setup_callback=setup, teardown_callback=teardown) + "", setup_callback=setup, teardown_callback=teardown + ) @cuda.jit(link=[lib]) def kernel(): @@ -166,7 +165,6 @@ def kernel(): class TestModuleCallbacks(CUDATestCase): - def setUp(self): super().setUp() @@ -189,7 +187,8 @@ def set_forty_two(handle): cuMemcpyHtoD(dptr, arr.ctypes.data, size) self.lib = CUSource( - module, setup_callback=set_forty_two, teardown_callback=None) + module, setup_callback=set_forty_two, teardown_callback=None + ) def test_decldevice_arg(self): get_num = cuda.declare_device("get_num", "int32()", link=[self.lib]) 
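The behaviour these tests pin down can be summarised in a short sketch: callbacks attached to a CUSource run once when the containing module is loaded and once when it is unloaded, which makes them a natural place for one-time module initialization. The empty CUDA source, callback names and counter below are illustrative only.

from numba import cuda
from numba.cuda.cudadrv.linkable_code import CUSource

loaded = 0

def on_load(handle):
    global loaded
    loaded += 1        # e.g. seed __device__ globals via the driver API

def on_unload(handle):
    global loaded
    loaded -= 1        # e.g. release anything acquired in on_load

lib = CUSource("", setup_callback=on_load, teardown_callback=on_unload)

@cuda.jit(link=[lib])
def kernel():
    pass

kernel[1, 1]()   # first launch loads the module and runs on_load once
kernel[1, 1]()   # the module is cached, so no further callback runs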
@@ -215,7 +214,6 @@ def kernel(arr): class TestMultithreadedCallbacks(CUDATestCase): - def test_concurrent_initialization(self): seen_mods = set() max_seen_mods = 0 @@ -242,7 +240,8 @@ def concurrent_compilation_launch(kernel): threads = [ threading.Thread( target=concurrent_compilation_launch, args=(kernel,) - ) for _ in range(4) + ) + for _ in range(4) ] for t in threads: t.start() @@ -274,11 +273,13 @@ def kernel(a): def concurrent_compilation_launch(): kernel[1, 1](42) # (int64)->() : module 1 - kernel[1, 1](9) # (int64)->() : module 1 from cache - kernel[1, 1](3.14) # (float64)->() : module 2 + kernel[1, 1](9) # (int64)->() : module 1 from cache + kernel[1, 1](3.14) # (float64)->() : module 2 - threads = [threading.Thread(target=concurrent_compilation_launch) - for _ in range(4)] + threads = [ + threading.Thread(target=concurrent_compilation_launch) + for _ in range(4) + ] for t in threads: t.start() for t in threads: @@ -286,8 +287,8 @@ def concurrent_compilation_launch(): wipe_all_modules_in_context() assert len(seen_mods) == 0 - self.assertEqual(max_seen_mods, 8) # 2 kernels per thread + self.assertEqual(max_seen_mods, 8) # 2 kernels per thread -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py b/numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py index 106ab0d30..0fe6177cb 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py @@ -10,6 +10,7 @@ try: import pynvjitlink # noqa: F401 + PYNVJITLINK_INSTALLED = True except ImportError: PYNVJITLINK_INSTALLED = False @@ -52,7 +53,7 @@ @unittest.skipIf( not config.CUDA_ENABLE_PYNVJITLINK or not TEST_BIN_DIR, - "pynvjitlink not enabled" + "pynvjitlink not enabled", ) @skip_on_cudasim("Linking unsupported in the simulator") class TestLinker(CUDATestCase): @@ -85,7 +86,6 @@ def test_nvjitlink_invalid_cc_type_error(self): PyNvJitLinker(cc=0) def test_nvjitlink_ptx_compile_options(self): - max_registers = (None, 32) lineinfo = (False, True) lto = (False, True) @@ -190,7 +190,7 @@ def test_nvjitlink_jit_with_linkable_code_lto_dump_assembly(self): files = [ test_device_functions_cu, test_device_functions_ltoir, - test_device_functions_fatbin_multi + test_device_functions_fatbin_multi, ] config.DUMP_ASSEMBLY = True @@ -228,7 +228,7 @@ def test_nvjitlink_jit_with_linkable_code_lto_dump_assembly_warn(self): for file in files: with self.subTest(file=file): with warnings.catch_warnings(record=True) as w: - with contextlib.redirect_stdout(None): # suppress other PTX + with contextlib.redirect_stdout(None): # suppress other PTX sig = "uint32(uint32, uint32)" add_from_numba = cuda.declare_device( "add_from_numba", sig @@ -243,8 +243,11 @@ def kernel(result): assert result[0] == 3 assert len(w) == 1 - self.assertIn("it is not optimizable at link time, and " - "`ignore_nonlto == True`", str(w[0].message)) + self.assertIn( + "it is not optimizable at link time, and " + "`ignore_nonlto == True`", + str(w[0].message), + ) config.DUMP_ASSEMBLY = False @@ -262,7 +265,7 @@ def kernel(): @unittest.skipIf( not PYNVJITLINK_INSTALLED or not TEST_BIN_DIR, - reason="pynvjitlink not enabled" + reason="pynvjitlink not enabled", ) class TestLinkerUsage(CUDATestCase): """Test that whether pynvjitlink can be enabled by both environment variable @@ -295,12 +298,12 @@ def kernel(result): def test_linker_enabled_envvar(self): env = os.environ.copy() - env['NUMBA_CUDA_ENABLE_PYNVJITLINK'] = "1" + 
env["NUMBA_CUDA_ENABLE_PYNVJITLINK"] = "1" run_in_subprocess(self.src.format(config=""), env=env) def test_linker_disabled_envvar(self): env = os.environ.copy() - env.pop('NUMBA_CUDA_ENABLE_PYNVJITLINK', None) + env.pop("NUMBA_CUDA_ENABLE_PYNVJITLINK", None) with self.assertRaisesRegex( AssertionError, "LTO and additional flags require PyNvJitLinker" ): @@ -310,19 +313,25 @@ def test_linker_disabled_envvar(self): def test_linker_enabled_config(self): env = os.environ.copy() - env.pop('NUMBA_CUDA_ENABLE_PYNVJITLINK', None) - run_in_subprocess(self.src.format( - config="config.CUDA_ENABLE_PYNVJITLINK = True"), env=env) + env.pop("NUMBA_CUDA_ENABLE_PYNVJITLINK", None) + run_in_subprocess( + self.src.format(config="config.CUDA_ENABLE_PYNVJITLINK = True"), + env=env, + ) def test_linker_disabled_config(self): env = os.environ.copy() - env.pop('NUMBA_CUDA_ENABLE_PYNVJITLINK', None) + env.pop("NUMBA_CUDA_ENABLE_PYNVJITLINK", None) with override_config("CUDA_ENABLE_PYNVJITLINK", False): with self.assertRaisesRegex( AssertionError, "LTO and additional flags require PyNvJitLinker" ): - run_in_subprocess(self.src.format( - config="config.CUDA_ENABLE_PYNVJITLINK = False"), env=env) + run_in_subprocess( + self.src.format( + config="config.CUDA_ENABLE_PYNVJITLINK = False" + ), + env=env, + ) if __name__ == "__main__": diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py b/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py index 0afa99115..76d732474 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py @@ -1,4 +1,4 @@ -from numba.tests.support import (override_config, captured_stdout) +from numba.tests.support import override_config, captured_stdout from numba.cuda.testing import skip_on_cudasim from numba import cuda from numba.core import types @@ -8,7 +8,7 @@ import unittest -@skip_on_cudasim('Simulator does not produce debug dumps') +@skip_on_cudasim("Simulator does not produce debug dumps") class TestCudaDebugInfo(CUDATestCase): """ These tests only checks the compiled PTX for debuginfo section @@ -49,7 +49,7 @@ def foo(x): self._check(foo, sig=(types.int32[:],), expect=True) def test_environment_override(self): - with override_config('CUDA_DEBUGINFO_DEFAULT', 1): + with override_config("CUDA_DEBUGINFO_DEFAULT", 1): # Using default value @cuda.jit(opt=False) def foo(x): @@ -86,7 +86,7 @@ def f(cond): llvm_ir = f.inspect_llvm(sig) # A varible name starting with "bool" in the debug metadata - pat = r'!DILocalVariable\(.*name:\s+\"bool' + pat = r"!DILocalVariable\(.*name:\s+\"bool" match = re.compile(pat).search(llvm_ir) self.assertIsNone(match, msg=llvm_ir) @@ -106,7 +106,7 @@ def f(x, y): mdnode_id = match.group(1) # verify the DIBasicType has correct encoding attribute DW_ATE_boolean - pat = rf'!{mdnode_id}\s+=\s+!DIBasicType\(.*DW_ATE_boolean' + pat = rf"!{mdnode_id}\s+=\s+!DIBasicType\(.*DW_ATE_boolean" match = re.compile(pat).search(llvm_ir) self.assertIsNotNone(match, msg=llvm_ir) @@ -133,14 +133,17 @@ def f(x): llvm_ir = f.inspect_llvm(sig) - defines = [line for line in llvm_ir.splitlines() - if 'define void @"_ZN6cudapy' in line] + defines = [ + line + for line in llvm_ir.splitlines() + if 'define void @"_ZN6cudapy' in line + ] # Make sure we only found one definition self.assertEqual(len(defines), 1) wrapper_define = defines[0] - self.assertIn('!dbg', wrapper_define) + self.assertIn("!dbg", wrapper_define) def test_debug_function_calls_internal_impl(self): # Calling a function in a module generated from an 
implementation @@ -198,16 +201,16 @@ def test_chained_device_function(self): debug_opts = itertools.product(*[(True, False)] * 3) for kernel_debug, f1_debug, f2_debug in debug_opts: - with self.subTest(kernel_debug=kernel_debug, - f1_debug=f1_debug, - f2_debug=f2_debug): - self._test_chained_device_function(kernel_debug, - f1_debug, - f2_debug) - - def _test_chained_device_function_two_calls(self, kernel_debug, f1_debug, - f2_debug): - + with self.subTest( + kernel_debug=kernel_debug, f1_debug=f1_debug, f2_debug=f2_debug + ): + self._test_chained_device_function( + kernel_debug, f1_debug, f2_debug + ) + + def _test_chained_device_function_two_calls( + self, kernel_debug, f1_debug, f2_debug + ): @cuda.jit(device=True, debug=f2_debug, opt=False) def f2(x): return x + 1 @@ -232,12 +235,12 @@ def test_chained_device_function_two_calls(self): debug_opts = itertools.product(*[(True, False)] * 3) for kernel_debug, f1_debug, f2_debug in debug_opts: - with self.subTest(kernel_debug=kernel_debug, - f1_debug=f1_debug, - f2_debug=f2_debug): - self._test_chained_device_function_two_calls(kernel_debug, - f1_debug, - f2_debug) + with self.subTest( + kernel_debug=kernel_debug, f1_debug=f1_debug, f2_debug=f2_debug + ): + self._test_chained_device_function_two_calls( + kernel_debug, f1_debug, f2_debug + ) def test_chained_device_three_functions(self): # Like test_chained_device_function, but with enough functions (three) @@ -278,13 +281,13 @@ def f(x, y): llvm_ir = f.inspect_llvm(sig) # extract the metadata node id from `types` field of DISubroutineType - pat = r'!DISubroutineType\(types:\s+!(\d+)\)' + pat = r"!DISubroutineType\(types:\s+!(\d+)\)" match = re.compile(pat).search(llvm_ir) self.assertIsNotNone(match, msg=llvm_ir) mdnode_id = match.group(1) # extract the metadata node ids from the flexible node of types - pat = rf'!{mdnode_id}\s+=\s+!{{\s+!(\d+),\s+!(\d+)\s+}}' + pat = rf"!{mdnode_id}\s+=\s+!{{\s+!(\d+),\s+!(\d+)\s+}}" match = re.compile(pat).search(llvm_ir) self.assertIsNotNone(match, msg=llvm_ir) mdnode_id1 = match.group(1) @@ -303,10 +306,10 @@ def test_kernel_args_types(self): def test_kernel_args_types_dump(self): # see issue#135 - with override_config('DUMP_LLVM', 1): + with override_config("DUMP_LLVM", 1): with captured_stdout(): self._test_kernel_args_types() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_device_func.py b/numba_cuda/numba/cuda/tests/cudapy/test_device_func.py index 0c7088b74..4ff973baa 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_device_func.py @@ -3,8 +3,13 @@ import numpy as np -from numba.cuda.testing import (skip_if_curand_kernel_missing, skip_on_cudasim, - test_data_dir, unittest, CUDATestCase) +from numba.cuda.testing import ( + skip_if_curand_kernel_missing, + skip_on_cudasim, + test_data_dir, + unittest, + CUDATestCase, +) from numba import cuda, jit, float32, int32, types from numba.core.errors import TypingError from numba.tests.support import skip_unless_cffi @@ -12,9 +17,7 @@ class TestDeviceFunc(CUDATestCase): - def test_use_add2f(self): - @cuda.jit("float32(float32, float32)", device=True) def add2f(a, b): return a + b @@ -33,7 +36,6 @@ def use_add2f(ary): self.assertTrue(np.all(ary == exp), (ary, exp)) def test_indirect_add2f(self): - @cuda.jit("float32(float32, float32)", device=True) def add2f(a, b): return a + b @@ -74,12 +76,12 @@ def add(a, b): self._check_cpu_dispatcher(add) - @skip_on_cudasim('not 
supported in cudasim') + @skip_on_cudasim("not supported in cudasim") def test_cpu_dispatcher_invalid(self): # Test invalid usage # Explicit signature disables compilation, which also disable # compiling on CUDA. - @jit('(i4, i4)') + @jit("(i4, i4)") def add(a, b): return a + b @@ -95,7 +97,7 @@ def test_cpu_dispatcher_other_module(self): def add(a, b): return a + b - mymod = ModuleType(name='mymod') + mymod = ModuleType(name="mymod") mymod.add = add del add @@ -109,7 +111,7 @@ def add_kernel(ary): add_kernel[1, ary.size](ary) np.testing.assert_equal(expect, ary) - @skip_on_cudasim('not supported in cudasim') + @skip_on_cudasim("not supported in cudasim") def test_inspect_llvm(self): @cuda.jit(device=True) def foo(x, y): @@ -120,13 +122,13 @@ def foo(x, y): fname = cres.fndesc.mangled_name # Verify that the function name has "foo" in it as in the python name - self.assertIn('foo', fname) + self.assertIn("foo", fname) llvm = foo.inspect_llvm(args) # Check that the compiled function name is in the LLVM. self.assertIn(fname, llvm) - @skip_on_cudasim('not supported in cudasim') + @skip_on_cudasim("not supported in cudasim") def test_inspect_asm(self): @cuda.jit(device=True) def foo(x, y): @@ -137,13 +139,13 @@ def foo(x, y): fname = cres.fndesc.mangled_name # Verify that the function name has "foo" in it as in the python name - self.assertIn('foo', fname) + self.assertIn("foo", fname) ptx = foo.inspect_asm(args) # Check that the compiled function name is in the PTX self.assertIn(fname, ptx) - @skip_on_cudasim('not supported in cudasim') + @skip_on_cudasim("not supported in cudasim") def test_inspect_sass_disallowed(self): @cuda.jit(device=True) def foo(x, y): @@ -152,10 +154,11 @@ def foo(x, y): with self.assertRaises(RuntimeError) as raises: foo.inspect_sass((int32, int32)) - self.assertIn('Cannot inspect SASS of a device function', - str(raises.exception)) + self.assertIn( + "Cannot inspect SASS of a device function", str(raises.exception) + ) - @skip_on_cudasim('cudasim will allow calling any function') + @skip_on_cudasim("cudasim will allow calling any function") def test_device_func_as_kernel_disallowed(self): @cuda.jit(device=True) def f(): @@ -164,10 +167,12 @@ def f(): with self.assertRaises(RuntimeError) as raises: f[1, 1]() - self.assertIn('Cannot compile a device function as a kernel', - str(raises.exception)) + self.assertIn( + "Cannot compile a device function as a kernel", + str(raises.exception), + ) - @skip_on_cudasim('cudasim ignores casting by jit decorator signature') + @skip_on_cudasim("cudasim ignores casting by jit decorator signature") def test_device_casting(self): # Ensure that casts to the correct type are forced when calling a # device function with a signature. This test ensures that: @@ -176,20 +181,23 @@ def test_device_casting(self): # shouldn't # - We insert a cast when calling rgba, as opposed to failing to type. 
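The CPU-dispatcher tests above rely on a behaviour worth spelling out: a function compiled with plain numba.jit can also be called from a CUDA kernel, where it is recompiled for the device, provided it was not given an explicit signature (an explicit signature disables further compilation and hence CUDA use, as the "invalid" test checks). A minimal sketch with an illustrative kernel and data:

import numpy as np
from numba import cuda, jit

@jit            # no explicit signature, so the CUDA target can retarget it
def add(a, b):
    return a + b

@cuda.jit
def add_kernel(ary):
    i = cuda.grid(1)
    if i < ary.size:
        ary[i] = add(ary[i], 1)

ary = np.arange(16, dtype=np.int32)
expect = ary + 1
add_kernel[1, ary.size](ary)
np.testing.assert_equal(ary, expect)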
- @cuda.jit('int32(int32, int32, int32, int32)', device=True) + @cuda.jit("int32(int32, int32, int32, int32)", device=True) def rgba(r, g, b, a): - return (((r & 0xFF) << 16) | - ((g & 0xFF) << 8) | - ((b & 0xFF) << 0) | - ((a & 0xFF) << 24)) + return ( + ((r & 0xFF) << 16) + | ((g & 0xFF) << 8) + | ((b & 0xFF) << 0) + | ((a & 0xFF) << 24) + ) @cuda.jit def rgba_caller(x, channels): x[0] = rgba(channels[0], channels[1], channels[2], channels[3]) x = cuda.device_array(1, dtype=np.int32) - channels = cuda.to_device(np.asarray([1.0, 2.0, 3.0, 4.0], - dtype=np.float32)) + channels = cuda.to_device( + np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + ) rgba_caller[1, 1](x, channels) @@ -259,32 +267,31 @@ def rgba_caller(x, channels): }""") -@skip_on_cudasim('External functions unsupported in the simulator') +@skip_on_cudasim("External functions unsupported in the simulator") class TestDeclareDevice(CUDATestCase): - def check_api(self, decl): - self.assertEqual(decl.name, 'f1') + self.assertEqual(decl.name, "f1") self.assertEqual(decl.sig.args, (float32[:],)) self.assertEqual(decl.sig.return_type, int32) def test_declare_device_signature(self): - f1 = cuda.declare_device('f1', int32(float32[:])) + f1 = cuda.declare_device("f1", int32(float32[:])) self.check_api(f1) def test_declare_device_string(self): - f1 = cuda.declare_device('f1', 'int32(float32[:])') + f1 = cuda.declare_device("f1", "int32(float32[:])") self.check_api(f1) def test_bad_declare_device_tuple(self): - with self.assertRaisesRegex(TypeError, 'Return type'): - cuda.declare_device('f1', (float32[:],)) + with self.assertRaisesRegex(TypeError, "Return type"): + cuda.declare_device("f1", (float32[:],)) def test_bad_declare_device_string(self): - with self.assertRaisesRegex(TypeError, 'Return type'): - cuda.declare_device('f1', '(float32[:],)') + with self.assertRaisesRegex(TypeError, "Return type"): + cuda.declare_device("f1", "(float32[:],)") def test_link_cu_source(self): - times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu) + times2 = cuda.declare_device("times2", "int32(int32)", link=times2_cu) @cuda.jit def kernel(r, x): @@ -301,7 +308,7 @@ def kernel(r, x): def _test_link_multiple_sources(self, link_type): link = link_type([times2_cu, times4_cu]) - times4 = cuda.declare_device('times4', 'int32(int32)', link=link) + times4 = cuda.declare_device("times4", "int32(int32)", link=link) @cuda.jit def kernel(r, x): @@ -360,7 +367,7 @@ def kernel(x, seed): np.testing.assert_equal(x[0], 323845807) def test_declared_in_called_function(self): - times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu) + times2 = cuda.declare_device("times2", "int32(int32)", link=times2_cu) @cuda.jit def device_func(x): @@ -380,7 +387,7 @@ def kernel(r, x): np.testing.assert_equal(r, x * 2) def test_declared_in_called_function_twice(self): - times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu) + times2 = cuda.declare_device("times2", "int32(int32)", link=times2_cu) @cuda.jit def device_func_1(x): @@ -404,7 +411,7 @@ def kernel(r, x): np.testing.assert_equal(r, x * 2) def test_declared_in_called_function_two_calls(self): - times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu) + times2 = cuda.declare_device("times2", "int32(int32)", link=times2_cu) @cuda.jit def device_func(x): @@ -424,7 +431,7 @@ def kernel(r, x): np.testing.assert_equal(r, x * 6) def test_call_declared_function_twice(self): - times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu) + times2 = 
cuda.declare_device("times2", "int32(int32)", link=times2_cu) @cuda.jit def kernel(r, x): @@ -440,7 +447,7 @@ def kernel(r, x): np.testing.assert_equal(r, x * 6) def test_declared_in_called_function_and_parent(self): - times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu) + times2 = cuda.declare_device("times2", "int32(int32)", link=times2_cu) @cuda.jit def device_func(x): @@ -460,8 +467,8 @@ def kernel(r, x): np.testing.assert_equal(r, x * 4) def test_call_two_different_declared_functions(self): - times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu) - times3 = cuda.declare_device('times3', 'int32(int32)', link=times3_cu) + times2 = cuda.declare_device("times2", "int32(int32)", link=times2_cu) + times3 = cuda.declare_device("times3", "int32(int32)", link=times3_cu) @cuda.jit def kernel(r, x): @@ -477,5 +484,5 @@ def kernel(r, x): np.testing.assert_equal(r, x * 5) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_overload.py b/numba_cuda/numba/cuda/tests/cudapy/test_overload.py index 746ea3f4a..51752f732 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_overload.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_overload.py @@ -8,6 +8,7 @@ # Dummy function definitions to overload + def generic_func_1(): pass @@ -83,109 +84,124 @@ def default_values_and_kwargs(): # Overload implementations -@overload(generic_func_1, target='generic') + +@overload(generic_func_1, target="generic") def ol_generic_func_1(x): def impl(x): x[0] *= GENERIC_FUNCTION_1 + return impl -@overload(cuda_func_1, target='cuda') +@overload(cuda_func_1, target="cuda") def ol_cuda_func_1(x): def impl(x): x[0] *= CUDA_FUNCTION_1 + return impl -@overload(generic_func_2, target='generic') +@overload(generic_func_2, target="generic") def ol_generic_func_2(x): def impl(x): x[0] *= GENERIC_FUNCTION_2 + return impl -@overload(cuda_func_2, target='cuda') +@overload(cuda_func_2, target="cuda") def ol_cuda_func(x): def impl(x): x[0] *= CUDA_FUNCTION_2 + return impl -@overload(generic_calls_generic, target='generic') +@overload(generic_calls_generic, target="generic") def ol_generic_calls_generic(x): def impl(x): x[0] *= GENERIC_CALLS_GENERIC generic_func_1(x) + return impl -@overload(generic_calls_cuda, target='generic') +@overload(generic_calls_cuda, target="generic") def ol_generic_calls_cuda(x): def impl(x): x[0] *= GENERIC_CALLS_CUDA cuda_func_1(x) + return impl -@overload(cuda_calls_generic, target='cuda') +@overload(cuda_calls_generic, target="cuda") def ol_cuda_calls_generic(x): def impl(x): x[0] *= CUDA_CALLS_GENERIC generic_func_1(x) + return impl -@overload(cuda_calls_cuda, target='cuda') +@overload(cuda_calls_cuda, target="cuda") def ol_cuda_calls_cuda(x): def impl(x): x[0] *= CUDA_CALLS_CUDA cuda_func_1(x) + return impl -@overload(target_overloaded, target='generic') +@overload(target_overloaded, target="generic") def ol_target_overloaded_generic(x): def impl(x): x[0] *= GENERIC_TARGET_OL + return impl -@overload(target_overloaded, target='cuda') +@overload(target_overloaded, target="cuda") def ol_target_overloaded_cuda(x): def impl(x): x[0] *= CUDA_TARGET_OL + return impl -@overload(generic_calls_target_overloaded, target='generic') +@overload(generic_calls_target_overloaded, target="generic") def ol_generic_calls_target_overloaded(x): def impl(x): x[0] *= GENERIC_CALLS_TARGET_OL target_overloaded(x) + return impl -@overload(cuda_calls_target_overloaded, target='cuda') +@overload(cuda_calls_target_overloaded, 
target="cuda") def ol_cuda_calls_target_overloaded(x): def impl(x): x[0] *= CUDA_CALLS_TARGET_OL target_overloaded(x) + return impl -@overload(target_overloaded_calls_target_overloaded, target='generic') +@overload(target_overloaded_calls_target_overloaded, target="generic") def ol_generic_calls_target_overloaded_generic(x): def impl(x): x[0] *= GENERIC_TARGET_OL_CALLS_TARGET_OL target_overloaded(x) + return impl -@overload(target_overloaded_calls_target_overloaded, target='cuda') +@overload(target_overloaded_calls_target_overloaded, target="cuda") def ol_generic_calls_target_overloaded_cuda(x): def impl(x): x[0] *= CUDA_TARGET_OL_CALLS_TARGET_OL target_overloaded(x) + return impl @@ -193,10 +209,11 @@ def impl(x): def ol_default_values_and_kwargs(out, x, y=5, z=6): def impl(out, x, y=5, z=6): out[0], out[1] = x + y, z + return impl -@skip_on_cudasim('Overloading not supported in cudasim') +@skip_on_cudasim("Overloading not supported in cudasim") class TestOverload(CUDATestCase): def check_overload(self, kernel, expected): x = np.ones(1, dtype=np.int32) @@ -311,7 +328,7 @@ def test_overload_attribute_target(self): MyDummy, MyDummyType = self.make_dummy_type() mydummy_type = typeof(MyDummy()) - @overload_attribute(MyDummyType, 'cuda_only', target='cuda') + @overload_attribute(MyDummyType, "cuda_only", target="cuda") def ov_dummy_cuda_attr(obj): def imp(obj): return 42 @@ -330,6 +347,7 @@ def imp(obj): msg = "Unknown attribute 'cuda_only'" with self.assertRaisesRegex(TypingError, msg): + @njit(types.int64(mydummy_type)) def illegal_target_attr_use(x): return x.cuda_only @@ -345,14 +363,15 @@ def test_default_values_and_kwargs(self): """ Test default values and kwargs. """ + @cuda.jit() def kernel(a, b, out): default_values_and_kwargs(out, a, z=b) out = np.empty(2, dtype=np.int64) - kernel[1,1](1, 2, out) + kernel[1, 1](1, 2, out) self.assertEqual(tuple(out), (6, 2)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/numba_cuda/numba/cuda/tests/nrt/test_nrt.py b/numba_cuda/numba/cuda/tests/nrt/test_nrt.py index acdaa0c8d..38d67f7ef 100644 --- a/numba_cuda/numba/cuda/tests/nrt/test_nrt.py +++ b/numba_cuda/numba/cuda/tests/nrt/test_nrt.py @@ -26,7 +26,7 @@ def g(): x = np.empty(10, np.int64) f(x) - g[1,1]() + g[1, 1]() cuda.synchronize() def test_nrt_ptx_contains_refcount(self): @@ -39,7 +39,7 @@ def g(): x = np.empty(10, np.int64) f(x) - g[1,1]() + g[1, 1]() ptx = next(iter(g.inspect_asm().values())) @@ -72,13 +72,12 @@ def g(out_ary): out_ary = np.zeros(1, dtype=np.int64) - g[1,1](out_ary) + g[1, 1](out_ary) self.assertEqual(out_ary[0], 1) class TestNrtStatistics(CUDATestCase): - def setUp(self): self._stream = cuda.default_stream() # Store the current stats state @@ -126,12 +125,11 @@ def foo(): # Check env var explicitly being set works env = os.environ.copy() - env['NUMBA_CUDA_NRT_STATS'] = "1" - env['NUMBA_CUDA_ENABLE_NRT'] = "1" + env["NUMBA_CUDA_NRT_STATS"] = "1" + env["NUMBA_CUDA_ENABLE_NRT"] = "1" run_in_subprocess(src, env=env) def check_env_var_off(self, env): - src = """if 1: from numba import cuda import numpy as np @@ -152,27 +150,26 @@ def foo(): def test_stats_env_var_explicit_off(self): # Checks that explicitly turning the stats off via the env var works. env = os.environ.copy() - env['NUMBA_CUDA_NRT_STATS'] = "0" + env["NUMBA_CUDA_NRT_STATS"] = "0" self.check_env_var_off(env) def test_stats_env_var_default_off(self): # Checks that the env var not being set is the same as "off", i.e. # default for Numba is off. 
env = os.environ.copy() - env.pop('NUMBA_CUDA_NRT_STATS', None) + env.pop("NUMBA_CUDA_NRT_STATS", None) self.check_env_var_off(env) def test_stats_status_toggle(self): - @cuda.jit def foo(): tmp = np.ones(3) - arr = np.arange(5 * tmp[0]) # noqa: F841 + arr = np.arange(5 * tmp[0]) # noqa: F841 return None with ( - override_config('CUDA_ENABLE_NRT', True), - override_config('CUDA_NRT_STATS', True) + override_config("CUDA_ENABLE_NRT", True), + override_config("CUDA_NRT_STATS", True), ): # Switch on stats rtsys.memsys_enable_stats() @@ -218,9 +215,9 @@ def test_rtsys_stats_query_raises_exception_when_disabled(self): def test_nrt_explicit_stats_query_raises_exception_when_disabled(self): # Checks the various memsys_get_stats functions raise if queried when # the stats counters are disabled. - method_variations = ('alloc', 'free', 'mi_alloc', 'mi_free') + method_variations = ("alloc", "free", "mi_alloc", "mi_free") for meth in method_variations: - stats_func = getattr(rtsys, f'memsys_get_stats_{meth}') + stats_func = getattr(rtsys, f"memsys_get_stats_{meth}") with self.subTest(stats_func=stats_func): # Turn stats off rtsys.memsys_disable_stats() @@ -230,5 +227,5 @@ def test_nrt_explicit_stats_query_raises_exception_when_disabled(self): self.assertIn("NRT stats are disabled.", str(raises.exception)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py b/numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py index 1e9b7aa30..27811bdae 100644 --- a/numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +++ b/numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py @@ -9,7 +9,6 @@ class TestNrtRefCt(EnableNRTStatsMixin, CUDATestCase): - def setUp(self): super(TestNrtRefCt, self).setUp() @@ -19,7 +18,7 @@ def tearDown(self): def run(self, result=None): with ( override_config("CUDA_ENABLE_NRT", True), - override_config('CUDA_NRT_STATS', True) + override_config("CUDA_NRT_STATS", True), ): super(TestNrtRefCt, self).run(result) @@ -33,7 +32,7 @@ def test_no_return(self): @cuda.jit def kernel(): for i in range(n): - temp = np.empty(2) # noqa: F841 + temp = np.empty(2) # noqa: F841 return None init_stats = rtsys.get_allocation_stats() @@ -49,14 +48,13 @@ def test_escaping_var_init_in_loop(self): @cuda.jit def g(n): - x = np.empty((n, 2)) for i in range(n): y = x[i] for i in range(n): - y = x[i] # noqa: F841 + y = x[i] # noqa: F841 return None @@ -70,6 +68,7 @@ def test_invalid_computation_of_lifetime(self): """ Test issue #1573 """ + @cuda.jit def if_with_allocation_and_initialization(arr1, test1): tmp_arr = np.empty_like(arr1) @@ -85,13 +84,15 @@ def if_with_allocation_and_initialization(arr1, test1): init_stats = rtsys.get_allocation_stats() if_with_allocation_and_initialization[1, 1](arr, False) cur_stats = rtsys.get_allocation_stats() - self.assertEqual(cur_stats.alloc - init_stats.alloc, - cur_stats.free - init_stats.free) + self.assertEqual( + cur_stats.alloc - init_stats.alloc, cur_stats.free - init_stats.free + ) def test_del_at_beginning_of_loop(self): """ Test issue #1734 """ + @cuda.jit def f(arr): res = 0 @@ -108,9 +109,10 @@ def f(arr): init_stats = rtsys.get_allocation_stats() f[1, 1](arr) cur_stats = rtsys.get_allocation_stats() - self.assertEqual(cur_stats.alloc - init_stats.alloc, - cur_stats.free - init_stats.free) + self.assertEqual( + cur_stats.alloc - init_stats.alloc, cur_stats.free - init_stats.free + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From 28d9dc83da17cb50636bd65854226538c027e4d1 
Mon Sep 17 00:00:00 2001
From: isVoid 
Date: Thu, 10 Apr 2025 13:38:50 -0700
Subject: [PATCH 36/36] apply compile lock to make sure modules are not
 compiled more than once

---
 numba_cuda/numba/cuda/dispatcher.py                           | 1 +
 numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py
index e68c3b745..72344aa22 100644
--- a/numba_cuda/numba/cuda/dispatcher.py
+++ b/numba_cuda/numba/cuda/dispatcher.py
@@ -1111,6 +1111,7 @@ def add_overload(self, kernel, argtypes):
         self._insert(c_sig, kernel, cuda=True)
         self.overloads[argtypes] = kernel
 
+    @global_compiler_lock
     def compile(self, sig):
         """
         Compile and bind to the current context a version of this kernel
diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py
index 4e37bec5b..304f2131a 100644
--- a/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py
+++ b/numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py
@@ -250,7 +250,7 @@ def concurrent_compilation_launch(kernel):
 
         wipe_all_modules_in_context()
         self.assertEqual(len(seen_mods), 0)
-        self.assertEqual(max_seen_mods, 4)
+        self.assertEqual(max_seen_mods, 1)  # one module shared across threads
 
     def test_concurrent_initialization_different_args(self):
         seen_mods = set()
@@ -287,7 +287,7 @@ def concurrent_compilation_launch():
 
         wipe_all_modules_in_context()
        assert len(seen_mods) == 0
-        self.assertEqual(max_seen_mods, 8)  # 2 kernels per thread
+        self.assertEqual(max_seen_mods, 2)  # two modules shared across threads
 
 
 if __name__ == "__main__":