diff --git a/numba_cuda/numba/cuda/core/annotations/pretty_annotate.py b/numba_cuda/numba/cuda/core/annotations/pretty_annotate.py
new file mode 100644
index 000000000..b3217a3cf
--- /dev/null
+++ b/numba_cuda/numba/cuda/core/annotations/pretty_annotate.py
@@ -0,0 +1,288 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+"""
+This module implements code highlighting of numba-cuda function annotations.
+"""
+
+from warnings import warn
+
+warn(
+ "The pretty_annotate functionality is experimental and might change API",
+ FutureWarning,
+)
+
+
+def hllines(code, style):
+ try:
+ from pygments import highlight
+ from pygments.lexers import PythonLexer
+ from pygments.formatters import HtmlFormatter
+ except ImportError:
+ raise ImportError("please install the 'pygments' package")
+ pylex = PythonLexer()
+ "Given a code string, return a list of html-highlighted lines"
+ hf = HtmlFormatter(noclasses=True, style=style, nowrap=True)
+ res = highlight(code, pylex, hf)
+ return res.splitlines()
+
+
+def htlines(code, style):
+ try:
+ from pygments import highlight
+ from pygments.lexers import PythonLexer
+
+        # TerminalFormatter does not support themes; Terminal256 should,
+        # but it seems not to work.
+ from pygments.formatters import TerminalFormatter
+ except ImportError:
+ raise ImportError("please install the 'pygments' package")
+ pylex = PythonLexer()
+ "Given a code string, return a list of ANSI-highlighted lines"
+ hf = TerminalFormatter(style=style)
+ res = highlight(code, pylex, hf)
+ return res.splitlines()
+
+
+def get_ansi_template():
+ try:
+ from jinja2 import Template
+ except ImportError:
+ raise ImportError("please install the 'jinja2' package")
+ return Template("""
+ {%- for func_key in func_data.keys() -%}
+ Function name: \x1b[34m{{func_data[func_key]['funcname']}}\x1b[39;49;00m
+ {%- if func_data[func_key]['filename'] -%}
+ {{'\n'}}In file: \x1b[34m{{func_data[func_key]['filename'] -}}\x1b[39;49;00m
+ {%- endif -%}
+ {{'\n'}}With signature: \x1b[34m{{func_key[1]}}\x1b[39;49;00m
+ {{- "\n" -}}
+ {%- for num, line, hl, hc in func_data[func_key]['pygments_lines'] -%}
+ {{-'\n'}}{{ num}}: {{hc-}}
+ {%- if func_data[func_key]['ir_lines'][num] -%}
+ {%- for ir_line, ir_line_type in func_data[func_key]['ir_lines'][num] %}
+ {{-'\n'}}--{{- ' '*func_data[func_key]['python_indent'][num]}}
+ {{- ' '*(func_data[func_key]['ir_indent'][num][loop.index0]+4)
+ }}{{ir_line }}\x1b[41m{{ir_line_type-}}\x1b[39;49;00m
+ {%- endfor -%}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- endfor -%}
+ """)
+
+
+def get_html_template():
+ try:
+ from jinja2 import Template
+ except ImportError:
+ raise ImportError("please install the 'jinja2' package")
+ return Template("""
+
+
+
+
+
+
+ {% for func_key in func_data.keys() %}
+
+ Function name: {{func_data[func_key]['funcname']}}
+ {% if func_data[func_key]['filename'] %}
+ in file: {{func_data[func_key]['filename']|escape}}
+ {% endif %}
+ with signature: {{func_key[1]|e}}
+
+
+
+ {%- for num, line, hl, hc in func_data[func_key]['pygments_lines'] -%}
+ {%- if func_data[func_key]['ir_lines'][num] %}
+
+
+
+
+ {{num}}:
+ {{' '*func_data[func_key]['python_indent'][num]}}{{hl}}
+
+
+
+
+ {%- for ir_line, ir_line_type in func_data[func_key]['ir_lines'][num] %}
+
+
+
+ {{- ' '*func_data[func_key]['python_indent'][num]}}
+ {{ ' '*func_data[func_key]['ir_indent'][num][loop.index0]}}{{ir_line|e -}}
+ {{ir_line_type}}
+
+ |
+
+ {%- endfor -%}
+
+
+
+ |
+ {% else -%}
+
+
+ {{num}}:
+ {{' '*func_data[func_key]['python_indent'][num]}}{{hl}}
+
+ |
+ {%- endif -%}
+ {%- endfor -%}
+
+
+ {% endfor %}
+
+
+ """)
+
+
+def reform_code(annotation):
+    """
+    Extract the code from the Numba-cuda annotation data structure.
+
+    Pygments can only highlight full multi-line strings; the Numba-cuda
+    annotation is a list of single lines, with indentation removed.
+    """
+ ident_dict = annotation["python_indent"]
+ s = ""
+ for n, l in annotation["python_lines"]:
+ s = s + " " * ident_dict[n] + l + "\n"
+ return s
+
+
+class Annotate:
+ """
+ Construct syntax highlighted annotation for a given jitted function:
+
+ Example:
+
+ >>> from numba import cuda
+ >>> import numpy as np
+ >>> from numba.cuda.core.annotations.pretty_annotate import Annotate
+ >>> @cuda.jit
+ ... def test(a):
+ ... tid = cuda.grid(1)
+ ... size = len(a)
+ ... if tid < size:
+ ... a[tid] = 1
+ >>> test[(4), (16)](np.ones(100))
+ >>> Annotate(test)
+
+ The last line will return an HTML and/or ANSI representation that will be
+ displayed accordingly in Jupyter/IPython.
+
+ Function annotations persist across compilation for newly encountered
+ type signatures and as a result annotations are shown for all signatures
+ by default.
+
+ Annotations for a specific signature can be shown by using the
+ ``signature`` parameter. For the above jitted function:
+
+ >>> test.signatures
+ [(Array(float64, 1, 'C', False, aligned=True),)]
+    >>> Annotate(test, signature=test.signatures[0])
+ # annotation for Array(float64, 1, 'C', False, aligned=True)
+ """
+
+ def __init__(self, function, signature=None, **kwargs):
+ style = kwargs.get("style", "default")
+ if not function.signatures:
+ raise ValueError(
+ "function need to be jitted for at least one signature"
+ )
+ ann = function.get_annotation_info(signature=signature)
+ self.ann = ann
+
+ for k, v in ann.items():
+ res = hllines(reform_code(v), style)
+ rest = htlines(reform_code(v), style)
+ v["pygments_lines"] = [
+ (a, b, c, d)
+ for (a, b), c, d in zip(v["python_lines"], res, rest)
+ ]
+
+ def _repr_html_(self):
+ return get_html_template().render(func_data=self.ann)
+
+ def __repr__(self):
+ return get_ansi_template().render(func_data=self.ann)
diff --git a/numba_cuda/numba/cuda/core/annotations/type_annotations.py b/numba_cuda/numba/cuda/core/annotations/type_annotations.py
index 72b985fdf..52f5b4f93 100644
--- a/numba_cuda/numba/cuda/core/annotations/type_annotations.py
+++ b/numba_cuda/numba/cuda/core/annotations/type_annotations.py
@@ -11,7 +11,6 @@
import textwrap
from io import StringIO
-import numba.core.dispatcher
from numba.core import ir
@@ -83,6 +82,8 @@ def __init__(
self.lifted_from = lifted_from
def prepare_annotations(self):
+ from numba.cuda.dispatcher import LiftedLoop
+
# Prepare annotations
groupedinst = defaultdict(list)
found_lifted_loop = False
@@ -103,7 +104,7 @@ def prepare_annotations(self):
):
atype = self.calltypes[inst.value]
elif isinstance(inst.value, ir.Const) and isinstance(
- inst.value.value, numba.core.dispatcher.LiftedLoop
+ inst.value.value, LiftedLoop
):
atype = "XXX Lifted Loop XXX"
found_lifted_loop = True
diff --git a/numba_cuda/numba/cuda/core/entrypoints.py b/numba_cuda/numba/cuda/core/entrypoints.py
new file mode 100644
index 000000000..fb8a6e06f
--- /dev/null
+++ b/numba_cuda/numba/cuda/core/entrypoints.py
@@ -0,0 +1,59 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+import logging
+import warnings
+
+from importlib import metadata as importlib_metadata
+
+
+_already_initialized = False
+logger = logging.getLogger(__name__)
+
+
+def init_all():
+ """Execute all `numba_cuda_extensions` entry points with the name `init`
+
+ If extensions have already been initialized, this function does nothing.
+ """
+ try:
+ from numba.core import entrypoints
+
+ entrypoints.init_all()
+ except ImportError:
+ pass
+
+ global _already_initialized
+ if _already_initialized:
+ return
+
+ # Must put this here to avoid extensions re-triggering initialization
+ _already_initialized = True
+
+ def load_ep(entry_point):
+ """Loads a given entry point. Warns and logs on failure."""
+ logger.debug("Loading extension: %s", entry_point)
+ try:
+ func = entry_point.load()
+ func()
+ except Exception as e:
+ msg = (
+ f"Numba extension module '{entry_point.module}' "
+ f"failed to load due to '{type(e).__name__}({str(e)})'."
+ )
+ warnings.warn(msg, stacklevel=3)
+ logger.debug("Extension loading failed for: %s", entry_point)
+
+ eps = importlib_metadata.entry_points()
+ # Split, Python 3.10+ and importlib_metadata 3.6+ have the "selectable"
+ # interface, versions prior to that do not. See "compatibility note" in:
+ # https://docs.python.org/3.10/library/importlib.metadata.html#entry-points
+ if hasattr(eps, "select"):
+ for entry_point in eps.select(
+ group="numba_cuda_extensions", name="init"
+ ):
+ load_ep(entry_point)
+ else:
+ for entry_point in eps.get("numba_cuda_extensions", ()):
+ if entry_point.name == "init":
+ load_ep(entry_point)
diff --git a/numba_cuda/numba/cuda/core/transforms.py b/numba_cuda/numba/cuda/core/transforms.py
index 201bb62f1..be10a30cd 100644
--- a/numba_cuda/numba/cuda/core/transforms.py
+++ b/numba_cuda/numba/cuda/core/transforms.py
@@ -193,7 +193,7 @@ def _loop_lift_modify_blocks(
Modify the block inplace to call to the lifted-loop.
Returns a dictionary of blocks of the lifted-loop.
"""
- from numba.core.dispatcher import LiftedLoop
+ from numba.cuda.dispatcher import LiftedLoop
# Copy loop blocks
loop = loopinfo.loop
@@ -402,7 +402,7 @@ def with_lifting(func_ir, typingctx, targetctx, flags, locals):
from numba.cuda.core import postproc
def dispatcher_factory(func_ir, objectmode=False, **kwargs):
- from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith
+ from numba.cuda.dispatcher import LiftedWith, ObjModeLiftedWith
myflags = flags.copy()
if objectmode:
diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py
index 2f0bb529d..ca26c03a1 100644
--- a/numba_cuda/numba/cuda/dispatcher.py
+++ b/numba_cuda/numba/cuda/dispatcher.py
@@ -9,11 +9,13 @@
import functools
import types as pytypes
import weakref
+from contextlib import ExitStack
+from abc import abstractmethod
import uuid
import re
from warnings import warn
-from numba.core import types, errors, entrypoints
+from numba.core import types, errors
from numba.cuda import serialize, utils
from numba import cuda
@@ -32,8 +34,9 @@
CUDACompiler,
kernel_fixup,
compile_extra,
+ compile_ir,
)
-from numba.cuda.core import sigutils, config
+from numba.cuda.core import sigutils, config, entrypoints
from numba.cuda.flags import Flags
from numba.cuda.cudadrv import driver, nvvm
from numba.cuda.locks import module_init_lock
@@ -46,7 +49,7 @@
from numba.cuda.cudadrv.linkable_code import LinkableCode
from numba.cuda.cudadrv.devices import get_context
from numba.cuda.memory_management.nrt import rtsys, NRT_LIBRARY
-
+import numba.cuda.core.event as ev
from numba.cuda.cext import _dispatcher
@@ -1148,7 +1151,7 @@ def inspect_types(
else:
if file is not None:
raise ValueError("`file` must be None if `pretty=True`")
- from numba.core.annotations.pretty_annotate import Annotate
+ from numba.cuda.core.annotations.pretty_annotate import Annotate
return Annotate(self, signature=signature, style=style)
@@ -2123,6 +2126,344 @@ def _reduce_states(self):
return dict(py_func=self.py_func, targetoptions=self.targetoptions)
+class LiftedCode(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
+ """
+ Implementation of the hidden dispatcher objects used for lifted code
+ (a lifted loop is really compiled as a separate function).
+ """
+
+ _fold_args = False
+ can_cache = False
+
+ def __init__(self, func_ir, typingctx, targetctx, flags, locals):
+ self.func_ir = func_ir
+ self.lifted_from = None
+
+ self.typingctx = typingctx
+ self.targetctx = targetctx
+ self.flags = flags
+ self.locals = locals
+
+ _DispatcherBase.__init__(
+ self,
+ self.func_ir.arg_count,
+ self.func_ir.func_id.func,
+ self.func_ir.func_id.pysig,
+ can_fallback=True,
+ exact_match_required=False,
+ )
+
+ def _reduce_states(self):
+ """
+ Reduce the instance for pickling. This will serialize
+ the original function as well the compilation options and
+ compiled signatures, but not the compiled code itself.
+
+ NOTE: part of ReduceMixin protocol
+ """
+ return dict(
+ uuid=self._uuid,
+ func_ir=self.func_ir,
+ flags=self.flags,
+ locals=self.locals,
+ extras=self._reduce_extras(),
+ )
+
+ def _reduce_extras(self):
+ """
+ NOTE: sub-class can override to add extra states
+ """
+ return {}
+
+ @classmethod
+ def _rebuild(cls, uuid, func_ir, flags, locals, extras):
+ """
+        Rebuild a Dispatcher instance after it was __reduce__'d.
+
+ NOTE: part of ReduceMixin protocol
+ """
+ try:
+ return cls._memo[uuid]
+ except KeyError:
+ pass
+
+ from numba.cuda.descriptor import cuda_target
+
+ typingctx = cuda_target.typing_context
+ targetctx = cuda_target.target_context
+
+ self = cls(func_ir, typingctx, targetctx, flags, locals, **extras)
+ self._set_uuid(uuid)
+ return self
+
+ def get_source_location(self):
+ """Return the starting line number of the loop."""
+ return self.func_ir.loc.line
+
+ def _pre_compile(self, args, return_type, flags):
+ """Pre-compile actions"""
+ pass
+
+ @abstractmethod
+ def compile(self, sig):
+ """Lifted code should implement a compilation method that will return
+ a CompileResult.entry_point for the given signature."""
+ pass
+
+ def _get_dispatcher_for_current_target(self):
+ # Lifted code does not honor the target switch currently.
+ # No work has been done to check if this can be allowed.
+ return self
+
+
+class LiftedLoop(LiftedCode):
+ def _pre_compile(self, args, return_type, flags):
+ assert not flags.enable_looplift, "Enable looplift flags is on"
+
+ def compile(self, sig):
+ with ExitStack() as scope:
+ cres = None
+
+ def cb_compiler(dur):
+ if cres is not None:
+ self._callback_add_compiler_timer(dur, cres)
+
+ def cb_llvm(dur):
+ if cres is not None:
+ self._callback_add_llvm_timer(dur, cres)
+
+ scope.enter_context(
+ ev.install_timer("numba:compiler_lock", cb_compiler)
+ )
+ scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm))
+ scope.enter_context(global_compiler_lock)
+
+ # Use counter to track recursion compilation depth
+ with self._compiling_counter:
+ # XXX this is mostly duplicated from Dispatcher.
+ flags = self.flags
+ args, return_type = sigutils.normalize_signature(sig)
+
+ # Don't recompile if signature already exists
+ # (e.g. if another thread compiled it before we got the lock)
+ existing = self.overloads.get(tuple(args))
+ if existing is not None:
+ return existing.entry_point
+
+ self._pre_compile(args, return_type, flags)
+
+ # copy the flags, use nopython first
+ npm_loop_flags = flags.copy()
+ npm_loop_flags.force_pyobject = False
+
+ pyobject_loop_flags = flags.copy()
+ pyobject_loop_flags.force_pyobject = True
+
+ # Clone IR to avoid (some of the) mutation in the rewrite pass
+ cloned_func_ir_npm = self.func_ir.copy()
+ cloned_func_ir_fbk = self.func_ir.copy()
+
+ ev_details = dict(
+ dispatcher=self,
+ args=args,
+ return_type=return_type,
+ )
+ with ev.trigger_event("numba:compile", data=ev_details):
+ # this emulates "object mode fall-back", try nopython, if it
+ # fails, then try again in object mode.
+ try:
+ cres = compile_ir(
+ typingctx=self.typingctx,
+ targetctx=self.targetctx,
+ func_ir=cloned_func_ir_npm,
+ args=args,
+ return_type=return_type,
+ flags=npm_loop_flags,
+ locals=self.locals,
+ lifted=(),
+ lifted_from=self.lifted_from,
+ is_lifted_loop=True,
+ )
+ except errors.TypingError:
+ cres = compile_ir(
+ typingctx=self.typingctx,
+ targetctx=self.targetctx,
+ func_ir=cloned_func_ir_fbk,
+ args=args,
+ return_type=return_type,
+ flags=pyobject_loop_flags,
+ locals=self.locals,
+ lifted=(),
+ lifted_from=self.lifted_from,
+ is_lifted_loop=True,
+ )
+ # Check typing error if object mode is used
+ if cres.typing_error is not None:
+ raise cres.typing_error
+ self.add_overload(cres)
+ return cres.entry_point
+
+
+class LiftedWith(LiftedCode):
+ can_cache = True
+
+ def _reduce_extras(self):
+ return dict(output_types=self.output_types)
+
+ @property
+ def _numba_type_(self):
+ return types.Dispatcher(self)
+
+ def get_call_template(self, args, kws):
+ """
+ Get a typing.ConcreteTemplate for this dispatcher and the given
+ *args* and *kws* types. This enables the resolving of the return type.
+
+ A (template, pysig, args, kws) tuple is returned.
+ """
+ # Ensure an overload is available
+ if self._can_compile:
+ self.compile(tuple(args))
+
+ pysig = None
+ # Create function type for typing
+ func_name = self.py_func.__name__
+ name = "CallTemplate({0})".format(func_name)
+ # The `key` isn't really used except for diagnosis here,
+ # so avoid keeping a reference to `cfunc`.
+ call_template = typing.make_concrete_template(
+ name, key=func_name, signatures=self.nopython_signatures
+ )
+ return call_template, pysig, args, kws
+
+ def compile(self, sig):
+        # This is similar to LiftedLoop's compile but does not include
+        # the fallback to object mode.
+ with ExitStack() as scope:
+ cres = None
+
+ def cb_compiler(dur):
+ if cres is not None:
+ self._callback_add_compiler_timer(dur, cres)
+
+ def cb_llvm(dur):
+ if cres is not None:
+ self._callback_add_llvm_timer(dur, cres)
+
+ scope.enter_context(
+ ev.install_timer("numba:compiler_lock", cb_compiler)
+ )
+ scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm))
+ scope.enter_context(global_compiler_lock)
+
+ # Use counter to track recursion compilation depth
+ with self._compiling_counter:
+ # XXX this is mostly duplicated from Dispatcher.
+ flags = self.flags
+ args, return_type = sigutils.normalize_signature(sig)
+
+ # Don't recompile if signature already exists
+ # (e.g. if another thread compiled it before we got the lock)
+ existing = self.overloads.get(tuple(args))
+ if existing is not None:
+ return existing.entry_point
+
+ self._pre_compile(args, return_type, flags)
+
+ # Clone IR to avoid (some of the) mutation in the rewrite pass
+ cloned_func_ir = self.func_ir.copy()
+
+ ev_details = dict(
+ dispatcher=self,
+ args=args,
+ return_type=return_type,
+ )
+ with ev.trigger_event("numba:compile", data=ev_details):
+ cres = compile_ir(
+ typingctx=self.typingctx,
+ targetctx=self.targetctx,
+ func_ir=cloned_func_ir,
+ args=args,
+ return_type=return_type,
+ flags=flags,
+ locals=self.locals,
+ lifted=(),
+ lifted_from=self.lifted_from,
+ is_lifted_loop=True,
+ )
+
+ # Check typing error if object mode is used
+ if (
+ cres.typing_error is not None
+ and not flags.enable_pyobject
+ ):
+ raise cres.typing_error
+ self.add_overload(cres)
+ return cres.entry_point
+
+
+class ObjModeLiftedWith(LiftedWith):
+ def __init__(self, *args, **kwargs):
+ self.output_types = kwargs.pop("output_types", None)
+ super(LiftedWith, self).__init__(*args, **kwargs)
+ if not self.flags.force_pyobject:
+ raise ValueError("expecting `flags.force_pyobject`")
+ if self.output_types is None:
+ raise TypeError("`output_types` must be provided")
+ # switch off rewrites, they have no effect
+ self.flags.no_rewrites = True
+
+ @property
+ def _numba_type_(self):
+ return types.ObjModeDispatcher(self)
+
+ def get_call_template(self, args, kws):
+ """
+ Get a typing.ConcreteTemplate for this dispatcher and the given
+ *args* and *kws* types. This enables the resolving of the return type.
+
+ A (template, pysig, args, kws) tuple is returned.
+ """
+ assert not kws
+ self._legalize_arg_types(args)
+ # Coerce to object mode
+ args = [types.ffi_forced_object] * len(args)
+
+ if self._can_compile:
+ self.compile(tuple(args))
+
+ signatures = [typing.signature(self.output_types, *args)]
+ pysig = None
+ func_name = self.py_func.__name__
+ name = "CallTemplate({0})".format(func_name)
+ call_template = typing.make_concrete_template(
+ name, key=func_name, signatures=signatures
+ )
+
+ return call_template, pysig, args, kws
+
+ def _legalize_arg_types(self, args):
+ for i, a in enumerate(args, start=1):
+ if isinstance(a, types.List):
+ msg = (
+ "Does not support list type inputs into "
+ "with-context for arg {}"
+ )
+ raise errors.TypingError(msg.format(i))
+ elif isinstance(a, types.Dispatcher):
+ msg = (
+ "Does not support function type inputs into "
+ "with-context for arg {}"
+ )
+ raise errors.TypingError(msg.format(i))
+
+ @global_compiler_lock
+ def compile(self, sig):
+ args, _ = sigutils.normalize_signature(sig)
+ sig = (types.ffi_forced_object,) * len(args)
+ return super().compile(sig)
+
+
# Initialize typeof machinery
_dispatcher.typeof_init(
OmittedArg, dict((str(t), t._code) for t in types.number_domain)
diff --git a/numba_cuda/numba/cuda/extending.py b/numba_cuda/numba/cuda/extending.py
index 393b3eab6..4df2f068e 100644
--- a/numba_cuda/numba/cuda/extending.py
+++ b/numba_cuda/numba/cuda/extending.py
@@ -619,14 +619,13 @@ def bind(self, *args, **kwargs):
def is_jitted(function):
- """Returns True if a function is wrapped by one of the Numba @jit
- decorators, for example: numba.jit, numba.njit
+ """Returns True if a function is wrapped by cuda.jit
The purpose of this function is to provide a means to check if a function is
already JIT decorated.
"""
# don't want to export this so import locally
- from numba.core.dispatcher import Dispatcher
+ from numba.cuda.dispatcher import CUDADispatcher
- return isinstance(function, Dispatcher)
+ return isinstance(function, CUDADispatcher)
diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py
index f310a770c..22b181822 100644
--- a/numba_cuda/numba/cuda/target.py
+++ b/numba_cuda/numba/cuda/target.py
@@ -11,7 +11,6 @@
from numba.core import types
from numba.core.compiler_lock import global_compiler_lock
-from numba.core.dispatcher import Dispatcher
from numba.core.errors import NumbaWarning
from numba.cuda.core.base import BaseContext
from numba.core.typing import cmathdecl
@@ -63,25 +62,32 @@ def resolve_value_type(self, val):
# treat other dispatcher object as another device function
from numba.cuda.dispatcher import CUDADispatcher
- if isinstance(val, Dispatcher) and not isinstance(val, CUDADispatcher):
- try:
- # use cached device function
- val = val.__dispatcher
- except AttributeError:
- if not val._can_compile:
- raise ValueError(
- "using cpu function on device "
- "but its compilation is disabled"
- )
- targetoptions = val.targetoptions.copy()
- targetoptions["device"] = True
- targetoptions["debug"] = targetoptions.get("debug", False)
- targetoptions["opt"] = targetoptions.get("opt", True)
- disp = CUDADispatcher(val.py_func, targetoptions)
- # cache the device function for future use and to avoid
- # duplicated copy of the same function.
- val.__dispatcher = disp
- val = disp
+ try:
+ from numba.core.dispatcher import Dispatcher
+
+ if isinstance(val, Dispatcher) and not isinstance(
+ val, CUDADispatcher
+ ):
+ try:
+ # use cached device function
+ val = val.__dispatcher
+ except AttributeError:
+ if not val._can_compile:
+ raise ValueError(
+ "using cpu function on device "
+ "but its compilation is disabled"
+ )
+ targetoptions = val.targetoptions.copy()
+ targetoptions["device"] = True
+ targetoptions["debug"] = targetoptions.get("debug", False)
+ targetoptions["opt"] = targetoptions.get("opt", True)
+ disp = CUDADispatcher(val.py_func, targetoptions)
+ # cache the device function for future use and to avoid
+ # duplicated copy of the same function.
+ val.__dispatcher = disp
+ val = disp
+ except ImportError:
+ pass
# continue with parent logic
return super(CUDATypingContext, self).resolve_value_type(val)