281 changes: 274 additions & 7 deletions numba_cuda/numba/cuda/dispatcher.py
@@ -2,14 +2,19 @@
import os
import sys
import ctypes
import collections
import functools
import types as pytypes
import weakref
import uuid

from numba.core import serialize, sigutils, types, typing, config
from numba.core import compiler, serialize, sigutils, types, typing, config
from numba.cuda import utils
from numba.core.caching import Cache, CacheImpl
from numba.core.caching import Cache, CacheImpl, NullCache
from numba.core.compiler_lock import global_compiler_lock
from numba.core.dispatcher import Dispatcher
from numba.core.errors import NumbaPerformanceWarning
from numba.core.dispatcher import _DispatcherBase
from numba.core.errors import NumbaPerformanceWarning, TypingError
from numba.core.typing.templates import fold_arguments
from numba.core.typing.typeof import Purpose, typeof
from numba.cuda.api import get_current_device
from numba.cuda.args import wrap_arg
@@ -728,7 +733,134 @@ def load_overload(self, sig, target_context):
return super().load_overload(sig, target_context)


class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
class _MemoMixin:
Contributor:

I don't think this is used / relevant to the CUDA target - if it is removed, does everything still appear to work?

Contributor:

Actually, don't worry about this comment too much - I think it will show up when we run coverage if it's not needed, so let's not think about it now.

Contributor Author:

Sounds good!

__uuid = None
# A {uuid -> instance} mapping, for deserialization
_memo = weakref.WeakValueDictionary()
# hold refs to last N functions deserialized, retaining them in _memo
# regardless of whether there is another reference
_recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE)

@property
def _uuid(self):
"""
An instance-specific UUID, to avoid multiple deserializations of
a given instance.

Note: this is lazily-generated, for performance reasons.
"""
u = self.__uuid
if u is None:
u = str(uuid.uuid4())
self._set_uuid(u)
return u

def _set_uuid(self, u):
assert self.__uuid is None
self.__uuid = u
self._memo[u] = self
self._recent.append(self)
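
The memo is what lets a pickled dispatcher deserialize back to the same live object within a process. A minimal sketch of the intended lookup on deserialization follows; the ``_rebuild``-style classmethod and its arguments are illustrative assumptions, not code from this change:

    @classmethod
    def _rebuild(cls, uuid_str, *args):
        # Sketch only: reuse the live instance if this process already holds
        # it (kept alive via _memo and the _recent deque).
        try:
            return cls._memo[uuid_str]
        except KeyError:
            inst = cls(*args)
            inst._set_uuid(uuid_str)  # registers the new instance in _memo
            return inst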


_CompileStats = collections.namedtuple(
"_CompileStats", ("cache_path", "cache_hits", "cache_misses")
)


class _FunctionCompiler(object):
def __init__(self, py_func, targetdescr, targetoptions, pipeline_class):
self.py_func = py_func
self.targetdescr = targetdescr
self.targetoptions = targetoptions
self.pysig = utils.pysignature(self.py_func)
self.pipeline_class = pipeline_class
# Remember key=(args, return_type) combinations that will fail
# compilation to avoid compilation attempt on them. The values are
# the exceptions.
self._failed_cache = {}

def fold_argument_types(self, args, kws):
"""
Given positional and named argument types, fold keyword arguments
and resolve defaults by inserting types.Omitted() instances.

A (pysig, argument types) tuple is returned.
"""

def normal_handler(index, param, value):
return value

def default_handler(index, param, default):
return types.Omitted(default)

def stararg_handler(index, param, values):
return types.StarArgTuple(values)

# For now, we take argument values from the @jit function
args = fold_arguments(
self.pysig,
args,
kws,
normal_handler,
default_handler,
stararg_handler,
)
return self.pysig, args
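
A small worked sketch (the function and argument types are illustrative assumptions, not part of this change):

    def f(a, b=1.5):
        ...

    # With args=(types.int64,) and kws={}, fold_argument_types returns the
    # pysig together with (types.int64, types.Omitted(1.5)): the unsupplied
    # default is carried as an Omitted type rather than a concrete value.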

def compile(self, args, return_type):
status, retval = self._compile_cached(args, return_type)
if status:
return retval
else:
raise retval

def _compile_cached(self, args, return_type):
key = tuple(args), return_type
try:
return False, self._failed_cache[key]
except KeyError:
pass

try:
retval = self._compile_core(args, return_type)
except TypingError as e:
self._failed_cache[key] = e
return False, e
else:
return True, retval

def _compile_core(self, args, return_type):
flags = compiler.Flags()
self.targetdescr.options.parse_as_flags(flags, self.targetoptions)
flags = self._customize_flags(flags)

impl = self._get_implementation(args, {})
cres = compiler.compile_extra(
self.targetdescr.typing_context,
self.targetdescr.target_context,
impl,
args=args,
return_type=return_type,
flags=flags,
locals={},
pipeline_class=self.pipeline_class,
)
# Check typing error if object mode is used
if cres.typing_error is not None and not flags.enable_pyobject:
raise cres.typing_error
return cres

def get_globals_for_reduction(self):
return serialize._get_function_globals_for_reduction(self.py_func)

def _get_implementation(self, args, kws):
return self.py_func

def _customize_flags(self, flags):
return flags
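
Because of the failed-compilation cache above, a signature that failed typing once is not pushed through the pipeline again; the stored ``TypingError`` is simply re-raised. A rough sketch of that flow (``fc`` is an already-constructed ``_FunctionCompiler``; the argument types are illustrative assumptions):

    bad_sig = (types.float64[:],)
    try:
        fc.compile(bad_sig, None)   # runs the pipeline and fails typing
    except TypingError:
        pass
    fc.compile(bad_sig, None)       # hits _failed_cache and re-raises the
                                    # same exception without calling
                                    # compiler.compile_extra again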


class CUDADispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
"""
CUDA Dispatcher object. When configured and called, the dispatcher will
specialize itself for the given arguments (if no suitable specialized
@@ -747,10 +879,42 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
targetdescr = cuda_target

def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler):
super().__init__(
py_func, targetoptions=targetoptions, pipeline_class=pipeline_class
"""
Parameters
----------
py_func: function object to be compiled
targetoptions: dict, optional
Target-specific config options.
pipeline_class: type numba.compiler.CompilerBase
The compiler pipeline type.
"""
self.typingctx = self.targetdescr.typing_context
self.targetctx = self.targetdescr.target_context

pysig = utils.pysignature(py_func)
arg_count = len(pysig.parameters)
can_fallback = not targetoptions.get("nopython", False)

_DispatcherBase.__init__(
self,
arg_count,
py_func,
pysig,
can_fallback,
exact_match_required=False,
)

functools.update_wrapper(self, py_func)

self.targetoptions = targetoptions
self._cache = NullCache()
compiler_class = _FunctionCompiler
self._compiler = compiler_class(
py_func, self.targetdescr, targetoptions, pipeline_class
)
self._cache_hits = collections.Counter()
self._cache_misses = collections.Counter()

# The following properties are for specialization of CUDADispatchers. A
# specialized CUDADispatcher is one that is compiled for exactly one
# set of argument types, and bypasses some argument type checking for
@@ -763,13 +927,29 @@ def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler):
# argument types
self.specializations = {}

def dump(self, tab=""):
print(
f"{tab}DUMP {type(self).__name__}[{self.py_func.__name__}"
f", type code={self._type._code}]"
)
for cres in self.overloads.values():
cres.dump(tab=tab + " ")
print(f"{tab}END DUMP {type(self).__name__}[{self.py_func.__name__}]")

@property
def _numba_type_(self):
return cuda_types.CUDADispatcher(self)

def enable_caching(self):
self._cache = CUDACache(self.py_func)

def __get__(self, obj, objtype=None):
"""Allow a JIT function to be bound as a method to an object"""
if obj is None: # Unbound method
return self
else: # Bound method
return pytypes.MethodType(self, obj)
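
A minimal sketch of the descriptor behaviour (the ``Worker`` class and the ``dispatcher`` instance are illustrative assumptions):

    class Worker:
        # Assigning a dispatcher to a class attribute makes it behave like an
        # ordinary method, thanks to __get__ above.
        kernel = dispatcher  # hypothetical CUDADispatcher instance

    w = Worker()
    w.kernel        # bound: pytypes.MethodType(dispatcher, w); w becomes the first argument
    Worker.kernel   # unbound (obj is None): the dispatcher itself is returned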

@functools.lru_cache(maxsize=128)
def configure(self, griddim, blockdim, stream=0, sharedmem=0):
griddim, blockdim = normalize_kernel_dimensions(griddim, blockdim)
@@ -1117,6 +1297,93 @@ def compile(self, sig):

return kernel

def get_compile_result(self, sig):
"""Compile (if needed) and return the compilation result with the
given signature.

Returns ``CompileResult``.
Raises ``NumbaError`` if the signature is incompatible.
"""
atypes = tuple(sig.args)
if atypes not in self.overloads:
if self._can_compile:
# Compiling may raise any NumbaError
self.compile(atypes)
else:
msg = f"{sig} not available and compilation disabled"
raise TypingError(msg)
return self.overloads[atypes]
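
A usage sketch (``disp`` and the signature are illustrative assumptions):

    sig = types.void(types.float32[:])
    cres = disp.get_compile_result(sig)   # compiles on first use if allowed
    # After disp.disable_compile(), requesting an unseen signature raises
    # TypingError instead of triggering a new compilation.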

def recompile(self):
"""
Recompile all signatures afresh.
"""
sigs = list(self.overloads)
old_can_compile = self._can_compile
# Ensure the old overloads are disposed of,
# including compiled functions.
self._make_finalizer()()
self._reset_overloads()
self._cache.flush()
self._can_compile = True
try:
for sig in sigs:
self.compile(sig)
finally:
self._can_compile = old_can_compile

@property
def stats(self):
return _CompileStats(
cache_path=self._cache.cache_path,
cache_hits=self._cache_hits,
cache_misses=self._cache_misses,
)
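
A small usage sketch (``disp`` is an illustrative dispatcher with caching enabled):

    disp.enable_caching()
    disp.stats.cache_path     # directory backing the on-disk cache
    disp.stats.cache_hits     # Counter of cache hits, keyed by signature
    disp.stats.cache_misses   # Counter of cache misses, keyed by signature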

def parallel_diagnostics(self, signature=None, level=1):
"""
Print parallel diagnostic information for the given signature. If no
signature is present it is printed for all known signatures. level is
used to adjust the verbosity, level=1 (default) is minimal verbosity,
and 2, 3, and 4 provide increasing levels of verbosity.
"""

def dump(sig):
ol = self.overloads[sig]
pfdiag = ol.metadata.get("parfor_diagnostics", None)
if pfdiag is None:
msg = "No parfors diagnostic available, is 'parallel=True' set?"
raise ValueError(msg)
pfdiag.dump(level)

if signature is not None:
dump(signature)
else:
[dump(sig) for sig in self.signatures]

def get_metadata(self, signature=None):
"""
Obtain the compilation metadata for a given signature.
"""
if signature is not None:
return self.overloads[signature].metadata
else:
return dict(
(sig, self.overloads[sig].metadata) for sig in self.signatures
)
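
Usage sketch (the key shown is an assumption about what the metadata may contain):

    md = disp.get_metadata(disp.signatures[0])
    md.get("pipeline_times")  # per-pass timing information, when the pipeline records it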

def get_function_type(self):
"""Return unique function type of dispatcher when possible, otherwise
return None.

A Dispatcher instance has unique function type when it
contains exactly one compilation result and its compilation
has been disabled (via its disable_compile method).
"""
if not self._can_compile and len(self.overloads) == 1:
cres = tuple(self.overloads.values())[0]
return types.FunctionType(cres.signature)
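
A minimal sketch of when a unique type is available (the argument types are illustrative assumptions):

    disp.compile((types.int32,))
    disp.disable_compile()
    fnty = disp.get_function_type()  # types.FunctionType for the single overload
    # With compilation still enabled, or with more than one overload, this
    # returns None.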

def inspect_llvm(self, signature=None):
"""
Return the LLVM IR for this kernel.