Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions numba_cuda/numba/cuda/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def jit(
lineinfo=False,
cache=False,
launch_bounds=None,
lto=None,
**kws,
):
"""
Expand Down Expand Up @@ -83,6 +84,10 @@ def jit(
If a scalar is provided, it is used as the maximum
number of threads per block.
:type launch_bounds: int | tuple[int]
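:param lto: Whether to enable link-time optimization (LTO). If unspecified,
LTO is enabled by default when pynvjitlink is available, except for
kernels where ``debug=True``. Explicitly passing ``lto=True`` when
pynvjitlink is not enabled raises a ``RuntimeError``.
:type lto: bool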
"""

if link and config.ENABLE_CUDASIM:
Expand Down Expand Up @@ -136,6 +141,13 @@ def jit(
if device and kws.get("link"):
raise ValueError("link keyword invalid for device function")

if lto is None:
# Default to using LTO if pynvjitlink is available and we're not debugging
lto = config.CUDA_ENABLE_PYNVJITLINK and not debug
else:
if lto and not config.CUDA_ENABLE_PYNVJITLINK:
raise RuntimeError("LTO requires pynvjitlink, which is not enabled")

if sigutils.is_signature(func_or_sig):
signatures = [func_or_sig]
specialized = True
Expand Down Expand Up @@ -165,6 +177,7 @@ def _jit(func):
targetoptions["forceinline"] = forceinline
targetoptions["extensions"] = extensions
targetoptions["launch_bounds"] = launch_bounds
targetoptions["lto"] = lto

disp = CUDADispatcher(func, targetoptions=targetoptions)

Expand Down Expand Up @@ -235,6 +248,7 @@ def autojitwrapper(func):
targetoptions["forceinline"] = forceinline
targetoptions["extensions"] = extensions
targetoptions["launch_bounds"] = launch_bounds
targetoptions["lto"] = lto
disp = CUDADispatcher(func_or_sig, targetoptions=targetoptions)

if cache:
Expand Down
4 changes: 2 additions & 2 deletions numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def test_linker_disabled_envvar(self):
env = os.environ.copy()
env["NUMBA_CUDA_ENABLE_PYNVJITLINK"] = "0"
with self.assertRaisesRegex(
AssertionError, "LTO and additional flags require PyNvJitLinker"
AssertionError, "LTO requires pynvjitlink, which is not enabled"
):
# Actual error raised is `RuntimeError`, but `run_in_subprocess`
# reraises as AssertionError.
Expand All @@ -323,7 +323,7 @@ def test_linker_disabled_config(self):
env.pop("NUMBA_CUDA_ENABLE_PYNVJITLINK", None)
with override_config("CUDA_ENABLE_PYNVJITLINK", False):
with self.assertRaisesRegex(
AssertionError, "LTO and additional flags require PyNvJitLinker"
AssertionError, "LTO requires pynvjitlink, which is not enabled"
):
run_in_subprocess(
self.src.format(
Expand Down
12 changes: 12 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from numba import cuda
from numba.core.errors import TypingError
from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim
from numba.tests.support import override_config


def noop(x):
Expand Down Expand Up @@ -89,6 +90,17 @@ def kernel_func():
self.assertIn("resolving callee type: type(CUDADispatcher", excstr)
self.assertIn("NameError: name 'floor' is not defined", excstr)

@skip_on_cudasim("Simulator does not use pynvjitlink")
def test_lto_without_pynvjitlink_error(self):
    # Explicitly requesting LTO while the CUDA_ENABLE_PYNVJITLINK config
    # option is overridden to False is expected to raise a RuntimeError
    # whose message matches "LTO requires pynvjitlink".
    expect_error = self.assertRaisesRegex(
        RuntimeError, "LTO requires pynvjitlink"
    )
    pynvjitlink_off = override_config("CUDA_ENABLE_PYNVJITLINK", False)

    # Entered left to right: the assertion context is the outer one, the
    # config override the inner one, so the override is undone before the
    # raised exception is checked.
    with expect_error, pynvjitlink_off:

        @cuda.jit(lto=True)
        def noop_kernel():
            pass

        # Launch the kernel as well, so an error deferred to compile or
        # link time would still be caught by the assertion.
        noop_kernel[1, 1]()


if __name__ == "__main__":
unittest.main()