Skip to content
41 changes: 29 additions & 12 deletions numba_cuda/numba/cuda/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from numba.core import config, serialize
from numba.core.codegen import Codegen, CodeLibrary
from .cudadrv import devices, driver, nvrtc, nvvm, runtime
from .cudadrv import devices, driver, nvvm, runtime, nvrtc
from numba.cuda.cudadrv.libs import get_cudalib
from numba.cuda.cudadrv.linkable_code import LinkableCode
from numba.cuda.memory_management.nrt import NRT_LIBRARY
Expand Down Expand Up @@ -233,6 +233,33 @@ def get_asm_str(self, cc=None):

return ptx

def get_lto_ptx(self, cc=None):
"""
Get the PTX code after LTO.
"""

if not self._lto:
raise RuntimeError("LTO is not enabled")

if not driver._have_nvjitlink():
raise RuntimeError("Link time optimization requires nvJitLink.")

cc = self._ensure_cc(cc)

linker = driver._Linker.new(
max_registers=self._max_registers,
cc=cc,
additional_flags=["-ptx"],
lto=self._lto,
)

self._link_all(linker, cc, ignore_nonlto=True)

ptx = linker.get_linked_ptx()
ptx = ptx.decode("utf-8")

return ptx

def get_ltoir(self, cc=None):
cc = self._ensure_cc(cc)

Expand Down Expand Up @@ -274,17 +301,7 @@ def get_cubin(self, cc=None):
return cubin

if self._lto and config.DUMP_ASSEMBLY:
linker = driver._Linker.new(
max_registers=self._max_registers,
cc=cc,
additional_flags=["-ptx"],
lto=self._lto,
)
# `-ptx` flag is meant to view the optimized PTX for LTO objects.
# Non-LTO objects are not passed to linker.
self._link_all(linker, cc, ignore_nonlto=True)
ptx = linker.get_linked_ptx()
ptx = ptx.decode("utf-8")
ptx = self.get_lto_ptx(cc=cc)

print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, "-"))
print(ptx)
Expand Down
34 changes: 34 additions & 0 deletions numba_cuda/numba/cuda/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,12 @@ def inspect_asm(self, cc):
"""
return self._codelibrary.get_asm_str(cc=cc)

def inspect_lto_ptx(self, cc):
"""
Returns the PTX code for the external functions linked to this kernel.
"""
return self._codelibrary.get_lto_ptx(cc=cc)

def inspect_sass_cfg(self):
"""
Returns the CFG of the SASS for this kernel.
Expand Down Expand Up @@ -1169,6 +1175,34 @@ def inspect_asm(self, signature=None):
for sig, overload in self.overloads.items()
}

def inspect_lto_ptx(self, signature=None):
"""
Return link-time optimized PTX code for the given signature.

:param signature: A tuple of argument types.
:return: The PTX code for the given signature, or a dict of PTX codes
for all previously-encountered signatures.
"""
cc = get_current_device().compute_capability
device = self.targetoptions.get("device")

if signature is not None:
if device:
return self.overloads[signature].library.get_lto_ptx(cc)
else:
return self.overloads[signature].inspect_lto_ptx(cc)
else:
if device:
return {
sig: overload.library.get_lto_ptx(cc)
for sig, overload in self.overloads.items()
}
else:
return {
sig: overload.inspect_lto_ptx(cc)
for sig, overload in self.overloads.items()
}

def inspect_sass_cfg(self, signature=None):
"""
Return this kernel's CFG for the device in the current context.
Expand Down
4 changes: 4 additions & 0 deletions numba_cuda/numba/cuda/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ def skip_if_cudadevrt_missing(fn):
return unittest.skipIf(cudadevrt_missing(), "cudadevrt missing")(fn)


def skip_if_nvjitlink_missing(reason):
return unittest.skipIf(not driver._have_nvjitlink(), reason)


class ForeignArray(object):
"""
Class for emulating an array coming from another library through the CUDA
Expand Down
56 changes: 56 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_inspect.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import re
import cffi

import numpy as np

from io import StringIO
from numba import cuda, float32, float64, int32, intp
from numba.types import float16, CPointer
from numba.cuda import declare_device
from numba.cuda.testing import unittest, CUDATestCase
from numba.cuda.testing import (
skip_on_cudasim,
skip_with_nvdisasm,
skip_without_nvdisasm,
skip_if_nvjitlink_missing,
)


Expand Down Expand Up @@ -108,6 +114,56 @@ def _test_inspect_sass(self, kernel, name, sass):
self.assertIn("BRA", sass) # Branch
self.assertIn("EXIT", sass) # Exit program

@skip_on_cudasim("Simulator does not generate code to be inspected")
@skip_if_nvjitlink_missing("nvJitLink is required for LTO")
def test_inspect_lto_asm(self):
ffi = cffi.FFI()

ext = cuda.CUSource("""
#include <cuda_fp16.h>
extern "C"
__device__ int add_f2_f2(__half * res, __half * a, __half *b) {
*res = *a + *b;
return 0;
}
""")

add = declare_device(
"add_f2_f2",
float16(CPointer(float16), CPointer(float16)),
link=ext,
)

@cuda.jit
def k(arr):
local_arr = cuda.local.array(shape=1, dtype=np.float16)
local_arr2 = cuda.local.array(shape=1, dtype=np.float16)
local_arr[0] = 1
local_arr2[0] = 2

ptr = ffi.from_buffer(local_arr)
ptr2 = ffi.from_buffer(local_arr2)

arr[0] = add(ptr, ptr2)

arr = np.array([0], dtype=np.float16)

k[1, 1](arr)

allasms = k.inspect_asm()
asm = next(iter(allasms.values()))

regex = re.compile(r"call(.|\n)*add_f2_f2")
self.assertRegex(asm, regex)

all_ext_asms = k.inspect_lto_ptx()
lto_asm = next(iter(all_ext_asms.values()))

self.assertIn("add.f16", lto_asm)
self.assertNotIn("call", lto_asm)

np.testing.assert_equal(arr[0], np.float16(1) + np.float16(2))

@skip_without_nvdisasm("nvdisasm needed for inspect_sass()")
def test_inspect_sass_eager(self):
sig = (float32[::1], int32[::1])
Expand Down