diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index 1b6e8ce0c..33fd463b4 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -2,14 +2,19 @@ import os import sys import ctypes +import collections import functools +import types as pytypes +import weakref +import uuid -from numba.core import serialize, sigutils, types, typing, config +from numba.core import compiler, serialize, sigutils, types, typing, config from numba.cuda import utils -from numba.core.caching import Cache, CacheImpl +from numba.core.caching import Cache, CacheImpl, NullCache from numba.core.compiler_lock import global_compiler_lock -from numba.core.dispatcher import Dispatcher -from numba.core.errors import NumbaPerformanceWarning +from numba.core.dispatcher import _DispatcherBase +from numba.core.errors import NumbaPerformanceWarning, TypingError +from numba.core.typing.templates import fold_arguments from numba.core.typing.typeof import Purpose, typeof from numba.cuda.api import get_current_device from numba.cuda.args import wrap_arg @@ -728,7 +733,134 @@ def load_overload(self, sig, target_context): return super().load_overload(sig, target_context) -class CUDADispatcher(Dispatcher, serialize.ReduceMixin): +class _MemoMixin: + __uuid = None + # A {uuid -> instance} mapping, for deserialization + _memo = weakref.WeakValueDictionary() + # Hold strong refs to the last N instances given a UUID, keeping them in + # _memo regardless of whether there is another reference + _recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE) + + @property + def _uuid(self): + """ + An instance-specific UUID, to avoid multiple deserializations of + a given instance. + + Note: this is lazily-generated, for performance reasons. 
+ """ + u = self.__uuid + if u is None: + u = str(uuid.uuid4()) + self._set_uuid(u) + return u + + def _set_uuid(self, u): + assert self.__uuid is None + self.__uuid = u + self._memo[u] = self + self._recent.append(self) + + +_CompileStats = collections.namedtuple( + "_CompileStats", ("cache_path", "cache_hits", "cache_misses") +) + + +class _FunctionCompiler(object): + def __init__(self, py_func, targetdescr, targetoptions, pipeline_class): + self.py_func = py_func + self.targetdescr = targetdescr + self.targetoptions = targetoptions + self.pysig = utils.pysignature(self.py_func) + self.pipeline_class = pipeline_class + # Remember key=(args, return_type) combinations that failed typing, + # to avoid attempting to compile them again. The values are + # the TypingError exceptions. + self._failed_cache = {} + + def fold_argument_types(self, args, kws): + """ + Given positional and named argument types, fold keyword arguments + and resolve defaults by inserting types.Omitted() instances. + + A (pysig, argument types) tuple is returned. 
+ """ + + def normal_handler(index, param, value): + return value + + def default_handler(index, param, default): + return types.Omitted(default) + + def stararg_handler(index, param, values): + return types.StarArgTuple(values) + + # For now, we take argument values from the @jit function + args = fold_arguments( + self.pysig, + args, + kws, + normal_handler, + default_handler, + stararg_handler, + ) + return self.pysig, args + + def compile(self, args, return_type): + status, retval = self._compile_cached(args, return_type) + if status: + return retval + else: + raise retval + + def _compile_cached(self, args, return_type): + key = tuple(args), return_type + try: + return False, self._failed_cache[key] + except KeyError: + pass + + try: + retval = self._compile_core(args, return_type) + except TypingError as e: + self._failed_cache[key] = e + return False, e + else: + return True, retval + + def _compile_core(self, args, return_type): + flags = compiler.Flags() + self.targetdescr.options.parse_as_flags(flags, self.targetoptions) + flags = self._customize_flags(flags) + + impl = self._get_implementation(args, {}) + cres = compiler.compile_extra( + self.targetdescr.typing_context, + self.targetdescr.target_context, + impl, + args=args, + return_type=return_type, + flags=flags, + locals={}, + pipeline_class=self.pipeline_class, + ) + # Raise any recorded typing error unless object mode is enabled + if cres.typing_error is not None and not flags.enable_pyobject: + raise cres.typing_error + return cres + + def get_globals_for_reduction(self): + return serialize._get_function_globals_for_reduction(self.py_func) + + def _get_implementation(self, args, kws): + return self.py_func + + def _customize_flags(self, flags): + return flags + + +class CUDADispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase): """ CUDA Dispatcher object. 
When configured and called, the dispatcher will specialize itself for the given arguments (if no suitable specialized @@ -747,10 +879,42 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin): targetdescr = cuda_target def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler): - super().__init__( - py_func, targetoptions=targetoptions, pipeline_class=pipeline_class + """ + Parameters + ---------- + py_func: function object to be compiled + targetoptions: dict + Target-specific config options. + pipeline_class: type numba.compiler.CompilerBase + The compiler pipeline type; defaults to CUDACompiler. + """ + self.typingctx = self.targetdescr.typing_context + self.targetctx = self.targetdescr.target_context + + pysig = utils.pysignature(py_func) + arg_count = len(pysig.parameters) + can_fallback = not targetoptions.get("nopython", False) + + _DispatcherBase.__init__( + self, + arg_count, + py_func, + pysig, + can_fallback, + exact_match_required=False, ) + functools.update_wrapper(self, py_func) + + self.targetoptions = targetoptions + self._cache = NullCache() + compiler_class = _FunctionCompiler + self._compiler = compiler_class( + py_func, self.targetdescr, targetoptions, pipeline_class + ) + self._cache_hits = collections.Counter() + self._cache_misses = collections.Counter() + # The following properties are for specialization of CUDADispatchers. 
A # specialized CUDADispatcher is one that is compiled for exactly one # set of argument types, and bypasses some argument type checking for @@ -763,6 +927,15 @@ def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler): # argument types self.specializations = {} + def dump(self, tab=""): + print( + f"{tab}DUMP {type(self).__name__}[{self.py_func.__name__}" + f", type code={self._type._code}]" + ) + for cres in self.overloads.values(): + cres.dump(tab=tab + " ") + print(f"{tab}END DUMP {type(self).__name__}[{self.py_func.__name__}]") + @property def _numba_type_(self): return cuda_types.CUDADispatcher(self) @@ -770,6 +943,13 @@ def _numba_type_(self): def enable_caching(self): self._cache = CUDACache(self.py_func) + def __get__(self, obj, objtype=None): + """Allow a JIT function to be bound as a method to an object""" + if obj is None: # Unbound method + return self + else: # Bound method + return pytypes.MethodType(self, obj) + @functools.lru_cache(maxsize=128) def configure(self, griddim, blockdim, stream=0, sharedmem=0): griddim, blockdim = normalize_kernel_dimensions(griddim, blockdim) @@ -1117,6 +1297,93 @@ def compile(self, sig): return kernel + def get_compile_result(self, sig): + """Compile (if needed) and return the compilation result with the + given signature. + + Returns ``CompileResult``. + Raises ``NumbaError`` if the signature is incompatible. + """ + atypes = tuple(sig.args) + if atypes not in self.overloads: + if self._can_compile: + # Compiling may raise any NumbaError + self.compile(atypes) + else: + msg = f"{sig} not available and compilation disabled" + raise TypingError(msg) + return self.overloads[atypes] + + def recompile(self): + """ + Recompile all signatures afresh. + """ + sigs = list(self.overloads) + old_can_compile = self._can_compile + # Ensure the old overloads are disposed of, + # including compiled functions. 
+ self._make_finalizer()() + self._reset_overloads() + self._cache.flush() + self._can_compile = True + try: + for sig in sigs: + self.compile(sig) + finally: + self._can_compile = old_can_compile + + @property + def stats(self): + return _CompileStats( + cache_path=self._cache.cache_path, + cache_hits=self._cache_hits, + cache_misses=self._cache_misses, + ) + + def parallel_diagnostics(self, signature=None, level=1): + """ + Print parallel diagnostic information for the given signature. If no + signature is given, it is printed for all known signatures. level + adjusts the verbosity: level=1 (default) is minimal verbosity, + and 2, 3, and 4 provide increasing levels of verbosity. + """ + + def dump(sig): + ol = self.overloads[sig] + pfdiag = ol.metadata.get("parfor_diagnostics", None) + if pfdiag is None: + msg = "No parfors diagnostic available, is 'parallel=True' set?" + raise ValueError(msg) + pfdiag.dump(level) + + if signature is not None: + dump(signature) + else: + [dump(sig) for sig in self.signatures] + + def get_metadata(self, signature=None): + """ + Obtain the metadata for a signature, or a {sig: metadata} dict for all. + """ + if signature is not None: + return self.overloads[signature].metadata + else: + return dict( + (sig, self.overloads[sig].metadata) for sig in self.signatures + ) + + def get_function_type(self): + """Return unique function type of dispatcher when possible, otherwise + return None. + + A Dispatcher instance has a unique function type when it + contains exactly one compilation result and its compilation + has been disabled (via its disable_compile method). + """ + if not self._can_compile and len(self.overloads) == 1: + cres = tuple(self.overloads.values())[0] + return types.FunctionType(cres.signature) + def inspect_llvm(self, signature=None): """ Return the LLVM IR for this kernel.