Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions numba_cuda/numba/cuda/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,57 @@ def disassemble_cubin_for_cfg(cubin):
return run_nvdisasm(cubin, flags)


class ExternalCodeLibrary(CodeLibrary):
    """Code library wrapping code that was produced outside of the compiler.

    It carries no LLVM IR of its own. It only records the files that must be
    linked together with generated code, plus any setup and teardown
    callbacks attached to those files.
    """

    def __init__(self, codegen, name):
        super().__init__(codegen, name)
        # Setup / teardown callbacks for the module, stored in the order in
        # which the objects providing them were added to this library.
        self._setup_functions = []
        self._teardown_functions = []
        # Everything that has to be handed to the linker.
        self._linking_files = set()

    @property
    def modules(self):
        # External code never contributes LLVM IR modules.
        return set()

    def add_linking_file(self, path_or_obj):
        """Record *path_or_obj* as something to link.

        Additions are refused once the library is finalized: by then its
        file list may already have been copied into another code library,
        and a late addition would be missing from that library's link.
        """
        self._raise_if_finalized()

        if isinstance(path_or_obj, LinkableCode):
            # Only LinkableCode objects are inspected for callbacks.
            setup = path_or_obj.setup_callback
            if setup:
                self._setup_functions.append(setup)
            teardown = path_or_obj.teardown_callback
            if teardown:
                self._teardown_functions.append(teardown)

        self._linking_files.add(path_or_obj)

    def finalize(self):
        """Mark the library as finalized; no further files may be added."""
        self._raise_if_finalized()
        self._finalized = True

    # The remaining operations are deliberately unsupported for external
    # code; each one raises NotImplementedError.

    def add_ir_module(self, module):
        raise NotImplementedError("Cannot add LLVM IR to external code")

    def add_linking_library(self, library):
        raise NotImplementedError("Cannot add libraries to external code")

    def get_asm_str(self):
        raise NotImplementedError("No assembly for external code")

    def get_llvm_str(self):
        raise NotImplementedError("No LLVM IR for external code")

    def get_function(self, name):
        raise NotImplementedError("Cannot get function from external code")


class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
"""
The CUDACodeLibrary generates PTX, SASS, cubins for multiple different
Expand Down Expand Up @@ -298,6 +349,9 @@ def add_linking_library(self, library):
self._raise_if_finalized()

self._linking_libraries.add(library)
self._linking_files.update(library._linking_files)
self._setup_functions.extend(library._setup_functions)
self._teardown_functions.extend(library._teardown_functions)

def add_linking_file(self, path_or_obj):
if isinstance(path_or_obj, LinkableCode):
Expand Down
35 changes: 20 additions & 15 deletions numba_cuda/numba/cuda/compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from llvmlite import ir
from numba.core.typing.templates import ConcreteTemplate
from numba.core import ir as numba_ir
from numba.core import (
cgutils,
Expand Down Expand Up @@ -37,6 +36,7 @@
from warnings import warn
from numba.cuda import nvvmutils
from numba.cuda.api import get_current_device
from numba.cuda.codegen import ExternalCodeLibrary
from numba.cuda.cudadrv import nvvm
from numba.cuda.descriptor import cuda_target
from numba.cuda.target import CUDACABICallConv
Expand Down Expand Up @@ -779,32 +779,37 @@ def compile_ptx_for_current_device(


def declare_device_function(name, restype, argtypes, link):
    """Register typing and lowering for an externally defined device function.

    Parameters
    ----------
    name
        Name of the external function.
    restype
        Return type of the function.
    argtypes
        Iterable of argument types, unpacked into the signature.
    link
        Iterable of files (paths or linkable code objects) that provide the
        implementation; each one is added to an ``ExternalCodeLibrary``.

    Returns
    -------
    The concrete typing template created for the function.
    """
    from .descriptor import cuda_target

    typingctx = cuda_target.typing_context
    targetctx = cuda_target.target_context
    sig = typing.signature(restype, *argtypes)

    # extfn is the descriptor used to call the function from Python code, and
    # is used as the key for typing and lowering.
    extfn = ExternFunction(name, sig)

    # Typing
    device_function_template = typing.make_concrete_template(name, extfn, [sig])
    typingctx.insert_user_function(extfn, device_function_template)

    # Lowering
    # ExternalCodeLibrary takes (codegen, name) in that order; passing the
    # name first would store the two values in each other's slot.
    lib = ExternalCodeLibrary(targetctx.codegen(), f"{name}_externals")
    for linkable in link:
        lib.add_linking_file(linkable)

    # ExternalFunctionDescriptor provides a lowering implementation for calling
    # external functions
    fndesc = funcdesc.ExternalFunctionDescriptor(name, restype, argtypes)
    targetctx.insert_user_function(extfn, fndesc, libs=(lib,))

    return device_function_template


class ExternFunction:
    """A descriptor that can be used to call the external function from within
    a Python kernel.

    Only the function name and its signature are stored here; no link
    information is kept on the descriptor.
    """

    def __init__(self, name, sig):
        # name: name of the external function.
        self.name = name
        # sig: signature object describing the function's types.
        self.sig = sig
6 changes: 3 additions & 3 deletions numba_cuda/numba/cuda/cudadecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from numba.cuda.types import dim3
from numba.core.typeconv import Conversion
from numba import cuda
from numba.cuda.compiler import declare_device_function_template
from numba.cuda.compiler import declare_device_function

registry = Registry()
register = registry.register
Expand Down Expand Up @@ -422,15 +422,15 @@ def _genfp16_binary_operator(op):

def _resolve_wrapped_unary(fname):
    """Return a function type for the float16 unary wrapper of *fname*.

    The wrapper is declared as ``__numba_wrapper_{fname}``, taking one
    ``float16`` and returning ``float16``, with nothing extra to link.
    """
    link = tuple()
    # declare_device_function hands back the typing template, which is what
    # types.Function wraps.
    decl = declare_device_function(
        f"__numba_wrapper_{fname}", types.float16, (types.float16,), link
    )
    return types.Function(decl)


def _resolve_wrapped_binary(fname):
link = tuple()
decl = declare_device_function_template(
decl = declare_device_function(
f"__numba_wrapper_{fname}",
types.float16,
(
Expand Down
4 changes: 3 additions & 1 deletion numba_cuda/numba/cuda/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,6 @@ def declare_device(name, sig, link=None):
msg = "Return type must be provided for device declarations"
raise TypeError(msg)

return declare_device_function(name, restype, argtypes, link)
template = declare_device_function(name, restype, argtypes, link)

return template.key
56 changes: 1 addition & 55 deletions numba_cuda/numba/cuda/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,19 @@
import sys
import ctypes
import functools
from collections import defaultdict

from numba.core import config, ir, serialize, sigutils, types, typing, utils
from numba.core import config, serialize, sigutils, types, typing, utils
from numba.core.caching import Cache, CacheImpl
from numba.core.compiler_lock import global_compiler_lock
from numba.core.dispatcher import Dispatcher
from numba.core.errors import NumbaPerformanceWarning
from numba.core.typing.typeof import Purpose, typeof
from numba.core.types.functions import Function
from numba.cuda.api import get_current_device
from numba.cuda.args import wrap_arg
from numba.cuda.compiler import (
compile_cuda,
CUDACompiler,
kernel_fixup,
ExternFunction,
)
import re
from numba.cuda.cudadrv import driver
Expand Down Expand Up @@ -60,54 +57,6 @@
reshape_funcs = ["nocopy_empty_reshape", "numba_attempt_nocopy_reshape"]


def get_cres_link_objects(cres):
    """Collect every linkable code object needed to fully link *cres*.

    The result is the union of the ``link`` entries of each ``ExternFunction``
    referenced directly in this compile result's typemap, and the link
    objects of each device-function overload it calls (gathered
    recursively).

    Returns a ``set`` of link objects.
    """
    fndesc = cres.fndesc
    required = set()

    # Group the signatures of all call expressions by the SSA name of the
    # callee, so each callee can be looked up by name below.
    signatures_by_callee = defaultdict(list)
    for expr, sig in fndesc.calltypes.items():
        if isinstance(expr, ir.Expr) and expr.op == "call":
            signatures_by_callee[expr.func.name].append(sig)

    # Pull in whatever the called device functions themselves require.
    for ssa_name, ty in fndesc.typemap.items():
        if not isinstance(ty, cuda_types.CUDADispatcher):
            continue
        for sig in signatures_by_callee.get(ssa_name, []):
            callee_cres = ty.dispatcher.overloads[sig.args]
            required.update(get_cres_link_objects(callee_cres))

    # Finally, the ExternFunction declarations referenced directly here.
    for ty in fndesc.typemap.values():
        is_extern = isinstance(ty, Function) and isinstance(
            ty.typing_key, ExternFunction
        )
        if is_extern:
            required.update(ty.typing_key.link)

    return required


class _Kernel(serialize.ReduceMixin):
"""
CUDA Kernel specialized for a given set of argument types. When called, this
Expand Down Expand Up @@ -238,9 +187,6 @@ def link_to_library_functions(

self.maybe_link_nrt(link, tgt_ctx, asm)

for obj in get_cres_link_objects(cres):
lib.add_linking_file(obj)

for filepath in link:
lib.add_linking_file(filepath)

Expand Down
49 changes: 49 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_extending.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import os
from numba import config, cuda, njit, types
from numba.extending import overload


class Interval:
Expand Down Expand Up @@ -250,6 +251,54 @@ def use_external_add(r, x, y):

np.testing.assert_equal(r[0], 3)

@cuda.jit(lto=lto)
def use_external_add_device(x, y):
return external_add(x, y)

@cuda.jit(lto=lto)
def use_external_add_kernel(r, x, y):
r[0] = use_external_add_device(x[0], y[0])

r = np.zeros(1, dtype=np.uint32)
x = np.ones(1, dtype=np.uint32)
y = np.ones(1, dtype=np.uint32) * 2

use_external_add_kernel[1, 1](r, x, y)

np.testing.assert_equal(r[0], 3)

def test_linked_called_through_overload(self):
    # A declared device function with linked CUDA source is only reached
    # through an @overload implementation, never called directly from the
    # kernel. The kernel must still produce the doubled values.
    cu_code = cuda.CUSource("""
extern "C" __device__
int bar(int *out, int a)
{
*out = a * 2;
return 0;
}
""")

    bar = cuda.declare_device("bar", "int32(int32)", link=cu_code)

    # Pure-Python stub; its CUDA implementation comes from the overload.
    def bar_call(val):
        pass

    @overload(bar_call, target="cuda")
    def ol_bar_call(a):
        def impl(a):
            return bar(a)

        return impl

    @cuda.jit("void(int32[::1], int32[::1])")
    def foo(r, x):
        i = cuda.grid(1)
        if i < len(r):
            r[i] = bar_call(x[i])

    inputs = np.arange(10, dtype=np.int32)
    outputs = np.empty_like(inputs)

    foo[1, 32](outputs, inputs)

    np.testing.assert_equal(outputs, inputs * 2)


if __name__ == "__main__":
unittest.main()