diff --git a/numba_cuda/numba/cuda/core/ir.py b/numba_cuda/numba/cuda/core/ir.py new file mode 100644 index 000000000..98503ea38 --- /dev/null +++ b/numba_cuda/numba/cuda/core/ir.py @@ -0,0 +1,1800 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +from collections import defaultdict +import copy +import itertools +import os +import linecache +import pprint +import re +import sys +import operator +from types import FunctionType, BuiltinFunctionType +from functools import total_ordering +from io import StringIO + +from numba.core import errors, config +from numba.cuda.utils import UNARY_BUILTINS_TO_OPERATORS, OPERATORS_TO_BUILTINS +from numba.core.errors import ( + NotDefinedError, + RedefinedError, + VerificationError, + ConstantInferenceError, +) +from numba.cuda.core import consts + +# terminal color markup +_termcolor = errors.termcolor() + + +class Loc(object): + """Source location""" + + _defmatcher = re.compile(r"def\s+(\w+)") + + def __init__(self, filename, line, col=None, maybe_decorator=False): + """Arguments: + filename - name of the file + line - line in file + col - column + maybe_decorator - Set to True if location is likely a jit decorator + """ + self.filename = filename + self.line = line + self.col = col + self.lines = None # the source lines from the linecache + self.maybe_decorator = maybe_decorator + + def __eq__(self, other): + # equivalence is solely based on filename, line and col + if type(self) is not type(other): + return False + if self.filename != other.filename: + return False + if self.line != other.line: + return False + if self.col != other.col: + return False + return True + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def from_function_id(cls, func_id): + return cls(func_id.filename, func_id.firstlineno, maybe_decorator=True) + + def __repr__(self): + return "Loc(filename=%s, line=%s, col=%s)" % ( + self.filename, + self.line, + self.col, + ) + + def __str__(self): + if self.col is not None: + return "%s (%s:%s)" % (self.filename, self.line, self.col) + else: + return "%s (%s)" % (self.filename, self.line) + + def _find_definition(self): + # try and find a def, go backwards from error line + fn_name = None + lines = self.get_lines() + for x in reversed(lines[: self.line - 1]): + # the strip and startswith is to handle user code with commented out + # 'def' or use of 'def' in a docstring. + if x.strip().startswith("def "): + fn_name = x + break + + return fn_name + + def _raw_function_name(self): + defn = self._find_definition() + if defn: + m = self._defmatcher.match(defn.strip()) + if m: + return m.groups()[0] + # Probably exec() or REPL. + return None + + def get_lines(self): + if self.lines is None: + path = self._get_path() + # Avoid reading from dynamic string. They are most likely + # overridden. Problem started with Python 3.13. "" seems + # to be something from multiprocessing. + lns = [] if path == "" else linecache.getlines(path) + self.lines = lns + return self.lines + + def _get_path(self): + path = None + try: + # Try to get a relative path + # ipython/jupyter input just returns as self.filename + path = os.path.relpath(self.filename) + except ValueError: + # Fallback to absolute path if error occurred in getting the + # relative path. + # This may happen on windows if the drive is different + path = os.path.abspath(self.filename) + return path + + def strformat(self, nlines_up=2): + lines = self.get_lines() + + use_line = self.line + + if self.maybe_decorator: + # try and sort out a better `loc`, if it's suspected that this loc + # points at a jit decorator by virtue of + # `__code__.co_firstlineno` + + # get lines, add a dummy entry at the start as lines count from + # 1 but list index counts from 0 + tmplines = [""] + lines + + if lines and use_line and "def " not in tmplines[use_line]: + # look forward 10 lines, unlikely anyone managed to stretch + # a jit call declaration over >10 lines?! + min_line = max(0, use_line) + max_line = use_line + 10 + selected = tmplines[min_line:max_line] + index = 0 + for idx, x in enumerate(selected): + if "def " in x: + index = idx + break + use_line = use_line + index + + ret = [] # accumulates output + if lines and use_line > 0: + + def count_spaces(string): + spaces = 0 + for x in itertools.takewhile(str.isspace, str(string)): + spaces += 1 + return spaces + + # A few places in the code still use no `loc` or default to line 1 + # this is often in places where exceptions are used for the purposes + # of flow control. As a result max is in use to prevent slice from + # `[negative: positive]` + selected = lines[max(0, use_line - nlines_up) : use_line] + + # see if selected contains a definition + def_found = False + for x in selected: + if "def " in x: + def_found = True + + # no definition found, try and find one + if not def_found: + # try and find a def, go backwards from error line + fn_name = None + for x in reversed(lines[: use_line - 1]): + if "def " in x: + fn_name = x + break + if fn_name: + ret.append(fn_name) + spaces = count_spaces(x) + ret.append(" " * (4 + spaces) + "\n") + + if selected: + ret.extend(selected[:-1]) + ret.append(_termcolor.highlight(selected[-1])) + + # point at the problem with a caret + spaces = count_spaces(selected[-1]) + ret.append(" " * (spaces) + _termcolor.indicate("^")) + + # if in the REPL source may not be available + if not ret: + if not lines: + ret = "" + elif use_line <= 0: + ret = "" + + err = _termcolor.filename('\nFile "%s", line %d:') + "\n%s" + tmp = err % (self._get_path(), use_line, _termcolor.code("".join(ret))) + return tmp + + def with_lineno(self, line, col=None): + """ + Return a new Loc with this line number. + """ + return type(self)(self.filename, line, col) + + def short(self): + """ + Returns a short string + """ + shortfilename = os.path.basename(self.filename) + return "%s:%s" % (shortfilename, self.line) + + +# Used for annotating errors when source location is unknown. +unknown_loc = Loc("unknown location", 0, 0) + + +@total_ordering +class SlotEqualityCheckMixin(object): + # some ir nodes are __dict__ free using __slots__ instead, this mixin + # should not trigger the unintended creation of __dict__. + __slots__ = tuple() + + def __eq__(self, other): + if type(self) is type(other): + for name in self.__slots__: + if getattr(self, name) != getattr(other, name): + return False + else: + return True + return False + + def __le__(self, other): + return str(self) <= str(other) + + def __hash__(self): + return id(self) + + +@total_ordering +class EqualityCheckMixin(object): + """Mixin for basic equality checking""" + + def __eq__(self, other): + if type(self) is type(other): + + def fixup(adict): + bad = ("loc", "scope") + d = dict(adict) + for x in bad: + d.pop(x, None) + return d + + d1 = fixup(self.__dict__) + d2 = fixup(other.__dict__) + if d1 == d2: + return True + return False + + def __le__(self, other): + return str(self) < str(other) + + def __hash__(self): + return id(self) + + +class VarMap(object): + def __init__(self): + self._con = {} + + def define(self, name, var): + if name in self._con: + raise RedefinedError(name) + else: + self._con[name] = var + + def get(self, name): + try: + return self._con[name] + except KeyError: + raise NotDefinedError(name) + + def __contains__(self, name): + return name in self._con + + def __len__(self): + return len(self._con) + + def __repr__(self): + return pprint.pformat(self._con) + + def __hash__(self): + return hash(self.name) + + def __iter__(self): + return self._con.iterkeys() + + def __eq__(self, other): + if type(self) is type(other): + # check keys only, else __eq__ ref cycles, scope -> varmap -> var + return self._con.keys() == other._con.keys() + return False + + def __ne__(self, other): + return not self.__eq__(other) + + +class AbstractRHS(object): + """Abstract base class for anything that can be the RHS of an assignment. + This class **does not** define any methods. + """ + + +class Inst(EqualityCheckMixin, AbstractRHS): + """ + Base class for all IR instructions. + """ + + def list_vars(self): + """ + List the variables used (read or written) by the instruction. + """ + raise NotImplementedError + + def _rec_list_vars(self, val): + """ + A recursive helper used to implement list_vars() in subclasses. + """ + if isinstance(val, Var): + return [val] + elif isinstance(val, Inst): + return val.list_vars() + elif isinstance(val, (list, tuple)): + lst = [] + for v in val: + lst.extend(self._rec_list_vars(v)) + return lst + elif isinstance(val, dict): + lst = [] + for v in val.values(): + lst.extend(self._rec_list_vars(v)) + return lst + else: + return [] + + +class Stmt(Inst): + """ + Base class for IR statements (instructions which can appear on their + own in a Block). + """ + + # Whether this statement ends its basic block (i.e. it will either jump + # to another block or exit the function). + is_terminator = False + # Whether this statement exits the function. + is_exit = False + + def list_vars(self): + return self._rec_list_vars(self.__dict__) + + +class Terminator(Stmt): + """ + IR statements that are terminators: the last statement in a block. + A terminator must either: + - exit the function + - jump to a block + + All subclass of Terminator must override `.get_targets()` to return a list + of jump targets. + """ + + is_terminator = True + + def get_targets(self): + raise NotImplementedError(type(self)) + + +class Expr(Inst): + """ + An IR expression (an instruction which can only be part of a larger + statement). + """ + + def __init__(self, op, loc, **kws): + assert isinstance(op, str) + assert isinstance(loc, Loc) + self.op = op + self.loc = loc + self._kws = kws + + def __getattr__(self, name): + if name.startswith("_"): + return Inst.__getattr__(self, name) + return self._kws[name] + + def __setattr__(self, name, value): + if name in ("op", "loc", "_kws"): + self.__dict__[name] = value + else: + self._kws[name] = value + + @classmethod + def binop(cls, fn, lhs, rhs, loc): + assert isinstance(fn, BuiltinFunctionType) + assert isinstance(lhs, Var) + assert isinstance(rhs, Var) + assert isinstance(loc, Loc) + op = "binop" + return cls( + op=op, + loc=loc, + fn=fn, + lhs=lhs, + rhs=rhs, + static_lhs=UNDEFINED, + static_rhs=UNDEFINED, + ) + + @classmethod + def inplace_binop(cls, fn, immutable_fn, lhs, rhs, loc): + assert isinstance(fn, BuiltinFunctionType) + assert isinstance(immutable_fn, BuiltinFunctionType) + assert isinstance(lhs, Var) + assert isinstance(rhs, Var) + assert isinstance(loc, Loc) + op = "inplace_binop" + return cls( + op=op, + loc=loc, + fn=fn, + immutable_fn=immutable_fn, + lhs=lhs, + rhs=rhs, + static_lhs=UNDEFINED, + static_rhs=UNDEFINED, + ) + + @classmethod + def unary(cls, fn, value, loc): + assert isinstance(value, (str, Var, FunctionType)) + assert isinstance(loc, Loc) + op = "unary" + fn = UNARY_BUILTINS_TO_OPERATORS.get(fn, fn) + return cls(op=op, loc=loc, fn=fn, value=value) + + @classmethod + def call( + cls, func, args, kws, loc, vararg=None, varkwarg=None, target=None + ): + assert isinstance(func, Var) + assert isinstance(loc, Loc) + op = "call" + return cls( + op=op, + loc=loc, + func=func, + args=args, + kws=kws, + vararg=vararg, + varkwarg=varkwarg, + target=target, + ) + + @classmethod + def build_tuple(cls, items, loc): + assert isinstance(loc, Loc) + op = "build_tuple" + return cls(op=op, loc=loc, items=items) + + @classmethod + def build_list(cls, items, loc): + assert isinstance(loc, Loc) + op = "build_list" + return cls(op=op, loc=loc, items=items) + + @classmethod + def build_set(cls, items, loc): + assert isinstance(loc, Loc) + op = "build_set" + return cls(op=op, loc=loc, items=items) + + @classmethod + def build_map(cls, items, size, literal_value, value_indexes, loc): + assert isinstance(loc, Loc) + op = "build_map" + return cls( + op=op, + loc=loc, + items=items, + size=size, + literal_value=literal_value, + value_indexes=value_indexes, + ) + + @classmethod + def pair_first(cls, value, loc): + assert isinstance(value, Var) + op = "pair_first" + return cls(op=op, loc=loc, value=value) + + @classmethod + def pair_second(cls, value, loc): + assert isinstance(value, Var) + assert isinstance(loc, Loc) + op = "pair_second" + return cls(op=op, loc=loc, value=value) + + @classmethod + def getiter(cls, value, loc): + assert isinstance(value, Var) + assert isinstance(loc, Loc) + op = "getiter" + return cls(op=op, loc=loc, value=value) + + @classmethod + def iternext(cls, value, loc): + assert isinstance(value, Var) + assert isinstance(loc, Loc) + op = "iternext" + return cls(op=op, loc=loc, value=value) + + @classmethod + def exhaust_iter(cls, value, count, loc): + assert isinstance(value, Var) + assert isinstance(count, int) + assert isinstance(loc, Loc) + op = "exhaust_iter" + return cls(op=op, loc=loc, value=value, count=count) + + @classmethod + def getattr(cls, value, attr, loc): + assert isinstance(value, Var) + assert isinstance(attr, str) + assert isinstance(loc, Loc) + op = "getattr" + return cls(op=op, loc=loc, value=value, attr=attr) + + @classmethod + def getitem(cls, value, index, loc): + assert isinstance(value, Var) + assert isinstance(index, Var) + assert isinstance(loc, Loc) + op = "getitem" + fn = operator.getitem + return cls(op=op, loc=loc, value=value, index=index, fn=fn) + + @classmethod + def typed_getitem(cls, value, dtype, index, loc): + assert isinstance(value, Var) + assert isinstance(loc, Loc) + op = "typed_getitem" + return cls(op=op, loc=loc, value=value, dtype=dtype, index=index) + + @classmethod + def static_getitem(cls, value, index, index_var, loc): + assert isinstance(value, Var) + assert index_var is None or isinstance(index_var, Var) + assert isinstance(loc, Loc) + op = "static_getitem" + fn = operator.getitem + return cls( + op=op, loc=loc, value=value, index=index, index_var=index_var, fn=fn + ) + + @classmethod + def cast(cls, value, loc): + """ + A node for implicit casting at the return statement + """ + assert isinstance(value, Var) + assert isinstance(loc, Loc) + op = "cast" + return cls(op=op, value=value, loc=loc) + + @classmethod + def phi(cls, loc): + """Phi node""" + assert isinstance(loc, Loc) + return cls(op="phi", incoming_values=[], incoming_blocks=[], loc=loc) + + @classmethod + def make_function(cls, name, code, closure, defaults, loc): + """ + A node for making a function object. + """ + assert isinstance(loc, Loc) + op = "make_function" + return cls( + op=op, + name=name, + code=code, + closure=closure, + defaults=defaults, + loc=loc, + ) + + @classmethod + def null(cls, loc): + """ + A node for null value. + + This node is not handled by type inference. It is only added by + post-typing passes. + """ + assert isinstance(loc, Loc) + op = "null" + return cls(op=op, loc=loc) + + @classmethod + def undef(cls, loc): + """ + A node for undefined value specifically from LOAD_FAST_AND_CLEAR opcode. + """ + assert isinstance(loc, Loc) + op = "undef" + return cls(op=op, loc=loc) + + @classmethod + def dummy(cls, op, info, loc): + """ + A node for a dummy value. + + This node is a place holder for carrying information through to a point + where it is rewritten into something valid. This node is not handled + by type inference or lowering. It's presence outside of the interpreter + renders IR as illegal. + """ + assert isinstance(loc, Loc) + assert isinstance(op, str) + return cls(op=op, info=info, loc=loc) + + def __repr__(self): + if self.op == "call": + args = ", ".join(str(a) for a in self.args) + pres_order = ( + self._kws.items() + if config.DIFF_IR == 0 + else sorted(self._kws.items()) + ) + kws = ", ".join("%s=%s" % (k, v) for k, v in pres_order) + vararg = "*%s" % (self.vararg,) if self.vararg is not None else "" + arglist = ", ".join(filter(None, [args, vararg, kws])) + return "call %s(%s)" % (self.func, arglist) + elif self.op == "binop": + lhs, rhs = self.lhs, self.rhs + if self.fn == operator.contains: + lhs, rhs = rhs, lhs + fn = OPERATORS_TO_BUILTINS.get(self.fn, self.fn) + return "%s %s %s" % (lhs, fn, rhs) + else: + pres_order = ( + self._kws.items() + if config.DIFF_IR == 0 + else sorted(self._kws.items()) + ) + args = ("%s=%s" % (k, v) for k, v in pres_order) + return "%s(%s)" % (self.op, ", ".join(args)) + + def list_vars(self): + return self._rec_list_vars(self._kws) + + def infer_constant(self): + raise ConstantInferenceError("%s" % self, loc=self.loc) + + +class SetItem(Stmt): + """ + target[index] = value + """ + + def __init__(self, target, index, value, loc): + assert isinstance(target, Var) + assert isinstance(index, Var) + assert isinstance(value, Var) + assert isinstance(loc, Loc) + self.target = target + self.index = index + self.value = value + self.loc = loc + + def __repr__(self): + return "%s[%s] = %s" % (self.target, self.index, self.value) + + +class StaticSetItem(Stmt): + """ + target[constant index] = value + """ + + def __init__(self, target, index, index_var, value, loc): + assert isinstance(target, Var) + assert not isinstance(index, Var) + assert isinstance(index_var, Var) + assert isinstance(value, Var) + assert isinstance(loc, Loc) + self.target = target + self.index = index + self.index_var = index_var + self.value = value + self.loc = loc + + def __repr__(self): + return "%s[%r] = %s" % (self.target, self.index, self.value) + + +class DelItem(Stmt): + """ + del target[index] + """ + + def __init__(self, target, index, loc): + assert isinstance(target, Var) + assert isinstance(index, Var) + assert isinstance(loc, Loc) + self.target = target + self.index = index + self.loc = loc + + def __repr__(self): + return "del %s[%s]" % (self.target, self.index) + + +class SetAttr(Stmt): + def __init__(self, target, attr, value, loc): + assert isinstance(target, Var) + assert isinstance(attr, str) + assert isinstance(value, Var) + assert isinstance(loc, Loc) + self.target = target + self.attr = attr + self.value = value + self.loc = loc + + def __repr__(self): + return "(%s).%s = %s" % (self.target, self.attr, self.value) + + +class DelAttr(Stmt): + def __init__(self, target, attr, loc): + assert isinstance(target, Var) + assert isinstance(attr, str) + assert isinstance(loc, Loc) + self.target = target + self.attr = attr + self.loc = loc + + def __repr__(self): + return "del (%s).%s" % (self.target, self.attr) + + +class StoreMap(Stmt): + def __init__(self, dct, key, value, loc): + assert isinstance(dct, Var) + assert isinstance(key, Var) + assert isinstance(value, Var) + assert isinstance(loc, Loc) + self.dct = dct + self.key = key + self.value = value + self.loc = loc + + def __repr__(self): + return "%s[%s] = %s" % (self.dct, self.key, self.value) + + +class Del(Stmt): + def __init__(self, value, loc): + assert isinstance(value, str) + assert isinstance(loc, Loc) + self.value = value + self.loc = loc + + def __str__(self): + return "del %s" % self.value + + +class Raise(Terminator): + is_exit = True + + def __init__(self, exception, loc): + assert exception is None or isinstance(exception, Var) + assert isinstance(loc, Loc) + self.exception = exception + self.loc = loc + + def __str__(self): + return "raise %s" % self.exception + + def get_targets(self): + return [] + + +class StaticRaise(Terminator): + """ + Raise an exception class and arguments known at compile-time. + Note that if *exc_class* is None, a bare "raise" statement is implied + (i.e. re-raise the current exception). + """ + + is_exit = True + + def __init__(self, exc_class, exc_args, loc): + assert exc_class is None or isinstance(exc_class, type) + assert isinstance(loc, Loc) + assert exc_args is None or isinstance(exc_args, tuple) + self.exc_class = exc_class + self.exc_args = exc_args + self.loc = loc + + def __str__(self): + if self.exc_class is None: + return " raise" + elif self.exc_args is None: + return " raise %s" % (self.exc_class,) + else: + return " raise %s(%s)" % ( + self.exc_class, + ", ".join(map(repr, self.exc_args)), + ) + + def get_targets(self): + return [] + + +class DynamicRaise(Terminator): + """ + Raise an exception class and some argument *values* unknown at compile-time. + Note that if *exc_class* is None, a bare "raise" statement is implied + (i.e. re-raise the current exception). + """ + + is_exit = True + + def __init__(self, exc_class, exc_args, loc): + assert exc_class is None or isinstance(exc_class, type) + assert isinstance(loc, Loc) + assert exc_args is None or isinstance(exc_args, tuple) + self.exc_class = exc_class + self.exc_args = exc_args + self.loc = loc + + def __str__(self): + if self.exc_class is None: + return " raise" + elif self.exc_args is None: + return " raise %s" % (self.exc_class,) + else: + return " raise %s(%s)" % ( + self.exc_class, + ", ".join(map(repr, self.exc_args)), + ) + + def get_targets(self): + return [] + + +class TryRaise(Stmt): + """A raise statement inside a try-block + Similar to ``Raise`` but does not terminate. + """ + + def __init__(self, exception, loc): + assert exception is None or isinstance(exception, Var) + assert isinstance(loc, Loc) + self.exception = exception + self.loc = loc + + def __str__(self): + return "try_raise %s" % self.exception + + +class StaticTryRaise(Stmt): + """A raise statement inside a try-block. + Similar to ``StaticRaise`` but does not terminate. + """ + + def __init__(self, exc_class, exc_args, loc): + assert exc_class is None or isinstance(exc_class, type) + assert isinstance(loc, Loc) + assert exc_args is None or isinstance(exc_args, tuple) + self.exc_class = exc_class + self.exc_args = exc_args + self.loc = loc + + def __str__(self): + if self.exc_class is None: + return "static_try_raise" + elif self.exc_args is None: + return f"static_try_raise {self.exc_class}" + else: + args = ", ".join(map(repr, self.exc_args)) + return f"static_try_raise {self.exc_class}({args})" + + +class DynamicTryRaise(Stmt): + """A raise statement inside a try-block. + Similar to ``DynamicRaise`` but does not terminate. + """ + + def __init__(self, exc_class, exc_args, loc): + assert exc_class is None or isinstance(exc_class, type) + assert isinstance(loc, Loc) + assert exc_args is None or isinstance(exc_args, tuple) + self.exc_class = exc_class + self.exc_args = exc_args + self.loc = loc + + def __str__(self): + if self.exc_class is None: + return "dynamic_try_raise" + elif self.exc_args is None: + return f"dynamic_try_raise {self.exc_class}" + else: + args = ", ".join(map(repr, self.exc_args)) + return f"dynamic_try_raise {self.exc_class}({args})" + + +class Return(Terminator): + """ + Return to caller. + """ + + is_exit = True + + def __init__(self, value, loc): + assert isinstance(value, Var), type(value) + assert isinstance(loc, Loc) + self.value = value + self.loc = loc + + def __str__(self): + return "return %s" % self.value + + def get_targets(self): + return [] + + +class Jump(Terminator): + """ + Unconditional branch. + """ + + def __init__(self, target, loc): + assert isinstance(loc, Loc) + self.target = target + self.loc = loc + + def __str__(self): + return "jump %s" % self.target + + def get_targets(self): + return [self.target] + + +class Branch(Terminator): + """ + Conditional branch. + """ + + def __init__(self, cond, truebr, falsebr, loc): + assert isinstance(cond, Var) + assert isinstance(loc, Loc) + self.cond = cond + self.truebr = truebr + self.falsebr = falsebr + self.loc = loc + + def __str__(self): + return "branch %s, %s, %s" % (self.cond, self.truebr, self.falsebr) + + def get_targets(self): + return [self.truebr, self.falsebr] + + +class Assign(Stmt): + """ + Assign to a variable. + """ + + def __init__(self, value, target, loc): + assert isinstance(value, AbstractRHS) + assert isinstance(target, Var) + assert isinstance(loc, Loc) + self.value = value + self.target = target + self.loc = loc + + def __str__(self): + return "%s = %s" % (self.target, self.value) + + +class Print(Stmt): + """ + Print some values. + """ + + def __init__(self, args, vararg, loc): + assert all(isinstance(x, Var) for x in args) + assert vararg is None or isinstance(vararg, Var) + assert isinstance(loc, Loc) + self.args = tuple(args) + self.vararg = vararg + # Constant-inferred arguments + self.consts = {} + self.loc = loc + + def __str__(self): + return "print(%s)" % ", ".join(str(v) for v in self.args) + + +class Yield(Inst): + def __init__(self, value, loc, index): + assert isinstance(value, Var) + assert isinstance(loc, Loc) + self.value = value + self.loc = loc + self.index = index + + def __str__(self): + return "yield %s" % (self.value,) + + def list_vars(self): + return [self.value] + + +class EnterWith(Stmt): + """Enter a "with" context""" + + def __init__(self, contextmanager, begin, end, loc): + """ + Parameters + ---------- + contextmanager : IR value + begin, end : int + The beginning and the ending offset of the with-body. + loc : ir.Loc instance + Source location + """ + assert isinstance(contextmanager, Var) + assert isinstance(loc, Loc) + self.contextmanager = contextmanager + self.begin = begin + self.end = end + self.loc = loc + + def __str__(self): + return "enter_with {}".format(self.contextmanager) + + def list_vars(self): + return [self.contextmanager] + + +class PopBlock(Stmt): + """Marker statement for a pop block op code""" + + def __init__(self, loc): + assert isinstance(loc, Loc) + self.loc = loc + + def __str__(self): + return "pop_block" + + +class Arg(EqualityCheckMixin, AbstractRHS): + def __init__(self, name, index, loc): + assert isinstance(name, str) + assert isinstance(index, int) + assert isinstance(loc, Loc) + self.name = name + self.index = index + self.loc = loc + + def __repr__(self): + return "arg(%d, name=%s)" % (self.index, self.name) + + def infer_constant(self): + raise ConstantInferenceError("%s" % self, loc=self.loc) + + +class Const(EqualityCheckMixin, AbstractRHS): + def __init__(self, value, loc, use_literal_type=True): + assert isinstance(loc, Loc) + self.value = value + self.loc = loc + # Note: need better way to tell if this is a literal or not. + self.use_literal_type = use_literal_type + + def __repr__(self): + return "const(%s, %s)" % (type(self.value).__name__, self.value) + + def infer_constant(self): + return self.value + + def __deepcopy__(self, memo): + # Override to not copy constant values in code + return Const( + value=self.value, + loc=self.loc, + use_literal_type=self.use_literal_type, + ) + + +class Global(EqualityCheckMixin, AbstractRHS): + def __init__(self, name, value, loc): + assert isinstance(loc, Loc) + self.name = name + self.value = value + self.loc = loc + + def __str__(self): + return "global(%s: %s)" % (self.name, self.value) + + def infer_constant(self): + return self.value + + def __deepcopy__(self, memo): + # don't copy value since it can fail (e.g. modules) + # value is readonly and doesn't need copying + return Global(self.name, self.value, copy.deepcopy(self.loc)) + + +class FreeVar(EqualityCheckMixin, AbstractRHS): + """ + A freevar, as loaded by LOAD_DECREF. + (i.e. a variable defined in an enclosing non-global scope) + """ + + def __init__(self, index, name, value, loc): + assert isinstance(index, int) + assert isinstance(name, str) + assert isinstance(loc, Loc) + # index inside __code__.co_freevars + self.index = index + # variable name + self.name = name + # frozen value + self.value = value + self.loc = loc + + def __str__(self): + return "freevar(%s: %s)" % (self.name, self.value) + + def infer_constant(self): + return self.value + + def __deepcopy__(self, memo): + # Override to not copy constant values in code + return FreeVar( + index=self.index, name=self.name, value=self.value, loc=self.loc + ) + + +class Var(EqualityCheckMixin, AbstractRHS): + """ + Attributes + ----------- + - scope: Scope + + - name: str + + - loc: Loc + Definition location + """ + + def __init__(self, scope, name, loc): + # NOTE: Use of scope=None should be removed. + assert scope is None or isinstance(scope, Scope) + assert isinstance(name, str) + assert isinstance(loc, Loc) + self.scope = scope + self.name = name + self.loc = loc + + def __repr__(self): + return "Var(%s, %s)" % (self.name, self.loc.short()) + + def __str__(self): + return self.name + + @property + def is_temp(self): + return self.name.startswith("$") + + @property + def unversioned_name(self): + """The unversioned name of this variable, i.e. SSA renaming removed""" + for k, redef_set in self.scope.var_redefinitions.items(): + if self.name in redef_set: + return k + return self.name + + @property + def versioned_names(self): + """Known versioned names for this variable, i.e. known variable names in + the scope that have been formed from applying SSA to this variable + """ + return self.scope.get_versions_of(self.unversioned_name) + + @property + def all_names(self): + """All known versioned and unversioned names for this variable""" + return self.versioned_names | { + self.unversioned_name, + } + + def __deepcopy__(self, memo): + out = Var(copy.deepcopy(self.scope, memo), self.name, self.loc) + memo[id(self)] = out + return out + + +class Scope(EqualityCheckMixin): + """ + Attributes + ----------- + - parent: Scope + Parent scope + + - localvars: VarMap + Scope-local variable map + + - loc: Loc + Start of scope location + + """ + + def __init__(self, parent, loc): + assert parent is None or isinstance(parent, Scope) + assert isinstance(loc, Loc) + self.parent = parent + self.localvars = VarMap() + self.loc = loc + self.redefined = defaultdict(int) + self.var_redefinitions = defaultdict(set) + + def define(self, name, loc): + """ + Define a variable + """ + v = Var(scope=self, name=name, loc=loc) + self.localvars.define(v.name, v) + return v + + def get(self, name): + """ + Refer to a variable. Returns the latest version. + """ + if name in self.redefined: + name = "%s.%d" % (name, self.redefined[name]) + return self.get_exact(name) + + def get_exact(self, name): + """ + Refer to a variable. The returned variable has the exact + name (exact variable version). + """ + try: + return self.localvars.get(name) + except NotDefinedError: + if self.has_parent: + return self.parent.get(name) + else: + raise + + def get_or_define(self, name, loc): + if name in self.redefined: + name = "%s.%d" % (name, self.redefined[name]) + + if name not in self.localvars: + return self.define(name, loc) + else: + return self.localvars.get(name) + + def redefine(self, name, loc, rename=True): + """ + Redefine if the name is already defined + """ + if name not in self.localvars: + return self.define(name, loc) + elif not rename: + # Must use the same name if the variable is a cellvar, which + # means it could be captured in a closure. + return self.localvars.get(name) + else: + while True: + ct = self.redefined[name] + self.redefined[name] = ct + 1 + newname = "%s.%d" % (name, ct + 1) + try: + res = self.define(newname, loc) + except RedefinedError: + continue + else: + self.var_redefinitions[name].add(newname) + return res + + def get_versions_of(self, name): + """ + Gets all known versions of a given name + """ + vers = set() + + def walk(thename): + redefs = self.var_redefinitions.get(thename, None) + if redefs: + for v in redefs: + vers.add(v) + walk(v) + + walk(name) + return vers + + def make_temp(self, loc): + n = len(self.localvars) + v = Var(scope=self, name="$%d" % n, loc=loc) + self.localvars.define(v.name, v) + return v + + @property + def has_parent(self): + return self.parent is not None + + def __repr__(self): + return "Scope(has_parent=%r, num_vars=%d, %s)" % ( + self.has_parent, + len(self.localvars), + self.loc, + ) + + +class Block(EqualityCheckMixin): + """A code block""" + + def __init__(self, scope, loc): + assert isinstance(scope, Scope) + assert isinstance(loc, Loc) + self.scope = scope + self.body = [] + self.loc = loc + + def copy(self): + block = Block(self.scope, self.loc) + block.body = self.body[:] + return block + + def find_exprs(self, op=None): + """ + Iterate over exprs of the given *op* in this block. + """ + for inst in self.body: + if isinstance(inst, Assign): + expr = inst.value + if isinstance(expr, Expr): + if op is None or expr.op == op: + yield expr + + def find_insts(self, cls=None): + """ + Iterate over insts of the given class in this block. + """ + for inst in self.body: + if isinstance(inst, cls): + yield inst + + def find_variable_assignment(self, name): + """ + Returns the assignment inst associated with variable "name", None if + it cannot be found. + """ + for x in self.find_insts(cls=Assign): + if x.target.name == name: + return x + return None + + def prepend(self, inst): + assert isinstance(inst, Stmt) + self.body.insert(0, inst) + + def append(self, inst): + assert isinstance(inst, Stmt) + self.body.append(inst) + + def remove(self, inst): + assert isinstance(inst, Stmt) + del self.body[self.body.index(inst)] + + def clear(self): + del self.body[:] + + def dump(self, file=None): + # Avoid early bind of sys.stdout as default value + file = file or sys.stdout + for inst in self.body: + if hasattr(inst, "dump"): + inst.dump(file) + else: + inst_vars = sorted(str(v) for v in inst.list_vars()) + print(" %-40s %s" % (inst, inst_vars), file=file) + + @property + def terminator(self): + return self.body[-1] + + @property + def is_terminated(self): + return self.body and self.body[-1].is_terminator + + def verify(self): + if not self.is_terminated: + raise VerificationError("Missing block terminator") + # Only the last instruction can be a terminator + for inst in self.body[:-1]: + if inst.is_terminator: + raise VerificationError( + "Terminator before the last instruction" + ) + + def insert_after(self, stmt, other): + """ + Insert *stmt* after *other*. + """ + index = self.body.index(other) + self.body.insert(index + 1, stmt) + + def insert_before_terminator(self, stmt): + assert isinstance(stmt, Stmt) + assert self.is_terminated + self.body.insert(-1, stmt) + + def __repr__(self): + return "" % (self.loc,) + + +class Loop(SlotEqualityCheckMixin): + """Describes a loop-block""" + + __slots__ = "entry", "exit" + + def __init__(self, entry, exit): + self.entry = entry + self.exit = exit + + def __repr__(self): + args = self.entry, self.exit + return "Loop(entry=%s, exit=%s)" % args + + +class With(SlotEqualityCheckMixin): + """Describes a with-block""" + + __slots__ = "entry", "exit" + + def __init__(self, entry, exit): + self.entry = entry + self.exit = exit + + def __repr__(self): + args = self.entry, self.exit + return "With(entry=%s, exit=%s)" % args + + +class FunctionIR(object): + def __init__( + self, + blocks, + is_generator, + func_id, + loc, + definitions, + arg_count, + arg_names, + ): + self.blocks = blocks + self.is_generator = is_generator + self.func_id = func_id + self.loc = loc + self.arg_count = arg_count + self.arg_names = arg_names + + self._definitions = definitions + + self._reset_analysis_variables() + + def equal_ir(self, other): + """Checks that the IR contained within is equal to the IR in other. + Equality is defined by being equal in fundamental structure (blocks, + labels, IR node type and the order in which they are defined) and the + IR nodes being equal. IR node equality essentially comes down to + ensuring a node's `.__dict__` or `.__slots__` is equal, with the + exception of ignoring 'loc' and 'scope' entries. The upshot is that the + comparison is essentially location and scope invariant, but otherwise + behaves as unsurprisingly as possible. + """ + if type(self) is type(other): + return self.blocks == other.blocks + return False + + def diff_str(self, other): + """ + Compute a human readable difference in the IR, returns a formatted + string ready for printing. + """ + msg = [] + for label, block in self.blocks.items(): + other_blk = other.blocks.get(label, None) + if other_blk is not None: + if block != other_blk: + msg.append(("Block %s differs" % label).center(80, "-")) + # see if the instructions are just a permutation + block_del = [x for x in block.body if isinstance(x, Del)] + oth_del = [x for x in other_blk.body if isinstance(x, Del)] + if block_del != oth_del: + # this is a common issue, dels are all present, but + # order shuffled. + if sorted(block_del) == sorted(oth_del): + msg.append( + ( + "Block %s contains the same dels but " + "their order is different" + ) + % label + ) + if len(block.body) > len(other_blk.body): + msg.append("This block contains more statements") + elif len(block.body) < len(other_blk.body): + msg.append("Other block contains more statements") + + # find the indexes where they don't match + tmp = [] + for idx, stmts in enumerate( + zip(block.body, other_blk.body) + ): + b_s, o_s = stmts + if b_s != o_s: + tmp.append(idx) + + def get_pad(ablock, l): + pointer = "-> " + sp = len(pointer) * " " + pad = [] + nstmt = len(ablock) + for i in range(nstmt): + if i in tmp: + item = pointer + elif i >= l: + item = pointer + else: + item = sp + pad.append(item) + return pad + + min_stmt_len = min(len(block.body), len(other_blk.body)) + + with StringIO() as buf: + it = [("self", block), ("other", other_blk)] + for name, _block in it: + buf.truncate(0) + _block.dump(file=buf) + stmts = buf.getvalue().splitlines() + pad = get_pad(_block.body, min_stmt_len) + title = "%s: block %s" % (name, label) + msg.append(title.center(80, "-")) + msg.extend( + [ + "{0}{1}".format(a, b) + for a, b in zip(pad, stmts) + ] + ) + if msg == []: + msg.append("IR is considered equivalent.") + return "\n".join(msg) + + def _reset_analysis_variables(self): + self._consts = consts.ConstantInference(self) + + # Will be computed by PostProcessor + self.generator_info = None + self.variable_lifetime = None + # { ir.Block: { variable names (potentially) alive at start of block } } + self.block_entry_vars = {} + + def derive( + self, + blocks, + arg_count=None, + arg_names=None, + force_non_generator=False, + loc=None, + ): + """ + Derive a new function IR from this one, using the given blocks, + and possibly modifying the argument count and generator flag. + + Post-processing will have to be run again on the new IR. + """ + firstblock = blocks[min(blocks)] + + new_ir = copy.copy(self) + new_ir.blocks = blocks + new_ir.loc = firstblock.loc if loc is None else loc + if force_non_generator: + new_ir.is_generator = False + if arg_count is not None: + new_ir.arg_count = arg_count + if arg_names is not None: + new_ir.arg_names = arg_names + new_ir._reset_analysis_variables() + # Make fresh func_id + new_ir.func_id = new_ir.func_id.derive() + return new_ir + + def copy(self): + new_ir = copy.copy(self) + blocks = {} + block_entry_vars = {} + for label, block in self.blocks.items(): + new_block = block.copy() + blocks[label] = new_block + if block in self.block_entry_vars: + block_entry_vars[new_block] = self.block_entry_vars[block] + new_ir.blocks = blocks + new_ir.block_entry_vars = block_entry_vars + return new_ir + + def get_block_entry_vars(self, block): + """ + Return a set of variable names possibly alive at the beginning of + the block. + """ + return self.block_entry_vars[block] + + def infer_constant(self, name): + """ + Try to infer the constant value of a given variable. + """ + if isinstance(name, Var): + name = name.name + return self._consts.infer_constant(name) + + def get_definition(self, value, lhs_only=False): + """ + Get the definition site for the given variable name or instance. + A Expr instance is returned by default, but if lhs_only is set + to True, the left-hand-side variable is returned instead. + """ + lhs = value + while True: + if isinstance(value, Var): + lhs = value + name = value.name + elif isinstance(value, str): + lhs = value + name = value + else: + return lhs if lhs_only else value + defs = self._definitions[name] + if len(defs) == 0: + raise KeyError("no definition for %r" % (name,)) + if len(defs) > 1: + raise KeyError("more than one definition for %r" % (name,)) + value = defs[0] + + def get_assignee(self, rhs_value, in_blocks=None): + """ + Finds the assignee for a given RHS value. If in_blocks is given the + search will be limited to the specified blocks. + """ + if in_blocks is None: + blocks = self.blocks.values() + elif isinstance(in_blocks, int): + blocks = [self.blocks[in_blocks]] + else: + blocks = [self.blocks[blk] for blk in list(in_blocks)] + + assert isinstance(rhs_value, AbstractRHS) + + for blk in blocks: + for assign in blk.find_insts(Assign): + if assign.value == rhs_value: + return assign.target + + raise ValueError("Could not find an assignee for %s" % rhs_value) + + def dump(self, file=None): + nofile = file is None + # Avoid early bind of sys.stdout as default value + file = file or StringIO() + for offset, block in sorted(self.blocks.items()): + print("label %s:" % (offset,), file=file) + block.dump(file=file) + if nofile: + text = file.getvalue() + if config.HIGHLIGHT_DUMPS: + try: + import pygments # noqa + except ImportError: + msg = "Please install pygments to see highlighted dumps" + raise ValueError(msg) + else: + from pygments import highlight + from numba.misc.dump_style import NumbaIRLexer as lexer + from numba.misc.dump_style import by_colorscheme + from pygments.formatters import Terminal256Formatter + + print( + highlight( + text, + lexer(), + Terminal256Formatter(style=by_colorscheme()), + ) + ) + else: + print(text) + + def dump_to_string(self): + with StringIO() as sb: + self.dump(file=sb) + return sb.getvalue() + + def dump_generator_info(self, file=None): + file = file or sys.stdout + gi = self.generator_info + print("generator state variables:", sorted(gi.state_vars), file=file) + for index, yp in sorted(gi.yield_points.items()): + print( + "yield point #%d: live variables = %s, weak live variables = %s" + % (index, sorted(yp.live_vars), sorted(yp.weak_live_vars)), + file=file, + ) + + def render_dot(self, filename_prefix="numba_ir", include_ir=True): + """Render the CFG of the IR with GraphViz DOT via the + ``graphviz`` python binding. + + Returns + ------- + g : graphviz.Digraph + Use `g.view()` to open the graph in the default PDF application. + """ + + try: + import graphviz as gv + except ImportError: + raise ImportError( + "The feature requires `graphviz` but it is not available. " + "Please install with `pip install graphviz`" + ) + g = gv.Digraph( + filename="{}{}.dot".format( + filename_prefix, + self.func_id.unique_name, + ) + ) + # Populate the nodes + for k, blk in self.blocks.items(): + with StringIO() as sb: + blk.dump(sb) + label = sb.getvalue() + if include_ir: + label = "".join( + [r" {}\l".format(x) for x in label.splitlines()], + ) + label = r"block {}\l".format(k) + label + g.node(str(k), label=label, shape="rect") + else: + label = r"{}\l".format(k) + g.node(str(k), label=label, shape="circle") + # Populate the edges + for src, blk in self.blocks.items(): + for dst in blk.terminator.get_targets(): + g.edge(str(src), str(dst)) + return g + + +# A stub for undefined global reference +class UndefinedType(EqualityCheckMixin): + _singleton = None + + def __new__(cls): + obj = cls._singleton + if obj is not None: + return obj + else: + obj = object.__new__(cls) + cls._singleton = obj + return obj + + def __repr__(self): + return "Undefined" + + +UNDEFINED = UndefinedType() diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_ir.py b/numba_cuda/numba/cuda/tests/cudapy/test_ir.py new file mode 100644 index 000000000..7fac682cb --- /dev/null +++ b/numba_cuda/numba/cuda/tests/cudapy/test_ir.py @@ -0,0 +1,592 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import unittest +from numba.cuda.testing import CUDATestCase +import warnings +import numpy as np + +from numba import objmode +from numba.cuda.core import ir +from numba.cuda import compiler +from numba.core import errors +from numba.cuda.core.compiler import ( + CompilerBase, +) +from numba.cuda.core.compiler_machinery import ( + FunctionPass, + PassManager, + register_pass, +) +from numba.cuda.core.untyped_passes import ( + TranslateByteCode, + IRProcessing, + ReconstructSSA, +) +from numba import njit + + +def requires_fully_vendored_ir(fn): + return unittest.skip( + "Cannot run until the dependency on the un-vendored IR module is removed" + )(fn) + + +class TestIR(CUDATestCase): + def test_IRScope(self): + filename = "" + top = ir.Scope(parent=None, loc=ir.Loc(filename=filename, line=1)) + local = ir.Scope(parent=top, loc=ir.Loc(filename=filename, line=2)) + + apple = local.define("apple", loc=ir.Loc(filename=filename, line=3)) + self.assertIs(local.get("apple"), apple) + self.assertEqual(len(local.localvars), 1) + + orange = top.define("orange", loc=ir.Loc(filename=filename, line=4)) + self.assertEqual(len(local.localvars), 1) + self.assertEqual(len(top.localvars), 1) + self.assertIs(top.get("orange"), orange) + self.assertIs(local.get("orange"), orange) + + more_orange = local.define( + "orange", loc=ir.Loc(filename=filename, line=5) + ) + self.assertIs(top.get("orange"), orange) + self.assertIsNot(local.get("orange"), not orange) + self.assertIs(local.get("orange"), more_orange) + + try: + local.define("orange", loc=ir.Loc(filename=filename, line=5)) + except ir.RedefinedError: + pass + else: + self.fail("Expecting an %s" % ir.RedefinedError) + + +class CheckEquality(CUDATestCase): + var_a = ir.Var(None, "a", ir.unknown_loc) + var_b = ir.Var(None, "b", ir.unknown_loc) + var_c = ir.Var(None, "c", ir.unknown_loc) + var_d = ir.Var(None, "d", ir.unknown_loc) + var_e = ir.Var(None, "e", ir.unknown_loc) + loc1 = ir.Loc("mock", 1, 0) + loc2 = ir.Loc("mock", 2, 0) + loc3 = ir.Loc("mock", 3, 0) + + def check(self, base, same=[], different=[]): + for s in same: + self.assertTrue(base == s) + for d in different: + self.assertTrue(base != d) + + +class TestIRMeta(CheckEquality): + """ + Tests IR node meta, like Loc and Scope + """ + + def test_loc(self): + a = ir.Loc("file", 1, 0) + b = ir.Loc("file", 1, 0) + c = ir.Loc("pile", 1, 0) + d = ir.Loc("file", 2, 0) + e = ir.Loc("file", 1, 1) + self.check( + a, + same=[ + b, + ], + different=[c, d, e], + ) + + f = ir.Loc("file", 1, 0, maybe_decorator=False) + g = ir.Loc("file", 1, 0, maybe_decorator=True) + self.check(a, same=[f, g]) + + def test_scope(self): + parent1 = ir.Scope(None, self.loc1) + parent2 = ir.Scope(None, self.loc1) + parent3 = ir.Scope(None, self.loc2) + self.check( + parent1, + same=[ + parent2, + parent3, + ], + ) + + a = ir.Scope(parent1, self.loc1) + b = ir.Scope(parent1, self.loc1) + c = ir.Scope(parent1, self.loc2) + d = ir.Scope(parent3, self.loc1) + self.check(a, same=[b, c, d]) + + # parent1 and parent2 are equal, so children referring to either parent + # should be equal + e = ir.Scope(parent2, self.loc1) + self.check( + a, + same=[ + e, + ], + ) + + +class TestIRNodes(CheckEquality): + """ + Tests IR nodes + """ + + def test_terminator(self): + # terminator base class inst should always be equal + t1 = ir.Terminator() + t2 = ir.Terminator() + self.check(t1, same=[t2]) + + def test_jump(self): + a = ir.Jump(1, self.loc1) + b = ir.Jump(1, self.loc1) + c = ir.Jump(1, self.loc2) + d = ir.Jump(2, self.loc1) + self.check(a, same=[b, c], different=[d]) + + def test_return(self): + a = ir.Return(self.var_a, self.loc1) + b = ir.Return(self.var_a, self.loc1) + c = ir.Return(self.var_a, self.loc2) + d = ir.Return(self.var_b, self.loc1) + self.check(a, same=[b, c], different=[d]) + + def test_raise(self): + a = ir.Raise(self.var_a, self.loc1) + b = ir.Raise(self.var_a, self.loc1) + c = ir.Raise(self.var_a, self.loc2) + d = ir.Raise(self.var_b, self.loc1) + self.check(a, same=[b, c], different=[d]) + + def test_staticraise(self): + a = ir.StaticRaise(AssertionError, None, self.loc1) + b = ir.StaticRaise(AssertionError, None, self.loc1) + c = ir.StaticRaise(AssertionError, None, self.loc2) + e = ir.StaticRaise(AssertionError, ("str",), self.loc1) + d = ir.StaticRaise(RuntimeError, None, self.loc1) + self.check(a, same=[b, c], different=[d, e]) + + def test_branch(self): + a = ir.Branch(self.var_a, 1, 2, self.loc1) + b = ir.Branch(self.var_a, 1, 2, self.loc1) + c = ir.Branch(self.var_a, 1, 2, self.loc2) + d = ir.Branch(self.var_b, 1, 2, self.loc1) + e = ir.Branch(self.var_a, 2, 2, self.loc1) + f = ir.Branch(self.var_a, 1, 3, self.loc1) + self.check(a, same=[b, c], different=[d, e, f]) + + def test_expr(self): + a = ir.Expr("some_op", self.loc1) + b = ir.Expr("some_op", self.loc1) + c = ir.Expr("some_op", self.loc2) + d = ir.Expr("some_other_op", self.loc1) + self.check(a, same=[b, c], different=[d]) + + def test_setitem(self): + a = ir.SetItem(self.var_a, self.var_b, self.var_c, self.loc1) + b = ir.SetItem(self.var_a, self.var_b, self.var_c, self.loc1) + c = ir.SetItem(self.var_a, self.var_b, self.var_c, self.loc2) + d = ir.SetItem(self.var_d, self.var_b, self.var_c, self.loc1) + e = ir.SetItem(self.var_a, self.var_d, self.var_c, self.loc1) + f = ir.SetItem(self.var_a, self.var_b, self.var_d, self.loc1) + self.check(a, same=[b, c], different=[d, e, f]) + + def test_staticsetitem(self): + a = ir.StaticSetItem(self.var_a, 1, self.var_b, self.var_c, self.loc1) + b = ir.StaticSetItem(self.var_a, 1, self.var_b, self.var_c, self.loc1) + c = ir.StaticSetItem(self.var_a, 1, self.var_b, self.var_c, self.loc2) + d = ir.StaticSetItem(self.var_d, 1, self.var_b, self.var_c, self.loc1) + e = ir.StaticSetItem(self.var_a, 2, self.var_b, self.var_c, self.loc1) + f = ir.StaticSetItem(self.var_a, 1, self.var_d, self.var_c, self.loc1) + g = ir.StaticSetItem(self.var_a, 1, self.var_b, self.var_d, self.loc1) + self.check(a, same=[b, c], different=[d, e, f, g]) + + def test_delitem(self): + a = ir.DelItem(self.var_a, self.var_b, self.loc1) + b = ir.DelItem(self.var_a, self.var_b, self.loc1) + c = ir.DelItem(self.var_a, self.var_b, self.loc2) + d = ir.DelItem(self.var_c, self.var_b, self.loc1) + e = ir.DelItem(self.var_a, self.var_c, self.loc1) + self.check(a, same=[b, c], different=[d, e]) + + def test_del(self): + a = ir.Del(self.var_a.name, self.loc1) + b = ir.Del(self.var_a.name, self.loc1) + c = ir.Del(self.var_a.name, self.loc2) + d = ir.Del(self.var_b.name, self.loc1) + self.check(a, same=[b, c], different=[d]) + + def test_setattr(self): + a = ir.SetAttr(self.var_a, "foo", self.var_b, self.loc1) + b = ir.SetAttr(self.var_a, "foo", self.var_b, self.loc1) + c = ir.SetAttr(self.var_a, "foo", self.var_b, self.loc2) + d = ir.SetAttr(self.var_c, "foo", self.var_b, self.loc1) + e = ir.SetAttr(self.var_a, "bar", self.var_b, self.loc1) + f = ir.SetAttr(self.var_a, "foo", self.var_c, self.loc1) + self.check(a, same=[b, c], different=[d, e, f]) + + def test_delattr(self): + a = ir.DelAttr(self.var_a, "foo", self.loc1) + b = ir.DelAttr(self.var_a, "foo", self.loc1) + c = ir.DelAttr(self.var_a, "foo", self.loc2) + d = ir.DelAttr(self.var_c, "foo", self.loc1) + e = ir.DelAttr(self.var_a, "bar", self.loc1) + self.check(a, same=[b, c], different=[d, e]) + + def test_assign(self): + a = ir.Assign(self.var_a, self.var_b, self.loc1) + b = ir.Assign(self.var_a, self.var_b, self.loc1) + c = ir.Assign(self.var_a, self.var_b, self.loc2) + d = ir.Assign(self.var_c, self.var_b, self.loc1) + e = ir.Assign(self.var_a, self.var_c, self.loc1) + self.check(a, same=[b, c], different=[d, e]) + + def test_print(self): + a = ir.Print((self.var_a,), self.var_b, self.loc1) + b = ir.Print((self.var_a,), self.var_b, self.loc1) + c = ir.Print((self.var_a,), self.var_b, self.loc2) + d = ir.Print((self.var_c,), self.var_b, self.loc1) + e = ir.Print((self.var_a,), self.var_c, self.loc1) + self.check(a, same=[b, c], different=[d, e]) + + def test_storemap(self): + a = ir.StoreMap(self.var_a, self.var_b, self.var_c, self.loc1) + b = ir.StoreMap(self.var_a, self.var_b, self.var_c, self.loc1) + c = ir.StoreMap(self.var_a, self.var_b, self.var_c, self.loc2) + d = ir.StoreMap(self.var_d, self.var_b, self.var_c, self.loc1) + e = ir.StoreMap(self.var_a, self.var_d, self.var_c, self.loc1) + f = ir.StoreMap(self.var_a, self.var_b, self.var_d, self.loc1) + self.check(a, same=[b, c], different=[d, e, f]) + + def test_yield(self): + a = ir.Yield(self.var_a, self.loc1, 0) + b = ir.Yield(self.var_a, self.loc1, 0) + c = ir.Yield(self.var_a, self.loc2, 0) + d = ir.Yield(self.var_b, self.loc1, 0) + e = ir.Yield(self.var_a, self.loc1, 1) + self.check(a, same=[b, c], different=[d, e]) + + def test_enterwith(self): + a = ir.EnterWith(self.var_a, 0, 1, self.loc1) + b = ir.EnterWith(self.var_a, 0, 1, self.loc1) + c = ir.EnterWith(self.var_a, 0, 1, self.loc2) + d = ir.EnterWith(self.var_b, 0, 1, self.loc1) + e = ir.EnterWith(self.var_a, 1, 1, self.loc1) + f = ir.EnterWith(self.var_a, 0, 2, self.loc1) + self.check(a, same=[b, c], different=[d, e, f]) + + def test_arg(self): + a = ir.Arg("foo", 0, self.loc1) + b = ir.Arg("foo", 0, self.loc1) + c = ir.Arg("foo", 0, self.loc2) + d = ir.Arg("bar", 0, self.loc1) + e = ir.Arg("foo", 1, self.loc1) + self.check(a, same=[b, c], different=[d, e]) + + def test_const(self): + a = ir.Const(1, self.loc1) + b = ir.Const(1, self.loc1) + c = ir.Const(1, self.loc2) + d = ir.Const(2, self.loc1) + self.check(a, same=[b, c], different=[d]) + + def test_global(self): + a = ir.Global("foo", 0, self.loc1) + b = ir.Global("foo", 0, self.loc1) + c = ir.Global("foo", 0, self.loc2) + d = ir.Global("bar", 0, self.loc1) + e = ir.Global("foo", 1, self.loc1) + self.check(a, same=[b, c], different=[d, e]) + + def test_var(self): + a = ir.Var(None, "foo", self.loc1) + b = ir.Var(None, "foo", self.loc1) + c = ir.Var(None, "foo", self.loc2) + d = ir.Var(ir.Scope(None, ir.unknown_loc), "foo", self.loc1) + e = ir.Var(None, "bar", self.loc1) + self.check(a, same=[b, c, d], different=[e]) + + def test_undefinedtype(self): + a = ir.UndefinedType() + b = ir.UndefinedType() + self.check(a, same=[b]) + + def test_loop(self): + a = ir.Loop(1, 3) + b = ir.Loop(1, 3) + c = ir.Loop(2, 3) + d = ir.Loop(1, 4) + self.check(a, same=[b], different=[c, d]) + + def test_with(self): + a = ir.With(1, 3) + b = ir.With(1, 3) + c = ir.With(2, 3) + d = ir.With(1, 4) + self.check(a, same=[b], different=[c, d]) + + +# used later +_GLOBAL = 1234 + + +class TestIRCompounds(CheckEquality): + """ + Tests IR concepts that have state + """ + + def test_varmap(self): + a = ir.VarMap() + a.define(self.var_a, "foo") + a.define(self.var_b, "bar") + + b = ir.VarMap() + b.define(self.var_a, "foo") + b.define(self.var_b, "bar") + + c = ir.VarMap() + c.define(self.var_a, "foo") + c.define(self.var_c, "bar") + + self.check(a, same=[b], different=[c]) + + def test_block(self): + def gen_block(): + parent = ir.Scope(None, self.loc1) + tmp = ir.Block(parent, self.loc2) + assign1 = ir.Assign(self.var_a, self.var_b, self.loc3) + assign2 = ir.Assign(self.var_a, self.var_c, self.loc3) + assign3 = ir.Assign(self.var_c, self.var_b, self.loc3) + tmp.append(assign1) + tmp.append(assign2) + tmp.append(assign3) + return tmp + + a = gen_block() + b = gen_block() + c = gen_block().append(ir.Assign(self.var_a, self.var_b, self.loc3)) + + self.check(a, same=[b], different=[c]) + + @requires_fully_vendored_ir + def test_functionir(self): + def run_frontend(x): + return compiler.run_frontend(x, emit_dels=True) + + # this creates a function full of all sorts of things to ensure the IR + # is pretty involved, it then compares two instances of the compiled + # function IR to check the IR is the same invariant of objects, and then + # a tiny mutation is made to the IR in the second function and detection + # of this change is checked. + def gen(): + _FREEVAR = 0xCAFE + + def foo(a, b, c=12, d=1j, e=None): + f = a + b + a += _FREEVAR + g = np.zeros(c, dtype=np.complex64) + h = f + g + i = 1j / d + if np.abs(i) > 0: + k = h / i + l = np.arange(1, c + 1) + with objmode(): + print(e, k) + m = np.sqrt(l - g) + if np.abs(m[0]) < 1: + n = 0 + for o in range(a): + n += 0 + if np.abs(n) < 3: + break + n += m[2] + p = g / l + q = [] + for r in range(len(p)): + q.append(p[r]) + if r > 4 + 1: + with objmode(s="intp", t="complex128"): + s = 123 + t = 5 + if s > 122: + t += s + t += q[0] + _GLOBAL + + return f + o + r + t + r + a + n + + return foo + + x = gen() + y = gen() + x_ir = run_frontend(x) + y_ir = run_frontend(y) + + self.assertTrue(x_ir.equal_ir(y_ir)) + + def check_diffstr(string, pointing_at=[]): + lines = string.splitlines() + for item in pointing_at: + for l in lines: + if l.startswith("->"): + if item in l: + break + else: + raise AssertionError("Could not find %s " % item) + + self.assertIn("IR is considered equivalent", x_ir.diff_str(y_ir)) + + # minor mutation, simply switch branch targets on last branch + for label in reversed(list(y_ir.blocks.keys())): + blk = y_ir.blocks[label] + if isinstance(blk.body[-1], ir.Branch): + ref = blk.body[-1] + ref.truebr, ref.falsebr = ref.falsebr, ref.truebr + break + + check_diffstr(x_ir.diff_str(y_ir), ["branch"]) + + z = gen() + self.assertFalse(x_ir.equal_ir(y_ir)) + + z_ir = run_frontend(z) + + change_set = set() + for label in reversed(list(z_ir.blocks.keys())): + blk = z_ir.blocks[label] + ref = blk.body[:-1] + idx = None + for i in range(len(ref) - 1): + # look for two adjacent Del + if isinstance(ref[i], ir.Del) and isinstance( + ref[i + 1], ir.Del + ): + idx = i + break + if idx is not None: + b = blk.body + change_set.add(str(b[idx + 1])) + change_set.add(str(b[idx])) + b[idx], b[idx + 1] = b[idx + 1], b[idx] + break + + # ensure that a mutation occurred. + self.assertTrue(change_set) + + self.assertFalse(x_ir.equal_ir(z_ir)) + self.assertEqual(len(change_set), 2) + for item in change_set: + self.assertTrue(item.startswith("del ")) + check_diffstr(x_ir.diff_str(z_ir), change_set) + + def foo(a, b): + c = a * 2 + d = c + b + e = np.sqrt(d) + return e + + def bar(a, b): # same as foo + c = a * 2 + d = c + b + e = np.sqrt(d) + return e + + def baz(a, b): + c = a * 2 + d = b + c + e = np.sqrt(d + 1) + return e + + foo_ir = run_frontend(foo) + bar_ir = run_frontend(bar) + self.assertTrue(foo_ir.equal_ir(bar_ir)) + self.assertIn("IR is considered equivalent", foo_ir.diff_str(bar_ir)) + + baz_ir = run_frontend(baz) + self.assertFalse(foo_ir.equal_ir(baz_ir)) + tmp = foo_ir.diff_str(baz_ir) + self.assertIn("Other block contains more statements", tmp) + check_diffstr(tmp, ["c + b", "b + c"]) + + +@requires_fully_vendored_ir +class TestIRPedanticChecks(CUDATestCase): + def test_var_in_scope_assumption(self): + # Create a pass that clears ir.Scope in ir.Block + @register_pass(mutates_CFG=False, analysis_only=False) + class RemoveVarInScope(FunctionPass): + _name = "_remove_var_in_scope" + + def __init__(self): + FunctionPass.__init__(self) + + # implement method to do the work, "state" is the internal compiler + # state from the CompilerBase instance. + def run_pass(self, state): + func_ir = state.func_ir + # walk the blocks + for blk in func_ir.blocks.values(): + oldscope = blk.scope + # put in an empty Scope + blk.scope = ir.Scope( + parent=oldscope.parent, loc=oldscope.loc + ) + return True + + # Create a pass that always fails, to stop the compiler + @register_pass(mutates_CFG=False, analysis_only=False) + class FailPass(FunctionPass): + _name = "_fail" + + def __init__(self, *args, **kwargs): + FunctionPass.__init__(self) + + def run_pass(self, state): + # This is unreachable. SSA pass should have raised before this + # pass when run with `error.NumbaPedanticWarning`s raised as + # errors. + raise AssertionError("unreachable") + + class MyCompiler(CompilerBase): + def define_pipelines(self): + pm = PassManager("testing pm") + pm.add_pass(TranslateByteCode, "analyzing bytecode") + pm.add_pass(IRProcessing, "processing IR") + pm.add_pass(RemoveVarInScope, "_remove_var_in_scope") + pm.add_pass(ReconstructSSA, "ssa") + pm.add_pass(FailPass, "_fail") + pm.finalize() + return [pm] + + @njit(pipeline_class=MyCompiler) + def dummy(x): + # To trigger SSA and the pedantic check, this function must have + # multiple assignments to the same variable in different blocks. + a = 1 + b = 2 + if a < b: + a = 2 + else: + b = 3 + return a, b + + with warnings.catch_warnings(): + # Make NumbaPedanticWarning an error + warnings.simplefilter("error", errors.NumbaPedanticWarning) + # Catch NumbaIRAssumptionWarning + with self.assertRaises(errors.NumbaIRAssumptionWarning) as raises: + dummy(1) + # Verify the error message + self.assertRegex( + str(raises.exception), + r"variable '[a-z]' is not in scope", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/numba_cuda/numba/cuda/utils.py b/numba_cuda/numba/cuda/utils.py index 76cdd1af1..bc012168d 100644 --- a/numba_cuda/numba/cuda/utils.py +++ b/numba_cuda/numba/cuda/utils.py @@ -115,7 +115,7 @@ def safe_relpath(path, start=os.curdir): ALL_BINOPS_TO_OPERATORS = {**BINOPS_TO_OPERATORS, **INPLACE_BINOPS_TO_OPERATORS} -UNARY_BUITINS_TO_OPERATORS = { +UNARY_BUILTINS_TO_OPERATORS = { "+": operator.pos, "-": operator.neg, "~": operator.invert,