diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index d2ec579ad..b23642e3d 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -2,7 +2,7 @@ from numba.core import config, serialize from numba.core.codegen import Codegen, CodeLibrary -from .cudadrv import devices, driver, nvrtc, nvvm, runtime +from .cudadrv import devices, driver, nvvm, runtime, nvrtc from numba.cuda.cudadrv.libs import get_cudalib from numba.cuda.cudadrv.linkable_code import LinkableCode from numba.cuda.memory_management.nrt import NRT_LIBRARY @@ -233,6 +233,33 @@ def get_asm_str(self, cc=None): return ptx + def get_lto_ptx(self, cc=None): + """ + Get the PTX code after LTO. + """ + + if not self._lto: + raise RuntimeError("LTO is not enabled") + + if not driver._have_nvjitlink(): + raise RuntimeError("Link time optimization requires nvJitLink.") + + cc = self._ensure_cc(cc) + + linker = driver._Linker.new( + max_registers=self._max_registers, + cc=cc, + additional_flags=["-ptx"], + lto=self._lto, + ) + + self._link_all(linker, cc, ignore_nonlto=True) + + ptx = linker.get_linked_ptx() + ptx = ptx.decode("utf-8") + + return ptx + def get_ltoir(self, cc=None): cc = self._ensure_cc(cc) @@ -274,17 +301,7 @@ def get_cubin(self, cc=None): return cubin if self._lto and config.DUMP_ASSEMBLY: - linker = driver._Linker.new( - max_registers=self._max_registers, - cc=cc, - additional_flags=["-ptx"], - lto=self._lto, - ) - # `-ptx` flag is meant to view the optimized PTX for LTO objects. - # Non-LTO objects are not passed to linker. 
- self._link_all(linker, cc, ignore_nonlto=True) - ptx = linker.get_linked_ptx() - ptx = ptx.decode("utf-8") + ptx = self.get_lto_ptx(cc=cc) print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, "-")) print(ptx) diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index 0967d2fd5..104ef3252 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -384,6 +384,12 @@ def inspect_asm(self, cc): """ return self._codelibrary.get_asm_str(cc=cc) + def inspect_lto_ptx(self, cc): + """ + Returns the PTX code for the external functions linked to this kernel. + """ + return self._codelibrary.get_lto_ptx(cc=cc) + def inspect_sass_cfg(self): """ Returns the CFG of the SASS for this kernel. @@ -1169,6 +1175,34 @@ def inspect_asm(self, signature=None): for sig, overload in self.overloads.items() } + def inspect_lto_ptx(self, signature=None): + """ + Return link-time optimized PTX code for the given signature. + + :param signature: A tuple of argument types. + :return: The PTX code for the given signature, or a dict of PTX codes + for all previously-encountered signatures. + """ + cc = get_current_device().compute_capability + device = self.targetoptions.get("device") + + if signature is not None: + if device: + return self.overloads[signature].library.get_lto_ptx(cc) + else: + return self.overloads[signature].inspect_lto_ptx(cc) + else: + if device: + return { + sig: overload.library.get_lto_ptx(cc) + for sig, overload in self.overloads.items() + } + else: + return { + sig: overload.inspect_lto_ptx(cc) + for sig, overload in self.overloads.items() + } + def inspect_sass_cfg(self, signature=None): """ Return this kernel's CFG for the device in the current context. 
diff --git a/numba_cuda/numba/cuda/testing.py b/numba_cuda/numba/cuda/testing.py index 217f7d370..f681a5cc2 100644 --- a/numba_cuda/numba/cuda/testing.py +++ b/numba_cuda/numba/cuda/testing.py @@ -189,6 +189,10 @@ def skip_if_cudadevrt_missing(fn): return unittest.skipIf(cudadevrt_missing(), "cudadevrt missing")(fn) +def skip_if_nvjitlink_missing(reason): + return unittest.skipIf(not driver._have_nvjitlink(), reason) + + class ForeignArray(object): """ Class for emulating an array coming from another library through the CUDA diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py b/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py index 5a038a11c..2d94e0bf3 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py @@ -1,12 +1,18 @@ +import re +import cffi + import numpy as np from io import StringIO from numba import cuda, float32, float64, int32, intp +from numba.types import float16, CPointer +from numba.cuda import declare_device from numba.cuda.testing import unittest, CUDATestCase from numba.cuda.testing import ( skip_on_cudasim, skip_with_nvdisasm, skip_without_nvdisasm, + skip_if_nvjitlink_missing, ) @@ -108,6 +114,56 @@ def _test_inspect_sass(self, kernel, name, sass): self.assertIn("BRA", sass) # Branch self.assertIn("EXIT", sass) # Exit program + @skip_on_cudasim("Simulator does not generate code to be inspected") + @skip_if_nvjitlink_missing("nvJitLink is required for LTO") + def test_inspect_lto_asm(self): + ffi = cffi.FFI() + + ext = cuda.CUSource(""" + #include <cuda_fp16.h> + extern "C" + __device__ int add_f2_f2(__half * res, __half * a, __half *b) { + *res = *a + *b; + return 0; + } + """) + + add = declare_device( + "add_f2_f2", + float16(CPointer(float16), CPointer(float16)), + link=ext, + ) + + @cuda.jit + def k(arr): + local_arr = cuda.local.array(shape=1, dtype=np.float16) + local_arr2 = cuda.local.array(shape=1, dtype=np.float16) + local_arr[0] = 1 + local_arr2[0] = 2 + + ptr = 
ffi.from_buffer(local_arr) + ptr2 = ffi.from_buffer(local_arr2) + + arr[0] = add(ptr, ptr2) + + arr = np.array([0], dtype=np.float16) + + k[1, 1](arr) + + allasms = k.inspect_asm() + asm = next(iter(allasms.values())) + + regex = re.compile(r"call(.|\n)*add_f2_f2") + self.assertRegex(asm, regex) + + all_ext_asms = k.inspect_lto_ptx() + lto_asm = next(iter(all_ext_asms.values())) + + self.assertIn("add.f16", lto_asm) + self.assertNotIn("call", lto_asm) + + np.testing.assert_equal(arr[0], np.float16(1) + np.float16(2)) + @skip_without_nvdisasm("nvdisasm needed for inspect_sass()") def test_inspect_sass_eager(self): sig = (float32[::1], int32[::1])