diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index b261b1d3c..68cce944a 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -206,6 +206,9 @@ def _ensure_cc(self, cc): return device.compute_capability def get_asm_str(self, cc=None): + return "\n".join(self.get_asm_strs(cc=cc)) + + def get_asm_strs(self, cc=None): cc = self._ensure_cc(cc) ptxes = self._ptx_cache.get(cc, None) @@ -218,21 +221,25 @@ def get_asm_str(self, cc=None): irs = self.llvm_strs - ptx = nvvm.compile_ir(irs, **options) + if "g" in options: + ptxes = [nvvm.compile_ir(ir, **options) for ir in irs] + else: + ptxes = [nvvm.compile_ir(irs, **options)] # Sometimes the result from NVVM contains trailing whitespace and # nulls, which we strip so that the assembly dump looks a little # tidier. - ptx = ptx.decode().strip("\x00").strip() + ptxes = [ptx.decode().strip("\x00").strip() for ptx in ptxes] if config.DUMP_ASSEMBLY: print(("ASSEMBLY %s" % self._name).center(80, "-")) - print(ptx) + for ptx in ptxes: + print(ptx) print("=" * 80) - self._ptx_cache[cc] = ptx + self._ptx_cache[cc] = ptxes - return ptx + return ptxes def get_lto_ptx(self, cc=None): """ @@ -284,8 +291,9 @@ def _link_all(self, linker, cc, ignore_nonlto=False): ltoir = self.get_ltoir(cc=cc) linker.add_ltoir(ltoir) else: - ptx = self.get_asm_str(cc=cc) - linker.add_ptx(ptx.encode()) + ptxes = self.get_asm_strs(cc=cc) + for ptx in ptxes: + linker.add_ptx(ptx.encode()) for path in self._linking_files: linker.add_file_guess_ext(path, ignore_nonlto) @@ -432,7 +440,10 @@ def finalize(self): for mod in library.modules: for fn in mod.functions: if not fn.is_declaration: - fn.linkage = "linkonce_odr" + if "g" in self._nvvm_options: + fn.linkage = "weak_odr" + else: + fn.linkage = "linkonce_odr" self._finalized = True diff --git a/numba_cuda/numba/cuda/compiler.py b/numba_cuda/numba/cuda/compiler.py index 13cc39eed..e38de0a20 100644 --- a/numba_cuda/numba/cuda/compiler.py +++ b/numba_cuda/numba/cuda/compiler.py @@ -1023,10 +1023,9 @@ def compile_all( ) if lto: - code = lib.get_ltoir(cc=cc) + codes = [lib.get_ltoir(cc=cc)] else: - code = lib.get_asm_str(cc=cc) - codes = [code] + codes = lib.get_asm_strs(cc=cc) # linking_files is_ltoir = output == "ltoir" @@ -1241,7 +1240,14 @@ def compile( if lto: code = lib.get_ltoir(cc=cc) else: - code = lib.get_asm_str(cc=cc) + codes = lib.get_asm_strs(cc=cc) + if len(codes) == 1: + code = codes[0] + else: + raise RuntimeError( + "Compiling this function results in multiple " + "PTX files. Use compile_all() instead" + ) return code, resty diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_compiler.py b/numba_cuda/numba/cuda/tests/cudapy/test_compiler.py index 7196a065a..31c847c20 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_compiler.py @@ -169,6 +169,16 @@ def check_debug_info(self, ptx): # ending in the filename of this module. self.assertRegex(ptx, '\\.file.*test_compiler.py"') + # We did test for the presence of debuginfo here, but in practice it made + # no sense - the C ABI wrapper generates a call instruction that has + # nothing to correlate with the DWARF, so it would confuse the debugger + # immediately anyway. With the resolution of Issue #588 (using separate + # translation of each IR module when debuginfo is enabled) the debuginfo + # isn't even produced for the ABI wrapper, because there was none present + # in that module anyway. So this test can only be expected to fail until we + # have a proper way of generating device functions with the C ABI without + # requiring the hack of generating a wrapper. + @unittest.expectedFailure def test_device_function_with_debug(self): # See Issue #6719 - this ensures that compilation with debug succeeds # with CUDA 11.2 / NVVM 7.0 onwards. Previously it failed because NVVM