Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion numba_cuda/numba/cuda/core/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from collections import namedtuple
from numba import types
from numba.core import consts, ir
from numba.core import ir
from numba.cuda.core import consts
from numba.core.analysis import compute_cfg_from_blocks


Expand Down
124 changes: 124 additions & 0 deletions numba_cuda/numba/cuda/core/consts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause
from types import ModuleType

import weakref

from numba.core.errors import ConstantInferenceError, NumbaError
from numba.core import ir


class ConstantInference(object):
"""
A constant inference engine for a given interpreter.
Inference inspects the IR to try and compute a compile-time constant for
a variable.

This shouldn't be used directly, instead call Interpreter.infer_constant().
"""

def __init__(self, func_ir):
# Avoid cyclic references as some user-visible objects may be
# held alive in the cache
self._func_ir = weakref.proxy(func_ir)
self._cache = {}

def infer_constant(self, name, loc=None):
"""
Infer a constant value for the given variable *name*.
If no value can be inferred, numba.errors.ConstantInferenceError
is raised.
"""
if name not in self._cache:
try:
self._cache[name] = (True, self._do_infer(name))
except ConstantInferenceError as exc:
# Store the exception args only, to avoid keeping
# a whole traceback alive.
self._cache[name] = (False, (exc.__class__, exc.args))
success, val = self._cache[name]
if success:
return val
else:
exc, args = val
if issubclass(exc, NumbaError):
raise exc(*args, loc=loc)
else:
raise exc(*args)

def _fail(self, val):
# The location here is set to None because `val` is the ir.Var name
# and not the actual offending use of the var. When this is raised it is
# caught in the flow control of `infer_constant` and the class and args
# (the message) are captured and then raised again but with the location
# set to the expression that caused the constant inference error.
raise ConstantInferenceError(
"Constant inference not possible for: %s" % (val,), loc=None
)

def _do_infer(self, name):
if not isinstance(name, str):
raise TypeError("infer_constant() called with non-str %r" % (name,))
try:
defn = self._func_ir.get_definition(name)
except KeyError:
raise ConstantInferenceError(
"no single definition for %r" % (name,)
)
try:
const = defn.infer_constant()
except ConstantInferenceError:
if isinstance(defn, ir.Expr):
return self._infer_expr(defn)
self._fail(defn)
return const

def _infer_expr(self, expr):
# Infer an expression: handle supported cases
if expr.op == "call":
func = self.infer_constant(expr.func.name, loc=expr.loc)
return self._infer_call(func, expr)
elif expr.op == "getattr":
value = self.infer_constant(expr.value.name, loc=expr.loc)
return self._infer_getattr(value, expr)
elif expr.op == "build_list":
return [
self.infer_constant(i.name, loc=expr.loc) for i in expr.items
]
elif expr.op == "build_tuple":
return tuple(
self.infer_constant(i.name, loc=expr.loc) for i in expr.items
)
self._fail(expr)

def _infer_call(self, func, expr):
if expr.kws or expr.vararg:
self._fail(expr)
# Check supported callables
_slice = func in (slice,)
_exc = isinstance(func, type) and issubclass(func, BaseException)
if _slice or _exc:
args = [
self.infer_constant(a.name, loc=expr.loc) for a in expr.args
]
if _slice:
return func(*args)
elif _exc:
# If the exception class is user defined it may implement a ctor
# that does not pass the args to the super. Therefore return the
# raw class and the args so this can be instantiated at the call
# site in the way the user source expects it to be.
return func, args
else:
assert 0, "Unreachable"

self._fail(expr)

def _infer_getattr(self, value, expr):
if isinstance(value, (ModuleType, type)):
# Allow looking up a constant on a class or module
try:
return getattr(value, expr.attr)
except AttributeError:
pass
self._fail(expr)
11 changes: 2 additions & 9 deletions numba_cuda/numba/cuda/core/untyped_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,9 @@
rewrites,
config,
transforms,
consts,
)


from numba.core.utils import PYVERSION

if PYVERSION < (3, 10):
from numba.core.interpreter import Interpreter
else:
from numba.cuda.core.interpreter import Interpreter
from numba.cuda.core import consts
from numba.cuda.core.interpreter import Interpreter


from numba.misc.special import literal_unroll
Expand Down