diff --git a/LAUNCH-CONFIG-CODEX-PROMPT.md b/LAUNCH-CONFIG-CODEX-PROMPT.md
new file mode 100644
index 000000000..f4c3f2a26
--- /dev/null
+++ b/LAUNCH-CONFIG-CODEX-PROMPT.md
@@ -0,0 +1,68 @@
+# Codex Prompt: Launch Config Sensitive (LC-S) work for cuda.coop
+
+You are resuming launch-config work in `numba-cuda` to support cuda.coop
+single-phase. The cache invalidation change for `numba_cuda.__version__` is
+already in a separate PR; do not redo that here.
+
+## Repos / worktrees
+- `numba-cuda` (current worktree): `/home/trentn/src/280-launch-config-v2`
+ - branch: `280-launch-config-v2`
+- `numba-cuda` main baseline: `/home/trentn/src/numba-cuda-main`
+- cuda.coop repo: `/home/trentn/src/cccl/python/cuda_cccl`
+ - see `SINGLE-PHASE-*.md` for context
+
+## Current local state (numba-cuda)
+Run:
+- `git status -sb`
+- `git diff`
+
+Expected (uncommitted) changes in this worktree:
+- `numba_cuda/numba/cuda/compiler.py`
+ - CUDABackend sets `state.metadata["launch_config_sensitive"] = True`
+ when the active launch config is explicitly marked.
+- `numba_cuda/numba/cuda/dispatcher.py`
+ - `_LaunchConfiguration` adds explicit API:
+ `mark_kernel_as_launch_config_sensitive()`, `get_kernel_launch_config_sensitive()`,
+ `is_kernel_launch_config_sensitive()`.
+- `scripts/bench-launch-overhead.py`
+ - import compatibility for `numba.cuda.core.config` vs `numba.core.config`.
+- Untracked: `PR.md`, `tags` (clean up before commit).
+
+## What is already implemented
+- TLS-based launch-config capture in C extension, exposed via
+ `numba_cuda/numba/cuda/launchconfig.py`.
+- Dispatcher plumbing for LC-S (per-config specialization + cache keys + `.lcs` marker).
+- Tests for LC-S recompile + cache coverage.
+- Docs updated for launch-config introspection.
+- In cccl: `cuda/coop/_rewrite.py` now marks LC-S when accessing launch config.
+ It calls `mark_kernel_as_launch_config_sensitive()` when available, with
+ fallback to `state.metadata["launch_config_sensitive"] = True`.
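+
+A minimal sketch of that try-API-then-fall-back pattern (the `cfg` and
+`metadata` objects here are stand-ins for illustration, not the real
+numba-cuda types used by `cuda/coop/_rewrite.py`):
+
+```python
+def mark_launch_config_sensitive(cfg, metadata):
+    # Prefer the explicit API when the installed numba-cuda provides it;
+    # otherwise fall back to writing the metadata flag directly.
+    marker = getattr(cfg, "mark_kernel_as_launch_config_sensitive", None)
+    if marker is not None:
+        marker()
+    else:
+        metadata["launch_config_sensitive"] = True
+```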
+
+## Open decisions / tasks
+1. **Explicit LC-S API decision: keep**
+ - `_LaunchConfiguration` explicit LC-S API is retained.
+ - Compiler hook in `CUDABackend` uses this API to set metadata.
+ - cccl rewrite is updated to use the API when available.
+
+2. **Run CUDA tests on a GPU**
+ - `pixi run -e cu-12-9-py312 pytest testing --pyargs numba.cuda.tests.cudapy.test_launch_config_sensitive -k launch_config_sensitive`
+ - `pixi run -e cu-12-9-py312 pytest testing --pyargs numba.cuda.tests.cudapy.test_caching -k launch_config_sensitive`
+ - Status: both passing on GPU in this worktree.
+
+3. **Validate disk-cache behavior across processes**
+ - Ensure `.lcs` marker + launch-config cache keying behave correctly.
+ - Status: covered by
+ `LaunchConfigSensitiveCachingTest.test_launch_config_sensitive_cache_keys`
+ in `test_caching.py` (passes, includes separate-process verification).
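+
+   The separate-process check boils down to this pattern (a sketch of the
+   approach; the real verification lives in `test_caching.py`):
+
+   ```python
+   import subprocess
+   import sys
+   import textwrap
+
+   def run_in_fresh_process(code: str) -> None:
+       # Re-importing and launching in a child process exercises the on-disk
+       # cache rather than the in-memory overloads of the parent process.
+       subprocess.check_call([sys.executable, "-c", textwrap.dedent(code)])
+   ```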
+
+4. **Audit launch paths**
+ - Confirm all kernel launch paths go through `CUDADispatcher.call()`.
+ - Status: Python launch paths in `dispatcher.py` verified.
+
+5. **Commit / cleanup**
+ - Remove untracked `PR.md` and `tags`.
+ - Prepare commit(s) for the launch-config work.
+
+## Notes
+- If you need to re-run the overhead micro-benchmark, see `LAUNCH-CONFIG.md`.
+- Update `LAUNCH-CONFIG-TODO.md` with any new decisions or test results.
diff --git a/LAUNCH-CONFIG-TODO.md b/LAUNCH-CONFIG-TODO.md
new file mode 100644
index 000000000..ff6645359
--- /dev/null
+++ b/LAUNCH-CONFIG-TODO.md
@@ -0,0 +1,58 @@
+# Launch Config Sensitive (LC-S) plumbing
+
+Last updated: 2026-02-19
+
+## Current status (summary)
+- Launch-config TLS capture exists in C extension and is exposed via
+ `numba_cuda/numba/cuda/launchconfig.py` (current/ensure/capture helpers).
+- Dispatcher plumbing for LC-S is implemented:
+ - `_Kernel` captures `launch_config_sensitive` from compile metadata.
+ - `CUDADispatcher` tracks LC-S and routes to per-launch-config sub-dispatchers.
+ - Disk cache includes a launch-config key and LC-S marker file (`.lcs`).
+- Tests added:
+ - `numba_cuda/numba/cuda/tests/cudapy/test_launch_config_sensitive.py`
+ - `numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_sensitive_usecases.py`
+ - `numba_cuda/numba/cuda/tests/cudapy/test_caching.py` LC-S coverage
+- Docs updated: `docs/source/reference/kernel.rst`.
+- In cccl, `cuda/coop/_rewrite.py` now marks LC-S when accessing launch config.
+ It uses the explicit LaunchConfiguration API when available, with fallback to
+ `state.metadata["launch_config_sensitive"] = True` for compatibility.
+- CUDA tests have been run and are passing on a GPU in this worktree:

+ - `pixi run -e cu-12-9-py312 pytest testing --pyargs numba.cuda.tests.cudapy.test_launch_config_sensitive -k launch_config_sensitive`
+ - `pixi run -e cu-12-9-py312 pytest testing --pyargs numba.cuda.tests.cudapy.test_caching -k launch_config_sensitive`
+
+## Local working tree state (numba-cuda)
+- Branch: `280-launch-config-v2`
+- Modified files (uncommitted):
+ - `numba_cuda/numba/cuda/compiler.py`
+ - `numba_cuda/numba/cuda/dispatcher.py`
+ - `scripts/bench-launch-overhead.py`
+- Untracked: `PR.md`, `tags`
+
+## New (uncommitted) LC-S API work
+- `_LaunchConfiguration` gains explicit helpers:
+ - `mark_kernel_as_launch_config_sensitive()`
+ - `get_kernel_launch_config_sensitive()`
+ - `is_kernel_launch_config_sensitive()`
+- `CUDABackend` sets metadata when the launch config is explicitly marked.
+ This provides an official path to mark LC-S without poking at `state.metadata`
+ directly from rewrites.
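+
+In miniature, the flag is tri-state (`None` until an explicit decision, then
+sticky once marked). This stand-in class illustrates the semantics; it is not
+the real `_LaunchConfiguration`:
+
+```python
+class LaunchConfigFlag:
+    """Illustrative stand-in for the LC-S flag on _LaunchConfiguration."""
+
+    def __init__(self):
+        self._sensitive = None  # None = no explicit decision yet
+
+    def mark_kernel_as_launch_config_sensitive(self):
+        # Intentionally sticky: once marked, it stays marked.
+        self._sensitive = True
+
+    def get_kernel_launch_config_sensitive(self):
+        # Returns None if no explicit decision was made.
+        return self._sensitive
+
+    def is_kernel_launch_config_sensitive(self):
+        return bool(self._sensitive)
+```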
+
+## Remaining TODO
+1. **Cleanup**
+ - Remove or handle untracked `PR.md` and `tags` before committing.
+
+## Completed checks (2026-02-19)
+- **Cross-process disk-cache behavior**
+ - Verified by:
+ `pixi run -e cu-12-9-py312 pytest testing --pyargs numba.cuda.tests.cudapy.test_caching -k launch_config_sensitive`
+ - `LaunchConfigSensitiveCachingTest.test_launch_config_sensitive_cache_keys`
+ exercises cache reuse in a separate process and passed.
+- **Launch path audit**
+ - Python launch paths in `dispatcher.py` all route through
+ `CUDADispatcher.call()`: `__getitem__()` -> `configure()` ->
+ `_LaunchConfiguration.__call__()` -> `call()`, plus `ForAll.__call__()`.
+
+## Notes
+- Separate PR for cache invalidation on `numba_cuda.__version__` is already
+ pushed; do not re-implement here.
diff --git a/LAUNCH-CONFIG.md b/LAUNCH-CONFIG.md
new file mode 100644
index 000000000..8865358e8
--- /dev/null
+++ b/LAUNCH-CONFIG.md
@@ -0,0 +1,104 @@
+# Launch Config Benchmarking
+
+This repo includes lightweight benchmarking scaffolding to quantify CUDA kernel
+launch overhead across three launch-config implementations (baseline, old
+contextvar branch, and the new v2 branch).
+
+## Status / Next Steps (Launch Config Work)
+- LC-S plumbing is implemented in `dispatcher.py` and supporting files.
+- CUDA LC-S tests have been run on GPU in this branch and are passing.
+- There are uncommitted changes in `compiler.py`, `dispatcher.py`, and
+ `scripts/bench-launch-overhead.py` that add an explicit LC-S API on
+ `_LaunchConfiguration` and a compiler hook to honor it.
+- cccl rewrite integration now uses the explicit LC-S API with a fallback to
+ metadata for compatibility.
+- Cross-process disk-cache behavior is covered by LC-S caching tests and passes.
+- See `LAUNCH-CONFIG-TODO.md` for a detailed handoff checklist.
+
+## What’s Included
+
+### 1) `scripts/bench-launch-overhead.py`
+A focused micro-benchmark that measures launch overhead (us/launch) for kernels
+with 0..4 arguments, using a 1x1 launch. It:
+- warms up each kernel
+- runs `loops` iterations per kernel (default: 100k for 0–3 args, 10k for 4 args)
+- repeats the measurement (default: 7 repeats)
+- reports mean/stdev and deltas vs the first repo
+- optionally writes JSON output
+
+The benchmark is designed to compare multiple repos (or worktrees) in the same
+Python environment.
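+
+The measurement loop has roughly this shape (a sketch of the methodology,
+not the actual script):
+
+```python
+import statistics
+import time
+
+def measure_launch_overhead(launch, loops=100_000, repeats=7, warmup=1_000):
+    """Return (mean, stdev) launch overhead in us/launch for a callable."""
+    for _ in range(warmup):
+        launch()
+    per_launch_us = []
+    for _ in range(repeats):
+        t0 = time.perf_counter()
+        for _ in range(loops):
+            launch()
+        elapsed = time.perf_counter() - t0
+        per_launch_us.append(elapsed / loops * 1e6)  # seconds -> us/launch
+    return statistics.mean(per_launch_us), statistics.stdev(per_launch_us)
+```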
+
+### 2) `scripts/bench-against.py`
+A helper that compares benchmarks between two git refs using a temporary
+worktree, running the pixi benchmark tasks before and after.
+
+### 3) Pixi tasks
+Defined in `pixi.toml`:
+- `bench-launch-overhead`: runs `scripts/bench-launch-overhead.py`
+- `bench`: pytest benchmark suite (`numba.cuda.tests.benchmarks`)
+- `benchcmp`: compare benchmark results from `bench`
+- `bench-against`: runs `scripts/bench-against.py`
+
+## Recommended Usage (Three-Way Compare)
+
+Assuming you have three working trees for:
+- **baseline** (main or a baseline ref)
+- **contextvar** (old implementation)
+- **v2** (new implementation)
+
+Run the launch-overhead micro-benchmark:
+
+```bash
+pixi run -e cu-12-9-py312 bench-launch-overhead \
+ --repo baseline=/path/to/numba-cuda-main \
+ --repo contextvar=/path/to/numba-cuda-contextvar \
+ --repo v2=/home/trentn/src/280-launch-config-v2
+```
+
+Notes:
+- The script will `pip install -e` each repo by default. Use `--no-install`
+ if you have already installed them and want to skip reinstalling.
+- Use `--python` to point at a specific interpreter if needed.
+- Use `--loops` to override the default loop counts, e.g. `--loops 200000,200000,200000,200000,20000`.
+- Use `--output results.json` to persist the results.
+
+## Example Output
+
+```
+Launch overhead (us/launch):
+args baseline contextvar v2
+0 4.10 +/- 0.05 6.20 +/- 0.06 4.50 +/- 0.04
+1 4.40 +/- 0.05 6.60 +/- 0.06 4.80 +/- 0.05
+...
+
+Deltas vs baseline:
+args contextvar v2
+0 2.10 (+51.2%) 0.40 (+9.8%)
+1 2.20 (+50.0%) 0.40 (+9.1%)
+...
+```
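+
+The delta columns are per-repo differences against the baseline column,
+computed as (sketch):
+
+```python
+def delta_vs_baseline(baseline_us, other_us):
+    # Absolute delta in us/launch and relative delta in percent.
+    diff = other_us - baseline_us
+    return diff, 100.0 * diff / baseline_us
+```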
+
+## Benchmark Suite (Broader Coverage)
+
+For more extensive benchmark coverage (not just launch overhead), use:
+
+```bash
+pixi run -e cu-12-9-py312 bench
+```
+
+To compare two git refs using a temporary worktree:
+
+```bash
+pixi run -e cu-12-9-py312 bench-against HEAD~ HEAD
+```
+
+This runs `bench` on the baseline ref and `benchcmp` on the proposed ref.
+
+## Notes / Constraints
+
+- Benchmarks require a real GPU (CUDA simulator is rejected).
+- The micro-benchmark intentionally keeps kernels trivial to isolate launch
+ overhead.
+- The three-way comparison is the most direct way to capture the relative
+ overhead introduced by launch-config state management.
diff --git a/docs/source/reference/kernel.rst b/docs/source/reference/kernel.rst
index 64a452947..c562a189e 100644
--- a/docs/source/reference/kernel.rst
+++ b/docs/source/reference/kernel.rst
@@ -57,6 +57,45 @@ This is similar to launch configuration in CUDA C/C++:
.. note:: The order of ``stream`` and ``sharedmem`` is reversed in Numba
   compared to CUDA C/C++.
+Launch configuration introspection (advanced)
+---------------------------------------------
+
+The current launch configuration is available during compilation triggered by
+kernel launches. This can be useful for debugging or for extensions that need
+to observe how kernels are configured.
+
+.. note:: The capture is compile-time only. If the kernel is already compiled
+ for the given argument types, the captured config may remain ``None``.
+
+.. code-block:: python
+
+ from numba import cuda
+ from numba.cuda import launchconfig
+
+ @cuda.jit
+ def f(x):
+ x[0] = 1
+
+ arr = cuda.device_array(1, dtype="i4")
+ with launchconfig.capture_compile_config(f) as capture:
+ f[1, 1](arr) # first launch triggers compilation
+
+ cfg = capture["config"]
+ print(cfg.griddim, cfg.blockdim, cfg.sharedmem)
+
+Configured kernels also expose pre-launch callbacks for lightweight
+instrumentation:
+
+.. code-block:: python
+
+ cfg = f[1, 1]
+
+ def log_launch(kernel, cfg):
+ print(cfg.griddim, cfg.blockdim)
+
+ cfg.pre_launch_callbacks.append(log_launch)
+ cfg(arr)
+
Dispatcher objects also provide several utility methods for inspection and
creating a specialized instance:
diff --git a/docs/source/user/index.rst b/docs/source/user/index.rst
index c145e0993..694a95e5f 100644
--- a/docs/source/user/index.rst
+++ b/docs/source/user/index.rst
@@ -11,6 +11,7 @@ User guide
.. toctree::
installation.rst
+ release_notes.rst
kernels.rst
memory.rst
device-functions.rst
diff --git a/numba_cuda/numba/cuda/cext/_dispatcher.cpp b/numba_cuda/numba/cuda/cext/_dispatcher.cpp
index d6d9a304c..fa6f79001 100644
--- a/numba_cuda/numba/cuda/cext/_dispatcher.cpp
+++ b/numba_cuda/numba/cuda/cext/_dispatcher.cpp
@@ -13,6 +13,81 @@
#include "traceback.h"
#include "typeconv.hpp"
+static Py_tss_t launch_config_tss_key = Py_tss_NEEDS_INIT;
+static const char *launch_config_kw = "__numba_cuda_launch_config__";
+
+static int
+launch_config_tss_init(void)
+{
+ if (PyThread_tss_create(&launch_config_tss_key) != 0) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Failed to initialize launch config TLS");
+ return -1;
+ }
+ return 0;
+}
+
+static PyObject *
+launch_config_get_borrowed(void)
+{
+ return (PyObject *) PyThread_tss_get(&launch_config_tss_key);
+}
+
+static int
+launch_config_set(PyObject *obj)
+{
+ PyObject *old = (PyObject *) PyThread_tss_get(&launch_config_tss_key);
+ if (obj != NULL) {
+ Py_INCREF(obj);
+ }
+ if (PyThread_tss_set(&launch_config_tss_key, (void *) obj) != 0) {
+ Py_XDECREF(obj);
+ PyErr_SetString(PyExc_RuntimeError,
+ "Failed to set launch config TLS");
+ return -1;
+ }
+ Py_XDECREF(old);
+ return 0;
+}
+
+class LaunchConfigGuard {
+public:
+ explicit LaunchConfigGuard(PyObject *value)
+ : prev(NULL), active(false), requested(value != NULL)
+ {
+ if (!requested) {
+ return;
+ }
+ prev = launch_config_get_borrowed();
+ Py_XINCREF(prev);
+ if (launch_config_set(value) != 0) {
+ Py_XDECREF(prev);
+ prev = NULL;
+ return;
+ }
+ active = true;
+ }
+
+ bool failed(void) const
+ {
+ return requested && !active;
+ }
+
+ ~LaunchConfigGuard(void)
+ {
+ if (!active) {
+ return;
+ }
+ launch_config_set(prev);
+ Py_XDECREF(prev);
+ }
+
+private:
+ PyObject *prev;
+ bool active;
+ bool requested;
+};
+
/*
* Notes on the C_TRACE macro:
*
@@ -840,6 +915,7 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
PyObject *cfunc;
PyThreadState *ts = PyThreadState_Get();
PyObject *locals = NULL;
+ PyObject *launch_config = NULL;
/* If compilation is enabled, ensure that an exact match is found and if
* not compile one */
@@ -855,9 +931,26 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
goto CLEANUP;
}
}
+ if (kws != NULL) {
+ launch_config = PyDict_GetItemString(kws, launch_config_kw);
+ if (launch_config != NULL) {
+ Py_INCREF(launch_config);
+ if (PyDict_DelItemString(kws, launch_config_kw) < 0) {
+ Py_DECREF(launch_config);
+ launch_config = NULL;
+ goto CLEANUP;
+ }
+ if (launch_config == Py_None) {
+ Py_DECREF(launch_config);
+ launch_config = NULL;
+ }
+ }
+ }
if (self->fold_args) {
- if (find_named_args(self, &args, &kws))
+ if (find_named_args(self, &args, &kws)) {
+ Py_XDECREF(launch_config);
return NULL;
+ }
}
else
Py_INCREF(args);
@@ -913,6 +1006,11 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
} else if (matches == 0) {
/* No matching definition */
if (self->can_compile) {
+ LaunchConfigGuard guard(launch_config);
+ if (guard.failed()) {
+ retval = NULL;
+ goto CLEANUP;
+ }
retval = cuda_compile_only(self, args, kws, locals);
} else if (self->fallbackdef) {
/* Have object fallback */
@@ -924,6 +1022,11 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
}
} else if (self->can_compile) {
/* Ambiguous, but are allowed to compile */
+ LaunchConfigGuard guard(launch_config);
+ if (guard.failed()) {
+ retval = NULL;
+ goto CLEANUP;
+ }
retval = cuda_compile_only(self, args, kws, locals);
} else {
/* Ambiguous */
@@ -935,6 +1038,7 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
if (tys != prealloc)
delete[] tys;
Py_DECREF(args);
+ Py_XDECREF(launch_config);
return retval;
}
@@ -1040,10 +1144,23 @@ static PyObject *compute_fingerprint(PyObject *self, PyObject *args)
return typeof_compute_fingerprint(val);
}
+static PyObject *
+get_current_launch_config(PyObject *self, PyObject *args)
+{
+ PyObject *config = launch_config_get_borrowed();
+ if (config == NULL) {
+ Py_RETURN_NONE;
+ }
+ Py_INCREF(config);
+ return config;
+}
+
static PyMethodDef ext_methods[] = {
#define declmethod(func) { #func , ( PyCFunction )func , METH_VARARGS , NULL }
declmethod(typeof_init),
declmethod(compute_fingerprint),
+ { "get_current_launch_config", (PyCFunction)get_current_launch_config,
+ METH_NOARGS, NULL },
{ NULL },
#undef declmethod
};
@@ -1055,6 +1172,10 @@ MOD_INIT(_dispatcher) {
if (m == NULL)
return MOD_ERROR_VAL;
+ if (launch_config_tss_init() != 0) {
+ return MOD_ERROR_VAL;
+ }
+
DispatcherType.tp_new = PyType_GenericNew;
if (PyType_Ready(&DispatcherType) < 0) {
return MOD_ERROR_VAL;
diff --git a/numba_cuda/numba/cuda/compiler.py b/numba_cuda/numba/cuda/compiler.py
index cbe1dfac2..4f666209d 100644
--- a/numba_cuda/numba/cuda/compiler.py
+++ b/numba_cuda/numba/cuda/compiler.py
@@ -15,6 +15,7 @@
from numba.cuda.core.interpreter import Interpreter
from numba.cuda import cgutils, typing, lowering, nvvmutils, utils
+from numba.cuda import launchconfig
from numba.cuda.api import get_current_device
from numba.cuda.codegen import ExternalCodeLibrary
@@ -398,6 +399,14 @@ def run_pass(self, state):
"""
lowered = state["cr"]
signature = typing.signature(state.return_type, *state.args)
+ launch_cfg = launchconfig.current_launch_config()
+ if (
+ launch_cfg is not None
+ and launch_cfg.is_kernel_launch_config_sensitive()
+ ):
+ if state.metadata is None:
+ state.metadata = {}
+ state.metadata["launch_config_sensitive"] = True
state.cr = cuda_compile_result(
typing_context=state.typingctx,
@@ -408,6 +417,9 @@ def run_pass(self, state):
call_helper=lowered.call_helper,
signature=signature,
fndesc=lowered.fndesc,
+ # Preserve metadata populated by rewrite passes (e.g. launch-config
+ # sensitivity) so downstream consumers can act on it.
+ metadata=state.metadata,
)
return True
diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py
index aa8fb76be..65488f5f5 100644
--- a/numba_cuda/numba/cuda/dispatcher.py
+++ b/numba_cuda/numba/cuda/dispatcher.py
@@ -20,6 +20,7 @@
from numba.cuda.core import errors
from numba.cuda import serialize, utils
from numba import cuda
+from numba.cuda import launchconfig
from numba.cuda.np import numpy_support
from numba.cuda.core.compiler_lock import global_compiler_lock
@@ -55,6 +56,7 @@
import numba.cuda.core.event as ev
from numba.cuda.cext import _dispatcher
+_LAUNCH_CONFIG_KW = "__numba_cuda_launch_config__"
cuda_fp16_math_funcs = [
"hsin",
@@ -165,6 +167,10 @@ def __init__(
max_registers=max_registers,
lto=lto,
)
+ self.launch_config_sensitive = bool(
+ getattr(cres, "metadata", None)
+ and cres.metadata.get("launch_config_sensitive", False)
+ )
tgt_ctx = cres.target_context
lib = cres.library
kernel = lib.get_function(cres.fndesc.llvm_func_name)
@@ -297,6 +303,7 @@ def _rebuild(
call_helper,
extensions,
shared_memory_carveout=None,
+ launch_config_sensitive=False,
):
"""
Rebuild an instance.
@@ -316,6 +323,7 @@ def _rebuild(
instance.call_helper = call_helper
instance.extensions = extensions
instance.shared_memory_carveout = shared_memory_carveout
+ instance.launch_config_sensitive = launch_config_sensitive
return instance
def _reduce_states(self):
@@ -336,6 +344,7 @@ def _reduce_states(self):
call_helper=self.call_helper,
extensions=self.extensions,
shared_memory_carveout=self.shared_memory_carveout,
+ launch_config_sensitive=self.launch_config_sensitive,
)
def _parse_carveout(self, carveout):
@@ -693,6 +702,9 @@ def __init__(self, dispatcher, griddim, blockdim, stream, sharedmem):
self.blockdim = blockdim
self.stream = driver._to_core_stream(stream)
self.sharedmem = sharedmem
+ self.pre_launch_callbacks = []
+ self.args = None
+ self._kernel_launch_config_sensitive = None
if (
config.CUDA_LOW_OCCUPANCY_WARNINGS
@@ -716,19 +728,46 @@ def __init__(self, dispatcher, griddim, blockdim, stream, sharedmem):
warn(errors.NumbaPerformanceWarning(msg))
def __call__(self, *args):
- return self.dispatcher.call(
- args, self.griddim, self.blockdim, self.stream, self.sharedmem
- )
+ return self.dispatcher.call(args, self)
+
+ def mark_kernel_as_launch_config_sensitive(self):
+ """Mark this configured launch path as launch-config sensitive.
+
+ Once set, this flag is intentionally sticky for this
+ ``_LaunchConfiguration`` instance. This aligns with the expected LC-S
+ use case: if code generation depends on launch config for this
+ kernel/configuration path, treat it as launch-config sensitive for all
+ subsequent compilations through the same configured launcher.
+ """
+ self._kernel_launch_config_sensitive = True
+
+ def get_kernel_launch_config_sensitive(self):
+ """Return the launch-config sensitivity flag.
+
+ The result is ``None`` if no explicit decision was made.
+ """
+ return self._kernel_launch_config_sensitive
+
+ def is_kernel_launch_config_sensitive(self):
+ """Return True if this kernel was marked as launch-config sensitive."""
+ return bool(self._kernel_launch_config_sensitive)
def __getstate__(self):
state = self.__dict__.copy()
state["stream"] = int(state["stream"].handle)
+ # Avoid serializing callables that may not be picklable.
+ state["pre_launch_callbacks"] = []
+ state["args"] = None
return state
def __setstate__(self, state):
handle = state.pop("stream")
self.__dict__.update(state)
self.stream = driver._to_core_stream(handle)
+ if "pre_launch_callbacks" not in self.__dict__:
+ self.pre_launch_callbacks = []
+ if "args" not in self.__dict__:
+ self.args = None
class CUDACacheImpl(CacheImpl):
@@ -756,6 +795,52 @@ class CUDACache(Cache):
_impl_class = CUDACacheImpl
+ def __init__(self, py_func):
+ self._launch_config_key = None
+ self._launch_config_sensitive_flag = None
+ super().__init__(py_func)
+ marker_name = f"{self._impl.filename_base}.lcs"
+ self._launch_config_marker_path = os.path.join(
+ self._cache_path, marker_name
+ )
+
+ def _index_key(self, sig, codegen):
+ key = super()._index_key(sig, codegen)
+ if self._launch_config_key is None:
+ return key
+ return key + (("launch_config", self._launch_config_key),)
+
+ def set_launch_config_key(self, key):
+ self._launch_config_key = key
+
+ def is_launch_config_sensitive(self):
+ if self._launch_config_sensitive_flag is None:
+ self._launch_config_sensitive_flag = os.path.exists(
+ self._launch_config_marker_path
+ )
+ return self._launch_config_sensitive_flag
+
+ def mark_launch_config_sensitive(self):
+ if self._launch_config_sensitive_flag is True:
+ return True
+ try:
+ self._impl.locator.ensure_cache_path()
+ with open(self._launch_config_marker_path, "a"):
+ pass
+ except OSError:
+ self._launch_config_sensitive_flag = False
+ return False
+ self._launch_config_sensitive_flag = True
+ return True
+
+ def flush(self):
+ super().flush()
+ try:
+ os.unlink(self._launch_config_marker_path)
+ except FileNotFoundError:
+ pass
+ self._launch_config_sensitive_flag = None
+
def load_overload(self, sig, target_context):
# Loading an overload refreshes the context to ensure it is initialized.
with utils.numba_target_override():
@@ -1540,6 +1625,14 @@ def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler):
self._cache_hits = collections.Counter()
self._cache_misses = collections.Counter()
+ # Whether the compiled kernels are launch-config sensitive (e.g., IR
+ # rewrites depend on launch configuration).
+ self._launch_config_sensitive = False
+ self._launch_config_default_key = None
+ self._launch_config_is_specialized = False
+ self._launch_config_specialization_key = None
+ self._launch_config_specializations = {}
+
# The following properties are for specialization of CUDADispatchers. A
# specialized CUDADispatcher is one that is compiled for exactly one
# set of argument types, and bypasses some argument type checking for
@@ -1585,6 +1678,80 @@ def __getitem__(self, args):
raise ValueError("must specify at least the griddim and blockdim")
return self.configure(*args)
+ @staticmethod
+ def _launch_config_key(launch_config):
+ return (
+ launch_config.griddim,
+ launch_config.blockdim,
+ launch_config.sharedmem,
+ )
+
+ def _cache_launch_config_key(self, launch_config):
+ if self._launch_config_is_specialized:
+ return self._launch_config_specialization_key
+ if self._launch_config_default_key is not None:
+ return self._launch_config_default_key
+ if launch_config is None:
+ return None
+ return self._launch_config_key(launch_config)
+
+ def _configure_cache_for_launch_config(self, launch_config):
+ if not isinstance(self._cache, CUDACache):
+ return
+ if self._launch_config_sensitive or self._launch_config_is_specialized:
+ key = self._cache_launch_config_key(launch_config)
+ self._cache.set_launch_config_key(key)
+ return
+ if self._cache.is_launch_config_sensitive():
+ if launch_config is None:
+ key = None
+ else:
+ key = self._launch_config_key(launch_config)
+ self._cache.set_launch_config_key(key)
+ return
+ self._cache.set_launch_config_key(None)
+
+ def _get_launch_config_specialization(self, key):
+ dispatcher = self._launch_config_specializations.get(key)
+ if dispatcher is None:
+ dispatcher = CUDADispatcher(
+ self.py_func,
+ targetoptions=self.targetoptions,
+ pipeline_class=self._compiler.pipeline_class,
+ )
+ dispatcher._launch_config_sensitive = True
+ dispatcher._launch_config_is_specialized = True
+ dispatcher._launch_config_specialization_key = key
+ dispatcher._launch_config_default_key = key
+ if isinstance(self._cache, CUDACache):
+ dispatcher.enable_caching()
+ dispatcher._configure_cache_for_launch_config(None)
+ self._launch_config_specializations[key] = dispatcher
+ return dispatcher
+
+ def _select_launch_config_dispatcher(self, launch_config):
+ if not self._launch_config_sensitive:
+ return self
+ if self._launch_config_is_specialized:
+ return self
+ key = self._launch_config_key(launch_config)
+ if self._launch_config_default_key is None:
+ self._launch_config_default_key = key
+ return self
+ if key == self._launch_config_default_key:
+ return self
+ return self._get_launch_config_specialization(key)
+
+ def _update_launch_config_sensitivity(self, kernel, launch_config):
+ if not getattr(kernel, "launch_config_sensitive", False):
+ return
+ if not self._launch_config_sensitive:
+ self._launch_config_sensitive = True
+ if self._launch_config_default_key is None:
+ self._launch_config_default_key = self._launch_config_key(
+ launch_config
+ )
+
def forall(self, ntasks, tpb=0, stream=0, sharedmem=0):
"""Returns a 1D-configured dispatcher for a given number of tasks.
@@ -1632,16 +1799,36 @@ def __call__(self, *args, **kwargs):
# An attempt to launch an unconfigured kernel
raise ValueError(missing_launch_config_msg)
- def call(self, args, griddim, blockdim, stream, sharedmem):
+ def call(self, args, launch_config):
"""
Compile if necessary and invoke this kernel with *args*.
"""
- if self.specialized:
- kernel = next(iter(self.overloads.values()))
- else:
- kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
+ griddim = launch_config.griddim
+ blockdim = launch_config.blockdim
+ stream = launch_config.stream
+ sharedmem = launch_config.sharedmem
- kernel.launch(args, griddim, blockdim, stream, sharedmem)
+ launch_config.args = args
+ try:
+ dispatcher = self._select_launch_config_dispatcher(launch_config)
+ if dispatcher is not self:
+ return dispatcher.call(args, launch_config)
+
+ if self.specialized:
+ kernel = next(iter(self.overloads.values()))
+ else:
+ kernel = _dispatcher.Dispatcher._cuda_call(
+ self, *args, **{_LAUNCH_CONFIG_KW: launch_config}
+ )
+
+ self._update_launch_config_sensitivity(kernel, launch_config)
+
+ for callback in launch_config.pre_launch_callbacks:
+ callback(kernel, launch_config)
+
+ kernel.launch(args, griddim, blockdim, stream, sharedmem)
+ finally:
+ launch_config.args = None
def _compile_for_args(self, *args, **kws):
# Based on _DispatcherBase._compile_for_args.
@@ -1892,9 +2079,36 @@ def compile(self, sig):
if kernel is not None:
return kernel
+ launch_config = launchconfig.current_launch_config()
+ self._configure_cache_for_launch_config(launch_config)
+
# Can we load from the disk cache?
kernel = self._cache.load_overload(sig, self.targetctx)
+ if (
+ kernel is not None
+ and isinstance(self._cache, CUDACache)
+ and getattr(kernel, "launch_config_sensitive", False)
+ ):
+ cache_has_marker = self._cache.is_launch_config_sensitive()
+ if not cache_has_marker:
+ # Pre-existing cache entries without a launch-config marker are
+ # unsafe for LCS kernels. Force a recompile under the new key.
+ if launch_config is not None:
+ self._cache.set_launch_config_key(
+ self._launch_config_key(launch_config)
+ )
+ if not self._cache.mark_launch_config_sensitive():
+ # If we cannot record the marker, disable disk cache to
+ # avoid unsafe reuse.
+ self._cache = NullCache()
+ kernel = None
+ else:
+ if launch_config is not None:
+ self._cache.set_launch_config_key(
+ self._launch_config_key(launch_config)
+ )
+
if kernel is not None:
self._cache_hits[sig] += 1
else:
@@ -1906,6 +2120,17 @@ def compile(self, sig):
kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)
# We call bind to force codegen, so that there is a cubin to cache
kernel.bind()
+ if isinstance(self._cache, CUDACache) and getattr(
+ kernel, "launch_config_sensitive", False
+ ):
+ if launch_config is not None:
+ self._cache.set_launch_config_key(
+ self._launch_config_key(launch_config)
+ )
+ if not self._cache.mark_launch_config_sensitive():
+ # If we cannot record the marker, disable disk cache to
+ # avoid unsafe reuse.
+ self._cache = NullCache()
self._cache.save_overload(sig, kernel)
self.add_overload(kernel, argtypes)
diff --git a/numba_cuda/numba/cuda/launchconfig.py b/numba_cuda/numba/cuda/launchconfig.py
new file mode 100644
index 000000000..c374d5b30
--- /dev/null
+++ b/numba_cuda/numba/cuda/launchconfig.py
@@ -0,0 +1,59 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+"""Launch configuration access for CUDA compilation contexts.
+
+The current launch configuration is populated only during CUDA compilation
+triggered by kernel launches. It is thread-local and cleared immediately after
+compilation completes.
+"""
+
+import contextlib
+from functools import wraps
+
+from numba.cuda.cext import _dispatcher
+
+
+def current_launch_config():
+ """Return the current launch configuration, or None if not set."""
+ return _dispatcher.get_current_launch_config()
+
+
+def ensure_current_launch_config():
+ """Return the current launch configuration or raise if not set."""
+ config = current_launch_config()
+ if config is None:
+ raise RuntimeError("No launch config set for this thread")
+ return config
+
+
+@contextlib.contextmanager
+def capture_compile_config(dispatcher):
+ """Capture the launch config seen during compilation for a dispatcher.
+
+ The returned dict has a single key, ``"config"``, which is populated when a
+ compilation is triggered by a kernel launch. If the kernel is already
+ compiled, the dict value may remain ``None``.
+ """
+ if dispatcher is None:
+ raise TypeError("dispatcher is required")
+
+ record = {"config": None}
+ original = dispatcher._compile_for_args
+
+ @wraps(original)
+ def wrapped(*args, **kws):
+ record["config"] = current_launch_config()
+ return original(*args, **kws)
+
+ dispatcher._compile_for_args = wrapped
+ try:
+ yield record
+ finally:
+ dispatcher._compile_for_args = original
+
+
+__all__ = [
+ "current_launch_config",
+ "ensure_current_launch_config",
+ "capture_compile_config",
+]
diff --git a/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_insensitive_usecases.py b/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_insensitive_usecases.py
new file mode 100644
index 000000000..354ec3f30
--- /dev/null
+++ b/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_insensitive_usecases.py
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+import sys
+import numpy as np
+
+from numba import cuda
+
+
+@cuda.jit(cache=True)
+def cache_kernel(x):
+ x[0] = 1
+
+
+def launch(blockdim):
+ arr = np.zeros(1, dtype=np.int32)
+ cache_kernel[1, blockdim](arr)
+ return arr
+
+
+def self_test():
+ mod = sys.modules[__name__]
+ out = mod.launch(32)
+ assert out[0] == 1
+ out = mod.launch(64)
+ assert out[0] == 1
diff --git a/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_sensitive_usecases.py b/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_sensitive_usecases.py
new file mode 100644
index 000000000..669007622
--- /dev/null
+++ b/numba_cuda/numba/cuda/tests/cudapy/cache_launch_config_sensitive_usecases.py
@@ -0,0 +1,61 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+import sys
+
+import numpy as np
+
+from numba import cuda
+from numba.cuda import launchconfig
+from numba.cuda.core.rewrites import register_rewrite, Rewrite, rewrite_registry
+
+_REWRITE_FLAG = "_launch_config_cache_rewrite_registered"
+
+
+if not getattr(rewrite_registry, _REWRITE_FLAG, False):
+
+ @register_rewrite("after-inference")
+ class LaunchConfigSensitiveCacheRewrite(Rewrite):
+ _TARGET_NAME = "lcs_cache_kernel"
+
+ def __init__(self, state):
+ super().__init__(state)
+ self._state = state
+ self._block = None
+ self._logged = False
+
+ def match(self, func_ir, block, typemap, calltypes):
+ if func_ir.func_id.func_name != self._TARGET_NAME:
+ return False
+ if self._logged:
+ return False
+ self._block = block
+ return True
+
+ def apply(self):
+ # Ensure launch config is available and mark compilation as
+ # launch-config sensitive so the disk cache keys include it.
+ cfg = launchconfig.ensure_current_launch_config()
+ cfg.mark_kernel_as_launch_config_sensitive()
+ self._logged = True
+ return self._block
+
+ setattr(rewrite_registry, _REWRITE_FLAG, True)
+
+
+@cuda.jit(cache=True)
+def lcs_cache_kernel(x):
+ x[0] = 1
+
+
+def launch(blockdim):
+ arr = np.zeros(1, dtype=np.int32)
+ lcs_cache_kernel[1, blockdim](arr)
+ return arr
+
+
+def self_test():
+ mod = sys.modules[__name__]
+ out = mod.launch(32)
+ assert out[0] == 1
+ out = mod.launch(64)
+ assert out[0] == 1
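
The `_REWRITE_FLAG` guard in the usecase module above keeps the rewrite from being registered a second time when the module is re-imported under a different name. The idempotence idea can be sketched with a plain registry object (the `Registry` class and `register_once` helper here are hypothetical):

```python
class Registry:
    def __init__(self):
        self.rewrites = []

    def register(self, cls):
        self.rewrites.append(cls)


registry = Registry()
_FLAG = "_demo_rewrite_registered"


def register_once():
    # Re-importing a module re-executes its top-level code; stashing a
    # flag on the shared registry makes registration happen only once.
    if not getattr(registry, _FLAG, False):
        registry.register(object)
        setattr(registry, _FLAG, True)


register_once()
register_once()  # second call is a no-op
print(len(registry.rewrites))  # -> 1
```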
diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_caching.py b/numba_cuda/numba/cuda/tests/cudapy/test_caching.py
index 3d3eadc32..2449bd647 100644
--- a/numba_cuda/numba/cuda/tests/cudapy/test_caching.py
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_caching.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-2-Clause
import multiprocessing
+import json
import os
import shutil
import unittest
@@ -84,6 +85,11 @@ def get_cache_mtimes(self):
for fn in sorted(self.cache_contents())
)
+ def count_cache_markers(self, suffix=".lcs"):
+ return len(
+ [fn for fn in self.cache_contents() if fn.endswith(suffix)]
+ )
+
def check_pycache(self, n):
c = self.cache_contents()
self.assertEqual(len(c), n, c)
@@ -97,19 +103,35 @@ class DispatcherCacheUsecasesTest(BaseCacheTest):
usecases_file = os.path.join(here, "cache_usecases.py")
modname = "dispatcher_caching_test_fodder"
- def run_in_separate_process(self):
+ def run_in_separate_process(self, *, envvars=None, report_code=None):
# Cached functions can be run from a distinct process.
# Also stresses issue #1603: uncached function calling cached function
# shouldn't fail compiling.
- code = """if 1:
- import sys
-
- sys.path.insert(0, %(tempdir)r)
- mod = __import__(%(modname)r)
- mod.self_test()
- """ % dict(tempdir=self.tempdir, modname=self.modname)
+ code_lines = [
+ "if 1:",
+ " import sys",
+ ]
+ if report_code is not None:
+ code_lines.append(" import json")
+ code_lines.extend(
+ [
+ f" sys.path.insert(0, {self.tempdir!r})",
+ f" mod = __import__({self.modname!r})",
+ " mod.self_test()",
+ ]
+ )
+ if report_code is not None:
+ for line in report_code.splitlines():
+ if line.strip():
+ code_lines.append(f" {line}")
+ code_lines.append(
+ ' print("__CACHE_REPORT__" + json.dumps(report))'
+ )
+ code = "\n".join(code_lines)
subp_env = os.environ.copy()
+ if envvars is not None:
+ subp_env.update(envvars)
popen = subprocess.Popen(
[sys.executable, "-c", code],
stdout=subprocess.PIPE,
@@ -124,6 +146,17 @@ def run_in_separate_process(self):
"stderr follows\n%s\n"
% (popen.returncode, out.decode(), err.decode()),
)
+ if report_code is None:
+ return None
+ stdout = out.decode().splitlines()
+ marker = "__CACHE_REPORT__"
+ for line in reversed(stdout):
+ if line.startswith(marker):
+ return json.loads(line[len(marker) :])
+ raise AssertionError(
+ "cache report missing from subprocess output:\n%s"
+ % out.decode()
+ )
def check_hits(self, func, hits, misses=None):
st = func.stats
@@ -415,6 +448,175 @@ def cached_kernel_global(output):
GLOBAL_DEVICE_ARRAY = None
+@skip_on_cudasim("Simulator does not implement caching")
+class LaunchConfigSensitiveCachingTest(DispatcherCacheUsecasesTest):
+ here = os.path.dirname(__file__)
+ usecases_file = os.path.join(
+ here, "cache_launch_config_sensitive_usecases.py"
+ )
+ modname = "cuda_launch_config_sensitive_cache_test_fodder"
+
+ def setUp(self):
+ DispatcherCacheUsecasesTest.setUp(self)
+ CUDATestCase.setUp(self)
+
+ def tearDown(self):
+ CUDATestCase.tearDown(self)
+ DispatcherCacheUsecasesTest.tearDown(self)
+
+ def test_launch_config_sensitive_cache_keys(self):
+ self.check_pycache(0)
+ mod = self.import_module()
+ self.check_pycache(0)
+
+ mod.launch(32)
+ self.check_pycache(3) # index, data, marker
+ self.assertEqual(self.count_cache_markers(), 1)
+
+ mod.launch(64)
+ self.check_pycache(4) # index, 2 data, marker
+ self.assertEqual(self.count_cache_markers(), 1)
+
+ mod2 = self.import_module()
+ self.assertIsNot(mod, mod2)
+ mod2.launch(32)
+ self.check_hits(mod2.lcs_cache_kernel, 1, 0)
+ self.check_pycache(4)
+
+ mod2.launch(64)
+ self.check_hits(mod2.lcs_cache_kernel, 1, 0)
+ self.assertEqual(
+ len(mod2.lcs_cache_kernel._launch_config_specializations), 1
+ )
+ specialization = next(
+ iter(mod2.lcs_cache_kernel._launch_config_specializations.values())
+ )
+ self.check_hits(specialization, 1, 0)
+ self.check_pycache(4)
+
+ mtimes = self.get_cache_mtimes()
+ report = self.run_in_separate_process(
+ report_code="\n".join(
+ [
+ "main_hits = sum(mod.lcs_cache_kernel.stats.cache_hits.values())",
+ "main_misses = sum(mod.lcs_cache_kernel.stats.cache_misses.values())",
+ "spec = next(iter(mod.lcs_cache_kernel._launch_config_specializations.values()))",
+ "spec_hits = sum(spec.stats.cache_hits.values())",
+ "spec_misses = sum(spec.stats.cache_misses.values())",
+ "report = {'main_hits': main_hits, 'main_misses': main_misses, 'spec_hits': spec_hits, 'spec_misses': spec_misses}",
+ ]
+ )
+ )
+ self.assertEqual(report["main_hits"], 1)
+ self.assertEqual(report["main_misses"], 0)
+ self.assertEqual(report["spec_hits"], 1)
+ self.assertEqual(report["spec_misses"], 0)
+ self.assertEqual(self.get_cache_mtimes(), mtimes)
+
+
+@skip_on_cudasim("Simulator does not implement caching")
+class LaunchConfigInsensitiveCachingTest(DispatcherCacheUsecasesTest):
+ here = os.path.dirname(__file__)
+ usecases_file = os.path.join(
+ here, "cache_launch_config_insensitive_usecases.py"
+ )
+ modname = "cuda_launch_config_insensitive_cache_test_fodder"
+
+ def setUp(self):
+ DispatcherCacheUsecasesTest.setUp(self)
+ CUDATestCase.setUp(self)
+
+ def tearDown(self):
+ CUDATestCase.tearDown(self)
+ DispatcherCacheUsecasesTest.tearDown(self)
+
+ def test_launch_config_insensitive_cache_keys(self):
+ self.check_pycache(0)
+ mod = self.import_module()
+ self.check_pycache(0)
+
+ mod.launch(32)
+ self.check_pycache(2) # index, data
+ self.assertEqual(self.count_cache_markers(), 0)
+
+ mod.launch(64)
+ self.check_pycache(2)
+
+ mod2 = self.import_module()
+ self.assertIsNot(mod, mod2)
+ mod2.launch(64)
+ self.check_hits(mod2.cache_kernel, 1, 0)
+ self.check_pycache(2)
+
+ mtimes = self.get_cache_mtimes()
+ report = self.run_in_separate_process(
+ report_code=(
+ "hits = sum(mod.cache_kernel.stats.cache_hits.values())\n"
+ "misses = sum(mod.cache_kernel.stats.cache_misses.values())\n"
+ "report = {'hits': hits, 'misses': misses}"
+ )
+ )
+ self.assertEqual(report["hits"], 1)
+ self.assertEqual(report["misses"], 0)
+ self.assertEqual(self.get_cache_mtimes(), mtimes)
+
+
+@skip_on_cudasim("Simulator does not implement caching")
+class MultiDeviceCachingTest(DispatcherCacheUsecasesTest):
+ here = os.path.dirname(__file__)
+ usecases_file = os.path.join(
+ here, "cache_launch_config_insensitive_usecases.py"
+ )
+ modname = "cuda_multi_device_cache_test_fodder"
+
+ def setUp(self):
+ DispatcherCacheUsecasesTest.setUp(self)
+ CUDATestCase.setUp(self)
+
+ def tearDown(self):
+ CUDATestCase.tearDown(self)
+ DispatcherCacheUsecasesTest.tearDown(self)
+
+ def test_cache_separate_per_compute_capability(self):
+ gpus = list(cuda.gpus)
+ if len(gpus) < 2:
+ self.skipTest("requires at least two GPUs")
+
+ cc_map = {}
+ for gpu in gpus:
+ cc_map.setdefault(gpu.compute_capability, []).append(gpu.id)
+
+ if len(cc_map) < 2:
+ self.skipTest(
+ "requires at least two distinct compute capabilities"
+ )
+
+ for cc, ids in sorted(cc_map.items()):
+ dev_id = ids[0]
+ with cuda.gpus[dev_id]:
+ mod = self.import_module()
+ mod.launch(32)
+ hits = sum(mod.cache_kernel.stats.cache_hits.values())
+ misses = sum(mod.cache_kernel.stats.cache_misses.values())
+ self.assertEqual(
+ hits, 0, f"unexpected cache hit for CC {cc}"
+ )
+ self.assertEqual(
+ misses, 1, f"expected cache miss for CC {cc}"
+ )
+
+ mod2 = self.import_module()
+ mod2.launch(32)
+ hits = sum(mod2.cache_kernel.stats.cache_hits.values())
+ misses = sum(mod2.cache_kernel.stats.cache_misses.values())
+ self.assertEqual(
+ hits, 1, f"expected cache hit for CC {cc}"
+ )
+ self.assertEqual(
+ misses, 0, f"unexpected cache miss for CC {cc}"
+ )
+
+
@skip_on_cudasim("Simulator does not implement caching")
class CUDACooperativeGroupTest(DispatcherCacheUsecasesTest):
# See Issue #9432: https://github.com/numba/numba/issues/9432
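
The `__CACHE_REPORT__` marker protocol used by `run_in_separate_process` (the subprocess emits one JSON line behind a marker; the parent scans stdout backwards for it) can be sketched without spawning a process:

```python
import json

MARKER = "__CACHE_REPORT__"


def extract_report(stdout_text):
    """Return the JSON payload from the last marker line, or raise."""
    # Scan backwards so test chatter before the report is ignored.
    for line in reversed(stdout_text.splitlines()):
        if line.startswith(MARKER):
            return json.loads(line[len(MARKER):])
    raise AssertionError("cache report missing from subprocess output")


# Simulated subprocess output: test chatter followed by the report line.
fake_stdout = "\n".join([
    "ok",
    MARKER + json.dumps({"hits": 1, "misses": 0}),
])
print(extract_report(fake_stdout))  # -> {'hits': 1, 'misses': 0}
```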
diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py b/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py
index aa837dc1b..eea77b09f 100644
--- a/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py
@@ -16,6 +16,7 @@
)
from numba import cuda
from numba.cuda import config, types
+from numba.cuda import launchconfig
from numba.cuda.core.errors import TypingError
from numba.cuda.testing import (
cc_X_or_above,
@@ -128,6 +129,52 @@ def f(x, y):
class TestDispatcher(CUDATestCase):
"""Most tests based on those in numba.tests.test_dispatcher."""
+ @skip_on_cudasim("Dispatcher C-extension not used in the simulator")
+ def test_launch_config_available_during_compile(self):
+ @cuda.jit
+ def f(x):
+ x[0] = 1
+
+ seen = {}
+ orig = f._compile_for_args
+
+ def wrapped(*args, **kws):
+ seen["config"] = launchconfig.current_launch_config()
+ return orig(*args, **kws)
+
+ f._compile_for_args = wrapped
+
+ arr = np.zeros(1, dtype=np.int32)
+ self.assertIsNone(launchconfig.current_launch_config())
+ f[1, 1](arr)
+
+ cfg = seen.get("config")
+ self.assertIsNotNone(cfg)
+ self.assertIs(cfg.dispatcher, f)
+ self.assertEqual(cfg.griddim, (1, 1, 1))
+ self.assertEqual(cfg.blockdim, (1, 1, 1))
+ self.assertIsNone(launchconfig.current_launch_config())
+ with self.assertRaises(RuntimeError):
+ launchconfig.ensure_current_launch_config()
+
+ @skip_on_cudasim("Dispatcher C-extension not used in the simulator")
+ def test_capture_compile_config(self):
+ @cuda.jit
+ def f(x):
+ x[0] = 1
+
+ arr = np.zeros(1, dtype=np.int32)
+ original = f._compile_for_args
+ with launchconfig.capture_compile_config(f) as capture:
+ f[1, 1](arr)
+
+ cfg = capture["config"]
+ self.assertIsNotNone(cfg)
+ self.assertIs(cfg.dispatcher, f)
+ self.assertEqual(cfg.griddim, (1, 1, 1))
+ self.assertEqual(cfg.blockdim, (1, 1, 1))
+ self.assertIs(f._compile_for_args, original)
+
def test_coerce_input_types(self):
# Do not allow unsafe conversions if we can still compile other
# specializations.
diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_launch_config_sensitive.py b/numba_cuda/numba/cuda/tests/cudapy/test_launch_config_sensitive.py
new file mode 100644
index 000000000..744a52532
--- /dev/null
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_launch_config_sensitive.py
@@ -0,0 +1,86 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+import numpy as np
+
+from numba import cuda
+from numba.cuda import launchconfig
+from numba.cuda.core.rewrites import register_rewrite, Rewrite
+from numba.cuda.testing import skip_on_cudasim, unittest, CUDATestCase
+
+
+LAUNCH_CONFIG_LOG = []
+
+
+def _clear_launch_config_log():
+ LAUNCH_CONFIG_LOG.clear()
+
+
+@register_rewrite("after-inference")
+class LaunchConfigSensitiveRewrite(Rewrite):
+ """Rewrite that marks kernels as launch-config sensitive and logs config.
+
+ This mimics cuda.coop's need to access launch config during rewrite, and
+ provides a global log for tests to assert on.
+ """
+
+ _TARGET_NAME = "launch_config_sensitive_kernel"
+
+ def __init__(self, state):
+ super().__init__(state)
+ self._state = state
+ self._logged = False
+ self._block = None
+
+ def match(self, func_ir, block, typemap, calltypes):
+ if func_ir.func_id.func_name != self._TARGET_NAME:
+ return False
+ if self._logged:
+ return False
+ self._block = block
+ return True
+
+ def apply(self):
+ cfg = launchconfig.ensure_current_launch_config()
+ LAUNCH_CONFIG_LOG.append(
+ {
+ "griddim": cfg.griddim,
+ "blockdim": cfg.blockdim,
+ "sharedmem": cfg.sharedmem,
+ }
+ )
+ # Mark compilation as launch-config sensitive so the dispatcher can
+ # avoid reusing the compiled kernel across different launch configs.
+ cfg.mark_kernel_as_launch_config_sensitive()
+ self._logged = True
+ return self._block
+
+
+@skip_on_cudasim("Dispatcher C-extension not used in the simulator")
+class TestLaunchConfigSensitive(CUDATestCase):
+ def setUp(self):
+ super().setUp()
+ _clear_launch_config_log()
+
+ def test_launch_config_sensitive_requires_recompile(self):
+ @cuda.jit
+ def launch_config_sensitive_kernel(x):
+ x[0] = 1
+
+ arr = np.zeros(1, dtype=np.int32)
+
+ launch_config_sensitive_kernel[1, 32](arr)
+ self.assertEqual(len(LAUNCH_CONFIG_LOG), 1)
+ self.assertEqual(LAUNCH_CONFIG_LOG[0]["blockdim"], (32, 1, 1))
+ self.assertEqual(LAUNCH_CONFIG_LOG[0]["griddim"], (1, 1, 1))
+
+ launch_config_sensitive_kernel[1, 64](arr)
+ # Expect a new compilation for the new launch config, which will log
+ # a second entry with the updated block dimension.
+ self.assertEqual(len(LAUNCH_CONFIG_LOG), 2)
+ self.assertEqual(LAUNCH_CONFIG_LOG[1]["blockdim"], (64, 1, 1))
+ self.assertEqual(LAUNCH_CONFIG_LOG[1]["griddim"], (1, 1, 1))
+
+
+if __name__ == "__main__":
+ unittest.main()
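
The `_logged` guard in the rewrite above makes `match` succeed at most once per instance (and a fresh instance is created per compilation pass), so `apply` logs exactly one entry per compile. The one-shot shape, independent of numba's `Rewrite` machinery (`OneShotRewrite` is a hypothetical name):

```python
class OneShotRewrite:
    """Match the target at most once per instance (one instance per pass)."""

    def __init__(self, target):
        self.target = target
        self._logged = False
        self._block = None

    def match(self, name, block):
        if name != self.target or self._logged:
            return False
        self._block = block
        return True

    def apply(self, log):
        log.append(self._block)
        self._logged = True
        return self._block


log = []
rw = OneShotRewrite("kernel")
for block in ["b0", "b1", "b2"]:
    if rw.match("kernel", block):
        rw.apply(log)
print(log)  # only the first matching block is logged: ['b0']
```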
diff --git a/pixi.toml b/pixi.toml
index cf85a0f00..7042e2814 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -261,6 +261,9 @@ bench-against = { cmd = [
"$PIXI_PROJECT_ROOT/scripts/bench-against.py",
] }
+[target.linux.tasks.bench-launch-overhead]
+cmd = ["python", "$PIXI_PROJECT_ROOT/scripts/bench-launch-overhead.py"]
+
[target.linux.tasks.build-docs]
cmd = ["make", "-C", "$PIXI_PROJECT_ROOT/docs", "html"]
diff --git a/plots/launch-overhead-abs.png b/plots/launch-overhead-abs.png
new file mode 100644
index 000000000..d3238c5d1
Binary files /dev/null and b/plots/launch-overhead-abs.png differ
diff --git a/plots/launch-overhead-abs.svg b/plots/launch-overhead-abs.svg
new file mode 100644
index 000000000..d099d1572
--- /dev/null
+++ b/plots/launch-overhead-abs.svg
@@ -0,0 +1,1568 @@
+
+
+
diff --git a/plots/launch-overhead-pct.png b/plots/launch-overhead-pct.png
new file mode 100644
index 000000000..56e6140cc
Binary files /dev/null and b/plots/launch-overhead-pct.png differ
diff --git a/plots/launch-overhead-pct.svg b/plots/launch-overhead-pct.svg
new file mode 100644
index 000000000..7c1732b72
--- /dev/null
+++ b/plots/launch-overhead-pct.svg
@@ -0,0 +1,1851 @@
+
+
+
diff --git a/plots/launch-overhead-us.png b/plots/launch-overhead-us.png
new file mode 100644
index 000000000..3c0a8e4f4
Binary files /dev/null and b/plots/launch-overhead-us.png differ
diff --git a/plots/launch-overhead-us.svg b/plots/launch-overhead-us.svg
new file mode 100644
index 000000000..3d76b4233
--- /dev/null
+++ b/plots/launch-overhead-us.svg
@@ -0,0 +1,1866 @@
+
+
+
diff --git a/scripts/bench-launch-overhead.py b/scripts/bench-launch-overhead.py
new file mode 100644
index 000000000..2f8b406cc
--- /dev/null
+++ b/scripts/bench-launch-overhead.py
@@ -0,0 +1,313 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+import argparse
+import json
+import os
+import statistics
+import subprocess
+import sys
+import time
+
+
+DEFAULT_LOOPS = {
+ 0: 100_000,
+ 1: 100_000,
+ 2: 100_000,
+ 3: 100_000,
+ 4: 10_000,
+}
+DEFAULT_REPEATS = 7
+
+
+def _parse_repo(spec):
+ if "=" not in spec:
+ raise ValueError("Repo spec must be in the form label=/path/to/repo.")
+ label, path = spec.split("=", 1)
+ label = label.strip()
+ path = os.path.abspath(os.path.expanduser(path.strip()))
+ if not label:
+ raise ValueError("Repo spec label cannot be empty.")
+ return label, path
+
+
+def _parse_loops(spec):
+ if spec is None:
+ return DEFAULT_LOOPS
+ parts = [p.strip() for p in spec.split(",") if p.strip()]
+ if len(parts) != 5:
+ raise ValueError("Loops must be a comma-separated list of 5 integers.")
+ loops = {}
+ for idx, value in enumerate(parts):
+ loops[idx] = int(value)
+ return loops
+
+
+def _git_rev(path):
+ return subprocess.check_output(
+ ["git", "rev-parse", "HEAD"],
+ cwd=path,
+ text=True,
+ ).strip()
+
+
+def _pip_install(repo_path, python):
+ subprocess.run(
+ [
+ python,
+ "-m",
+ "pip",
+ "install",
+ "-e",
+ repo_path,
+ "--no-deps",
+ ],
+ check=True,
+ )
+
+
+def _run_worker(label, loops, repeats, json_only):
+ import numpy as np
+ import numba
+ from numba import cuda
+
+ try:
+ from numba.cuda.core import config
+ except ModuleNotFoundError:
+        # Older branches ship numba.core.config, not numba.cuda.core.config.
+ from numba.core import config
+
+ if config.ENABLE_CUDASIM:
+ raise RuntimeError("CUDA simulator enabled; benchmarks require GPU.")
+
+ cuda.current_context()
+
+ arrs = [cuda.device_array(10_000, dtype=np.float32) for _ in range(4)]
+
+ @cuda.jit("void()")
+ def some_kernel_1():
+ return
+
+ @cuda.jit("void(float32[:])")
+ def some_kernel_2(arr1):
+ return
+
+ @cuda.jit("void(float32[:],float32[:])")
+ def some_kernel_3(arr1, arr2):
+ return
+
+ @cuda.jit("void(float32[:],float32[:],float32[:])")
+ def some_kernel_4(arr1, arr2, arr3):
+ return
+
+ @cuda.jit("void(float32[:],float32[:],float32[:],float32[:])")
+ def some_kernel_5(arr1, arr2, arr3, arr4):
+ return
+
+ kernels = [
+ ("0", some_kernel_1, ()),
+ ("1", some_kernel_2, (arrs[0],)),
+ ("2", some_kernel_3, (arrs[0], arrs[1])),
+ ("3", some_kernel_4, (arrs[0], arrs[1], arrs[2])),
+ ("4", some_kernel_5, (arrs[0], arrs[1], arrs[2], arrs[3])),
+ ]
+
+ results = {}
+ for idx, (name, kernel, args) in enumerate(kernels):
+ loop_count = loops[idx]
+ kernel[1, 1](*args)
+ cuda.synchronize()
+ samples = []
+ for _ in range(repeats):
+ start = time.perf_counter()
+ for _ in range(loop_count):
+ kernel[1, 1](*args)
+ cuda.synchronize()
+ elapsed = time.perf_counter() - start
+ samples.append(elapsed / loop_count)
+ mean_s = statistics.mean(samples)
+ stdev_s = statistics.stdev(samples) if repeats > 1 else 0.0
+ results[name] = {
+ "loops": loop_count,
+ "mean_us": mean_s * 1e6,
+ "stdev_us": stdev_s * 1e6,
+ }
+
+ device = cuda.get_current_device()
+ payload = {
+ "label": label,
+ "numba_version": numba.__version__,
+ "device": {
+ "name": device.name.decode()
+ if hasattr(device.name, "decode")
+ else device.name,
+ "cc": device.compute_capability,
+ },
+ "cuda_runtime_version": cuda.runtime.get_version(),
+ "results": results,
+ "repeats": repeats,
+ }
+
+ if not json_only:
+ print(
+ f"{label}: {payload['numba_version']} "
+ f"CUDA {payload['cuda_runtime_version']}"
+ )
+ print(json.dumps(payload))
+
+
+def _format_us(value):
+ return f"{value:.2f}"
+
+
+def _print_table(results):
+ labels = [r["label"] for r in results]
+ baseline = results[0]["label"]
+ print("Launch overhead (us/launch):")
+ header = ["args"] + labels
+ rows = []
+ for arg in ["0", "1", "2", "3", "4"]:
+ row = [arg]
+ for r in results:
+ mean = r["results"][arg]["mean_us"]
+ stdev = r["results"][arg]["stdev_us"]
+ row.append(f"{_format_us(mean)} +/- {_format_us(stdev)}")
+ rows.append(row)
+
+ col_widths = [
+ max(len(row[i]) for row in [header] + rows) for i in range(len(header))
+ ]
+ fmt = " ".join(f"{{:<{width}}}" for width in col_widths)
+ print(fmt.format(*header))
+ for row in rows:
+ print(fmt.format(*row))
+
+ if len(results) > 1:
+ print("")
+ print(f"Deltas vs {baseline}:")
+ header = ["args"] + labels[1:]
+ delta_rows = []
+ for arg in ["0", "1", "2", "3", "4"]:
+ base_mean = results[0]["results"][arg]["mean_us"]
+ row = [arg]
+ for r in results[1:]:
+ mean = r["results"][arg]["mean_us"]
+ delta = mean - base_mean
+ pct = (delta / base_mean) * 100 if base_mean else 0.0
+ row.append(f"{_format_us(delta)} ({pct:+.1f}%)")
+ delta_rows.append(row)
+
+ col_widths = [
+ max(len(row[i]) for row in [header] + delta_rows)
+ for i in range(len(header))
+ ]
+ fmt = " ".join(f"{{:<{width}}}" for width in col_widths)
+ print(fmt.format(*header))
+ for row in delta_rows:
+ print(fmt.format(*row))
+
+
+def _run_driver(args):
+ repos = [_parse_repo(spec) for spec in args.repo]
+ loops = _parse_loops(args.loops)
+ results = []
+ for label, path in repos:
+ sha = _git_rev(path)
+ if not args.no_install:
+ _pip_install(path, args.python)
+ output = subprocess.check_output(
+ [
+ args.python,
+ os.path.abspath(__file__),
+ "--run",
+ "--label",
+ label,
+ "--loops",
+ args.loops
+ if args.loops
+ else ",".join(str(loops[i]) for i in range(5)),
+ "--repeats",
+ str(args.repeats),
+ ],
+ text=True,
+ )
+ last_line = output.strip().splitlines()[-1]
+ payload = json.loads(last_line)
+ payload["repo"] = path
+ payload["sha"] = sha
+ results.append(payload)
+
+ _print_table(results)
+ if args.output:
+ with open(args.output, "w", encoding="utf-8") as handle:
+ json.dump(results, handle, indent=2, sort_keys=True)
+ print(f"Wrote {args.output}")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description=("Benchmark kernel launch overhead across repos.")
+ )
+ parser.add_argument(
+ "--repo",
+ action="append",
+ default=[],
+ help="Repo spec as label=/path (repeatable).",
+ )
+ parser.add_argument(
+ "--no-install",
+ action="store_true",
+ help="Skip pip install -e for repos.",
+ )
+ parser.add_argument(
+ "--python",
+ default=sys.executable,
+ help="Python executable to use.",
+ )
+ parser.add_argument(
+ "--loops",
+ default=None,
+ help="Comma-separated loops for 0..4 args.",
+ )
+ parser.add_argument(
+ "--repeats",
+ type=int,
+ default=DEFAULT_REPEATS,
+ help="Number of repeats for each kernel.",
+ )
+ parser.add_argument(
+ "--output",
+ default=None,
+ help="Optional JSON output path.",
+ )
+ parser.add_argument(
+ "--run",
+ action="store_true",
+ help="Run in worker mode (internal).",
+ )
+ parser.add_argument(
+ "--label",
+ default="",
+ help="Label for worker mode.",
+ )
+ parser.add_argument(
+ "--json-only",
+ action="store_true",
+ help="Suppress non-JSON output in worker mode.",
+ )
+ args = parser.parse_args()
+
+ if args.run:
+ loops = _parse_loops(args.loops)
+ _run_worker(args.label, loops, args.repeats, args.json_only)
+ return
+
+ if not args.repo:
+ parser.error("--repo must be provided at least once.")
+
+ _run_driver(args)
+
+
+if __name__ == "__main__":
+ main()
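
The aligned-table output of `_print_table` boils down to computing per-column widths and building a `str.format` template; a minimal standalone version, with made-up benchmark numbers:

```python
def format_table(header, rows):
    """Left-align columns, padding each to its widest cell."""
    all_rows = [header] + rows
    widths = [max(len(r[i]) for r in all_rows) for i in range(len(header))]
    fmt = "  ".join(f"{{:<{w}}}" for w in widths)
    return "\n".join(fmt.format(*r) for r in all_rows)


table = format_table(
    ["args", "main", "branch"],
    [["0", "3.10 +/- 0.02", "3.12 +/- 0.03"],
     ["1", "3.45 +/- 0.05", "3.44 +/- 0.04"]],
)
print(table)
```

Because every cell is padded to its column's maximum width, all emitted lines come out the same length, which keeps the columns visually aligned in terminal output.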
diff --git a/scripts/plot-launch-overhead.py b/scripts/plot-launch-overhead.py
new file mode 100644
index 000000000..a0f119803
--- /dev/null
+++ b/scripts/plot-launch-overhead.py
@@ -0,0 +1,313 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+"""Plot launch-overhead results from bench-launch-overhead.py."""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import re
+from pathlib import Path
+
+
+def _maybe_set_backend() -> None:
+ if os.environ.get("DISPLAY") is None and os.name != "nt":
+ import matplotlib
+
+ matplotlib.use("Agg")
+
+
+def _load_input(path: Path) -> list[dict]:
+ text = path.read_text(encoding="utf-8")
+ try:
+ data = json.loads(text)
+ return _normalize_json(data)
+ except json.JSONDecodeError:
+ return _parse_table(text)
+
+
+def _normalize_json(data: object) -> list[dict]:
+ if isinstance(data, dict):
+ data = [data]
+ if not isinstance(data, list):
+ raise ValueError(
+ "Expected a JSON list of results or a single result dict."
+ )
+ for entry in data:
+ if not isinstance(entry, dict):
+ raise ValueError("Each JSON entry must be a dict.")
+ if "label" not in entry or "results" not in entry:
+ raise ValueError(
+ "Each JSON entry must include 'label' and 'results'."
+ )
+ return data
+
+
+def _split_cols(line: str) -> list[str]:
+ return [col for col in re.split(r"\s{2,}|\t+", line.strip()) if col]
+
+
+def _parse_table(text: str) -> list[dict]:
+ lines = text.splitlines()
+ start = None
+ for idx, line in enumerate(lines):
+ if line.strip().startswith("Launch overhead (us/launch):"):
+ start = idx
+ break
+ if start is None or start + 2 >= len(lines):
+ raise ValueError("Could not locate launch-overhead table in text.")
+
+ header = _split_cols(lines[start + 1])
+ if not header or header[0] != "args":
+ raise ValueError("Malformed table header.")
+ labels = header[1:]
+ results: dict[str, dict[str, dict[str, float]]] = {
+ label: {} for label in labels
+ }
+
+ row_idx = start + 2
+ while row_idx < len(lines):
+ line = lines[row_idx].rstrip()
+        if not line.strip() or line.strip().startswith("Deltas vs"):
+ break
+ cols = _split_cols(line)
+ if len(cols) < 2:
+ row_idx += 1
+ continue
+ arg = cols[0]
+ for label, cell in zip(labels, cols[1:]):
+ match = re.search(
+ r"([+-]?\d+(?:\.\d+)?)\s*\+/-\s*([+-]?\d+(?:\.\d+)?)",
+ cell,
+ )
+ if not match:
+ raise ValueError(f"Could not parse cell '{cell}'.")
+ mean = float(match.group(1))
+ stdev = float(match.group(2))
+ results[label][arg] = {"mean_us": mean, "stdev_us": stdev}
+ row_idx += 1
+
+ return [{"label": label, "results": results[label]} for label in labels]
+
+
+def _extract_series(
+ results: list[dict],
+) -> tuple[list[int], list[str], dict, dict]:
+ labels = [entry["label"] for entry in results]
+ arg_keys = sorted(
+ {int(k) for entry in results for k in entry["results"].keys()}
+ )
+ means: dict[str, list[float]] = {}
+ stdevs: dict[str, list[float]] = {}
+ for entry in results:
+ label = entry["label"]
+ means[label] = []
+ stdevs[label] = []
+ for arg in arg_keys:
+ record = entry["results"].get(str(arg))
+ if not record:
+ raise ValueError(f"Missing arg {arg} for label '{label}'.")
+ means[label].append(float(record.get("mean_us", 0.0)))
+ stdevs[label].append(float(record.get("stdev_us", 0.0)))
+ return arg_keys, labels, means, stdevs
+
+
+def _select_baseline(labels: list[str], baseline: str | None) -> str:
+ if baseline:
+ if baseline not in labels:
+ raise ValueError(
+ f"Baseline '{baseline}' not found in labels: {labels}"
+ )
+ return baseline
+ return labels[0]
+
+
+def _summarize_device(results: list[dict]) -> str | None:
+ devices = []
+ runtimes = []
+ for entry in results:
+ device = entry.get("device")
+ runtime = entry.get("cuda_runtime_version")
+ if device:
+ name = device.get("name")
+ cc = device.get("cc")
+ if name and cc:
+ devices.append((name, tuple(cc)))
+ if runtime:
+ runtimes.append(tuple(runtime))
+ device_line = None
+ if devices and len(set(devices)) == 1:
+ name, cc = devices[0]
+ device_line = f"{name} (CC {cc[0]}.{cc[1]})"
+ runtime_line = None
+ if runtimes and len(set(runtimes)) == 1:
+ runtime_line = f"CUDA runtime {runtimes[0][0]}.{runtimes[0][1]}"
+ if device_line and runtime_line:
+ return f"{device_line} • {runtime_line}"
+ if device_line:
+ return device_line
+ if runtime_line:
+ return runtime_line
+ return None
+
+
+def _sanitize_svg(path: Path) -> None:
+ if path.suffix.lower() != ".svg":
+ return
+ text = path.read_text(encoding="utf-8")
+ sanitized = "\n".join(line.rstrip() for line in text.splitlines())
+ if text.endswith("\n"):
+ sanitized += "\n"
+ if sanitized != text:
+ path.write_text(sanitized, encoding="utf-8")
+
+
+def main() -> int:
+ parser = argparse.ArgumentParser(
+ description="Plot launch-overhead results from bench-launch-overhead.py."
+ )
+ parser.add_argument(
+ "input",
+ type=Path,
+ help="Path to results JSON (preferred) or bench stdout text.",
+ )
+ parser.add_argument(
+ "--output",
+ type=Path,
+ default=None,
+ help="Output image path (png/svg/pdf). If omitted, show the plot.",
+ )
+ parser.add_argument(
+ "--baseline",
+ default=None,
+ help="Label to use as baseline (defaults to first entry).",
+ )
+ parser.add_argument(
+ "--delta",
+ choices=("pct", "us", "none"),
+ default="pct",
+ help="Show delta vs baseline as percent, microseconds, or not at all.",
+ )
+ parser.add_argument(
+ "--title",
+ default=None,
+ help="Optional title override.",
+ )
+ parser.add_argument(
+ "--no-sns",
+ action="store_true",
+ help="Disable seaborn styling even if installed.",
+ )
+ parser.add_argument(
+ "--dpi",
+ type=int,
+ default=150,
+ help="Output DPI when saving images.",
+ )
+ args = parser.parse_args()
+
+ _maybe_set_backend()
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ try:
+ import seaborn as sns # type: ignore
+ except Exception:
+ sns = None
+
+ results = _load_input(args.input)
+ arg_keys, labels, means, stdevs = _extract_series(results)
+ baseline_label = _select_baseline(labels, args.baseline)
+
+ if sns is not None and not args.no_sns:
+ sns.set_theme(style="whitegrid")
+ palette = sns.color_palette(n_colors=len(labels))
+ else:
+ palette = plt.cm.tab10.colors
+
+ nrows = 2 if args.delta != "none" else 1
+ height = 6.5 if nrows == 2 else 4.0
+ fig, axes = plt.subplots(
+ nrows=nrows,
+ ncols=1,
+ figsize=(8.5, height),
+ sharex=True,
+ constrained_layout=True,
+ )
+ if nrows == 1:
+ ax_abs = axes
+ ax_delta = None
+ else:
+ ax_abs, ax_delta = axes
+
+ x = np.array(arg_keys)
+ for idx, label in enumerate(labels):
+ ax_abs.errorbar(
+ x,
+ means[label],
+ yerr=stdevs[label],
+ label=label,
+ color=palette[idx % len(palette)],
+ marker="o",
+ linewidth=2,
+ capsize=3,
+ )
+ ax_abs.set_ylabel("Launch overhead (us/launch)")
+ ax_abs.set_xticks(x)
+ ax_abs.grid(True, axis="y", alpha=0.3)
+ ax_abs.legend(frameon=False, title="Repo", ncol=min(3, len(labels)))
+
+ if ax_delta is not None:
+ base = np.array(means[baseline_label])
+ ax_delta.axhline(0.0, color="0.5", linestyle="--", linewidth=1)
+ for idx, label in enumerate(labels):
+ if label == baseline_label:
+ continue
+ delta = np.array(means[label]) - base
+ if args.delta == "pct":
+ delta = np.where(base != 0, (delta / base) * 100, 0.0)
+ ylabel = "Delta vs baseline (%)"
+ else:
+ ylabel = "Delta vs baseline (us/launch)"
+ ax_delta.plot(
+ x,
+ delta,
+ label=label,
+ color=palette[idx % len(palette)],
+ marker="o",
+ linewidth=2,
+ )
+ ax_delta.set_ylabel(ylabel)
+ ax_delta.set_xlabel("Kernel args")
+ ax_delta.grid(True, axis="y", alpha=0.3)
+ ax_delta.legend(frameon=False, title="Repo", ncol=min(3, len(labels)))
+ else:
+ ax_abs.set_xlabel("Kernel args")
+
+ title = args.title or "CUDA kernel launch overhead"
+ subtitle = _summarize_device(results)
+ if subtitle:
+ title = f"{title}\n{subtitle}"
+ fig.suptitle(title, fontsize=12)
+
+ if args.output:
+ fig.savefig(args.output, dpi=args.dpi)
+ _sanitize_svg(args.output)
+ print(f"Wrote {args.output}")
+ return 0
+
+ try:
+ plt.show()
+ except Exception:
+ fallback = Path("launch-overhead.png")
+ fig.savefig(fallback, dpi=args.dpi)
+ print(f"Wrote {fallback}")
+ return 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
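
The cell regex in `_parse_table` can be checked in isolation; it pulls the mean and stdev out of a `3.45 +/- 0.05` cell as printed by the bench script (the `parse_cell` wrapper is illustrative):

```python
import re

# Same pattern as _parse_table: signed decimal, "+/-", signed decimal.
CELL_RE = re.compile(
    r"([+-]?\d+(?:\.\d+)?)\s*\+/-\s*([+-]?\d+(?:\.\d+)?)"
)


def parse_cell(cell):
    """Return (mean, stdev) parsed from a 'mean +/- stdev' table cell."""
    match = CELL_RE.search(cell)
    if not match:
        raise ValueError(f"Could not parse cell '{cell}'.")
    return float(match.group(1)), float(match.group(2))


print(parse_cell("3.45 +/- 0.05"))   # -> (3.45, 0.05)
print(parse_cell("-0.12 +/- 0.03"))  # -> (-0.12, 0.03)
```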