diff --git a/docs/source/reference/kernel.rst b/docs/source/reference/kernel.rst index 64a452947..9b8e15183 100644 --- a/docs/source/reference/kernel.rst +++ b/docs/source/reference/kernel.rst @@ -57,6 +57,87 @@ This is similar to launch configuration in CUDA C/C++: .. note:: The order of ``stream`` and ``sharedmem`` are reversed in Numba compared to in CUDA C/C++. +Launch configuration access (advanced) +-------------------------------------- + +The configured-launch object returned by ``dispatcher[griddim, blockdim, ...]`` +exposes launch metadata and callback hooks that can be consumed by advanced +tooling (for example, rewrite passes and extension integrations). + +Compile-time launch-config access +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The active launch configuration is available only while compilation is in +progress for a kernel launch. + +.. note:: This is compile-time state. If a launch reuses an existing compiled + kernel for the same cache key, no compilation occurs and no compile-time + launch config is set. For launch-config-sensitive kernels, a different + launch configuration can trigger a separate compilation/specialization; see + :ref:`cuda-launch-config-sensitive-compilation`. + +.. code-block:: python + + from numba import cuda + from numba.cuda import launchconfig + + @cuda.jit + def f(x): + x[0] = 1 + + arr = cuda.device_array(1, dtype="i4") + with launchconfig.capture_compile_config(f) as capture: + f[1, 1](arr) # first launch triggers compilation + + cfg = capture["config"] + print(cfg.griddim, cfg.blockdim, cfg.sharedmem) + +For use inside compilation passes: + +.. code-block:: python + + from numba.cuda import launchconfig + + cfg = launchconfig.ensure_current_launch_config() + print(cfg.griddim, cfg.blockdim, cfg.sharedmem, cfg.args) + +Pre-launch callbacks +~~~~~~~~~~~~~~~~~~~~ + +Configured launches expose ``pre_launch_callbacks``. Each callback is called +immediately before launch with ``(kernel, launch_config)``. + +.. code-block:: python + + cfg = f[1, 1] + + def log_launch(kernel, cfg): + print(cfg.griddim, cfg.blockdim) + + cfg.pre_launch_callbacks.append(log_launch) + cfg(arr) + +.. _cuda-launch-config-sensitive-compilation: + +Launch-config-sensitive compilation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If code generation depends on launch configuration (for example, a rewrite +pass that inspects ``cfg.blockdim`` and emits different IR), mark the active +launch as launch-config sensitive: + +.. code-block:: python + + cfg = launchconfig.ensure_current_launch_config() + cfg.mark_kernel_as_launch_config_sensitive() + +This instructs the dispatcher/cache machinery to avoid unsafe reuse across +different launch configurations for that kernel path. + +.. note:: Launch-config-sensitive cache keying for ``cache=True`` applies to + kernels that are otherwise disk-cacheable. Kernels that require external + linking files are not currently disk-cacheable. + Dispatcher objects also provide several utility methods for inspection and creating a specialized instance: diff --git a/numba_cuda/numba/cuda/cext/_dispatcher.cpp b/numba_cuda/numba/cuda/cext/_dispatcher.cpp index d6d9a304c..fa6f79001 100644 --- a/numba_cuda/numba/cuda/cext/_dispatcher.cpp +++ b/numba_cuda/numba/cuda/cext/_dispatcher.cpp @@ -13,6 +13,81 @@ #include "traceback.h" #include "typeconv.hpp" +static Py_tss_t launch_config_tss_key = Py_tss_NEEDS_INIT; +static const char *launch_config_kw = "__numba_cuda_launch_config__"; + +static int +launch_config_tss_init(void) +{ + if (PyThread_tss_create(&launch_config_tss_key) != 0) { + PyErr_SetString(PyExc_RuntimeError, + "Failed to initialize launch config TLS"); + return -1; + } + return 0; +} + +static PyObject * +launch_config_get_borrowed(void) +{ + return (PyObject *) PyThread_tss_get(&launch_config_tss_key); +} + +static int +launch_config_set(PyObject *obj) +{ + PyObject *old = (PyObject *) PyThread_tss_get(&launch_config_tss_key); + if (obj != NULL) { + Py_INCREF(obj); + } + if (PyThread_tss_set(&launch_config_tss_key, (void *) obj) != 0) { + Py_XDECREF(obj); + PyErr_SetString(PyExc_RuntimeError, + "Failed to set launch config TLS"); + return -1; + } + Py_XDECREF(old); + return 0; +} + +class LaunchConfigGuard { +public: + explicit LaunchConfigGuard(PyObject *value) + : prev(NULL), active(false), requested(value != NULL) + { + if (!requested) { + return; + } + prev = launch_config_get_borrowed(); + Py_XINCREF(prev); + if (launch_config_set(value) != 0) { + Py_XDECREF(prev); + prev = NULL; + return; + } + active = true; + } + + bool failed(void) const + { + return requested && !active; + } + + ~LaunchConfigGuard(void) + { + if (!active) { + return; + } + launch_config_set(prev); + Py_XDECREF(prev); + } + +private: + PyObject *prev; + bool active; + bool requested; +}; + /* * Notes on the C_TRACE macro: * @@ -840,6 +915,7 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws) PyObject *cfunc; PyThreadState *ts = PyThreadState_Get(); PyObject *locals = NULL; + PyObject *launch_config = NULL; /* If compilation is enabled, ensure that an exact match is found and if * not compile one */ @@ -855,9 +931,26 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws) goto CLEANUP; } } + if (kws != NULL) { + launch_config = PyDict_GetItemString(kws, launch_config_kw); + if (launch_config != NULL) { + Py_INCREF(launch_config); + if (PyDict_DelItemString(kws, launch_config_kw) < 0) { + Py_DECREF(launch_config); + launch_config = NULL; + goto CLEANUP; + } + if (launch_config == Py_None) { + Py_DECREF(launch_config); + launch_config = NULL; + } + } + } if (self->fold_args) { - if (find_named_args(self, &args, &kws)) + if (find_named_args(self, &args, &kws)) { + Py_XDECREF(launch_config); return NULL; + } } else Py_INCREF(args); @@ -913,6 +1006,11 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws) } else if (matches == 0) { /* No matching definition */ if (self->can_compile) { + LaunchConfigGuard guard(launch_config); + if (guard.failed()) { + retval = NULL; + goto CLEANUP; + } retval = cuda_compile_only(self, args, kws, locals); } else if (self->fallbackdef) { /* Have object fallback */ @@ -924,6 +1022,11 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws) } } else if (self->can_compile) { /* Ambiguous, but are allowed to compile */ + LaunchConfigGuard guard(launch_config); + if (guard.failed()) { + retval = NULL; + goto CLEANUP; + } retval = cuda_compile_only(self, args, kws, locals); } else { /* Ambiguous */ @@ -935,6 +1038,7 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws) if (tys != prealloc) delete[] tys; Py_DECREF(args); + Py_XDECREF(launch_config); return retval; } @@ -1040,10 +1144,23 @@ static PyObject *compute_fingerprint(PyObject *self, PyObject *args) return typeof_compute_fingerprint(val); } +static PyObject * +get_current_launch_config(PyObject *self, PyObject *args) +{ + PyObject *config = launch_config_get_borrowed(); + if (config == NULL) { + Py_RETURN_NONE; + } + Py_INCREF(config); + return config; +} + static PyMethodDef ext_methods[] = { #define declmethod(func) { #func , ( PyCFunction )func , METH_VARARGS , NULL } declmethod(typeof_init), declmethod(compute_fingerprint), + { "get_current_launch_config", (PyCFunction)get_current_launch_config, + METH_NOARGS, NULL }, { NULL }, #undef declmethod }; @@ -1055,6 +1172,10 @@ MOD_INIT(_dispatcher) { if (m == NULL) return MOD_ERROR_VAL; + if (launch_config_tss_init() != 0) { + return MOD_ERROR_VAL; + } + DispatcherType.tp_new = PyType_GenericNew; if (PyType_Ready(&DispatcherType) < 0) { return MOD_ERROR_VAL; diff --git a/numba_cuda/numba/cuda/compiler.py b/numba_cuda/numba/cuda/compiler.py index cbe1dfac2..4f666209d 100644 --- a/numba_cuda/numba/cuda/compiler.py +++ b/numba_cuda/numba/cuda/compiler.py @@ -15,6 +15,7 @@ from numba.cuda.core.interpreter import Interpreter from numba.cuda import cgutils, typing, lowering, nvvmutils, utils +from numba.cuda import launchconfig from numba.cuda.api import get_current_device from numba.cuda.codegen import ExternalCodeLibrary @@ -398,6 +399,14 @@ def run_pass(self, state): """ lowered = state["cr"] signature = typing.signature(state.return_type, *state.args) + launch_cfg = launchconfig.current_launch_config() + if ( + launch_cfg is not None + and launch_cfg.is_kernel_launch_config_sensitive() + ): + if state.metadata is None: + state.metadata = {} + state.metadata["launch_config_sensitive"] = True state.cr = cuda_compile_result( typing_context=state.typingctx, @@ -408,6 +417,9 @@ def run_pass(self, state): call_helper=lowered.call_helper, signature=signature, fndesc=lowered.fndesc, + # Preserve metadata populated by rewrite passes (e.g. launch-config + # sensitivity) so downstream consumers can act on it. + metadata=state.metadata, ) return True diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index aa8fb76be..65488f5f5 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -20,6 +20,7 @@ from numba.cuda.core import errors from numba.cuda import serialize, utils from numba import cuda +from numba.cuda import launchconfig from numba.cuda.np import numpy_support from numba.cuda.core.compiler_lock import global_compiler_lock @@ -55,6 +56,7 @@ import numba.cuda.core.event as ev from numba.cuda.cext import _dispatcher +_LAUNCH_CONFIG_KW = "__numba_cuda_launch_config__" cuda_fp16_math_funcs = [ "hsin", @@ -165,6 +167,10 @@ def __init__( max_registers=max_registers, lto=lto, ) + self.launch_config_sensitive = bool( + getattr(cres, "metadata", None) + and cres.metadata.get("launch_config_sensitive", False) + ) tgt_ctx = cres.target_context lib = cres.library kernel = lib.get_function(cres.fndesc.llvm_func_name) @@ -297,6 +303,7 @@ def _rebuild( call_helper, extensions, shared_memory_carveout=None, + launch_config_sensitive=False, ): """ Rebuild an instance. @@ -316,6 +323,7 @@ def _rebuild( instance.call_helper = call_helper instance.extensions = extensions instance.shared_memory_carveout = shared_memory_carveout + instance.launch_config_sensitive = launch_config_sensitive return instance def _reduce_states(self): @@ -336,6 +344,7 @@ def _reduce_states(self): call_helper=self.call_helper, extensions=self.extensions, shared_memory_carveout=self.shared_memory_carveout, + launch_config_sensitive=self.launch_config_sensitive, ) def _parse_carveout(self, carveout): @@ -693,6 +702,9 @@ def __init__(self, dispatcher, griddim, blockdim, stream, sharedmem): self.blockdim = blockdim self.stream = driver._to_core_stream(stream) self.sharedmem = sharedmem + self.pre_launch_callbacks = [] + self.args = None + self._kernel_launch_config_sensitive = None if ( config.CUDA_LOW_OCCUPANCY_WARNINGS @@ -716,19 +728,46 @@ def __init__(self, dispatcher, griddim, blockdim, stream, sharedmem): warn(errors.NumbaPerformanceWarning(msg)) def __call__(self, *args): - return self.dispatcher.call( - args, self.griddim, self.blockdim, self.stream, self.sharedmem - ) + return self.dispatcher.call(args, self) + + def mark_kernel_as_launch_config_sensitive(self): + """Mark this configured launch path as launch-config sensitive. + + Once set, this flag is intentionally sticky for this + ``_LaunchConfiguration`` instance. This aligns with the expected LC-S + use case: if code generation depends on launch config for this + kernel/configuration path, treat it as launch-config sensitive for all + subsequent compilations through the same configured launcher. + """ + self._kernel_launch_config_sensitive = True + + def get_kernel_launch_config_sensitive(self): + """Return the launch-config sensitivity flag. + + The result is ``None`` if no explicit decision was made. + """ + return self._kernel_launch_config_sensitive + + def is_kernel_launch_config_sensitive(self): + """Return True if this kernel was marked as launch-config sensitive.""" + return bool(self._kernel_launch_config_sensitive) def __getstate__(self): state = self.__dict__.copy() state["stream"] = int(state["stream"].handle) + # Avoid serializing callables that may not be picklable. + state["pre_launch_callbacks"] = [] + state["args"] = None return state def __setstate__(self, state): handle = state.pop("stream") self.__dict__.update(state) self.stream = driver._to_core_stream(handle) + if "pre_launch_callbacks" not in self.__dict__: + self.pre_launch_callbacks = [] + if "args" not in self.__dict__: + self.args = None class CUDACacheImpl(CacheImpl): @@ -756,6 +795,52 @@ class CUDACache(Cache): _impl_class = CUDACacheImpl + def __init__(self, py_func): + self._launch_config_key = None + self._launch_config_sensitive_flag = None + super().__init__(py_func) + marker_name = f"{self._impl.filename_base}.lcs" + self._launch_config_marker_path = os.path.join( + self._cache_path, marker_name + ) + + def _index_key(self, sig, codegen): + key = super()._index_key(sig, codegen) + if self._launch_config_key is None: + return key + return key + (("launch_config", self._launch_config_key),) + + def set_launch_config_key(self, key): + self._launch_config_key = key + + def is_launch_config_sensitive(self): + if self._launch_config_sensitive_flag is None: + self._launch_config_sensitive_flag = os.path.exists( + self._launch_config_marker_path + ) + return self._launch_config_sensitive_flag + + def mark_launch_config_sensitive(self): + if self._launch_config_sensitive_flag is True: + return True + try: + self._impl.locator.ensure_cache_path() + with open(self._launch_config_marker_path, "a"): + pass + except OSError: + self._launch_config_sensitive_flag = False + return False + self._launch_config_sensitive_flag = True + return True + + def flush(self): + super().flush() + try: + os.unlink(self._launch_config_marker_path) + except FileNotFoundError: + pass + self._launch_config_sensitive_flag = None + def load_overload(self, sig, target_context): # Loading an overload refreshes the context to ensure it is initialized. with utils.numba_target_override(): @@ -1540,6 +1625,14 @@ def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler): self._cache_hits = collections.Counter() self._cache_misses = collections.Counter() + # Whether the compiled kernels are launch-config sensitive (e.g., IR + # rewrites depend on launch configuration). + self._launch_config_sensitive = False + self._launch_config_default_key = None + self._launch_config_is_specialized = False + self._launch_config_specialization_key = None + self._launch_config_specializations = {} + # The following properties are for specialization of CUDADispatchers. A # specialized CUDADispatcher is one that is compiled for exactly one # set of argument types, and bypasses some argument type checking for @@ -1585,6 +1678,80 @@ def __getitem__(self, args): raise ValueError("must specify at least the griddim and blockdim") return self.configure(*args) + @staticmethod + def _launch_config_key(launch_config): + return ( + launch_config.griddim, + launch_config.blockdim, + launch_config.sharedmem, + ) + + def _cache_launch_config_key(self, launch_config): + if self._launch_config_is_specialized: + return self._launch_config_specialization_key + if self._launch_config_default_key is not None: + return self._launch_config_default_key + if launch_config is None: + return None + return self._launch_config_key(launch_config) + + def _configure_cache_for_launch_config(self, launch_config): + if not isinstance(self._cache, CUDACache): + return + if self._launch_config_sensitive or self._launch_config_is_specialized: + key = self._cache_launch_config_key(launch_config) + self._cache.set_launch_config_key(key) + return + if self._cache.is_launch_config_sensitive(): + if launch_config is None: + key = None + else: + key = self._launch_config_key(launch_config) + self._cache.set_launch_config_key(key) + return + self._cache.set_launch_config_key(None) + + def _get_launch_config_specialization(self, key): + dispatcher = self._launch_config_specializations.get(key) + if dispatcher is None: + dispatcher = CUDADispatcher( + self.py_func, + targetoptions=self.targetoptions, + pipeline_class=self._compiler.pipeline_class, + ) + dispatcher._launch_config_sensitive = True + dispatcher._launch_config_is_specialized = True + dispatcher._launch_config_specialization_key = key + dispatcher._launch_config_default_key = key + if isinstance(self._cache, CUDACache): + dispatcher.enable_caching() + dispatcher._configure_cache_for_launch_config(None) + self._launch_config_specializations[key] = dispatcher + return dispatcher + + def _select_launch_config_dispatcher(self, launch_config): + if not self._launch_config_sensitive: + return self + if self._launch_config_is_specialized: + return self + key = self._launch_config_key(launch_config) + if self._launch_config_default_key is None: + self._launch_config_default_key = key + return self + if key == self._launch_config_default_key: + return self + return self._get_launch_config_specialization(key) + + def _update_launch_config_sensitivity(self, kernel, launch_config): + if not getattr(kernel, "launch_config_sensitive", False): + return + if not self._launch_config_sensitive: + self._launch_config_sensitive = True + if self._launch_config_default_key is None: + self._launch_config_default_key = self._launch_config_key( + launch_config + ) + def forall(self, ntasks, tpb=0, stream=0, sharedmem=0): """Returns a 1D-configured dispatcher for a given number of tasks. @@ -1632,16 +1799,36 @@ def __call__(self, *args, **kwargs): # An attempt to launch an unconfigured kernel raise ValueError(missing_launch_config_msg) - def call(self, args, griddim, blockdim, stream, sharedmem): + def call(self, args, launch_config): """ Compile if necessary and invoke this kernel with *args*. """ - if self.specialized: - kernel = next(iter(self.overloads.values())) - else: - kernel = _dispatcher.Dispatcher._cuda_call(self, *args) + griddim = launch_config.griddim + blockdim = launch_config.blockdim + stream = launch_config.stream + sharedmem = launch_config.sharedmem - kernel.launch(args, griddim, blockdim, stream, sharedmem) + launch_config.args = args + try: + dispatcher = self._select_launch_config_dispatcher(launch_config) + if dispatcher is not self: + return dispatcher.call(args, launch_config) + + if self.specialized: + kernel = next(iter(self.overloads.values())) + else: + kernel = _dispatcher.Dispatcher._cuda_call( + self, *args, **{_LAUNCH_CONFIG_KW: launch_config} + ) + + self._update_launch_config_sensitivity(kernel, launch_config) + + for callback in launch_config.pre_launch_callbacks: + callback(kernel, launch_config) + + kernel.launch(args, griddim, blockdim, stream, sharedmem) + finally: + launch_config.args = None def _compile_for_args(self, *args, **kws): # Based on _DispatcherBase._compile_for_args. @@ -1892,9 +2079,36 @@ def compile(self, sig): if kernel is not None: return kernel + launch_config = launchconfig.current_launch_config() + self._configure_cache_for_launch_config(launch_config) + # Can we load from the disk cache? kernel = self._cache.load_overload(sig, self.targetctx) + if ( + kernel is not None + and isinstance(self._cache, CUDACache) + and getattr(kernel, "launch_config_sensitive", False) + ): + cache_has_marker = self._cache.is_launch_config_sensitive() + if not cache_has_marker: + # Pre-existing cache entries without a launch-config marker are + # unsafe for LCS kernels. Force a recompile under the new key. + if launch_config is not None: + self._cache.set_launch_config_key( + self._launch_config_key(launch_config) + ) + if not self._cache.mark_launch_config_sensitive(): + # If we cannot record the marker, disable disk cache to + # avoid unsafe reuse. + self._cache = NullCache() + kernel = None + else: + if launch_config is not None: + self._cache.set_launch_config_key( + self._launch_config_key(launch_config) + ) + if kernel is not None: self._cache_hits[sig] += 1 else: @@ -1906,6 +2120,17 @@ def compile(self, sig): kernel = _Kernel(self.py_func, argtypes, **self.targetoptions) # We call bind to force codegen, so that there is a cubin to cache kernel.bind() + if isinstance(self._cache, CUDACache) and getattr( + kernel, "launch_config_sensitive", False + ): + if launch_config is not None: + self._cache.set_launch_config_key( + self._launch_config_key(launch_config) + ) + if not self._cache.mark_launch_config_sensitive(): + # If we cannot record the marker, disable disk cache to + # avoid unsafe reuse. + self._cache = NullCache() self._cache.save_overload(sig, kernel) self.add_overload(kernel, argtypes) diff --git a/numba_cuda/numba/cuda/launchconfig.py b/numba_cuda/numba/cuda/launchconfig.py new file mode 100644 index 000000000..c374d5b30 --- /dev/null +++ b/numba_cuda/numba/cuda/launchconfig.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause +"""Launch configuration access for CUDA compilation contexts. + +The current launch configuration is populated only during CUDA compilation +triggered by kernel launches. It is thread-local and cleared immediately after +compilation completes. +""" + +import contextlib +from functools import wraps + +from numba.cuda.cext import _dispatcher + + +def current_launch_config(): + """Return the current launch configuration, or None if not set.""" + return _dispatcher.get_current_launch_config() + + +def ensure_current_launch_config(): + """Return the current launch configuration or raise if not set.""" + config = current_launch_config() + if config is None: + raise RuntimeError("No launch config set for this thread") + return config + + +@contextlib.contextmanager +def capture_compile_config(dispatcher): + """Capture the launch config seen during compilation for a dispatcher. + + The returned dict has a single key, ``"config"``, which is populated when a + compilation is triggered by a kernel launch. If the kernel is already + compiled, the dict value may remain ``None``. + """ + if dispatcher is None: + raise TypeError("dispatcher is required") + + record = {"config": None} + original = dispatcher._compile_for_args + + @wraps(original) + def wrapped(*args, **kws): + record["config"] = current_launch_config() + return original(*args, **kws) + + dispatcher._compile_for_args = wrapped + try: + yield record + finally: + dispatcher._compile_for_args = original + + +__all__ = [ + "current_launch_config", + "ensure_current_launch_config", + "capture_compile_config", +] diff --git a/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_insensitive_usecases.py b/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_insensitive_usecases.py new file mode 100644 index 000000000..354ec3f30 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_insensitive_usecases.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import numpy as np + +from numba import cuda +import sys + + +@cuda.jit(cache=True) +def cache_kernel(x): + x[0] = 1 + + +def launch(blockdim): + arr = np.zeros(1, dtype=np.int32) + cache_kernel[1, blockdim](arr) + return arr + + +def self_test(): + mod = sys.modules[__name__] + out = mod.launch(32) + assert out[0] == 1 + out = mod.launch(64) + assert out[0] == 1 diff --git a/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_sensitive_usecases.py b/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_sensitive_usecases.py new file mode 100644 index 000000000..669007622 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_sensitive_usecases.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import numpy as np + +from numba import cuda +from numba.cuda import launchconfig +from numba.cuda.core.rewrites import register_rewrite, Rewrite, rewrite_registry +import sys + +_REWRITE_FLAG = "_launch_config_cache_rewrite_registered" + + +if not getattr(rewrite_registry, _REWRITE_FLAG, False): + + @register_rewrite("after-inference") + class LaunchConfigSensitiveCacheRewrite(Rewrite): + _TARGET_NAME = "lcs_cache_kernel" + + def __init__(self, state): + super().__init__(state) + self._state = state + self._block = None + self._logged = False + + def match(self, func_ir, block, typemap, calltypes): + if func_ir.func_id.func_name != self._TARGET_NAME: + return False + if self._logged: + return False + self._block = block + return True + + def apply(self): + # Ensure launch config is available and mark compilation as + # launch-config sensitive so the disk cache keys include it. + cfg = launchconfig.ensure_current_launch_config() + cfg.mark_kernel_as_launch_config_sensitive() + self._logged = True + return self._block + + setattr(rewrite_registry, _REWRITE_FLAG, True) + + +@cuda.jit(cache=True) +def lcs_cache_kernel(x): + x[0] = 1 + + +def launch(blockdim): + arr = np.zeros(1, dtype=np.int32) + lcs_cache_kernel[1, blockdim](arr) + return arr + + +def self_test(): + mod = sys.modules[__name__] + out = mod.launch(32) + assert out[0] == 1 + out = mod.launch(64) + assert out[0] == 1 diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_caching.py b/numba_cuda/numba/cuda/tests/cudapy/test_caching.py index 3d3eadc32..5d6ed023a 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_caching.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_caching.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause import multiprocessing +import json import os import shutil import unittest @@ -84,6 +85,9 @@ def get_cache_mtimes(self): for fn in sorted(self.cache_contents()) ) + def count_cache_markers(self, suffix=".lcs"): + return len([fn for fn in self.cache_contents() if fn.endswith(suffix)]) + def check_pycache(self, n): c = self.cache_contents() self.assertEqual(len(c), n, c) @@ -97,19 +101,35 @@ class DispatcherCacheUsecasesTest(BaseCacheTest): usecases_file = os.path.join(here, "cache_usecases.py") modname = "dispatcher_caching_test_fodder" - def run_in_separate_process(self): + def run_in_separate_process(self, *, envvars=None, report_code=None): # Cached functions can be run from a distinct process. # Also stresses issue #1603: uncached function calling cached function # shouldn't fail compiling. - code = """if 1: - import sys - - sys.path.insert(0, %(tempdir)r) - mod = __import__(%(modname)r) - mod.self_test() - """ % dict(tempdir=self.tempdir, modname=self.modname) + code_lines = [ + "if 1:", + " import sys", + ] + if report_code is not None: + code_lines.append(" import json") + code_lines.extend( + [ + f" sys.path.insert(0, {self.tempdir!r})", + f" mod = __import__({self.modname!r})", + " mod.self_test()", + ] + ) + if report_code is not None: + for line in report_code.splitlines(): + if line.strip(): + code_lines.append(f" {line}") + code_lines.append( + ' print("__CACHE_REPORT__" + json.dumps(report))' + ) + code = "\n".join(code_lines) subp_env = os.environ.copy() + if envvars is not None: + subp_env.update(envvars) popen = subprocess.Popen( [sys.executable, "-c", code], stdout=subprocess.PIPE, @@ -124,6 +144,16 @@ def run_in_separate_process(self): "stderr follows\n%s\n" % (popen.returncode, out.decode(), err.decode()), ) + if report_code is None: + return None + stdout = out.decode().splitlines() + marker = "__CACHE_REPORT__" + for line in reversed(stdout): + if line.startswith(marker): + return json.loads(line[len(marker) :]) + raise AssertionError( + "cache report missing from subprocess output:\n%s" % out.decode() + ) def check_hits(self, func, hits, misses=None): st = func.stats @@ -415,6 +445,167 @@ def cached_kernel_global(output): GLOBAL_DEVICE_ARRAY = None +@skip_on_cudasim("Simulator does not implement caching") +class LaunchConfigSensitiveCachingTest(DispatcherCacheUsecasesTest): + here = os.path.dirname(__file__) + usecases_file = os.path.join( + here, "cache_launch_config_sensitive_usecases.py" + ) + modname = "cuda_launch_config_sensitive_cache_test_fodder" + + def setUp(self): + DispatcherCacheUsecasesTest.setUp(self) + CUDATestCase.setUp(self) + + def tearDown(self): + CUDATestCase.tearDown(self) + DispatcherCacheUsecasesTest.tearDown(self) + + def test_launch_config_sensitive_cache_keys(self): + self.check_pycache(0) + mod = self.import_module() + self.check_pycache(0) + + mod.launch(32) + self.check_pycache(3) # index, data, marker + self.assertEqual(self.count_cache_markers(), 1) + + mod.launch(64) + self.check_pycache(4) # index, 2 data, marker + self.assertEqual(self.count_cache_markers(), 1) + + mod2 = self.import_module() + self.assertIsNot(mod, mod2) + mod2.launch(32) + self.check_hits(mod2.lcs_cache_kernel, 1, 0) + self.check_pycache(4) + + mod2.launch(64) + self.check_hits(mod2.lcs_cache_kernel, 1, 0) + self.assertEqual( + len(mod2.lcs_cache_kernel._launch_config_specializations), 1 + ) + specialization = next( + iter(mod2.lcs_cache_kernel._launch_config_specializations.values()) + ) + self.check_hits(specialization, 1, 0) + self.check_pycache(4) + + mtimes = self.get_cache_mtimes() + report = self.run_in_separate_process( + report_code="\n".join( + [ + "main_hits = sum(mod.lcs_cache_kernel.stats.cache_hits.values())", + "main_misses = sum(mod.lcs_cache_kernel.stats.cache_misses.values())", + "spec = next(iter(mod.lcs_cache_kernel._launch_config_specializations.values()))", + "spec_hits = sum(spec.stats.cache_hits.values())", + "spec_misses = sum(spec.stats.cache_misses.values())", + "report = {'main_hits': main_hits, 'main_misses': main_misses, 'spec_hits': spec_hits, 'spec_misses': spec_misses}", + ] + ) + ) + self.assertEqual(report["main_hits"], 1) + self.assertEqual(report["main_misses"], 0) + self.assertEqual(report["spec_hits"], 1) + self.assertEqual(report["spec_misses"], 0) + self.assertEqual(self.get_cache_mtimes(), mtimes) + + +@skip_on_cudasim("Simulator does not implement caching") +class LaunchConfigInsensitiveCachingTest(DispatcherCacheUsecasesTest): + here = os.path.dirname(__file__) + usecases_file = os.path.join( + here, "cache_launch_config_insensitive_usecases.py" + ) + modname = "cuda_launch_config_insensitive_cache_test_fodder" + + def setUp(self): + DispatcherCacheUsecasesTest.setUp(self) + CUDATestCase.setUp(self) + + def tearDown(self): + CUDATestCase.tearDown(self) + DispatcherCacheUsecasesTest.tearDown(self) + + def test_launch_config_insensitive_cache_keys(self): + self.check_pycache(0) + mod = self.import_module() + self.check_pycache(0) + + mod.launch(32) + self.check_pycache(2) # index, data + self.assertEqual(self.count_cache_markers(), 0) + + mod.launch(64) + self.check_pycache(2) + + mod2 = self.import_module() + self.assertIsNot(mod, mod2) + mod2.launch(64) + self.check_hits(mod2.cache_kernel, 1, 0) + self.check_pycache(2) + + mtimes = self.get_cache_mtimes() + report = self.run_in_separate_process( + report_code=( + "hits = sum(mod.cache_kernel.stats.cache_hits.values())\n" + "misses = sum(mod.cache_kernel.stats.cache_misses.values())\n" + "report = {'hits': hits, 'misses': misses}" + ) + ) + self.assertEqual(report["hits"], 1) + self.assertEqual(report["misses"], 0) + self.assertEqual(self.get_cache_mtimes(), mtimes) + + +@skip_on_cudasim("Simulator does not implement caching") +class MultiDeviceCachingTest(DispatcherCacheUsecasesTest): + here = os.path.dirname(__file__) + usecases_file = os.path.join( + here, "cache_launch_config_insensitive_usecases.py" + ) + modname = "cuda_multi_device_cache_test_fodder" + + def setUp(self): + DispatcherCacheUsecasesTest.setUp(self) + CUDATestCase.setUp(self) + + def tearDown(self): + CUDATestCase.tearDown(self) + DispatcherCacheUsecasesTest.tearDown(self) + + def test_cache_separate_per_compute_capability(self): + gpus = list(cuda.gpus) + if len(gpus) < 2: + self.skipTest("requires at least two GPUs") + + cc_map = {} + for gpu in gpus: + cc_map.setdefault(gpu.compute_capability, []).append(gpu.id) + + if len(cc_map) < 2: + self.skipTest("requires at least two distinct compute capabilities") + + for cc, ids in sorted(cc_map.items()): + dev_id = ids[0] + with cuda.gpus[dev_id]: + mod = self.import_module() + mod.launch(32) + hits = sum(mod.cache_kernel.stats.cache_hits.values()) + misses = sum(mod.cache_kernel.stats.cache_misses.values()) + self.assertEqual(hits, 0, f"unexpected cache hit for CC {cc}") + self.assertEqual(misses, 1, f"expected cache miss for CC {cc}") + + mod2 = self.import_module() + mod2.launch(32) + hits = sum(mod2.cache_kernel.stats.cache_hits.values()) + misses = sum(mod2.cache_kernel.stats.cache_misses.values()) + self.assertEqual(hits, 1, f"expected cache hit for CC {cc}") + self.assertEqual( + misses, 0, f"unexpected cache miss for CC {cc}" + ) + + @skip_on_cudasim("Simulator does not implement caching") class CUDACooperativeGroupTest(DispatcherCacheUsecasesTest): # See Issue #9432: https://github.com/numba/numba/issues/9432 diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py b/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py index aa837dc1b..eea77b09f 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py @@ -16,6 +16,7 @@ ) from numba import cuda from numba.cuda import config, types +from numba.cuda import launchconfig from numba.cuda.core.errors import TypingError from numba.cuda.testing import ( cc_X_or_above, @@ -128,6 +129,52 @@ def f(x, y): class TestDispatcher(CUDATestCase): """Most tests based on those in numba.tests.test_dispatcher.""" + @skip_on_cudasim("Dispatcher C-extension not used in the simulator") + def test_launch_config_available_during_compile(self): + @cuda.jit + def f(x): + x[0] = 1 + + seen = {} + orig = f._compile_for_args + + def wrapped(*args, **kws): + seen["config"] = launchconfig.current_launch_config() + return orig(*args, **kws) + + f._compile_for_args = wrapped + + arr = np.zeros(1, dtype=np.int32) + self.assertIsNone(launchconfig.current_launch_config()) + f[1, 1](arr) + + cfg = seen.get("config") + self.assertIsNotNone(cfg) + self.assertIs(cfg.dispatcher, f) + self.assertEqual(cfg.griddim, (1, 1, 1)) + self.assertEqual(cfg.blockdim, (1, 1, 1)) + self.assertIsNone(launchconfig.current_launch_config()) + with self.assertRaises(RuntimeError): + launchconfig.ensure_current_launch_config() + + @skip_on_cudasim("Dispatcher C-extension not used in the simulator") + def test_capture_compile_config(self): + @cuda.jit + def f(x): + x[0] = 1 + + arr = np.zeros(1, dtype=np.int32) + original = f._compile_for_args + with launchconfig.capture_compile_config(f) as capture: + f[1, 1](arr) + + cfg = capture["config"] + self.assertIsNotNone(cfg) + self.assertIs(cfg.dispatcher, f) + self.assertEqual(cfg.griddim, (1, 1, 1)) + self.assertEqual(cfg.blockdim, (1, 1, 1)) + self.assertIs(f._compile_for_args, original) + def test_coerce_input_types(self): # Do not allow unsafe conversions if we can still compile other # specializations. diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_launch_config_sensitive.py b/numba_cuda/numba/cuda/tests/cudapy/test_launch_config_sensitive.py new file mode 100644 index 000000000..744a52532 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/cudapy/test_launch_config_sensitive.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import numpy as np + +from numba import cuda +from numba.cuda import launchconfig +from numba.cuda.core.rewrites import register_rewrite, Rewrite +from numba.cuda.testing import skip_on_cudasim, unittest, CUDATestCase + + +LAUNCH_CONFIG_LOG = [] + + +def _clear_launch_config_log(): + LAUNCH_CONFIG_LOG.clear() + + +@register_rewrite("after-inference") +class LaunchConfigSensitiveRewrite(Rewrite): + """Rewrite that marks kernels as launch-config sensitive and logs config. + + This mimics cuda.coop's need to access launch config during rewrite, and + provides a global log for tests to assert on. + """ + + _TARGET_NAME = "launch_config_sensitive_kernel" + + def __init__(self, state): + super().__init__(state) + self._state = state + self._logged = False + self._block = None + + def match(self, func_ir, block, typemap, calltypes): + if func_ir.func_id.func_name != self._TARGET_NAME: + return False + if self._logged: + return False + self._block = block + return True + + def apply(self): + cfg = launchconfig.ensure_current_launch_config() + LAUNCH_CONFIG_LOG.append( + { + "griddim": cfg.griddim, + "blockdim": cfg.blockdim, + "sharedmem": cfg.sharedmem, + } + ) + # Mark compilation as launch-config sensitive so the dispatcher can + # avoid reusing the compiled kernel across different launch configs. + cfg.mark_kernel_as_launch_config_sensitive() + self._logged = True + return self._block + + +@skip_on_cudasim("Dispatcher C-extension not used in the simulator") +class TestLaunchConfigSensitive(CUDATestCase): + def setUp(self): + super().setUp() + _clear_launch_config_log() + + def test_launch_config_sensitive_requires_recompile(self): + @cuda.jit + def launch_config_sensitive_kernel(x): + x[0] = 1 + + arr = np.zeros(1, dtype=np.int32) + + launch_config_sensitive_kernel[1, 32](arr) + self.assertEqual(len(LAUNCH_CONFIG_LOG), 1) + self.assertEqual(LAUNCH_CONFIG_LOG[0]["blockdim"], (32, 1, 1)) + self.assertEqual(LAUNCH_CONFIG_LOG[0]["griddim"], (1, 1, 1)) + + launch_config_sensitive_kernel[1, 64](arr) + # Expect a new compilation for the new launch config, which will log + # a second entry with the updated block dimension. + self.assertEqual(len(LAUNCH_CONFIG_LOG), 2) + self.assertEqual(LAUNCH_CONFIG_LOG[1]["blockdim"], (64, 1, 1)) + self.assertEqual(LAUNCH_CONFIG_LOG[1]["griddim"], (1, 1, 1)) + + +if __name__ == "__main__": + unittest.main()