Updates for recent API changes #313

Merged
gmarkall merged 7 commits into NVIDIA:main from gmarkall:api-updates
Jul 2, 2025

Contributor

gmarkall commented Jul 1, 2025

This PR updates Numba-CUDA for recent driver API changes. Some related changes and simplifications are required; these are detailed in individual commit messages.

Fixes #281.

gmarkall added 4 commits July 1, 2025 23:49
Changes in the compute capability support matrix in nvvm.py will
continue to be needed with new CUDA versions if we maintain a list of
explicitly-supported compute capabilities. NVRTC supports retrieving the
supported list programmatically, so we switch to using it instead.
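
As a rough illustration of the paragraph above, the sketch below turns a list of architecture numbers into compute capability tuples. It assumes the list has already been obtained from NVRTC (the C API exposes `nvrtcGetSupportedArchs` for this) and that architectures are reported as integers such as 75 for compute capability 7.5. The helper `archs_to_ccs` and the sample values are hypothetical and not taken from this PR.

```python
def archs_to_ccs(archs):
    """Convert NVRTC-style architecture integers to (major, minor) tuples."""
    # divmod(75, 10) == (7, 5); divmod(120, 10) == (12, 0)
    return sorted(divmod(arch, 10) for arch in archs)


# Sample values only; a real list would come from the installed NVRTC.
sample_archs = [90, 75, 80, 86, 120]
supported_ccs = archs_to_ccs(sample_archs)

# The lowest supported compute capability is simply the first entry.
lowest_cc = supported_ccs[0]
```

Because the list comes from the installed toolkit, no hard-coded support matrix has to be updated when a new CUDA version appears.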

This does assume that the user's environment has a consistent set of
components (NVVM, NVRTC, etc.) - this is generally expected to be the
case with recent developments in package management, and there's little
we can do about an inconsistent environment anyway.

Changes outside of nvvm.py / nvrtc.py are to accommodate the movement of
this functionality. A major side effect is that we no longer need to
initialize the list of supported CCs prior to forking, because we don't
need to use the CUDA runtime to populate the supported CC list.

We only used the CUDA runtime library to get the runtime version so that
we could populate the list of supported compute capabilities in nvvm.py.
Now that we don't do this, and that NVRTC provides the CUDA toolkit
version, there is no need to use the CUDA runtime API at all.

The Numba API for the runtime version is not deleted in case it was used
by external code - instead, it uses NVRTC to obtain the toolkit version.

Because NVRTC used the runtime version to determine what prototypes to
bind, we need to stop doing that to avoid a circular dependency /
deadlock - instead of checking the runtime version and creating the list
of prototypes, we try to add all known prototypes, and ignore errors in
those related to LTOIR, which can occur with CUDA 11 where they were not
present.
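
The "bind everything, tolerate missing LTOIR entry points" approach could look roughly like the sketch below. It is not the PR's code: `bind_prototypes` is a hypothetical helper, and a `SimpleNamespace` stands in for the loaded library. With a real `ctypes` library object, looking up a missing symbol also raises `AttributeError`, which is what the sketch relies on.

```python
from types import SimpleNamespace

# Entry points assumed to be optional because CUDA 11 NVRTC lacks them.
OPTIONAL_NAMES = {"nvrtcGetLTOIRSize", "nvrtcGetLTOIR"}


def bind_prototypes(lib, names):
    """Look up every known name, skipping optional ones that are absent."""
    bound = {}
    for name in names:
        try:
            bound[name] = getattr(lib, name)
        except AttributeError:
            if name not in OPTIONAL_NAMES:
                raise  # a required function is missing: a real error
    return bound


# Stand-in for a CUDA 11 style library without the LTOIR functions.
cuda11_like = SimpleNamespace(
    nvrtcVersion=lambda: (11, 8),
    nvrtcCompileProgram=lambda: 0,
)
known = ["nvrtcVersion", "nvrtcCompileProgram",
         "nvrtcGetLTOIRSize", "nvrtcGetLTOIR"]
bound = bind_prototypes(cuda11_like, known)
```

No version query is needed before binding, which is what removes the circular dependency on the runtime version.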

The `runtime.is_supported_version()` API and its test are removed - the API
would always have returned `False` on CUDA 12 (incorrectly) and this has
never been reported as an issue, so it seems very unlikely that anyone
was using it.

Recent toolkits move the CCCL headers into their own subdirectory, so we
need to add this subdirectory to the include path so that headers such
as `cuda/atomic` etc. can be located successfully in all cases.
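
A minimal sketch of extending the include path follows. The subdirectory name `cccl` and the helper `include_paths` are assumptions made for illustration; the commit message does not name the directory.

```python
import os
import tempfile


def include_paths(include_dir, subdir="cccl"):
    """Return the include directory plus the CCCL subdirectory if present."""
    paths = [include_dir]
    candidate = os.path.join(include_dir, subdir)
    if os.path.isdir(candidate):
        paths.append(candidate)
    return paths


# Newer layout: headers live in a dedicated subdirectory.
with tempfile.TemporaryDirectory() as root:
    os.mkdir(os.path.join(root, "cccl"))
    new_layout = [os.path.relpath(p, root) for p in include_paths(root)]

# Older layout: no subdirectory, so only the base directory is used.
with tempfile.TemporaryDirectory() as root:
    old_layout = [os.path.relpath(p, root) for p in include_paths(root)]
```

Checking for the directory rather than the toolkit version keeps the same code working with both layouts.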

The most recent `cuCtxCreate()` API in the CUDA bindings will require an
additional optional parameter. We don't have to supply a value for it
(other than `None`), but we do need to provide the argument on binding
versions where it is required.
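
One way to handle such a signature difference is to branch on the bindings version at the call site. The sketch below is not the PR's code: `create_context`, the two stand-in functions, the position of the extra argument, and the `(13, 0)` threshold are all assumptions chosen for illustration.

```python
def legacy_ctx_create(flags, device):
    """Stand-in for a binding that takes (flags, device)."""
    return ("context", None, flags, device)


def newer_ctx_create(params, flags, device):
    """Stand-in for a binding that also takes an optional parameter object."""
    return ("context", params, flags, device)


def create_context(ctx_create, binding_version, flags, device):
    # Assumed threshold: bindings from 13.0 onwards need the extra argument.
    if binding_version >= (13, 0):
        return ctx_create(None, flags, device)
    return ctx_create(flags, device)


old = create_context(legacy_ctx_create, (12, 9), 0, 0)
new = create_context(newer_ctx_create, (13, 0), 0, 0)
```

Both calls produce the same result; only the argument list differs.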

copy-pr-bot bot commented Jul 1, 2025

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.

gmarkall added the "2 - In Progress" (Currently a work in progress) label Jul 1, 2025
Contributor Author

gmarkall commented Jul 1, 2025

/ok to test

gmarkall added 2 commits July 2, 2025 00:16
The change to use NVRTC for the supported compute capabilities also had
the implicit effect of making the default compute capability the lowest
supported by the installed NVRTC version. We need it to default to at
least 7.5 (unless specified higher by the user) to maintain the
behaviour of the compute capability logic from nvvm.py that was
replaced.
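
The selection rule described in this commit message can be sketched as follows. The function `default_cc` is hypothetical, and ignoring a user value below the floor is an assumption, since the message only covers the case where the user specifies a higher value.

```python
def default_cc(supported_ccs, user_cc=None, floor=(7, 5)):
    """Pick a default compute capability that is never below the floor."""
    # Lowest CC the toolkit supports, raised to the floor if needed.
    cc = max(min(supported_ccs), floor)
    # A user setting only takes effect here when it is higher.
    if user_cc is not None and user_cc > cc:
        cc = user_cc
    return cc


# A toolkit whose lowest supported CC is 5.0 still defaults to 7.5.
older_toolkit = [(5, 0), (6, 0), (7, 5), (9, 0)]
cc_default = default_cc(older_toolkit)
cc_user = default_cc(older_toolkit, user_cc=(8, 6))
```

Tuples compare element by element, so `(7, 5) > (5, 0)` behaves as expected for major and minor versions.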
Contributor Author

gmarkall commented Jul 1, 2025

/ok to test

gmarkall added the "3 - Ready for Review" (Ready for review by team) label and removed the "2 - In Progress" (Currently a work in progress) label Jul 2, 2025
Contributor Author

gmarkall commented Jul 2, 2025

The wheels-deps-wheels timeout is due to a deadlock:

$ NUMBA_CUDA_ENABLE_PYNVJITLINK=1 NUMBA_CUDA_TEST_BIN_DIR=$NUMBA_CUDA_TEST_BIN_DIR python -m numba.runtests numba.cuda.tests -v
/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/__init__.py:55: UserWarning: Explicitly enabling pynvjitlink is no longer necessary. NVIDIA bindings are enabled. cuda.core will be used in place of pynvjitlink.
  warnings.warn(
^CTraceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba/runtests.py", line 9, in <module>
    sys.exit(0 if _main(sys.argv) else 1)
                  ~~~~~^^^^^^^^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba/testing/_runtests.py", line 25, in _main
    return run_tests(argv, defaultTest='numba.tests',
           ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                     **kwds).wasSuccessful()
                     ^^^^^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba/testing/__init__.py", line 54, in run_tests
    prog = NumbaTestProgram(argv=argv,
                            module=None,
    ...<3 lines>...
                            verbosity=verbosity,
                            nomultiproc=nomultiproc)
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba/testing/main.py", line 204, in __init__
    super(NumbaTestProgram, self).__init__(*args, **kwargs)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/unittest/main.py", line 103, in __init__
    self.parseArgs(argv)
    ~~~~~~~~~~~~~~^^^^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba/testing/main.py", line 293, in parseArgs
    super(NumbaTestProgram, self).parseArgs(argv)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/unittest/main.py", line 142, in parseArgs
    self.createTests()
    ~~~~~~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/unittest/main.py", line 153, in createTests
    self.test = self.testLoader.loadTestsFromNames(self.testNames,
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
                                                   self.module)
                                                   ^^^^^^^^^^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/unittest/loader.py", line 207, in loadTestsFromNames
    suites = [self.loadTestsFromName(name, module) for name in names]
              ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/unittest/loader.py", line 174, in loadTestsFromName
    return self.loadTestsFromModule(obj)
           ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/unittest/loader.py", line 113, in loadTestsFromModule
    return load_tests(self, tests, pattern)
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/tests/__init__.py", line 44, in load_tests
    suite.addTests(load_testsuite(loader, join(this_dir, "nocuda")))
                   ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/tests/__init__.py", line 34, in load_testsuite
    suite.addTests(loader.loadTestsFromName(f))
                   ~~~~~~~~~~~~~~~~~~~~~~~~^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/unittest/loader.py", line 137, in loadTestsFromName
    module = __import__(module_name)
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py", line 7, in <module>
    @unittest.skipIf(not nvvm.is_available(), "No libNVVM")
                         ~~~~~~~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/nvvm.py", line 69, in is_available
    NVVM()
    ~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/nvvm.py", line 158, in __new__
    inst.driver = open_cudalib("nvvm")
                  ~~~~~~~~~~~~^^^^^^^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/libs.py", line 83, in open_cudalib
    path = get_cudalib(lib)
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/libs.py", line 54, in get_cudalib
    return get_cuda_paths()[lib].info or _dllnamepattern % lib
           ~~~~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 427, in get_cuda_paths
    "nvrtc": _get_nvrtc_path(),
             ~~~~~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 398, in _get_nvrtc_path
    by, path = _get_nvrtc_path_decision()
               ~~~~~~~~~~~~~~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 110, in _get_nvrtc_path_decision
    return _find_first_valid_lazy(options)
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 35, in _find_first_valid_lazy
    value = fn()
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 173, in _get_nvrtc_wheel
    dso_path = get_nvrtc_dso_path()
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 154, in get_nvrtc_dso_path
    major = get_major_cuda_version()
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 140, in get_major_cuda_version
    return get_version()[0]
           ~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/runtime.py", line 23, in get_version
    return runtime.get_version()
           ~~~~~~~~~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/runtime.py", line 13, in get_version
    return NVRTC().get_version()
           ~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/nvrtc.py", line 150, in __new__
    lib = open_cudalib("nvrtc")
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/libs.py", line 83, in open_cudalib
    path = get_cudalib(lib)
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/libs.py", line 54, in get_cudalib
    return get_cuda_paths()[lib].info or _dllnamepattern % lib
           ~~~~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 427, in get_cuda_paths
    "nvrtc": _get_nvrtc_path(),
             ~~~~~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 398, in _get_nvrtc_path
    by, path = _get_nvrtc_path_decision()
               ~~~~~~~~~~~~~~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 110, in _get_nvrtc_path_decision
    return _find_first_valid_lazy(options)
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 35, in _find_first_valid_lazy
    value = fn()
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 173, in _get_nvrtc_wheel
    dso_path = get_nvrtc_dso_path()
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 154, in get_nvrtc_dso_path
    major = get_major_cuda_version()
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cuda_paths.py", line 140, in get_major_cuda_version
    return get_version()[0]
           ~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/runtime.py", line 23, in get_version
    return runtime.get_version()
           ~~~~~~~~~~~~~~~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/runtime.py", line 13, in get_version
    return NVRTC().get_version()
           ~~~~~^^
  File "/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/cudadrv/nvrtc.py", line 144, in __new__
    with _nvrtc_lock:
         ^^^^^^^^^^^
KeyboardInterrupt

Contributor Author

gmarkall commented Jul 2, 2025

Simple reproducer:

$ python -c "from numba.cuda.cudadrv import nvvm; nvvm.is_available()"
/home/gmarkall/miniforge3/envs/test-ctk13-api-wheels/lib/python3.13/site-packages/numba_cuda/numba/cuda/__init__.py:55: UserWarning: Explicitly enabling pynvjitlink is no longer necessary. NVIDIA bindings are enabled. cuda.core will be used in place of pynvjitlink.
  warnings.warn(

Contributor Author

gmarkall commented Jul 2, 2025

Unfortunately, I'm relying on NVRTC to get the toolkit version to determine the name of the library to use to load NVRTC. 🤦
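
The circular dependency can be reproduced in miniature. The sketch below assumes `_nvrtc_lock` is a non-reentrant `threading.Lock` (the traceback shows the lock being waited on, but not its type). `load_library` and `toolkit_version` are hypothetical stand-ins for the NVRTC loader and the version query, and a timeout replaces the indefinite wait so that the example terminates.

```python
import threading

_lock = threading.Lock()  # stand-in for _nvrtc_lock; its type is an assumption


def toolkit_version():
    # Needs the library, so it tries to take the loader lock again.
    if not _lock.acquire(timeout=0.1):
        return None  # without the timeout this would block forever
    try:
        return (13, 0)
    finally:
        _lock.release()


def load_library():
    with _lock:
        # Choosing the library file name requires the toolkit version,
        # which in turn requires the library that is still being loaded.
        return toolkit_version()


result = load_library()
```

Here `result` is `None` because the inner acquisition times out while the outer `with` block still holds the lock.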

Contributor Author

gmarkall commented Jul 2, 2025

/ok to test

We use NVRTC to get the CUDA version, so we can't use the CUDA version
to determine the NVRTC DLL / SO anymore. Instead, check for the presence
of each version, preferring the highest.
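
Probing for each version and preferring the highest might look like the sketch below. The file name pattern `libnvrtc.so.{major}`, the candidate major versions, and the helper `find_nvrtc` are assumptions for illustration; the commit message does not list the names that are actually checked.

```python
import os
import tempfile


def find_nvrtc(lib_dir, majors=(12, 13), pattern="libnvrtc.so.{major}"):
    """Return (major, path) for the highest version present, else None."""
    for major in sorted(majors, reverse=True):
        candidate = os.path.join(lib_dir, pattern.format(major=major))
        if os.path.isfile(candidate):
            return major, candidate
    return None


with tempfile.TemporaryDirectory() as lib_dir:
    nothing_found = find_nvrtc(lib_dir)

    # Create empty placeholder files for two versions.
    for major in (12, 13):
        open(os.path.join(lib_dir, f"libnvrtc.so.{major}"), "w").close()
    highest_major = find_nvrtc(lib_dir)[0]

    # With the newer file removed, the older one is picked up instead.
    os.remove(os.path.join(lib_dir, "libnvrtc.so.13"))
    fallback_major = find_nvrtc(lib_dir)[0]
```

The search touches only the file system, so it does not need NVRTC to be loaded first.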
Contributor Author

gmarkall commented Jul 2, 2025

/ok to test

gmarkall merged commit f5f81fa into NVIDIA:main Jul 2, 2025
39 checks passed
gmarkall added the "5 - Ready to merge" (Testing and reviews complete, ready to merge) label and removed the "3 - Ready for Review" (Ready for review by team) label Jul 2, 2025
gmarkall added a commit to gmarkall/numba-cuda that referenced this pull request Jul 2, 2025
- Updates for recent API changes (NVIDIA#313)
- Fix lineinfo generation when compile_internal used (NVIDIA#271) (NVIDIA#287)
- Build docs with NVIDIA Sphinx theme (NVIDIA#312)
- Don't skip debug tests when LTO enabled by default (NVIDIA#311)
- Use `cuda.bindings` and `cuda.core` for `Linker` (NVIDIA#133)
- Enable LTO by default when pynvjitlink is available (NVIDIA#310)
gmarkall mentioned this pull request Jul 2, 2025
gmarkall added a commit that referenced this pull request Jul 2, 2025
- Updates for recent API changes (#313)
- Fix lineinfo generation when compile_internal used (#271) (#287)
- Build docs with NVIDIA Sphinx theme (#312)
- Don't skip debug tests when LTO enabled by default (#311)
- Use `cuda.bindings` and `cuda.core` for `Linker` (#133)
- Enable LTO by default when pynvjitlink is available (#310)
gmarkall added a commit to gmarkall/numba-cuda that referenced this pull request Nov 3, 2025
PR NVIDIA#313 removed the `runtime.is_supported_version()` API, but it is used
by the `cuda.is_supported_version()` public API. This commit restores
the `cuda.is_supported_version()` API by checking whether the CUDA
runtime major version is 12 or 13.

The version number check will need bumping as appropriate when future
toolkit major versions are added and existing toolkit major versions are
dropped. This situation will be caught by the test that is added to
exercise this API.
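
The restored check amounts to a membership test on the major version. In the sketch below, the function takes the version as an argument so that it can be exercised without a CUDA installation; that signature and the constant name are assumptions, not the library's API.

```python
# Major versions treated as supported; needs bumping with new toolkits.
SUPPORTED_TOOLKIT_MAJORS = (12, 13)


def is_supported_version(version):
    """Return True if the (major, minor) version has a supported major."""
    major = version[0]
    return major in SUPPORTED_TOOLKIT_MAJORS


cuda12_ok = is_supported_version((12, 4))
cuda11_ok = is_supported_version((11, 8))
```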
Labels

5 - Ready to merge Testing and reviews complete, ready to merge

Development

Successfully merging this pull request may close these issues.

Support programmatically retrieving the lowest supported compute capability

2 participants