diff --git a/numba_cuda/numba/cuda/core/annotations/pretty_annotate.py b/numba_cuda/numba/cuda/core/annotations/pretty_annotate.py new file mode 100644 index 000000000..b3217a3cf --- /dev/null +++ b/numba_cuda/numba/cuda/core/annotations/pretty_annotate.py @@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +This module implements code highlighting of numba-cuda function annotations. +""" + +from warnings import warn + +warn( + "The pretty_annotate functionality is experimental and might change API", + FutureWarning, +) + + +def hllines(code, style): + try: + from pygments import highlight + from pygments.lexers import PythonLexer + from pygments.formatters import HtmlFormatter + except ImportError: + raise ImportError("please install the 'pygments' package") + pylex = PythonLexer() + "Given a code string, return a list of html-highlighted lines" + hf = HtmlFormatter(noclasses=True, style=style, nowrap=True) + res = highlight(code, pylex, hf) + return res.splitlines() + + +def htlines(code, style): + try: + from pygments import highlight + from pygments.lexers import PythonLexer + + # TerminalFormatter does not support themes, Terminal256 should, + # but seem to not work. + from pygments.formatters import TerminalFormatter + except ImportError: + raise ImportError("please install the 'pygments' package") + pylex = PythonLexer() + "Given a code string, return a list of ANSI-highlighted lines" + hf = TerminalFormatter(style=style) + res = highlight(code, pylex, hf) + return res.splitlines() + + +def get_ansi_template(): + try: + from jinja2 import Template + except ImportError: + raise ImportError("please install the 'jinja2' package") + return Template(""" + {%- for func_key in func_data.keys() -%} + Function name: \x1b[34m{{func_data[func_key]['funcname']}}\x1b[39;49;00m + {%- if func_data[func_key]['filename'] -%} + {{'\n'}}In file: \x1b[34m{{func_data[func_key]['filename'] -}}\x1b[39;49;00m + {%- endif -%} + {{'\n'}}With signature: \x1b[34m{{func_key[1]}}\x1b[39;49;00m + {{- "\n" -}} + {%- for num, line, hl, hc in func_data[func_key]['pygments_lines'] -%} + {{-'\n'}}{{ num}}: {{hc-}} + {%- if func_data[func_key]['ir_lines'][num] -%} + {%- for ir_line, ir_line_type in func_data[func_key]['ir_lines'][num] %} + {{-'\n'}}--{{- ' '*func_data[func_key]['python_indent'][num]}} + {{- ' '*(func_data[func_key]['ir_indent'][num][loop.index0]+4) + }}{{ir_line }}\x1b[41m{{ir_line_type-}}\x1b[39;49;00m + {%- endfor -%} + {%- endif -%} + {%- endfor -%} + {%- endfor -%} + """) + + +def get_html_template(): + try: + from jinja2 import Template + except ImportError: + raise ImportError("please install the 'jinja2' package") + return Template(""" + + + + + + + {% for func_key in func_data.keys() %} +
+ Function name: {{func_data[func_key]['funcname']}}
+ {% if func_data[func_key]['filename'] %} + in file: {{func_data[func_key]['filename']|escape}}
+ {% endif %} + with signature: {{func_key[1]|e}} +
+
+ + {%- for num, line, hl, hc in func_data[func_key]['pygments_lines'] -%} + {%- if func_data[func_key]['ir_lines'][num] %} + + {% else -%} + + {%- endif -%} + {%- endfor -%} +
+
+ + + {{num}}: + {{' '*func_data[func_key]['python_indent'][num]}}{{hl}} + + + + + {%- for ir_line, ir_line_type in func_data[func_key]['ir_lines'][num] %} + + + + {%- endfor -%} + +
+   + {{- ' '*func_data[func_key]['python_indent'][num]}} + {{ ' '*func_data[func_key]['ir_indent'][num][loop.index0]}}{{ir_line|e -}} + {{ir_line_type}} + +
+
+
+ + {{num}}: + {{' '*func_data[func_key]['python_indent'][num]}}{{hl}} + +
+
+ {% endfor %} + + + """) + + +def reform_code(annotation): + """ + Extract the code from the Numba-cuda annotation datastructure. + + Pygments can only highlight full multi-line strings, the Numba-cuda + annotation is list of single lines, with indentation removed. + """ + ident_dict = annotation["python_indent"] + s = "" + for n, l in annotation["python_lines"]: + s = s + " " * ident_dict[n] + l + "\n" + return s + + +class Annotate: + """ + Construct syntax highlighted annotation for a given jitted function: + + Example: + + >>> from numba import cuda + >>> import numpy as np + >>> from numba.cuda.core.annotations.pretty_annotate import Annotate + >>> @cuda.jit + ... def test(a): + ... tid = cuda.grid(1) + ... size = len(a) + ... if tid < size: + ... a[tid] = 1 + >>> test[(4), (16)](np.ones(100)) + >>> Annotate(test) + + The last line will return an HTML and/or ANSI representation that will be + displayed accordingly in Jupyter/IPython. + + Function annotations persist across compilation for newly encountered + type signatures and as a result annotations are shown for all signatures + by default. + + Annotations for a specific signature can be shown by using the + ``signature`` parameter. For the above jitted function: + + >>> test.signatures + [(Array(float64, 1, 'C', False, aligned=True),)] + >>> Annotate(f, signature=f.signatures[0]) + # annotation for Array(float64, 1, 'C', False, aligned=True) + """ + + def __init__(self, function, signature=None, **kwargs): + style = kwargs.get("style", "default") + if not function.signatures: + raise ValueError( + "function need to be jitted for at least one signature" + ) + ann = function.get_annotation_info(signature=signature) + self.ann = ann + + for k, v in ann.items(): + res = hllines(reform_code(v), style) + rest = htlines(reform_code(v), style) + v["pygments_lines"] = [ + (a, b, c, d) + for (a, b), c, d in zip(v["python_lines"], res, rest) + ] + + def _repr_html_(self): + return get_html_template().render(func_data=self.ann) + + def __repr__(self): + return get_ansi_template().render(func_data=self.ann) diff --git a/numba_cuda/numba/cuda/core/annotations/type_annotations.py b/numba_cuda/numba/cuda/core/annotations/type_annotations.py index 72b985fdf..52f5b4f93 100644 --- a/numba_cuda/numba/cuda/core/annotations/type_annotations.py +++ b/numba_cuda/numba/cuda/core/annotations/type_annotations.py @@ -11,7 +11,6 @@ import textwrap from io import StringIO -import numba.core.dispatcher from numba.core import ir @@ -83,6 +82,8 @@ def __init__( self.lifted_from = lifted_from def prepare_annotations(self): + from numba.cuda.dispatcher import LiftedLoop + # Prepare annotations groupedinst = defaultdict(list) found_lifted_loop = False @@ -103,7 +104,7 @@ def prepare_annotations(self): ): atype = self.calltypes[inst.value] elif isinstance(inst.value, ir.Const) and isinstance( - inst.value.value, numba.core.dispatcher.LiftedLoop + inst.value.value, LiftedLoop ): atype = "XXX Lifted Loop XXX" found_lifted_loop = True diff --git a/numba_cuda/numba/cuda/core/entrypoints.py b/numba_cuda/numba/cuda/core/entrypoints.py new file mode 100644 index 000000000..fb8a6e06f --- /dev/null +++ b/numba_cuda/numba/cuda/core/entrypoints.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import logging +import warnings + +from importlib import metadata as importlib_metadata + + +_already_initialized = False +logger = logging.getLogger(__name__) + + +def init_all(): + """Execute all `numba_cuda_extensions` entry points with the name `init` + + If extensions have already been initialized, this function does nothing. + """ + try: + from numba.core import entrypoints + + entrypoints.init_all() + except ImportError: + pass + + global _already_initialized + if _already_initialized: + return + + # Must put this here to avoid extensions re-triggering initialization + _already_initialized = True + + def load_ep(entry_point): + """Loads a given entry point. Warns and logs on failure.""" + logger.debug("Loading extension: %s", entry_point) + try: + func = entry_point.load() + func() + except Exception as e: + msg = ( + f"Numba extension module '{entry_point.module}' " + f"failed to load due to '{type(e).__name__}({str(e)})'." + ) + warnings.warn(msg, stacklevel=3) + logger.debug("Extension loading failed for: %s", entry_point) + + eps = importlib_metadata.entry_points() + # Split, Python 3.10+ and importlib_metadata 3.6+ have the "selectable" + # interface, versions prior to that do not. See "compatibility note" in: + # https://docs.python.org/3.10/library/importlib.metadata.html#entry-points + if hasattr(eps, "select"): + for entry_point in eps.select( + group="numba_cuda_extensions", name="init" + ): + load_ep(entry_point) + else: + for entry_point in eps.get("numba_cuda_extensions", ()): + if entry_point.name == "init": + load_ep(entry_point) diff --git a/numba_cuda/numba/cuda/core/transforms.py b/numba_cuda/numba/cuda/core/transforms.py index 201bb62f1..be10a30cd 100644 --- a/numba_cuda/numba/cuda/core/transforms.py +++ b/numba_cuda/numba/cuda/core/transforms.py @@ -193,7 +193,7 @@ def _loop_lift_modify_blocks( Modify the block inplace to call to the lifted-loop. Returns a dictionary of blocks of the lifted-loop. 
""" - from numba.core.dispatcher import LiftedLoop + from numba.cuda.dispatcher import LiftedLoop # Copy loop blocks loop = loopinfo.loop @@ -402,7 +402,7 @@ def with_lifting(func_ir, typingctx, targetctx, flags, locals): from numba.cuda.core import postproc def dispatcher_factory(func_ir, objectmode=False, **kwargs): - from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith + from numba.cuda.dispatcher import LiftedWith, ObjModeLiftedWith myflags = flags.copy() if objectmode: diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py index 2f0bb529d..ca26c03a1 100644 --- a/numba_cuda/numba/cuda/dispatcher.py +++ b/numba_cuda/numba/cuda/dispatcher.py @@ -9,11 +9,13 @@ import functools import types as pytypes import weakref +from contextlib import ExitStack +from abc import abstractmethod import uuid import re from warnings import warn -from numba.core import types, errors, entrypoints +from numba.core import types, errors from numba.cuda import serialize, utils from numba import cuda @@ -32,8 +34,9 @@ CUDACompiler, kernel_fixup, compile_extra, + compile_ir, ) -from numba.cuda.core import sigutils, config +from numba.cuda.core import sigutils, config, entrypoints from numba.cuda.flags import Flags from numba.cuda.cudadrv import driver, nvvm from numba.cuda.locks import module_init_lock @@ -46,7 +49,7 @@ from numba.cuda.cudadrv.linkable_code import LinkableCode from numba.cuda.cudadrv.devices import get_context from numba.cuda.memory_management.nrt import rtsys, NRT_LIBRARY - +import numba.cuda.core.event as ev from numba.cuda.cext import _dispatcher @@ -1148,7 +1151,7 @@ def inspect_types( else: if file is not None: raise ValueError("`file` must be None if `pretty=True`") - from numba.core.annotations.pretty_annotate import Annotate + from numba.cuda.core.annotations.pretty_annotate import Annotate return Annotate(self, signature=signature, style=style) @@ -2123,6 +2126,344 @@ def _reduce_states(self): return dict(py_func=self.py_func, targetoptions=self.targetoptions) +class LiftedCode(serialize.ReduceMixin, _MemoMixin, _DispatcherBase): + """ + Implementation of the hidden dispatcher objects used for lifted code + (a lifted loop is really compiled as a separate function). + """ + + _fold_args = False + can_cache = False + + def __init__(self, func_ir, typingctx, targetctx, flags, locals): + self.func_ir = func_ir + self.lifted_from = None + + self.typingctx = typingctx + self.targetctx = targetctx + self.flags = flags + self.locals = locals + + _DispatcherBase.__init__( + self, + self.func_ir.arg_count, + self.func_ir.func_id.func, + self.func_ir.func_id.pysig, + can_fallback=True, + exact_match_required=False, + ) + + def _reduce_states(self): + """ + Reduce the instance for pickling. This will serialize + the original function as well the compilation options and + compiled signatures, but not the compiled code itself. + + NOTE: part of ReduceMixin protocol + """ + return dict( + uuid=self._uuid, + func_ir=self.func_ir, + flags=self.flags, + locals=self.locals, + extras=self._reduce_extras(), + ) + + def _reduce_extras(self): + """ + NOTE: sub-class can override to add extra states + """ + return {} + + @classmethod + def _rebuild(cls, uuid, func_ir, flags, locals, extras): + """ + Rebuild an Dispatcher instance after it was __reduce__'d. 
+ + NOTE: part of ReduceMixin protocol + """ + try: + return cls._memo[uuid] + except KeyError: + pass + + from numba.cuda.descriptor import cuda_target + + typingctx = cuda_target.typing_context + targetctx = cuda_target.target_context + + self = cls(func_ir, typingctx, targetctx, flags, locals, **extras) + self._set_uuid(uuid) + return self + + def get_source_location(self): + """Return the starting line number of the loop.""" + return self.func_ir.loc.line + + def _pre_compile(self, args, return_type, flags): + """Pre-compile actions""" + pass + + @abstractmethod + def compile(self, sig): + """Lifted code should implement a compilation method that will return + a CompileResult.entry_point for the given signature.""" + pass + + def _get_dispatcher_for_current_target(self): + # Lifted code does not honor the target switch currently. + # No work has been done to check if this can be allowed. + return self + + +class LiftedLoop(LiftedCode): + def _pre_compile(self, args, return_type, flags): + assert not flags.enable_looplift, "Enable looplift flags is on" + + def compile(self, sig): + with ExitStack() as scope: + cres = None + + def cb_compiler(dur): + if cres is not None: + self._callback_add_compiler_timer(dur, cres) + + def cb_llvm(dur): + if cres is not None: + self._callback_add_llvm_timer(dur, cres) + + scope.enter_context( + ev.install_timer("numba:compiler_lock", cb_compiler) + ) + scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm)) + scope.enter_context(global_compiler_lock) + + # Use counter to track recursion compilation depth + with self._compiling_counter: + # XXX this is mostly duplicated from Dispatcher. + flags = self.flags + args, return_type = sigutils.normalize_signature(sig) + + # Don't recompile if signature already exists + # (e.g. if another thread compiled it before we got the lock) + existing = self.overloads.get(tuple(args)) + if existing is not None: + return existing.entry_point + + self._pre_compile(args, return_type, flags) + + # copy the flags, use nopython first + npm_loop_flags = flags.copy() + npm_loop_flags.force_pyobject = False + + pyobject_loop_flags = flags.copy() + pyobject_loop_flags.force_pyobject = True + + # Clone IR to avoid (some of the) mutation in the rewrite pass + cloned_func_ir_npm = self.func_ir.copy() + cloned_func_ir_fbk = self.func_ir.copy() + + ev_details = dict( + dispatcher=self, + args=args, + return_type=return_type, + ) + with ev.trigger_event("numba:compile", data=ev_details): + # this emulates "object mode fall-back", try nopython, if it + # fails, then try again in object mode. 
+ try: + cres = compile_ir( + typingctx=self.typingctx, + targetctx=self.targetctx, + func_ir=cloned_func_ir_npm, + args=args, + return_type=return_type, + flags=npm_loop_flags, + locals=self.locals, + lifted=(), + lifted_from=self.lifted_from, + is_lifted_loop=True, + ) + except errors.TypingError: + cres = compile_ir( + typingctx=self.typingctx, + targetctx=self.targetctx, + func_ir=cloned_func_ir_fbk, + args=args, + return_type=return_type, + flags=pyobject_loop_flags, + locals=self.locals, + lifted=(), + lifted_from=self.lifted_from, + is_lifted_loop=True, + ) + # Check typing error if object mode is used + if cres.typing_error is not None: + raise cres.typing_error + self.add_overload(cres) + return cres.entry_point + + +class LiftedWith(LiftedCode): + can_cache = True + + def _reduce_extras(self): + return dict(output_types=self.output_types) + + @property + def _numba_type_(self): + return types.Dispatcher(self) + + def get_call_template(self, args, kws): + """ + Get a typing.ConcreteTemplate for this dispatcher and the given + *args* and *kws* types. This enables the resolving of the return type. + + A (template, pysig, args, kws) tuple is returned. + """ + # Ensure an overload is available + if self._can_compile: + self.compile(tuple(args)) + + pysig = None + # Create function type for typing + func_name = self.py_func.__name__ + name = "CallTemplate({0})".format(func_name) + # The `key` isn't really used except for diagnosis here, + # so avoid keeping a reference to `cfunc`. + call_template = typing.make_concrete_template( + name, key=func_name, signatures=self.nopython_signatures + ) + return call_template, pysig, args, kws + + def compile(self, sig): + # this is similar to LiftedLoop's compile but does not have the + # "fallback" to object mode part. + with ExitStack() as scope: + cres = None + + def cb_compiler(dur): + if cres is not None: + self._callback_add_compiler_timer(dur, cres) + + def cb_llvm(dur): + if cres is not None: + self._callback_add_llvm_timer(dur, cres) + + scope.enter_context( + ev.install_timer("numba:compiler_lock", cb_compiler) + ) + scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm)) + scope.enter_context(global_compiler_lock) + + # Use counter to track recursion compilation depth + with self._compiling_counter: + # XXX this is mostly duplicated from Dispatcher. + flags = self.flags + args, return_type = sigutils.normalize_signature(sig) + + # Don't recompile if signature already exists + # (e.g. 
if another thread compiled it before we got the lock) + existing = self.overloads.get(tuple(args)) + if existing is not None: + return existing.entry_point + + self._pre_compile(args, return_type, flags) + + # Clone IR to avoid (some of the) mutation in the rewrite pass + cloned_func_ir = self.func_ir.copy() + + ev_details = dict( + dispatcher=self, + args=args, + return_type=return_type, + ) + with ev.trigger_event("numba:compile", data=ev_details): + cres = compile_ir( + typingctx=self.typingctx, + targetctx=self.targetctx, + func_ir=cloned_func_ir, + args=args, + return_type=return_type, + flags=flags, + locals=self.locals, + lifted=(), + lifted_from=self.lifted_from, + is_lifted_loop=True, + ) + + # Check typing error if object mode is used + if ( + cres.typing_error is not None + and not flags.enable_pyobject + ): + raise cres.typing_error + self.add_overload(cres) + return cres.entry_point + + +class ObjModeLiftedWith(LiftedWith): + def __init__(self, *args, **kwargs): + self.output_types = kwargs.pop("output_types", None) + super(LiftedWith, self).__init__(*args, **kwargs) + if not self.flags.force_pyobject: + raise ValueError("expecting `flags.force_pyobject`") + if self.output_types is None: + raise TypeError("`output_types` must be provided") + # switch off rewrites, they have no effect + self.flags.no_rewrites = True + + @property + def _numba_type_(self): + return types.ObjModeDispatcher(self) + + def get_call_template(self, args, kws): + """ + Get a typing.ConcreteTemplate for this dispatcher and the given + *args* and *kws* types. This enables the resolving of the return type. + + A (template, pysig, args, kws) tuple is returned. + """ + assert not kws + self._legalize_arg_types(args) + # Coerce to object mode + args = [types.ffi_forced_object] * len(args) + + if self._can_compile: + self.compile(tuple(args)) + + signatures = [typing.signature(self.output_types, *args)] + pysig = None + func_name = self.py_func.__name__ + name = "CallTemplate({0})".format(func_name) + call_template = typing.make_concrete_template( + name, key=func_name, signatures=signatures + ) + + return call_template, pysig, args, kws + + def _legalize_arg_types(self, args): + for i, a in enumerate(args, start=1): + if isinstance(a, types.List): + msg = ( + "Does not support list type inputs into " + "with-context for arg {}" + ) + raise errors.TypingError(msg.format(i)) + elif isinstance(a, types.Dispatcher): + msg = ( + "Does not support function type inputs into " + "with-context for arg {}" + ) + raise errors.TypingError(msg.format(i)) + + @global_compiler_lock + def compile(self, sig): + args, _ = sigutils.normalize_signature(sig) + sig = (types.ffi_forced_object,) * len(args) + return super().compile(sig) + + # Initialize typeof machinery _dispatcher.typeof_init( OmittedArg, dict((str(t), t._code) for t in types.number_domain) diff --git a/numba_cuda/numba/cuda/extending.py b/numba_cuda/numba/cuda/extending.py index 393b3eab6..4df2f068e 100644 --- a/numba_cuda/numba/cuda/extending.py +++ b/numba_cuda/numba/cuda/extending.py @@ -619,14 +619,13 @@ def bind(self, *args, **kwargs): def is_jitted(function): - """Returns True if a function is wrapped by one of the Numba @jit - decorators, for example: numba.jit, numba.njit + """Returns True if a function is wrapped by cuda.jit The purpose of this function is to provide a means to check if a function is already JIT decorated. 
""" # don't want to export this so import locally - from numba.core.dispatcher import Dispatcher + from numba.cuda.dispatcher import CUDADispatcher - return isinstance(function, Dispatcher) + return isinstance(function, CUDADispatcher) diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index f310a770c..22b181822 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -11,7 +11,6 @@ from numba.core import types from numba.core.compiler_lock import global_compiler_lock -from numba.core.dispatcher import Dispatcher from numba.core.errors import NumbaWarning from numba.cuda.core.base import BaseContext from numba.core.typing import cmathdecl @@ -63,25 +62,32 @@ def resolve_value_type(self, val): # treat other dispatcher object as another device function from numba.cuda.dispatcher import CUDADispatcher - if isinstance(val, Dispatcher) and not isinstance(val, CUDADispatcher): - try: - # use cached device function - val = val.__dispatcher - except AttributeError: - if not val._can_compile: - raise ValueError( - "using cpu function on device " - "but its compilation is disabled" - ) - targetoptions = val.targetoptions.copy() - targetoptions["device"] = True - targetoptions["debug"] = targetoptions.get("debug", False) - targetoptions["opt"] = targetoptions.get("opt", True) - disp = CUDADispatcher(val.py_func, targetoptions) - # cache the device function for future use and to avoid - # duplicated copy of the same function. - val.__dispatcher = disp - val = disp + try: + from numba.core.dispatcher import Dispatcher + + if isinstance(val, Dispatcher) and not isinstance( + val, CUDADispatcher + ): + try: + # use cached device function + val = val.__dispatcher + except AttributeError: + if not val._can_compile: + raise ValueError( + "using cpu function on device " + "but its compilation is disabled" + ) + targetoptions = val.targetoptions.copy() + targetoptions["device"] = True + targetoptions["debug"] = targetoptions.get("debug", False) + targetoptions["opt"] = targetoptions.get("opt", True) + disp = CUDADispatcher(val.py_func, targetoptions) + # cache the device function for future use and to avoid + # duplicated copy of the same function. + val.__dispatcher = disp + val = disp + except ImportError: + pass # continue with parent logic return super(CUDATypingContext, self).resolve_value_type(val)