
Conversation

@rparolin
Contributor

@rparolin rparolin commented Nov 25, 2025

PR #609: Migrate kernel launch to cuda.core

TL;DR: Migrates numba-cuda's kernel launching from the legacy launch_kernel API to cuda.core.experimental.launch. Performance concerns were addressed through optimization work. Tests pass. Removes old abstractions.


What Changed

Core Files Modified:

  1. driver.py - Major surgery here

    • Removed: launch_kernel(), abstract Module/Function base classes, CtypesModule, CtypesFunction
    • Added: _to_core_stream() helper for stream conversion (see the sketch after this list)
    • Changed: CudaPythonModule now wraps ObjectCode directly
    • Uses cuLibraryGetGlobal/cuLibraryUnload instead of cuModuleGetGlobal/cuModuleUnload
    • Function attributes now read from kernel.attributes instead of CUDA driver calls
  2. dispatcher.py - Launch infrastructure rewrite

    • Now uses LaunchConfig and cuda.core.experimental.launch()
    • Stream conversion happens early in _LaunchConfiguration.__init__
    • Kernel args preparation simplified (less ctypes overhead)
    • Added __getstate__/__setstate__ for pickle support; see issue #648 ("figure out exactly what pickling a stream means")
  3. Tests - Updated everywhere

    • Changed CudaAPIError → CUDAError
    • Module callbacks now receive ObjectCode instead of raw handles
    • Stream conversions use ExperimentalStream.from_handle() or _to_core_stream()
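
Taken together, the pieces above might fit roughly as in this sketch (a hedged illustration, not the PR's actual code; the helper's exact logic and the stream.handle attribute are assumptions):

```
# Hedged sketch of the new launch path built on cuda.core.experimental.
from cuda.core.experimental import LaunchConfig, launch
from cuda.core.experimental import Stream as ExperimentalStream

def _to_core_stream(stream):
    """Convert a numba-cuda stream to a cuda.core Stream."""
    if isinstance(stream, ExperimentalStream):
        return stream  # already a cuda.core stream
    # Assumption: numba streams expose their raw CUstream as an int-convertible handle
    return ExperimentalStream.from_handle(int(stream.handle))

def launch_with_core(kernel, griddim, blockdim, stream, sharedmem, *args):
    # LaunchConfig replaces the grid/block/shared-memory arguments of launch_kernel
    config = LaunchConfig(grid=griddim, block=blockdim, shmem_size=sharedmem)
    # cuda.core's launch takes the stream, the config, the kernel, then the kernel args
    launch(_to_core_stream(stream), config, kernel, *args)
```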

Performance Impact

Initial concern: Early benchmarks showed 30-40% regression on single-arg cases.

Resolution: Addressed through optimization work in recent commits:

  • Removed instance-checking helpers (prepare_args) to minimize overhead
  • Simplified kernel args preparation (less ctypes wrapping, direct value passing)
  • Reduced stream conversion overhead by moving the conversion outside of the kernel invocation.

@copy-pr-bot

copy-pr-bot bot commented Nov 25, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@rparolin rparolin marked this pull request as ready for review November 25, 2025 20:43
@copy-pr-bot

copy-pr-bot bot commented Nov 25, 2025

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.


@rparolin rparolin marked this pull request as draft November 25, 2025 20:43
@rparolin
Contributor Author

/ok to test 8f72cb7

@greptile-apps
Contributor

greptile-apps bot commented Nov 25, 2025

Greptile Overview

Greptile Summary

This PR migrates numba-cuda's kernel launching infrastructure from the legacy launch_kernel API to cuda.core.experimental.launch. Key changes include:

  • driver.py: Removed launch_kernel function and the abstract Module/Function base classes. CudaPythonModule now wraps ObjectCode directly and uses cuLibraryGetGlobal/cuLibraryUnload instead of cuModuleGetGlobal/cuModuleUnload. Added _to_core_stream() for converting numba Stream objects to ExperimentalStream.
  • dispatcher.py: _Kernel.launch now uses LaunchConfig and cuda.core.experimental.launch. Stream conversion happens early in _LaunchConfiguration.__init__. Added __getstate__/__setstate__ methods for pickling support.
  • nrt.py: Updated _single_thread_launch to use the new launch API.
  • Tests: Updated to use new APIs and CUDAError instead of CudaAPIError. Module callbacks now receive ObjectCode instead of raw handles.
  • pixi.toml: Added ipython and pyinstrument dev dependencies.

Confidence Score: 4/5

  • This PR is a significant but well-structured refactoring with comprehensive test updates - safe to merge after verification that all tests pass.
  • The changes are substantial (migrating to a new launch API) but follow a consistent pattern throughout. The code removes legacy abstractions in favor of using cuda.core directly. Tests have been properly updated to reflect the new error types and APIs. No obvious logic errors were found in the core changes.
  • Pay attention to driver.py and dispatcher.py as they contain the core launch infrastructure changes.

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| numba_cuda/numba/cuda/cudadrv/driver.py | 4/5 | Major refactoring to use the cuda.core launch API: removes the old launch_kernel, the Module ABC, and the CtypesModule/CtypesFunction classes; adds _to_core_stream for stream conversion and simplifies to use ObjectCode directly. |
| numba_cuda/numba/cuda/dispatcher.py | 4/5 | Kernel launching updated to use LaunchConfig and cuda.core.experimental.launch; stream conversion moved to _LaunchConfiguration.__init__; pickle support added via __getstate__/__setstate__. |
| numba_cuda/numba/cuda/memory_management/nrt.py | 5/5 | _single_thread_launch updated to use the new LaunchConfig and launch API with _to_core_stream conversion. |
| numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py | 5/5 | Tests updated to use the cuda.core.experimental LaunchConfig and launch APIs; stream conversion uses _to_core_stream or ExperimentalStream.from_handle. |

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 1 comment

@rparolin
Contributor Author

/ok to test 015a37c

@rparolin
Contributor Author

/ok to test aa92ce5

@rparolin rparolin force-pushed the rparolin/migrate_driver_launch_to_cuda_core branch from aa92ce5 to b95207c Compare November 25, 2025 21:30
@rparolin
Contributor Author

/ok to test b95207c

@rparolin
Contributor Author

/ok to test 7917bb1

@rparolin
Contributor Author

/ok to test fa14321

@rparolin rparolin marked this pull request as ready for review November 25, 2025 23:06
Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

@leofang
Member

leofang commented Nov 25, 2025

  • I suggest holding off on merging for now. We should avoid importing private attributes and functions if possible. Any gap identified by this PR (which is a good exercise!) needs to be addressed in cuda-core.
  • What is the perf number before/after this change? As it stands, it is hard to tell whether this PR brings any perf improvement. I would hate to cause a perf regression just because we want dogfooding 🙂

@rparolin
Contributor Author

> What is the perf number before/after this change?
  ┌─────────────────────────────────────┬──────────────┬──────────────┬──────────────┬────────────┐
  │ Component                           │   Baseline   │   Current    │    Delta     │   Change   │
  ├─────────────────────────────────────┼──────────────┼──────────────┼──────────────┼────────────┤
  │ PYTHON DISPATCH LAYER               │              │              │              │            │
  ├─────────────────────────────────────┼──────────────┼──────────────┼──────────────┼────────────┤
  │   dispatcher.py → launch_kernel()   │     ~2.0 μs  │     ~2.0 μs  │      ~0.0 μs │      ~0%   │
  ├─────────────────────────────────────┼──────────────┼──────────────┼──────────────┼────────────┤
  │ CUDA.CORE OBJECTS (NEW OVERHEAD)    │              │              │              │            │
  ├─────────────────────────────────────┼──────────────┼──────────────┼──────────────┼────────────┤
  │   LaunchConfig creation             │      ---     │     0.411 μs │    +0.411 μs │    +100%   │
  │   ObjectCode stub creation          │      ---     │     0.093 μs │    +0.093 μs │    +100%   │
  │   Kernel._from_obj()                │      ---     │    10.971 μs │   +10.971 μs │    +100%   │
  │   Stream.from_handle()              │      ---     │     2.706 μs │    +2.706 μs │    +100%   │
  ├─────────────────────────────────────┼──────────────┼──────────────┼──────────────┼────────────┤
  │   Subtotal (cuda.core objects)      │      0.0 μs  │    14.181 μs │   +14.181 μs │    +100%   │
  ├─────────────────────────────────────┼──────────────┼──────────────┼──────────────┼────────────┤
  │ C/GPU DISPATCH LAYER                │              │              │              │            │
  ├─────────────────────────────────────┼──────────────┼──────────────┼──────────────┼────────────┤
  │   CUDA driver API call              │    ~25.2 μs  │    ~25.0 μs  │     -0.2 μs  │     -1%    │
  │   GPU kernel dispatch               │  (included)  │  (included)  │       ---    │     ---    │
  ├─────────────────────────────────────┼──────────────┼──────────────┼──────────────┼────────────┤
  │ TOTAL END-TO-END                    │              │              │              │            │
  ├─────────────────────────────────────┼──────────────┼──────────────┼──────────────┼────────────┤
  │   Median launch time                │    27.23 μs  │    39.18 μs  │   +11.95 μs  │   +43.9%   │
  │   Minimum launch time               │    23.70 μs  │    30.15 μs  │    +6.45 μs  │   +27.2%   │
  │   P95 launch time                   │    74.26 μs  │    57.90 μs  │   -16.36 μs  │   -22.0%   │
  └─────────────────────────────────────┴──────────────┴──────────────┴──────────────┴────────────┘
    **🔬 Bottom Line**
  The new cuda.core object overhead totals ~14.2 μs per launch, of which ~11.95 μs shows up in the median end-to-end delta; it is now fully accounted for:
  • ~77% from Kernel._from_obj() (11.0 μs)
  • ~19% from Stream.from_handle() (2.7 μs)
  • ~3% from LaunchConfig (0.4 μs)
  • ~1% from ObjectCode stub (0.1 μs)

@gmarkall
Contributor

> What is the perf number before/after this change?

It would be good to run the launch benchmarks to quantify this change (edit testing/pytest.ini to remove --benchmark-disable), then run pytest --benchmark-only. You should get something like:

------------------------------------------------------------------------------------------------- benchmark: 4 tests ------------------------------------------------------------------------------------------------
Name (time in us)                        Min                    Max                   Mean             StdDev                 Median                IQR            Outliers         OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_one_arg[device_array]          749.1760 (1.0)         998.8480 (1.0)         769.6751 (1.0)      17.1446 (1.0)         765.7590 (1.0)       8.1370 (1.0)        83;102  1,299.2495 (1.0)         943           1
test_one_arg[cupy]                1,896.8900 (2.53)      2,359.8850 (2.36)      1,946.2075 (2.53)     41.6876 (2.43)      1,937.1740 (2.53)     23.5080 (2.89)        36;34    513.8198 (0.40)        444           1
test_many_args[device_array]      6,195.0860 (8.27)      6,472.5740 (6.48)      6,272.8671 (8.15)     44.8138 (2.61)      6,264.1275 (8.18)     50.9895 (6.27)         40;6    159.4167 (0.12)        156           1
test_many_args[cupy]             27,561.5590 (36.79)    27,951.6780 (27.98)    27,702.3494 (35.99)    80.1222 (4.67)     27,687.4050 (36.16)    96.2945 (11.83)         7;1     36.0980 (0.03)         37           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

@gmarkall gmarkall added the 4 - Waiting on author Waiting for author to respond to review label Nov 26, 2025
Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

@rparolin
Contributor Author

> It would be good to run the launch benchmarks to quantify this change (edit testing/pytest.ini to remove --benchmark-disable), then run pytest --benchmark-only.
Name (time in ms)                    Min                Max               Mean            StdDev             Median               IQR            Outliers       OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_one_arg[device_array]        2.5915 (1.0)       4.6078 (1.0)       3.1092 (1.0)      0.4016 (1.0)       2.9746 (1.0)      0.4342 (1.0)          15;5  321.6312 (1.0)          92           1
test_many_args[device_array]     16.0708 (6.20)     29.9002 (6.49)     20.1777 (6.49)     2.9787 (7.42)     19.5070 (6.56)     3.1564 (7.27)         14;4   49.5596 (0.15)         54           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

I'm having issues getting the cupy benchmarks to run successfully, but here is what I can report for now.

@gmarkall
Contributor

/ok to test

@rparolin
Contributor Author

/ok to test 9421787

Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

After some more investigation, it looks like _to_core_stream is the main source of the bottleneck.

Is caching the numba Stream -> cuda core Stream conversion a viable option?

I will also look into whether we can move the conversion higher up in the loop.
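
For illustration, the caching idea might look something like this (a hypothetical sketch keyed on the raw stream handle, not code from the PR; cache invalidation when a stream is destroyed is deliberately ignored):

```
# Hypothetical sketch of caching numba Stream -> cuda.core Stream conversions.
# Assumes numba streams expose their raw CUstream as an int-convertible .handle.
from cuda.core.experimental import Stream as ExperimentalStream

_core_stream_cache: dict[int, ExperimentalStream] = {}

def _to_core_stream_cached(stream):
    handle = int(stream.handle)
    core = _core_stream_cache.get(handle)
    if core is None:
        # Convert once per unique handle, then reuse on every subsequent launch
        core = _core_stream_cache[handle] = ExperimentalStream.from_handle(handle)
    return core
```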

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

It looks like streams are already being cached as part of the CUDADispatcher.configure call, so it seems like it's okay to cache the conversion as well (or the existing code is incorrect).

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

Ok, moving the core stream conversion into the _LaunchConfiguration constructor shaves off most of the difference now.

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

Naturally, that breaks pickling of jitted kernels 🥳

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

Pushed up a possible implementation of __getstate__/__setstate__ for _LaunchConfiguration.
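
The shape of such an implementation might be roughly this (a hedged sketch; the attribute names and the eager conversion in __init__ are assumptions based on the discussion above, not the PR's actual code):

```
# Hedged sketch of pickle support for _LaunchConfiguration.
from numba.cuda.cudadrv.driver import _to_core_stream  # private helper added by this PR

class _LaunchConfiguration:
    def __init__(self, dispatcher, griddim, blockdim, stream, sharedmem):
        self.dispatcher = dispatcher
        self.griddim = griddim
        self.blockdim = blockdim
        self.stream = stream        # numba-level stream, kept so it can be re-converted
        self.sharedmem = sharedmem
        # Eager conversion: paid once here rather than on every launch (the perf fix)
        self._core_stream = _to_core_stream(stream)

    def __getstate__(self):
        state = self.__dict__.copy()
        # The cuda.core Stream wraps a raw CUstream handle and cannot be pickled
        del state["_core_stream"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Rebuild the core stream from the unpickled numba-level stream
        self._core_stream = _to_core_stream(self.stream)
```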

@cpcloud cpcloud force-pushed the rparolin/migrate_driver_launch_to_cuda_core branch from fca2332 to 83fb41a Compare December 10, 2025 18:06
@cpcloud
Contributor

cpcloud commented Dec 10, 2025

New single-argument benchmarks:

[screenshot: single-argument benchmark results]

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

The variance seems to be rather large between runs:

[screenshot: benchmark variance between runs]

Note that in this case there's actually a negligible improvement.

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

This seems within noise to me, since I can't reproduce a consistent slowdown; what I see varies between 3% and 15%:

[screenshot: benchmark results across repeated runs]

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

/ok to test

Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, no comments

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

One thing that may also have gotten lost in my sea of comments is that there'll be some slight improvements from the next release of cuda.core, due to Cythonization of some of the launch configuration machinery.

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

/ok to test

@rparolin
Contributor Author

Just to confirm: we aren't able to reliably reproduce the 30%+ performance regression once we eliminated the overhead from the stream creation helper?

Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, no comments

@cpcloud
Contributor

cpcloud commented Dec 10, 2025

That is correct. There's still overhead, but it's not incurred on every single launch call; it now happens when you subscript the kernel to set the grid and block dimensions.

Before:

```
arr = cuda.device_array(...)

func = one_arg[1, 1]

for _ in range(HUGE_NUMBER):
    func(arr)  # _to_core_stream invoked here
```

After:

```
arr = cuda.device_array(...)

func = one_arg[1, 1]  # _to_core_stream invoked here

for _ in range(HUGE_NUMBER):
    func(arr)
```

Contributor

@kkraus14 kkraus14 left a comment

LGTM

@cpcloud cpcloud merged commit 19e256a into NVIDIA:main Dec 10, 2025
71 checks passed
gmarkall added a commit to gmarkall/numba-cuda that referenced this pull request Dec 17, 2025
- Fix NVIDIA#624: Accept Numba IR nodes in all places Numba-CUDA IR nodes are expected (NVIDIA#643)
- Fix Issue NVIDIA#588: separate compilation of NVVM IR modules when generating debuginfo (NVIDIA#591)
- feat: allow printing nested tuples (NVIDIA#667)
- build(deps): bump actions/setup-python from 5.6.0 to 6.1.0 (NVIDIA#655)
- build(deps): bump actions/upload-artifact from 4 to 5 (NVIDIA#652)
- Test RAPIDS 25.12 (NVIDIA#661)
- Do not manually set DUMP_ASSEMBLY in `nvjitlink` tests (NVIDIA#662)
- feat: add print support for int64 tuples (NVIDIA#663)
- Only run dependabot monthly and open fewer PRs (NVIDIA#658)
- test: fix bogus `self` argument to `Context` (NVIDIA#656)
- Fix false negative NRT link decision when NRT was previously toggled on (NVIDIA#650)
- Add support for dependabot (NVIDIA#647)
- refactor: cull dead linker objects (NVIDIA#649)
- Migrate numba-cuda driver to use cuda.core.launch API (NVIDIA#609)
- feat: add set_shared_memory_carveout (NVIDIA#629)
- chore: bump version in pixi.toml (NVIDIA#641)
- refactor: remove devicearray code to reduce complexity (NVIDIA#600)
@gmarkall gmarkall mentioned this pull request Dec 17, 2025
gmarkall added a commit that referenced this pull request Dec 17, 2025
- Capture global device arrays in kernels and device functions (#666)
- Fix #624: Accept Numba IR nodes in all places Numba-CUDA IR nodes are expected (#643)
- Fix Issue #588: separate compilation of NVVM IR modules when generating debuginfo (#591)
- feat: allow printing nested tuples (#667)
- build(deps): bump actions/setup-python from 5.6.0 to 6.1.0 (#655)
- build(deps): bump actions/upload-artifact from 4 to 5 (#652)
- Test RAPIDS 25.12 (#661)
- Do not manually set DUMP_ASSEMBLY in `nvjitlink` tests (#662)
- feat: add print support for int64 tuples (#663)
- Only run dependabot monthly and open fewer PRs (#658)
- test: fix bogus `self` argument to `Context` (#656)
- Fix false negative NRT link decision when NRT was previously toggled on (#650)
- Add support for dependabot (#647)
- refactor: cull dead linker objects (#649)
- Migrate numba-cuda driver to use cuda.core.launch API (#609)
- feat: add set_shared_memory_carveout (#629)
- chore: bump version in pixi.toml (#641)
- refactor: remove devicearray code to reduce complexity (#600)
ZzEeKkAa added a commit to ZzEeKkAa/numba-cuda that referenced this pull request Jan 8, 2026
v0.23.0

- Capture global device arrays in kernels and device functions (NVIDIA#666)
- Fix NVIDIA#624: Accept Numba IR nodes in all places Numba-CUDA IR nodes are expected (NVIDIA#643)
- Fix Issue NVIDIA#588: separate compilation of NVVM IR modules when generating debuginfo (NVIDIA#591)
- feat: allow printing nested tuples (NVIDIA#667)
- build(deps): bump actions/setup-python from 5.6.0 to 6.1.0 (NVIDIA#655)
- build(deps): bump actions/upload-artifact from 4 to 5 (NVIDIA#652)
- Test RAPIDS 25.12 (NVIDIA#661)
- Do not manually set DUMP_ASSEMBLY in `nvjitlink` tests (NVIDIA#662)
- feat: add print support for int64 tuples (NVIDIA#663)
- Only run dependabot monthly and open fewer PRs (NVIDIA#658)
- test: fix bogus `self` argument to `Context` (NVIDIA#656)
- Fix false negative NRT link decision when NRT was previously toggled on (NVIDIA#650)
- Add support for dependabot (NVIDIA#647)
- refactor: cull dead linker objects (NVIDIA#649)
- Migrate numba-cuda driver to use cuda.core.launch API (NVIDIA#609)
- feat: add set_shared_memory_carveout (NVIDIA#629)
- chore: bump version in pixi.toml (NVIDIA#641)
- refactor: remove devicearray code to reduce complexity (NVIDIA#600)
brandon-b-miller added a commit that referenced this pull request Jan 26, 2026
PR #609 made some changes to the way modules were loaded that result in the wrong object being passed to `cuOccupancyMaxPotentialBlockSize` (previously a `CUfunction`, now a `CUkernel`). This causes the max block size calculation to fail once it eventually receives the wrong object, leading to a `CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES` on certain GPUs. This is observable on a V100 with a resource-hungry kernel:

```
python -m numba.runtests numba.cuda.tests.cudapy.test_gufunc.TestCUDAGufunc.test_gufunc_small
```
```
cuda.core._utils.cuda_utils.CUDAError: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES: This indicates that a launch did not occur because it did not have appropriate resources. 
```

This PR removes the `numba-cuda` native maximum threads per block
computation machinery and routes through `cuda-python` APIs to get the
same information.
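
For reference, querying that limit through cuda-python might look roughly like this (a hedged sketch; the wrapper name is hypothetical and error handling is simplified):

```
# Hedged sketch of querying max potential block size via cuda-python bindings.
from cuda.bindings import driver as cu

def max_potential_block_size(func, dynamic_smem_size=0, block_size_limit=0):
    # `func` must be a CUfunction here, not a CUkernel (the bug described above)
    err, min_grid_size, block_size = cu.cuOccupancyMaxPotentialBlockSize(
        func,
        None,              # no per-block dynamic shared memory callback
        dynamic_smem_size,
        block_size_limit,  # 0 means no upper limit
    )
    if err != cu.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"cuOccupancyMaxPotentialBlockSize failed: {err}")
    return min_grid_size, block_size
```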