diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py
index 4030c6452..cb72c0dc4 100644
--- a/numba_cuda/numba/cuda/codegen.py
+++ b/numba_cuda/numba/cuda/codegen.py
@@ -58,6 +58,57 @@ def disassemble_cubin_for_cfg(cubin):
     return run_nvdisasm(cubin, flags)
 
 
+class ExternalCodeLibrary(CodeLibrary):
+    """Holds code produced externally, for linking with generated code."""
+
+    def __init__(self, codegen, name):
+        super().__init__(codegen, name)
+        # Files to link
+        self._linking_files = set()
+        # Setup and teardown functions for the module.
+        # The order is determined by the order they are added to the codelib.
+        self._setup_functions = []
+        self._teardown_functions = []
+
+    @property
+    def modules(self):
+        # There are no LLVM IR modules in an ExternalCodeLibrary
+        return set()
+
+    def add_linking_file(self, path_or_obj):
+        # Adding new files after finalization is prohibited, in case the list
+        # of libraries has already been added to another code library; the
+        # newly-added files would be omitted from their linking process.
+        self._raise_if_finalized()
+
+        if isinstance(path_or_obj, LinkableCode):
+            if path_or_obj.setup_callback:
+                self._setup_functions.append(path_or_obj.setup_callback)
+            if path_or_obj.teardown_callback:
+                self._teardown_functions.append(path_or_obj.teardown_callback)
+
+        self._linking_files.add(path_or_obj)
+
+    def add_ir_module(self, module):
+        raise NotImplementedError("Cannot add LLVM IR to external code")
+
+    def add_linking_library(self, library):
+        raise NotImplementedError("Cannot add libraries to external code")
+
+    def finalize(self):
+        self._raise_if_finalized()
+        self._finalized = True
+
+    def get_asm_str(self):
+        raise NotImplementedError("No assembly for external code")
+
+    def get_llvm_str(self):
+        raise NotImplementedError("No LLVM IR for external code")
+
+    def get_function(self, name):
+        raise NotImplementedError("Cannot get function from external code")
+
+
 class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
     """
     The CUDACodeLibrary generates PTX, SASS, cubins for multiple different
@@ -298,6 +349,9 @@ def add_linking_library(self, library):
         self._raise_if_finalized()
 
         self._linking_libraries.add(library)
+        self._linking_files.update(library._linking_files)
+        self._setup_functions.extend(library._setup_functions)
+        self._teardown_functions.extend(library._teardown_functions)
 
     def add_linking_file(self, path_or_obj):
         if isinstance(path_or_obj, LinkableCode):
diff --git a/numba_cuda/numba/cuda/compiler.py b/numba_cuda/numba/cuda/compiler.py
index 3cab447c3..ce6b5b8fe 100644
--- a/numba_cuda/numba/cuda/compiler.py
+++ b/numba_cuda/numba/cuda/compiler.py
@@ -1,5 +1,4 @@
 from llvmlite import ir
-from numba.core.typing.templates import ConcreteTemplate
 from numba.core import ir as numba_ir
 from numba.core import (
     cgutils,
@@ -37,6 +36,7 @@
 from warnings import warn
 from numba.cuda import nvvmutils
 from numba.cuda.api import get_current_device
+from numba.cuda.codegen import ExternalCodeLibrary
 from numba.cuda.cudadrv import nvvm
 from numba.cuda.descriptor import cuda_target
 from numba.cuda.target import CUDACABICallConv
@@ -779,32 +779,37 @@ def compile_ptx_for_current_device(
 
 
 def declare_device_function(name, restype, argtypes, link):
-    return declare_device_function_template(name, restype, argtypes, link).key
-
-
-def declare_device_function_template(name, restype, argtypes, link):
     from .descriptor import cuda_target
 
     typingctx = cuda_target.typing_context
     targetctx = cuda_target.target_context
     sig = typing.signature(restype, *argtypes)
-    extfn = ExternFunction(name, sig, link)
 
-    class device_function_template(ConcreteTemplate):
-        key = extfn
-        cases = [sig]
+    # extfn is the descriptor used to call the function from Python code, and
+    # is used as the key for typing and lowering.
+    extfn = ExternFunction(name, sig)
 
-    fndesc = funcdesc.ExternalFunctionDescriptor(
-        name=name, restype=restype, argtypes=argtypes
-    )
+    # Typing
+    device_function_template = typing.make_concrete_template(name, extfn, [sig])
     typingctx.insert_user_function(extfn, device_function_template)
-    targetctx.insert_user_function(extfn, fndesc)
+
+    # Lowering
+    lib = ExternalCodeLibrary(targetctx.codegen(), f"{name}_externals")
+    for file in link:
+        lib.add_linking_file(file)
+
+    # ExternalFunctionDescriptor provides a lowering implementation for calling
+    # external functions
+    fndesc = funcdesc.ExternalFunctionDescriptor(name, restype, argtypes)
+    targetctx.insert_user_function(extfn, fndesc, libs=(lib,))
 
     return device_function_template
 
 
 class ExternFunction:
-    def __init__(self, name, sig, link):
+    """A descriptor that can be used to call the external function from within
+    a Python kernel."""
+
+    def __init__(self, name, sig):
         self.name = name
         self.sig = sig
-        self.link = link
diff --git a/numba_cuda/numba/cuda/cudadecl.py b/numba_cuda/numba/cuda/cudadecl.py
index 1d9992cda..a15dacde2 100644
--- a/numba_cuda/numba/cuda/cudadecl.py
+++ b/numba_cuda/numba/cuda/cudadecl.py
@@ -21,7 +21,7 @@
 from numba.cuda.types import dim3
 from numba.core.typeconv import Conversion
 from numba import cuda
-from numba.cuda.compiler import declare_device_function_template
+from numba.cuda.compiler import declare_device_function
 
 registry = Registry()
 register = registry.register
@@ -422,7 +422,7 @@ def _genfp16_binary_operator(op):
 
 def _resolve_wrapped_unary(fname):
     link = tuple()
-    decl = declare_device_function_template(
+    decl = declare_device_function(
         f"__numba_wrapper_{fname}", types.float16, (types.float16,), link
     )
     return types.Function(decl)
@@ -430,7 +430,7 @@ def _resolve_wrapped_unary(fname):
 
 def _resolve_wrapped_binary(fname):
     link = tuple()
-    decl = declare_device_function_template(
+    decl = declare_device_function(
         f"__numba_wrapper_{fname}",
         types.float16,
         (
diff --git a/numba_cuda/numba/cuda/decorators.py b/numba_cuda/numba/cuda/decorators.py
index d5a0a29b3..8deee374e 100644
--- a/numba_cuda/numba/cuda/decorators.py
+++ b/numba_cuda/numba/cuda/decorators.py
@@ -236,4 +236,6 @@ def declare_device(name, sig, link=None):
         msg = "Return type must be provided for device declarations"
         raise TypeError(msg)
 
-    return declare_device_function(name, restype, argtypes, link)
+    template = declare_device_function(name, restype, argtypes, link)
+
+    return template.key
diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py
index 6e40ea5eb..e9a00cfa4 100644
--- a/numba_cuda/numba/cuda/dispatcher.py
+++ b/numba_cuda/numba/cuda/dispatcher.py
@@ -3,22 +3,19 @@
 import sys
 import ctypes
 import functools
-from collections import defaultdict
 
-from numba.core import config, ir, serialize, sigutils, types, typing, utils
+from numba.core import config, serialize, sigutils, types, typing, utils
 from numba.core.caching import Cache, CacheImpl
 from numba.core.compiler_lock import global_compiler_lock
 from numba.core.dispatcher import Dispatcher
 from numba.core.errors import NumbaPerformanceWarning
 from numba.core.typing.typeof import Purpose, typeof
-from numba.core.types.functions import Function
 from numba.cuda.api import get_current_device
 from numba.cuda.args import wrap_arg
 from numba.cuda.compiler import (
     compile_cuda,
     CUDACompiler,
     kernel_fixup,
-    ExternFunction,
 )
 import re
 from numba.cuda.cudadrv import driver
@@ -60,54 +57,6 @@
 reshape_funcs = ["nocopy_empty_reshape", "numba_attempt_nocopy_reshape"]
 
 
-def get_cres_link_objects(cres):
-    """Given a compile result, return a set of all linkable code objects that
-    are required for it to be fully linked."""
-
-    link_objects = set()
-
-    # List of calls into declared device functions
-    device_func_calls = [
-        (name, v)
-        for name, v in cres.fndesc.typemap.items()
-        if (isinstance(v, cuda_types.CUDADispatcher))
-    ]
-
-    # List of tuples with SSA name of calls and corresponding signature
-    call_signatures = [
-        (call.func.name, sig)
-        for call, sig in cres.fndesc.calltypes.items()
-        if (isinstance(call, ir.Expr) and call.op == "call")
-    ]
-
-    # Map SSA names to all invoked signatures
-    call_signature_d = defaultdict(list)
-    for name, sig in call_signatures:
-        call_signature_d[name].append(sig)
-
-    # Add the link objects from the current function's callees
-    for name, v in device_func_calls:
-        for sig in call_signature_d.get(name, []):
-            called_cres = v.dispatcher.overloads[sig.args]
-            called_link_objects = get_cres_link_objects(called_cres)
-            link_objects.update(called_link_objects)
-
-    # From this point onwards, we are only interested in ExternFunction
-    # declarations - these are the calls made directly in this function to
-    # them.
-    for name, v in cres.fndesc.typemap.items():
-        if not isinstance(v, Function):
-            continue
-
-        if not isinstance(v.typing_key, ExternFunction):
-            continue
-
-        for obj in v.typing_key.link:
-            link_objects.add(obj)
-
-    return link_objects
-
-
 class _Kernel(serialize.ReduceMixin):
     """
     CUDA Kernel specialized for a given set of argument types. When called, this
@@ -238,9 +187,6 @@ def link_to_library_functions(
 
         self.maybe_link_nrt(link, tgt_ctx, asm)
 
-        for obj in get_cres_link_objects(cres):
-            lib.add_linking_file(obj)
-
         for filepath in link:
             lib.add_linking_file(filepath)
 
diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_extending.py b/numba_cuda/numba/cuda/tests/cudapy/test_extending.py
index 9f78ec851..3dbe00c5d 100644
--- a/numba_cuda/numba/cuda/tests/cudapy/test_extending.py
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_extending.py
@@ -4,6 +4,7 @@
 import numpy as np
 import os
 from numba import config, cuda, njit, types
+from numba.extending import overload
 
 
 class Interval:
@@ -250,6 +251,54 @@ def use_external_add(r, x, y):
 
             np.testing.assert_equal(r[0], 3)
 
+            @cuda.jit(lto=lto)
+            def use_external_add_device(x, y):
+                return external_add(x, y)
+
+            @cuda.jit(lto=lto)
+            def use_external_add_kernel(r, x, y):
+                r[0] = use_external_add_device(x[0], y[0])
+
+            r = np.zeros(1, dtype=np.uint32)
+            x = np.ones(1, dtype=np.uint32)
+            y = np.ones(1, dtype=np.uint32) * 2
+
+            use_external_add_kernel[1, 1](r, x, y)
+
+            np.testing.assert_equal(r[0], 3)
+
+    def test_linked_called_through_overload(self):
+        cu_code = cuda.CUSource("""
+        extern "C" __device__
+        int bar(int *out, int a)
+        {
+            *out = a * 2;
+            return 0;
+        }
+        """)
+
+        bar = cuda.declare_device("bar", "int32(int32)", link=cu_code)
+
+        def bar_call(val):
+            pass
+
+        @overload(bar_call, target="cuda")
+        def ol_bar_call(a):
+            return lambda a: bar(a)
+
+        @cuda.jit("void(int32[::1], int32[::1])")
+        def foo(r, x):
+            i = cuda.grid(1)
+            if i < len(r):
+                r[i] = bar_call(x[i])
+
+        x = np.arange(10, dtype=np.int32)
+        r = np.empty_like(x)
+
+        foo[1, 32](r, x)
+
+        np.testing.assert_equal(r, x * 2)
+
 
 if __name__ == "__main__":
     unittest.main()
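
The example below is a minimal sketch of the user-facing flow this patch enables, modelled on the tests added above: code passed to cuda.declare_device() via link= is now carried on an ExternalCodeLibrary attached to the declared function's lowering, so kernels that call it (directly, through another device function, or through an overload) no longer need to repeat link= on @cuda.jit. The function name mul2 and its CUDA source are illustrative only and not part of this patch; running the sketch requires a CUDA-capable GPU.

import numpy as np
from numba import cuda

# External device function using the declare_device C ABI: the return value is
# written through the first (pointer) parameter.
mul2_source = cuda.CUSource("""
extern "C" __device__
int mul2(int *out, int a)
{
    *out = a * 2;
    return 0;
}
""")

# Declaring the function associates mul2_source with it; the ExternalCodeLibrary
# added in this patch propagates it to any kernel that calls mul2.
mul2 = cuda.declare_device("mul2", "int32(int32)", link=mul2_source)


@cuda.jit
def double_kernel(r, x):
    i = cuda.grid(1)
    if i < len(r):
        # No link= argument is needed on @cuda.jit for this call.
        r[i] = mul2(x[i])


x = np.arange(16, dtype=np.int32)
r = np.zeros_like(x)
double_kernel[1, 32](r, x)
assert (r == x * 2).all()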