diff --git a/numba_cuda/numba/cuda/core/analysis.py b/numba_cuda/numba/cuda/core/analysis.py index 6326ca260..aa70a800d 100644 --- a/numba_cuda/numba/cuda/core/analysis.py +++ b/numba_cuda/numba/cuda/core/analysis.py @@ -1,11 +1,449 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-2-Clause -from collections import namedtuple +from collections import namedtuple, defaultdict from numba import types -from numba.core import ir +from numba.core import ir, errors from numba.cuda.core import consts -from numba.core.analysis import compute_cfg_from_blocks +import operator +from functools import reduce + +from .controlflow import CFGraph +from numba.misc import special + +# +# Analysis related to variable lifetime +# + +_use_defs_result = namedtuple("use_defs_result", "usemap,defmap") + +# other packages that define new nodes add calls for finding defs +# format: {type:function} +ir_extension_usedefs = {} + + +def compute_use_defs(blocks): + """ + Find variable use/def per block. + """ + + var_use_map = {} # { block offset -> set of vars } + var_def_map = {} # { block offset -> set of vars } + for offset, ir_block in blocks.items(): + var_use_map[offset] = use_set = set() + var_def_map[offset] = def_set = set() + for stmt in ir_block.body: + if type(stmt) in ir_extension_usedefs: + func = ir_extension_usedefs[type(stmt)] + func(stmt, use_set, def_set) + continue + if isinstance(stmt, ir.Assign): + if isinstance(stmt.value, ir.Inst): + rhs_set = set(var.name for var in stmt.value.list_vars()) + elif isinstance(stmt.value, ir.Var): + rhs_set = set([stmt.value.name]) + elif isinstance( + stmt.value, (ir.Arg, ir.Const, ir.Global, ir.FreeVar) + ): + rhs_set = () + else: + raise AssertionError("unreachable", type(stmt.value)) + # If lhs not in rhs of the assignment + if stmt.target.name not in rhs_set: + def_set.add(stmt.target.name) + + for var in stmt.list_vars(): + # do not include locally defined vars to use-map + if var.name not in def_set: + use_set.add(var.name) + + return _use_defs_result(usemap=var_use_map, defmap=var_def_map) + + +def compute_live_map(cfg, blocks, var_use_map, var_def_map): + """ + Find variables that must be alive at the ENTRY of each block. + We use a simple fix-point algorithm that iterates until the set of + live variables is unchanged for each block. + """ + + def fix_point_progress(dct): + """Helper function to determine if a fix-point has been reached.""" + return tuple(len(v) for v in dct.values()) + + def fix_point(fn, dct): + """Helper function to run fix-point algorithm.""" + old_point = None + new_point = fix_point_progress(dct) + while old_point != new_point: + fn(dct) + old_point = new_point + new_point = fix_point_progress(dct) + + def def_reach(dct): + """Find all variable definition reachable at the entry of a block""" + for offset in var_def_map: + used_or_defined = var_def_map[offset] | var_use_map[offset] + dct[offset] |= used_or_defined + # Propagate to outgoing nodes + for out_blk, _ in cfg.successors(offset): + dct[out_blk] |= dct[offset] + + def liveness(dct): + """Find live variables. + + Push var usage backward. 
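Reviewer note on the dataflow above: `compute_use_defs` collects per-block use/def sets and `compute_live_map` runs a backward fix-point over them. A minimal sketch of that fix-point on a hypothetical three-block CFG (plain dicts stand in for numba IR blocks, and the def-reach filtering of the real pass is omitted):

```python
# Toy backward-liveness fix-point; block names and sets are invented.
use = {"A": {"x"}, "B": {"x", "y"}, "C": {"y"}}
defs = {"A": {"y"}, "B": set(), "C": set()}
preds = {"A": [], "B": ["A"], "C": ["A", "B"]}

live_in = {blk: set(u) for blk, u in use.items()}
changed = True
while changed:  # iterate until no live-in set grows
    changed = False
    for blk, live in list(live_in.items()):
        for pred in preds[blk]:
            # a var live here is live in the predecessor unless defined there
            propagated = live - defs[pred]
            if not propagated <= live_in[pred]:
                live_in[pred] |= propagated
                changed = True

print(live_in)  # e.g. {'A': {'x'}, 'B': {'x', 'y'}, 'C': {'y'}}
```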
+ """ + for offset in dct: + # Live vars here + live_vars = dct[offset] + for inc_blk, _data in cfg.predecessors(offset): + # Reachable at the predecessor + reachable = live_vars & def_reach_map[inc_blk] + # But not defined in the predecessor + dct[inc_blk] |= reachable - var_def_map[inc_blk] + + live_map = {} + for offset in blocks.keys(): + live_map[offset] = set(var_use_map[offset]) + + def_reach_map = defaultdict(set) + fix_point(def_reach, def_reach_map) + fix_point(liveness, live_map) + return live_map + + +_dead_maps_result = namedtuple("dead_maps_result", "internal,escaping,combined") + + +def compute_dead_maps(cfg, blocks, live_map, var_def_map): + """ + Compute the end-of-live information for variables. + `live_map` contains a mapping of block offset to all the living + variables at the ENTRY of the block. + """ + # The following three dictionaries will be + # { block offset -> set of variables to delete } + # all vars that should be deleted at the start of the successors + escaping_dead_map = defaultdict(set) + # all vars that should be deleted within this block + internal_dead_map = defaultdict(set) + # all vars that should be deleted after the function exit + exit_dead_map = defaultdict(set) + + for offset, ir_block in blocks.items(): + # live vars WITHIN the block will include all the locally + # defined variables + cur_live_set = live_map[offset] | var_def_map[offset] + # vars alive in the outgoing blocks + outgoing_live_map = dict( + (out_blk, live_map[out_blk]) + for out_blk, _data in cfg.successors(offset) + ) + # vars to keep alive for the terminator + terminator_liveset = set( + v.name for v in ir_block.terminator.list_vars() + ) + # vars to keep alive in the successors + combined_liveset = reduce( + operator.or_, outgoing_live_map.values(), set() + ) + # include variables used in terminator + combined_liveset |= terminator_liveset + # vars that are dead within the block because they are not + # propagated to any outgoing blocks + internal_set = cur_live_set - combined_liveset + internal_dead_map[offset] = internal_set + # vars that escape this block + escaping_live_set = cur_live_set - internal_set + for out_blk, new_live_set in outgoing_live_map.items(): + # successor should delete the unused escaped vars + new_live_set = new_live_set | var_def_map[out_blk] + escaping_dead_map[out_blk] |= escaping_live_set - new_live_set + + # if no outgoing blocks + if not outgoing_live_map: + # insert var used by terminator + exit_dead_map[offset] = terminator_liveset + + # Verify that the dead maps cover all live variables + all_vars = reduce(operator.or_, live_map.values(), set()) + internal_dead_vars = reduce(operator.or_, internal_dead_map.values(), set()) + escaping_dead_vars = reduce(operator.or_, escaping_dead_map.values(), set()) + exit_dead_vars = reduce(operator.or_, exit_dead_map.values(), set()) + dead_vars = internal_dead_vars | escaping_dead_vars | exit_dead_vars + missing_vars = all_vars - dead_vars + if missing_vars: + # There are no exit points + if not cfg.exit_points(): + # We won't be able to verify this + pass + else: + msg = "liveness info missing for vars: {0}".format(missing_vars) + raise RuntimeError(msg) + + combined = dict( + (k, internal_dead_map[k] | escaping_dead_map[k]) for k in blocks + ) + + return _dead_maps_result( + internal=internal_dead_map, + escaping=escaping_dead_map, + combined=combined, + ) + + +def compute_live_variables(cfg, blocks, var_def_map, var_dead_map): + """ + Compute the live variables at the beginning of each block + and at each 
yield point.
+    The ``var_def_map`` and ``var_dead_map`` indicate the variables defined
+    and deleted at each block, respectively.
+    """
+    # live var at the entry per block
+    block_entry_vars = defaultdict(set)
+
+    def fix_point_progress():
+        return tuple(map(len, block_entry_vars.values()))
+
+    old_point = None
+    new_point = fix_point_progress()
+
+    # Propagate variables that are defined and still live to the successors.
+    # (note the entry block automatically gets an empty set)
+
+    # Note: This is finding the actual available variables at the entry
+    # of each block. The algorithm in compute_live_map() is finding
+    # the variables that must be available at the entry of each block.
+    # This is top-down in the dataflow. The other one is bottom-up.
+    while old_point != new_point:
+        # We iterate until the result stabilizes. This is necessary
+        # because of loops in the graph itself.
+        for offset in blocks:
+            # vars available + variables defined
+            avail = block_entry_vars[offset] | var_def_map[offset]
+            # subtract variables deleted
+            avail -= var_dead_map[offset]
+            # add ``avail`` to each successor
+            for succ, _data in cfg.successors(offset):
+                block_entry_vars[succ] |= avail
+
+        old_point = new_point
+        new_point = fix_point_progress()
+
+    return block_entry_vars
+
+
+#
+# Analysis related to controlflow
+#
+
+
+def compute_cfg_from_blocks(blocks):
+    cfg = CFGraph()
+    for k in blocks:
+        cfg.add_node(k)
+
+    for k, b in blocks.items():
+        term = b.terminator
+        for target in term.get_targets():
+            cfg.add_edge(k, target)
+
+    cfg.set_entry_point(min(blocks))
+    cfg.process()
+    return cfg
+
+
+def find_top_level_loops(cfg):
+    """
+    A generator that yields toplevel loops given a control-flow-graph
+    """
+    blocks_in_loop = set()
+    # get loop bodies
+    for loop in cfg.loops().values():
+        insiders = set(loop.body) | set(loop.entries) | set(loop.exits)
+        insiders.discard(loop.header)
+        blocks_in_loop |= insiders
+    # find loops that are not part of other loops
+    for loop in cfg.loops().values():
+        if loop.header not in blocks_in_loop:
+            yield _fix_loop_exit(cfg, loop)
+
+
+def _fix_loop_exit(cfg, loop):
+    """
+    Fixes loop.exits for Py3.8+ bytecode CFG changes.
+    This is to handle `break` inside loops.
+    """
+    # Computes the common postdoms of exit nodes
+    postdoms = cfg.post_dominators()
+    exits = reduce(
+        operator.and_,
+        [postdoms[b] for b in loop.exits],
+        loop.exits,
+    )
+    if exits:
+        # Put the non-common-exits as body nodes
+        body = loop.body | loop.exits - exits
+        return loop._replace(exits=exits, body=body)
+    else:
+        return loop
+
+
+def rewrite_semantic_constants(func_ir, called_args):
+    """
+    This rewrites values known to be constant by their semantics as ir.Const
+    nodes. This is to give branch pruning the best chance possible of killing
+    branches. An example might be rewriting len(tuple) as the literal length.
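A hypothetical host-side illustration of the pattern this pass targets, shown with `numba.njit` rather than a CUDA kernel: once `len(t)` is rewritten to `ir.Const(3)`, the `len(t) == 0` branch is provably dead and branch pruning can drop it before type inference runs.

```python
from numba import njit

@njit
def first_or_zero(t):
    # len() of a 3-tuple argument is rewritten to ir.Const(3) by this pass
    if len(t) == 0:
        return 0
    return t[0]

first_or_zero((1, 2, 3))  # the dead `return 0` branch is pruned
```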
+ + func_ir is the IR + called_args are the actual arguments with which the function is called + """ + DEBUG = 0 + + if DEBUG > 1: + print( + ("rewrite_semantic_constants: " + func_ir.func_id.func_name).center( + 80, "-" + ) + ) + print("before".center(80, "*")) + func_ir.dump() + + def rewrite_statement(func_ir, stmt, new_val): + """ + Rewrites the stmt as a ir.Const new_val and fixes up the entries in + func_ir._definitions + """ + stmt.value = ir.Const(new_val, stmt.loc) + defns = func_ir._definitions[stmt.target.name] + repl_idx = defns.index(val) + defns[repl_idx] = stmt.value + + def rewrite_array_ndim(val, func_ir, called_args): + # rewrite Array.ndim as const(ndim) + if getattr(val, "op", None) == "getattr": + if val.attr == "ndim": + arg_def = guard(get_definition, func_ir, val.value) + if isinstance(arg_def, ir.Arg): + argty = called_args[arg_def.index] + if isinstance(argty, types.Array): + rewrite_statement(func_ir, stmt, argty.ndim) + + def rewrite_tuple_len(val, func_ir, called_args): + # rewrite len(tuple) as const(len(tuple)) + if getattr(val, "op", None) == "call": + func = guard(get_definition, func_ir, val.func) + if ( + func is not None + and isinstance(func, ir.Global) + and getattr(func, "value", None) is len + ): + (arg,) = val.args + arg_def = guard(get_definition, func_ir, arg) + if isinstance(arg_def, ir.Arg): + argty = called_args[arg_def.index] + if isinstance(argty, types.BaseTuple): + rewrite_statement(func_ir, stmt, argty.count) + elif ( + isinstance(arg_def, ir.Expr) + and arg_def.op == "typed_getitem" + ): + argty = arg_def.dtype + if isinstance(argty, types.BaseTuple): + rewrite_statement(func_ir, stmt, argty.count) + + from numba.core.ir_utils import get_definition, guard + + for blk in func_ir.blocks.values(): + for stmt in blk.body: + if isinstance(stmt, ir.Assign): + val = stmt.value + if isinstance(val, ir.Expr): + rewrite_array_ndim(val, func_ir, called_args) + rewrite_tuple_len(val, func_ir, called_args) + + if DEBUG > 1: + print("after".center(80, "*")) + func_ir.dump() + print("-" * 80) + + +def find_literally_calls(func_ir, argtypes): + """An analysis to find `numba.literally` call inside the given IR. + When an unsatisfied literal typing request is found, a `ForceLiteralArg` + exception is raised. + + Parameters + ---------- + + func_ir : numba.ir.FunctionIR + + argtypes : Sequence[numba.types.Type] + The argument types. 
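For context, a hypothetical use of `numba.literally` that this analysis reacts to (host-side `njit` shown for brevity): the first compilation sees a plain integer for `n`, so the `ForceLiteralArg` raised here tells the dispatcher to retry with `n` typed as `types.IntegerLiteral(4)`.

```python
from numba import njit, literally

@njit
def scale(x, n):
    # literally(n) requests that `n` be typed as a Literal; when it is not,
    # find_literally_calls() raises ForceLiteralArg for argument index 1.
    return x * literally(n)

scale(10, 4)  # recompiled with n typed as types.IntegerLiteral(4)
```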
+ """ + from numba.core import ir_utils + + marked_args = set() + first_loc = {} + # Scan for literally calls + for blk in func_ir.blocks.values(): + for assign in blk.find_exprs(op="call"): + var = ir_utils.guard(ir_utils.get_definition, func_ir, assign.func) + if isinstance(var, (ir.Global, ir.FreeVar)): + fnobj = var.value + else: + fnobj = ir_utils.guard( + ir_utils.resolve_func_from_module, func_ir, var + ) + if fnobj is special.literally: + # Found + [arg] = assign.args + defarg = func_ir.get_definition(arg) + if isinstance(defarg, ir.Arg): + argindex = defarg.index + marked_args.add(argindex) + first_loc.setdefault(argindex, assign.loc) + # Signal the dispatcher to force literal typing + for pos in marked_args: + query_arg = argtypes[pos] + do_raise = ( + isinstance(query_arg, types.InitialValue) + and query_arg.initial_value is None + ) + if do_raise: + loc = first_loc[pos] + raise errors.ForceLiteralArg(marked_args, loc=loc) + + if not isinstance(query_arg, (types.Literal, types.InitialValue)): + loc = first_loc[pos] + raise errors.ForceLiteralArg(marked_args, loc=loc) + + +ir_extension_use_alloca = {} + + +def must_use_alloca(blocks): + """ + Analyzes a dictionary of blocks to find variables that must be + stack allocated with alloca. For each statement in the blocks, + determine if that statement requires certain variables to be + stack allocated. This function uses the extension point + ir_extension_use_alloca to allow other IR node types like parfors + to register to be processed by this analysis function. At the + moment, parfors are the only IR node types that may require + something to be stack allocated. + """ + use_alloca_vars = set() + + for ir_block in blocks.values(): + for stmt in ir_block.body: + if type(stmt) in ir_extension_use_alloca: + func = ir_extension_use_alloca[type(stmt)] + func(stmt, use_alloca_vars) + continue + + return use_alloca_vars # Used to describe a nullified condition in dead branch pruning diff --git a/numba_cuda/numba/cuda/core/byteflow.py b/numba_cuda/numba/cuda/core/byteflow.py new file mode 100644 index 000000000..4be6d4c10 --- /dev/null +++ b/numba_cuda/numba/cuda/core/byteflow.py @@ -0,0 +1,2346 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause
+
+"""
+Implement Python 3.8+ bytecode analysis
+"""
+
+import dis
+import logging
+from collections import namedtuple, defaultdict, deque
+from functools import total_ordering
+
+from numba.cuda.utils import (
+    UniqueDict,
+    PYVERSION,
+    ALL_BINOPS_TO_OPERATORS,
+    _lazy_pformat,
+)
+from numba.cuda.core.controlflow import NEW_BLOCKERS, CFGraph
+from numba.core.ir import Loc
+from numba.cuda.errors import UnsupportedBytecodeError
+
+
+_logger = logging.getLogger(__name__)
+
+_EXCEPT_STACK_OFFSET = 6
+_FINALLY_POP = _EXCEPT_STACK_OFFSET
+_NO_RAISE_OPS = frozenset(
+    {
+        "LOAD_CONST",
+        "NOP",
+        "LOAD_DEREF",
+        "PRECALL",
+    }
+)
+
+if PYVERSION in ((3, 12), (3, 13)):
+    from enum import Enum
+
+    # Operands for CALL_INTRINSIC_1
+    class CALL_INTRINSIC_1_Operand(Enum):
+        INTRINSIC_STOPITERATION_ERROR = 3
+        UNARY_POSITIVE = 5
+        INTRINSIC_LIST_TO_TUPLE = 6
+
+    ci1op = CALL_INTRINSIC_1_Operand
+elif PYVERSION in ((3, 9), (3, 10), (3, 11)):
+    pass
+else:
+    raise NotImplementedError(PYVERSION)
+
+
+@total_ordering
+class BlockKind(object):
+    """Kinds of block to make related code safer than just `str`."""
+
+    _members = frozenset(
+        {
+            "LOOP",
+            "TRY",
+            "EXCEPT",
+            "FINALLY",
+            "WITH",
+            "WITH_FINALLY",
+        }
+    )
+
+    def __init__(self, value):
+        assert value in self._members
+        self._value = value
+
+    def __hash__(self):
+        return hash((type(self), self._value))
+
+    def __lt__(self, other):
+        if isinstance(other, BlockKind):
+            return self._value < other._value
+        else:
+            raise TypeError("cannot compare to {!r}".format(type(other)))
+
+    def __eq__(self, other):
+        if isinstance(other, BlockKind):
+            return self._value == other._value
+        else:
+            raise TypeError("cannot compare to {!r}".format(type(other)))
+
+    def __repr__(self):
+        return "BlockKind({})".format(self._value)
+
+
+class Flow(object):
+    """Data+Control Flow analysis.
+
+    Simulate execution to recover dataflow and controlflow information.
+    """
+
+    def __init__(self, bytecode):
+        _logger.debug(
+            "bytecode dump:\n%s",
+            _lazy_pformat(bytecode, lazy_func=lambda x: x.dump()),
+        )
+        self._bytecode = bytecode
+        self.block_infos = UniqueDict()
+
+    def run(self):
+        """Run a trace over the bytecode over all reachable paths.
+
+        The trace starts at bytecode offset 0 and gathers stack and control-
+        flow information by partially interpreting each bytecode.
+        Each ``State`` instance in the trace corresponds to a basic-block.
+        A ``State`` instance forks when a jump instruction is encountered.
+        A newly forked state is then added to the list of pending states.
+        The trace ends when there are no more pending states.
+        """
+        firststate = State(
+            bytecode=self._bytecode, pc=0, nstack=0, blockstack=()
+        )
+        runner = TraceRunner(debug_filename=self._bytecode.func_id.filename)
+        runner.pending.append(firststate)
+
+        # Enforce uniqueness of the initial PC to avoid re-entering the PC
+        # with a different stack-depth. We don't know if such a case is ever
+        # possible, but no such case has been encountered in our tests.
+        first_encounter = UniqueDict()
+        # Loop over each pending state at an initial PC.
+        # Each state traces a basic block.
+        while runner.pending:
+            _logger.debug("pending: %s", runner.pending)
+            state = runner.pending.popleft()
+            if state not in runner.finished:
+                _logger.debug("stack: %s", state._stack)
+                _logger.debug("state.pc_initial: %s", state)
+                first_encounter[state.pc_initial] = state
+                # Loop over the state until it is terminated.
+                while True:
+                    runner.dispatch(state)
+                    # Terminated?
+ if state.has_terminated(): + break + else: + if self._run_handle_exception(runner, state): + break + + if self._is_implicit_new_block(state): + # check if this is a with...as, abort if so + self._guard_with_as(state) + # else split + state.split_new_block() + break + _logger.debug("end state. edges=%s", state.outgoing_edges) + runner.finished.add(state) + out_states = state.get_outgoing_states() + runner.pending.extend(out_states) + + # Complete controlflow + self._build_cfg(runner.finished) + # Prune redundant PHI-nodes + self._prune_phis(runner) + # Post process + for state in sorted(runner.finished, key=lambda x: x.pc_initial): + self.block_infos[state.pc_initial] = si = adapt_state_infos(state) + _logger.debug("block_infos %s:\n%s", state, si) + + if PYVERSION in ((3, 11), (3, 12), (3, 13)): + + def _run_handle_exception(self, runner, state): + if not state.in_with() and ( + state.has_active_try() + and state.get_inst().opname not in _NO_RAISE_OPS + ): + # Is in a *try* block + state.fork(pc=state.get_inst().next) + runner._adjust_except_stack(state) + return True + else: + state.advance_pc() + + # Must the new PC be a new block? + if not state.in_with() and state.is_in_exception(): + _logger.debug( + "3.11 exception %s PC=%s", + state.get_exception(), + state._pc, + ) + eh = state.get_exception() + eh_top = state.get_top_block("TRY") + if eh_top and eh_top["end"] == eh.target: + # Same exception + eh_block = None + else: + eh_block = state.make_block("TRY", end=eh.target) + eh_block["end_offset"] = eh.end + eh_block["stack_depth"] = eh.depth + eh_block["push_lasti"] = eh.lasti + state.fork(pc=state._pc, extra_block=eh_block) + return True + elif PYVERSION in ( + (3, 9), + (3, 10), + ): + + def _run_handle_exception(self, runner, state): + if ( + state.has_active_try() + and state.get_inst().opname not in _NO_RAISE_OPS + ): + # Is in a *try* block + state.fork(pc=state.get_inst().next) + tryblk = state.get_top_block("TRY") + state.pop_block_and_above(tryblk) + nstack = state.stack_depth + kwargs = {} + if nstack > tryblk["entry_stack"]: + kwargs["npop"] = nstack - tryblk["entry_stack"] + handler = tryblk["handler"] + kwargs["npush"] = { + BlockKind("EXCEPT"): _EXCEPT_STACK_OFFSET, + BlockKind("FINALLY"): _FINALLY_POP, + }[handler["kind"]] + kwargs["extra_block"] = handler + state.fork(pc=tryblk["end"], **kwargs) + return True + else: + state.advance_pc() + else: + raise NotImplementedError(PYVERSION) + + def _build_cfg(self, all_states): + graph = CFGraph() + for state in all_states: + b = state.pc_initial + graph.add_node(b) + for state in all_states: + for edge in state.outgoing_edges: + graph.add_edge(state.pc_initial, edge.pc, 0) + graph.set_entry_point(0) + graph.process() + self.cfgraph = graph + + def _prune_phis(self, runner): + # Find phis that are unused in the local block + _logger.debug("Prune PHIs".center(60, "-")) + + # Compute dataflow for used phis and propagate + + # 1. 
Get used-phis for each block + # Map block to used_phis + def get_used_phis_per_state(): + used_phis = defaultdict(set) + phi_set = set() + for state in runner.finished: + used = set(state._used_regs) + phis = set(state._phis) + used_phis[state] |= phis & used + phi_set |= phis + return used_phis, phi_set + + # Find use-defs + def find_use_defs(): + defmap = {} + phismap = defaultdict(set) + for state in runner.finished: + for phi, rhs in state._outgoing_phis.items(): + if rhs not in phi_set: + # Is a definition + defmap[phi] = state + phismap[phi].add((rhs, state)) + _logger.debug("defmap: %s", _lazy_pformat(defmap)) + _logger.debug("phismap: %s", _lazy_pformat(phismap)) + return defmap, phismap + + def propagate_phi_map(phismap): + """An iterative dataflow algorithm to find the definition + (the source) of each PHI node. + """ + blacklist = defaultdict(set) + + while True: + changing = False + for phi, defsites in sorted(list(phismap.items())): + for rhs, state in sorted(list(defsites)): + if rhs in phi_set: + defsites |= phismap[rhs] + blacklist[phi].add((rhs, state)) + to_remove = blacklist[phi] + if to_remove & defsites: + defsites -= to_remove + changing = True + + _logger.debug("changing phismap: %s", _lazy_pformat(phismap)) + if not changing: + break + + def apply_changes(used_phis, phismap): + keep = {} + for state, used_set in used_phis.items(): + for phi in used_set: + keep[phi] = phismap[phi] + _logger.debug("keep phismap: %s", _lazy_pformat(keep)) + new_out = defaultdict(dict) + for phi in keep: + for rhs, state in keep[phi]: + new_out[state][phi] = rhs + + _logger.debug("new_out: %s", _lazy_pformat(new_out)) + for state in runner.finished: + state._outgoing_phis.clear() + state._outgoing_phis.update(new_out[state]) + + used_phis, phi_set = get_used_phis_per_state() + _logger.debug("Used_phis: %s", _lazy_pformat(used_phis)) + defmap, phismap = find_use_defs() + propagate_phi_map(phismap) + apply_changes(used_phis, phismap) + _logger.debug("DONE Prune PHIs".center(60, "-")) + + def _is_implicit_new_block(self, state): + inst = state.get_inst() + + if inst.offset in self._bytecode.labels: + return True + elif inst.opname in NEW_BLOCKERS: + return True + else: + return False + + def _guard_with_as(self, state): + """Checks if the next instruction after a SETUP_WITH is something other + than a POP_TOP, if it is something else it'll be some sort of store + which is not supported (this corresponds to `with CTXMGR as VAR(S)`).""" + current_inst = state.get_inst() + if current_inst.opname in {"SETUP_WITH", "BEFORE_WITH"}: + next_op = self._bytecode[current_inst.next].opname + if next_op != "POP_TOP": + msg = ( + "The 'with (context manager) as " + "(variable):' construct is not " + "supported." 
+ ) + raise UnsupportedBytecodeError(msg) + + +def _is_null_temp_reg(reg): + return reg.startswith("$null$") + + +class TraceRunner(object): + """Trace runner contains the states for the trace and the opcode dispatch.""" + + def __init__(self, debug_filename): + self.debug_filename = debug_filename + self.pending = deque() + self.finished = set() + + def get_debug_loc(self, lineno): + return Loc(self.debug_filename, lineno) + + def dispatch(self, state): + if PYVERSION in ((3, 11), (3, 12), (3, 13)): + if state._blockstack: + state: State + while state._blockstack: + topblk = state._blockstack[-1] + blk_end = topblk["end"] + if blk_end is not None and blk_end <= state.pc_initial: + state._blockstack.pop() + else: + break + elif PYVERSION in ( + (3, 9), + (3, 10), + ): + pass + else: + raise NotImplementedError(PYVERSION) + inst = state.get_inst() + if inst.opname != "CACHE": + _logger.debug("dispatch pc=%s, inst=%s", state._pc, inst) + _logger.debug("stack %s", state._stack) + fn = getattr(self, "op_{}".format(inst.opname), None) + if fn is not None: + fn(state, inst) + else: + msg = "Use of unsupported opcode (%s) found" % inst.opname + raise UnsupportedBytecodeError( + msg, loc=self.get_debug_loc(inst.lineno) + ) + + def _adjust_except_stack(self, state): + """ + Adjust stack when entering an exception handler to match expectation + by the bytecode. + """ + tryblk = state.get_top_block("TRY") + state.pop_block_and_above(tryblk) + nstack = state.stack_depth + kwargs = {} + expected_depth = tryblk["stack_depth"] + if nstack > expected_depth: + # Pop extra item in the stack + kwargs["npop"] = nstack - expected_depth + # Set extra stack itemcount due to the exception values. + extra_stack = 1 + if tryblk["push_lasti"]: + extra_stack += 1 + kwargs["npush"] = extra_stack + state.fork(pc=tryblk["end"], **kwargs) + + def op_NOP(self, state, inst): + state.append(inst) + + def op_RESUME(self, state, inst): + state.append(inst) + + def op_CACHE(self, state, inst): + state.append(inst) + + def op_PRECALL(self, state, inst): + state.append(inst) + + def op_PUSH_NULL(self, state, inst): + state.push(state.make_null()) + state.append(inst) + + def op_RETURN_GENERATOR(self, state, inst): + # This impl doesn't follow what CPython does. CPython is hacking + # the frame stack in the interpreter. From usage, it always + # has a POP_TOP after it so we push a dummy value to the stack. + # + # Example bytecode: + # > 0 NOP(arg=None, lineno=80) + # 2 RETURN_GENERATOR(arg=None, lineno=80) + # 4 POP_TOP(arg=None, lineno=80) + # 6 RESUME(arg=0, lineno=80) + state.push(state.make_temp()) + state.append(inst) + + if PYVERSION in ((3, 13),): + + def op_FORMAT_SIMPLE(self, state, inst): + assert PYVERSION == (3, 13) + value = state.pop() + strvar = state.make_temp() + res = state.make_temp() + state.append(inst, value=value, res=res, strvar=strvar) + state.push(res) + + def op_FORMAT_VALUE(self, state, inst): + """ + FORMAT_VALUE(flags): flags argument specifies format spec which is + not supported yet. Currently, we just call str() on the value. + Pops a value from stack and pushes results back. + Required for supporting f-strings. 
+ https://docs.python.org/3/library/dis.html#opcode-FORMAT_VALUE + """ + if inst.arg != 0: + msg = "format spec in f-strings not supported yet" + raise UnsupportedBytecodeError( + msg, loc=self.get_debug_loc(inst.lineno) + ) + value = state.pop() + strvar = state.make_temp() + res = state.make_temp() + state.append(inst, value=value, res=res, strvar=strvar) + state.push(res) + + def op_BUILD_STRING(self, state, inst): + """ + BUILD_STRING(count): Concatenates count strings from the stack and + pushes the resulting string onto the stack. + Required for supporting f-strings. + https://docs.python.org/3/library/dis.html#opcode-BUILD_STRING + """ + count = inst.arg + strings = list(reversed([state.pop() for _ in range(count)])) + # corner case: f"" + if count == 0: + tmps = [state.make_temp()] + else: + tmps = [state.make_temp() for _ in range(count - 1)] + state.append(inst, strings=strings, tmps=tmps) + state.push(tmps[-1]) + + def op_POP_TOP(self, state, inst): + state.pop() + + if PYVERSION in ((3, 13),): + + def op_TO_BOOL(self, state, inst): + res = state.make_temp() + tos = state.pop() + state.append(inst, val=tos, res=res) + state.push(res) + + elif PYVERSION < (3, 13): + pass + + if PYVERSION in ((3, 13),): + + def op_LOAD_GLOBAL(self, state, inst): + # Ordering of the global value and NULL is swapped in Py3.13 + res = state.make_temp() + idx = inst.arg >> 1 + state.append(inst, idx=idx, res=res) + state.push(res) + # ignoring the NULL + if inst.arg & 1: + state.push(state.make_null()) + elif PYVERSION in ((3, 11), (3, 12)): + + def op_LOAD_GLOBAL(self, state, inst): + res = state.make_temp() + idx = inst.arg >> 1 + state.append(inst, idx=idx, res=res) + # ignoring the NULL + if inst.arg & 1: + state.push(state.make_null()) + state.push(res) + elif PYVERSION in ( + (3, 9), + (3, 10), + ): + + def op_LOAD_GLOBAL(self, state, inst): + res = state.make_temp() + state.append(inst, res=res) + state.push(res) + else: + raise NotImplementedError(PYVERSION) + + def op_COPY_FREE_VARS(self, state, inst): + state.append(inst) + + def op_MAKE_CELL(self, state, inst): + state.append(inst) + + def op_LOAD_DEREF(self, state, inst): + res = state.make_temp() + state.append(inst, res=res) + state.push(res) + + def op_LOAD_CONST(self, state, inst): + # append const index for interpreter to read the const value + res = state.make_temp("const") + f".{inst.arg}" + state.push(res) + state.append(inst, res=res) + + def op_LOAD_ATTR(self, state, inst): + item = state.pop() + res = state.make_temp() + if PYVERSION in ((3, 13),): + state.push(res) # the attr + if inst.arg & 1: + state.push(state.make_null()) + elif PYVERSION in ((3, 12),): + if inst.arg & 1: + state.push(state.make_null()) + state.push(res) + elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + state.push(res) + else: + raise NotImplementedError(PYVERSION) + state.append(inst, item=item, res=res) + + def op_LOAD_FAST(self, state, inst): + assert PYVERSION <= (3, 13) + if PYVERSION in ((3, 13),): + try: + name = state.get_varname(inst) + except IndexError: # oparg is out of range + # Handle this like a LOAD_DEREF + # Assume MAKE_CELL and COPY_FREE_VARS has correctly setup the + # states. 
+ # According to https://github.com/python/cpython/blob/9ac606080a0074cdf7589d9b7c9413a73e0ddf37/Objects/codeobject.c#L730C9-L759 # noqa E501 + # localsplus is locals + cells + freevars + bc = state._bytecode + num_varnames = len(bc.co_varnames) + num_freevars = len(bc.co_freevars) + num_cellvars = len(bc.co_cellvars) + max_fast_local = num_cellvars + num_freevars + assert 0 <= inst.arg - num_varnames < max_fast_local + res = state.make_temp() + state.append(inst, res=res, as_load_deref=True) + state.push(res) + return + else: + name = state.get_varname(inst) + res = state.make_temp(name) + state.append(inst, res=res) + state.push(res) + + if PYVERSION in ((3, 13),): + + def op_LOAD_FAST_LOAD_FAST(self, state, inst): + oparg = inst.arg + oparg1 = oparg >> 4 + oparg2 = oparg & 15 + name1 = state.get_varname_by_arg(oparg1) + name2 = state.get_varname_by_arg(oparg2) + res1 = state.make_temp(name1) + res2 = state.make_temp(name2) + state.append(inst, res1=res1, res2=res2) + state.push(res1) + state.push(res2) + + def op_STORE_FAST_LOAD_FAST(self, state, inst): + oparg = inst.arg + # oparg1 = oparg >> 4 # not needed + oparg2 = oparg & 15 + store_value = state.pop() + load_name = state.get_varname_by_arg(oparg2) + load_res = state.make_temp(load_name) + state.append(inst, store_value=store_value, load_res=load_res) + state.push(load_res) + + def op_STORE_FAST_STORE_FAST(self, state, inst): + value1 = state.pop() + value2 = state.pop() + state.append(inst, value1=value1, value2=value2) + + elif PYVERSION in ((3, 9), (3, 10), (3, 11), (3, 12)): + pass + else: + raise NotImplementedError(PYVERSION) + + if PYVERSION in ((3, 12), (3, 13)): + op_LOAD_FAST_CHECK = op_LOAD_FAST + op_LOAD_FAST_AND_CLEAR = op_LOAD_FAST + elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + pass + else: + raise NotImplementedError(PYVERSION) + + def op_DELETE_FAST(self, state, inst): + state.append(inst) + + def op_DELETE_ATTR(self, state, inst): + target = state.pop() + state.append(inst, target=target) + + def op_STORE_ATTR(self, state, inst): + target = state.pop() + value = state.pop() + state.append(inst, target=target, value=value) + + def op_STORE_DEREF(self, state, inst): + value = state.pop() + state.append(inst, value=value) + + def op_STORE_FAST(self, state, inst): + value = state.pop() + state.append(inst, value=value) + + def op_SLICE_1(self, state, inst): + """ + TOS = TOS1[TOS:] + """ + tos = state.pop() + tos1 = state.pop() + res = state.make_temp() + slicevar = state.make_temp() + indexvar = state.make_temp() + nonevar = state.make_temp() + state.append( + inst, + base=tos1, + start=tos, + res=res, + slicevar=slicevar, + indexvar=indexvar, + nonevar=nonevar, + ) + state.push(res) + + def op_SLICE_2(self, state, inst): + """ + TOS = TOS1[:TOS] + """ + tos = state.pop() + tos1 = state.pop() + res = state.make_temp() + slicevar = state.make_temp() + indexvar = state.make_temp() + nonevar = state.make_temp() + state.append( + inst, + base=tos1, + stop=tos, + res=res, + slicevar=slicevar, + indexvar=indexvar, + nonevar=nonevar, + ) + state.push(res) + + def op_SLICE_3(self, state, inst): + """ + TOS = TOS2[TOS1:TOS] + """ + tos = state.pop() + tos1 = state.pop() + tos2 = state.pop() + res = state.make_temp() + slicevar = state.make_temp() + indexvar = state.make_temp() + state.append( + inst, + base=tos2, + start=tos1, + stop=tos, + res=res, + slicevar=slicevar, + indexvar=indexvar, + ) + state.push(res) + + def op_STORE_SLICE_0(self, state, inst): + """ + TOS[:] = TOS1 + """ + tos = state.pop() + value = 
state.pop() + slicevar = state.make_temp() + indexvar = state.make_temp() + nonevar = state.make_temp() + state.append( + inst, + base=tos, + value=value, + slicevar=slicevar, + indexvar=indexvar, + nonevar=nonevar, + ) + + def op_STORE_SLICE_1(self, state, inst): + """ + TOS1[TOS:] = TOS2 + """ + tos = state.pop() + tos1 = state.pop() + value = state.pop() + slicevar = state.make_temp() + indexvar = state.make_temp() + nonevar = state.make_temp() + state.append( + inst, + base=tos1, + start=tos, + slicevar=slicevar, + value=value, + indexvar=indexvar, + nonevar=nonevar, + ) + + def op_STORE_SLICE_2(self, state, inst): + """ + TOS1[:TOS] = TOS2 + """ + tos = state.pop() + tos1 = state.pop() + value = state.pop() + slicevar = state.make_temp() + indexvar = state.make_temp() + nonevar = state.make_temp() + state.append( + inst, + base=tos1, + stop=tos, + value=value, + slicevar=slicevar, + indexvar=indexvar, + nonevar=nonevar, + ) + + def op_STORE_SLICE_3(self, state, inst): + """ + TOS2[TOS1:TOS] = TOS3 + """ + tos = state.pop() + tos1 = state.pop() + tos2 = state.pop() + value = state.pop() + slicevar = state.make_temp() + indexvar = state.make_temp() + state.append( + inst, + base=tos2, + start=tos1, + stop=tos, + value=value, + slicevar=slicevar, + indexvar=indexvar, + ) + + def op_DELETE_SLICE_0(self, state, inst): + """ + del TOS[:] + """ + tos = state.pop() + slicevar = state.make_temp() + indexvar = state.make_temp() + nonevar = state.make_temp() + state.append( + inst, + base=tos, + slicevar=slicevar, + indexvar=indexvar, + nonevar=nonevar, + ) + + def op_DELETE_SLICE_1(self, state, inst): + """ + del TOS1[TOS:] + """ + tos = state.pop() + tos1 = state.pop() + slicevar = state.make_temp() + indexvar = state.make_temp() + nonevar = state.make_temp() + state.append( + inst, + base=tos1, + start=tos, + slicevar=slicevar, + indexvar=indexvar, + nonevar=nonevar, + ) + + def op_DELETE_SLICE_2(self, state, inst): + """ + del TOS1[:TOS] + """ + tos = state.pop() + tos1 = state.pop() + slicevar = state.make_temp() + indexvar = state.make_temp() + nonevar = state.make_temp() + state.append( + inst, + base=tos1, + stop=tos, + slicevar=slicevar, + indexvar=indexvar, + nonevar=nonevar, + ) + + def op_DELETE_SLICE_3(self, state, inst): + """ + del TOS2[TOS1:TOS] + """ + tos = state.pop() + tos1 = state.pop() + tos2 = state.pop() + slicevar = state.make_temp() + indexvar = state.make_temp() + state.append( + inst, + base=tos2, + start=tos1, + stop=tos, + slicevar=slicevar, + indexvar=indexvar, + ) + + def op_BUILD_SLICE(self, state, inst): + """ + slice(TOS1, TOS) or slice(TOS2, TOS1, TOS) + """ + argc = inst.arg + if argc == 2: + tos = state.pop() + tos1 = state.pop() + start = tos1 + stop = tos + step = None + elif argc == 3: + tos = state.pop() + tos1 = state.pop() + tos2 = state.pop() + start = tos2 + stop = tos1 + step = tos + else: + raise Exception("unreachable") + slicevar = state.make_temp() + res = state.make_temp() + state.append( + inst, start=start, stop=stop, step=step, res=res, slicevar=slicevar + ) + state.push(res) + + if PYVERSION in ((3, 12), (3, 13)): + + def op_BINARY_SLICE(self, state, inst): + end = state.pop() + start = state.pop() + container = state.pop() + temp_res = state.make_temp() + res = state.make_temp() + slicevar = state.make_temp() + state.append( + inst, + start=start, + end=end, + container=container, + res=res, + slicevar=slicevar, + temp_res=temp_res, + ) + state.push(res) + elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + pass + else: + raise 
NotImplementedError(PYVERSION) + + if PYVERSION in ((3, 12), (3, 13)): + + def op_STORE_SLICE(self, state, inst): + end = state.pop() + start = state.pop() + container = state.pop() + value = state.pop() + + slicevar = state.make_temp() + res = state.make_temp() + state.append( + inst, + start=start, + end=end, + container=container, + value=value, + res=res, + slicevar=slicevar, + ) + elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + pass + else: + raise NotImplementedError(PYVERSION) + + def _op_POP_JUMP_IF(self, state, inst): + pred = state.pop() + state.append(inst, pred=pred) + + target_inst = inst.get_jump_target() + next_inst = inst.next + # if the next inst and the jump target are the same location, issue one + # fork else issue a fork for the next and the target. + state.fork(pc=next_inst) + if target_inst != next_inst: + state.fork(pc=target_inst) + + op_POP_JUMP_IF_TRUE = _op_POP_JUMP_IF + op_POP_JUMP_IF_FALSE = _op_POP_JUMP_IF + + if PYVERSION in ((3, 12), (3, 13)): + op_POP_JUMP_IF_NONE = _op_POP_JUMP_IF + op_POP_JUMP_IF_NOT_NONE = _op_POP_JUMP_IF + elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + pass + else: + raise NotImplementedError(PYVERSION) + + def _op_JUMP_IF_OR_POP(self, state, inst): + pred = state.get_tos() + state.append(inst, pred=pred) + state.fork(pc=inst.next, npop=1) + state.fork(pc=inst.get_jump_target()) + + op_JUMP_IF_FALSE_OR_POP = _op_JUMP_IF_OR_POP + op_JUMP_IF_TRUE_OR_POP = _op_JUMP_IF_OR_POP + + def op_POP_JUMP_FORWARD_IF_NONE(self, state, inst): + self._op_POP_JUMP_IF(state, inst) + + def op_POP_JUMP_FORWARD_IF_NOT_NONE(self, state, inst): + self._op_POP_JUMP_IF(state, inst) + + def op_POP_JUMP_BACKWARD_IF_NONE(self, state, inst): + self._op_POP_JUMP_IF(state, inst) + + def op_POP_JUMP_BACKWARD_IF_NOT_NONE(self, state, inst): + self._op_POP_JUMP_IF(state, inst) + + def op_POP_JUMP_FORWARD_IF_FALSE(self, state, inst): + self._op_POP_JUMP_IF(state, inst) + + def op_POP_JUMP_FORWARD_IF_TRUE(self, state, inst): + self._op_POP_JUMP_IF(state, inst) + + def op_POP_JUMP_BACKWARD_IF_FALSE(self, state, inst): + self._op_POP_JUMP_IF(state, inst) + + def op_POP_JUMP_BACKWARD_IF_TRUE(self, state, inst): + self._op_POP_JUMP_IF(state, inst) + + def op_JUMP_FORWARD(self, state, inst): + state.append(inst) + state.fork(pc=inst.get_jump_target()) + + def op_JUMP_BACKWARD(self, state, inst): + state.append(inst) + state.fork(pc=inst.get_jump_target()) + + op_JUMP_BACKWARD_NO_INTERRUPT = op_JUMP_BACKWARD + + def op_JUMP_ABSOLUTE(self, state, inst): + state.append(inst) + state.fork(pc=inst.get_jump_target()) + + def op_BREAK_LOOP(self, state, inst): + # NOTE: bytecode removed since py3.8 + end = state.get_top_block("LOOP")["end"] + state.append(inst, end=end) + state.pop_block() + state.fork(pc=end) + + def op_RETURN_VALUE(self, state, inst): + state.append(inst, retval=state.pop(), castval=state.make_temp()) + state.terminate() + + if PYVERSION in ((3, 12), (3, 13)): + + def op_RETURN_CONST(self, state, inst): + res = state.make_temp("const") + state.append(inst, retval=res, castval=state.make_temp()) + state.terminate() + elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + pass + else: + raise NotImplementedError(PYVERSION) + + def op_YIELD_VALUE(self, state, inst): + val = state.pop() + res = state.make_temp() + state.append(inst, value=val, res=res) + state.push(res) + + if PYVERSION in ((3, 11), (3, 12), (3, 13)): + + def op_RAISE_VARARGS(self, state, inst): + if inst.arg == 0: + exc = None + # No re-raising within a try-except block. + # But we allow bare reraise. 
+ if state.has_active_try(): + raise UnsupportedBytecodeError( + "The re-raising of an exception is not yet supported.", + loc=self.get_debug_loc(inst.lineno), + ) + elif inst.arg == 1: + exc = state.pop() + else: + raise ValueError("Multiple argument raise is not supported.") + state.append(inst, exc=exc) + + if state.has_active_try(): + self._adjust_except_stack(state) + else: + state.terminate() + + elif PYVERSION in ( + (3, 9), + (3, 10), + ): + + def op_RAISE_VARARGS(self, state, inst): + in_exc_block = any( + [ + state.get_top_block("EXCEPT") is not None, + state.get_top_block("FINALLY") is not None, + ] + ) + if inst.arg == 0: + exc = None + if in_exc_block: + raise UnsupportedBytecodeError( + "The re-raising of an exception is not yet supported.", + loc=self.get_debug_loc(inst.lineno), + ) + elif inst.arg == 1: + exc = state.pop() + else: + raise ValueError("Multiple argument raise is not supported.") + state.append(inst, exc=exc) + state.terminate() + else: + raise NotImplementedError(PYVERSION) + + def op_BEGIN_FINALLY(self, state, inst): + temps = [] + for i in range(_EXCEPT_STACK_OFFSET): + tmp = state.make_temp() + temps.append(tmp) + state.push(tmp) + state.append(inst, temps=temps) + + def op_END_FINALLY(self, state, inst): + blk = state.pop_block() + state.reset_stack(blk["entry_stack"]) + + if PYVERSION in ((3, 13),): + + def op_END_FOR(self, state, inst): + state.pop() + elif PYVERSION in ((3, 12),): + + def op_END_FOR(self, state, inst): + state.pop() + state.pop() + elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + pass + else: + raise NotImplementedError(PYVERSION) + + def op_POP_FINALLY(self, state, inst): + # we don't emulate the exact stack behavior + if inst.arg != 0: + msg = ( + "Unsupported use of a bytecode related to try..finally" + " or a with-context" + ) + raise UnsupportedBytecodeError( + msg, loc=self.get_debug_loc(inst.lineno) + ) + + def op_CALL_FINALLY(self, state, inst): + pass + + def op_WITH_EXCEPT_START(self, state, inst): + state.terminate() # do not support + + def op_WITH_CLEANUP_START(self, state, inst): + # we don't emulate the exact stack behavior + state.append(inst) + + def op_WITH_CLEANUP_FINISH(self, state, inst): + # we don't emulate the exact stack behavior + state.append(inst) + + def op_SETUP_LOOP(self, state, inst): + # NOTE: bytecode removed since py3.8 + state.push_block( + state.make_block( + kind="LOOP", + end=inst.get_jump_target(), + ) + ) + + def op_BEFORE_WITH(self, state, inst): + # Almost the same as py3.10 SETUP_WITH just lacking the finally block. + cm = state.pop() # the context-manager + + yielded = state.make_temp() + exitfn = state.make_temp(prefix="setup_with_exitfn") + + state.push(exitfn) + state.push(yielded) + + # Gather all exception entries for this WITH. There maybe multiple + # entries; esp. for nested WITHs. 
+ bc = state._bytecode + ehhead = bc.find_exception_entry(inst.next) + ehrelated = [ehhead] + for eh in bc.exception_entries: + if eh.target == ehhead.target: + ehrelated.append(eh) + end = max(eh.end for eh in ehrelated) + state.append(inst, contextmanager=cm, exitfn=exitfn, end=end) + + state.push_block( + state.make_block( + kind="WITH", + end=end, + ) + ) + # Forces a new block + state.fork(pc=inst.next) + + def op_SETUP_WITH(self, state, inst): + cm = state.pop() # the context-manager + + yielded = state.make_temp() + exitfn = state.make_temp(prefix="setup_with_exitfn") + state.append(inst, contextmanager=cm, exitfn=exitfn) + + state.push(exitfn) + state.push(yielded) + + state.push_block( + state.make_block( + kind="WITH", + end=inst.get_jump_target(), + ) + ) + # Forces a new block + state.fork(pc=inst.next) + + def _setup_try(self, kind, state, next, end): + # Forces a new block + # Fork to the body of the finally + handler_block = state.make_block( + kind=kind, + end=None, + reset_stack=False, + ) + # Forces a new block + # Fork to the body of the finally + state.fork( + pc=next, + extra_block=state.make_block( + kind="TRY", + end=end, + reset_stack=False, + handler=handler_block, + ), + ) + + def op_PUSH_EXC_INFO(self, state, inst): + tos = state.pop() + state.push(state.make_temp("exception")) + state.push(tos) + + def op_SETUP_FINALLY(self, state, inst): + state.append(inst) + self._setup_try( + "FINALLY", + state, + next=inst.next, + end=inst.get_jump_target(), + ) + + if PYVERSION in ((3, 11), (3, 12), (3, 13)): + + def op_POP_EXCEPT(self, state, inst): + state.pop() + + elif PYVERSION in ( + (3, 9), + (3, 10), + ): + + def op_POP_EXCEPT(self, state, inst): + blk = state.pop_block() + if blk["kind"] not in {BlockKind("EXCEPT"), BlockKind("FINALLY")}: + raise UnsupportedBytecodeError( + f"POP_EXCEPT got an unexpected block: {blk['kind']}", + loc=self.get_debug_loc(inst.lineno), + ) + state.pop() + state.pop() + state.pop() + # Forces a new block + state.fork(pc=inst.next) + else: + raise NotImplementedError(PYVERSION) + + def op_POP_BLOCK(self, state, inst): + blk = state.pop_block() + if blk["kind"] == BlockKind("TRY"): + state.append(inst, kind="try") + elif blk["kind"] == BlockKind("WITH"): + state.append(inst, kind="with") + state.fork(pc=inst.next) + + def op_BINARY_SUBSCR(self, state, inst): + index = state.pop() + target = state.pop() + res = state.make_temp() + state.append(inst, index=index, target=target, res=res) + state.push(res) + + def op_STORE_SUBSCR(self, state, inst): + index = state.pop() + target = state.pop() + value = state.pop() + state.append(inst, target=target, index=index, value=value) + + def op_DELETE_SUBSCR(self, state, inst): + index = state.pop() + target = state.pop() + state.append(inst, target=target, index=index) + + def op_CALL(self, state, inst): + narg = inst.arg + args = list(reversed([state.pop() for _ in range(narg)])) + if PYVERSION == (3, 13): + null_or_self = state.pop() + # position of the callable is fixed + callable = state.pop() + if not _is_null_temp_reg(null_or_self): + args = [null_or_self, *args] + kw_names = None + elif PYVERSION < (3, 13): + callable_or_firstarg = state.pop() + null_or_callable = state.pop() + if _is_null_temp_reg(null_or_callable): + callable = callable_or_firstarg + else: + callable = null_or_callable + args = [callable_or_firstarg, *args] + kw_names = state.pop_kw_names() + res = state.make_temp() + + state.append(inst, func=callable, args=args, kw_names=kw_names, res=res) + state.push(res) + + def 
op_KW_NAMES(self, state, inst): + state.set_kw_names(inst.arg) + + def op_CALL_FUNCTION(self, state, inst): + narg = inst.arg + args = list(reversed([state.pop() for _ in range(narg)])) + func = state.pop() + + res = state.make_temp() + state.append(inst, func=func, args=args, res=res) + state.push(res) + + def op_CALL_FUNCTION_KW(self, state, inst): + narg = inst.arg + names = state.pop() # tuple of names + args = list(reversed([state.pop() for _ in range(narg)])) + func = state.pop() + + res = state.make_temp() + state.append(inst, func=func, args=args, names=names, res=res) + state.push(res) + + if PYVERSION in ((3, 13),): + + def op_CALL_KW(self, state, inst): + narg = inst.arg + kw_names = state.pop() + args = list(reversed([state.pop() for _ in range(narg)])) + null_or_firstarg = state.pop() + callable = state.pop() + if not _is_null_temp_reg(null_or_firstarg): + args = [null_or_firstarg, *args] + + res = state.make_temp() + state.append( + inst, func=callable, args=args, kw_names=kw_names, res=res + ) + state.push(res) + + elif PYVERSION in ((3, 9), (3, 10), (3, 11), (3, 12)): + pass + else: + raise NotImplementedError(PYVERSION) + + if PYVERSION in ((3, 13),): + + def op_CALL_FUNCTION_EX(self, state, inst): + # (func, unused, callargs, kwargs if (oparg & 1) -- result)) + if inst.arg & 1: + varkwarg = state.pop() + else: + varkwarg = None + + vararg = state.pop() + state.pop() # unused + func = state.pop() + + res = state.make_temp() + state.append( + inst, func=func, vararg=vararg, varkwarg=varkwarg, res=res + ) + state.push(res) + + elif PYVERSION in ((3, 9), (3, 10), (3, 11), (3, 12)): + + def op_CALL_FUNCTION_EX(self, state, inst): + if inst.arg & 1: + varkwarg = state.pop() + else: + varkwarg = None + vararg = state.pop() + func = state.pop() + + if PYVERSION in ((3, 11), (3, 12)): + if _is_null_temp_reg(state.peek(1)): + state.pop() # pop NULL, it's not used + elif PYVERSION in ( + (3, 9), + (3, 10), + ): + pass + else: + raise NotImplementedError(PYVERSION) + + res = state.make_temp() + state.append( + inst, func=func, vararg=vararg, varkwarg=varkwarg, res=res + ) + state.push(res) + else: + raise NotImplementedError(PYVERSION) + + def _dup_topx(self, state, inst, count): + orig = [state.pop() for _ in range(count)] + orig.reverse() + # We need to actually create new temporaries if we want the + # IR optimization pass to work correctly (see issue #580) + duped = [state.make_temp() for _ in range(count)] + state.append(inst, orig=orig, duped=duped) + for val in orig: + state.push(val) + for val in duped: + state.push(val) + + if PYVERSION in ((3, 12), (3, 13)): + + def op_CALL_INTRINSIC_1(self, state, inst): + # See https://github.com/python/cpython/blob/v3.12.0rc2/Include/ + # internal/pycore_intrinsics.h#L3-L17C36 + try: + operand = CALL_INTRINSIC_1_Operand(inst.arg) + except TypeError: + msg = f"op_CALL_INTRINSIC_1({inst.arg})" + loc = self.get_debug_loc(inst.lineno) + raise UnsupportedBytecodeError(msg, loc=loc) + if operand == ci1op.INTRINSIC_STOPITERATION_ERROR: + state.append(inst, operand=operand) + state.terminate() + return + elif operand == ci1op.UNARY_POSITIVE: + val = state.pop() + res = state.make_temp() + state.append(inst, operand=operand, value=val, res=res) + state.push(res) + return + elif operand == ci1op.INTRINSIC_LIST_TO_TUPLE: + tos = state.pop() + res = state.make_temp() + state.append(inst, operand=operand, const_list=tos, res=res) + state.push(res) + return + else: + raise NotImplementedError(operand) + + elif PYVERSION in ((3, 9), (3, 10), (3, 11)): 
+ pass + else: + raise NotImplementedError(PYVERSION) + + def op_DUP_TOPX(self, state, inst): + count = inst.arg + assert 1 <= count <= 5, "Invalid DUP_TOPX count" + self._dup_topx(state, inst, count) + + def op_DUP_TOP(self, state, inst): + self._dup_topx(state, inst, count=1) + + def op_DUP_TOP_TWO(self, state, inst): + self._dup_topx(state, inst, count=2) + + def op_COPY(self, state, inst): + state.push(state.peek(inst.arg)) + + def op_SWAP(self, state, inst): + state.swap(inst.arg) + + def op_ROT_TWO(self, state, inst): + first = state.pop() + second = state.pop() + state.push(first) + state.push(second) + + def op_ROT_THREE(self, state, inst): + first = state.pop() + second = state.pop() + third = state.pop() + state.push(first) + state.push(third) + state.push(second) + + def op_ROT_FOUR(self, state, inst): + first = state.pop() + second = state.pop() + third = state.pop() + forth = state.pop() + state.push(first) + state.push(forth) + state.push(third) + state.push(second) + + def op_UNPACK_SEQUENCE(self, state, inst): + count = inst.arg + iterable = state.pop() + stores = [state.make_temp() for _ in range(count)] + tupleobj = state.make_temp() + state.append(inst, iterable=iterable, stores=stores, tupleobj=tupleobj) + for st in reversed(stores): + state.push(st) + + def op_BUILD_TUPLE(self, state, inst): + count = inst.arg + items = list(reversed([state.pop() for _ in range(count)])) + tup = state.make_temp() + state.append(inst, items=items, res=tup) + state.push(tup) + + def _build_tuple_unpack(self, state, inst): + # Builds tuple from other tuples on the stack + tuples = list(reversed([state.pop() for _ in range(inst.arg)])) + temps = [state.make_temp() for _ in range(len(tuples) - 1)] + + # if the unpack is assign-like, e.g. x = (*y,), it needs handling + # differently. + is_assign = len(tuples) == 1 + if is_assign: + temps = [ + state.make_temp(), + ] + + state.append(inst, tuples=tuples, temps=temps, is_assign=is_assign) + # The result is in the last temp var + state.push(temps[-1]) + + def op_BUILD_TUPLE_UNPACK_WITH_CALL(self, state, inst): + # just unpack the input tuple, call inst will be handled afterwards + self._build_tuple_unpack(state, inst) + + def op_BUILD_TUPLE_UNPACK(self, state, inst): + self._build_tuple_unpack(state, inst) + + def op_LIST_TO_TUPLE(self, state, inst): + # "Pops a list from the stack and pushes a tuple containing the same + # values." 
+ tos = state.pop() + res = state.make_temp() # new tuple var + state.append(inst, const_list=tos, res=res) + state.push(res) + + def op_BUILD_CONST_KEY_MAP(self, state, inst): + keys = state.pop() + vals = list(reversed([state.pop() for _ in range(inst.arg)])) + keytmps = [state.make_temp() for _ in range(inst.arg)] + res = state.make_temp() + state.append(inst, keys=keys, keytmps=keytmps, values=vals, res=res) + state.push(res) + + def op_BUILD_LIST(self, state, inst): + count = inst.arg + items = list(reversed([state.pop() for _ in range(count)])) + lst = state.make_temp() + state.append(inst, items=items, res=lst) + state.push(lst) + + def op_LIST_APPEND(self, state, inst): + value = state.pop() + index = inst.arg + target = state.peek(index) + appendvar = state.make_temp() + res = state.make_temp() + state.append( + inst, target=target, value=value, appendvar=appendvar, res=res + ) + + def op_LIST_EXTEND(self, state, inst): + value = state.pop() + index = inst.arg + target = state.peek(index) + extendvar = state.make_temp() + res = state.make_temp() + state.append( + inst, target=target, value=value, extendvar=extendvar, res=res + ) + + def op_BUILD_MAP(self, state, inst): + dct = state.make_temp() + count = inst.arg + items = [] + # In 3.5+, BUILD_MAP takes pairs from the stack + for i in range(count): + v, k = state.pop(), state.pop() + items.append((k, v)) + state.append(inst, items=items[::-1], size=count, res=dct) + state.push(dct) + + def op_MAP_ADD(self, state, inst): + TOS = state.pop() + TOS1 = state.pop() + key, value = (TOS1, TOS) + index = inst.arg + target = state.peek(index) + setitemvar = state.make_temp() + res = state.make_temp() + state.append( + inst, + target=target, + key=key, + value=value, + setitemvar=setitemvar, + res=res, + ) + + def op_BUILD_SET(self, state, inst): + count = inst.arg + # Note: related python bug http://bugs.python.org/issue26020 + items = list(reversed([state.pop() for _ in range(count)])) + res = state.make_temp() + state.append(inst, items=items, res=res) + state.push(res) + + def op_SET_UPDATE(self, state, inst): + value = state.pop() + index = inst.arg + target = state.peek(index) + updatevar = state.make_temp() + res = state.make_temp() + state.append( + inst, target=target, value=value, updatevar=updatevar, res=res + ) + + def op_DICT_UPDATE(self, state, inst): + value = state.pop() + index = inst.arg + target = state.peek(index) + updatevar = state.make_temp() + res = state.make_temp() + state.append( + inst, target=target, value=value, updatevar=updatevar, res=res + ) + + def op_GET_ITER(self, state, inst): + value = state.pop() + res = state.make_temp() + state.append(inst, value=value, res=res) + state.push(res) + + def op_FOR_ITER(self, state, inst): + iterator = state.get_tos() + pair = state.make_temp() + indval = state.make_temp() + pred = state.make_temp() + state.append( + inst, iterator=iterator, pair=pair, indval=indval, pred=pred + ) + state.push(indval) + end = inst.get_jump_target() + if PYVERSION in ((3, 12), (3, 13)): + # Changed in version 3.12: Up until 3.11 the iterator was + # popped when it was exhausted. Now this is handled using END_FOR + # op code. + state.fork(pc=end) + elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + state.fork(pc=end, npop=2) + else: + raise NotImplementedError(PYVERSION) + state.fork(pc=inst.next) + + def op_GEN_START(self, state, inst): + """Pops TOS. If TOS was not None, raises an exception. 
The kind + operand corresponds to the type of generator or coroutine and + determines the error message. The legal kinds are 0 for generator, + 1 for coroutine, and 2 for async generator. + + New in version 3.10. + """ + # no-op in Numba + pass + + def op_BINARY_OP(self, state, inst): + op = dis._nb_ops[inst.arg][1] + rhs = state.pop() + lhs = state.pop() + op_name = ALL_BINOPS_TO_OPERATORS[op].__name__ + res = state.make_temp(prefix=f"binop_{op_name}") + state.append(inst, op=op, lhs=lhs, rhs=rhs, res=res) + state.push(res) + + def _unaryop(self, state, inst): + val = state.pop() + res = state.make_temp() + state.append(inst, value=val, res=res) + state.push(res) + + op_UNARY_NEGATIVE = _unaryop + op_UNARY_POSITIVE = _unaryop + op_UNARY_NOT = _unaryop + op_UNARY_INVERT = _unaryop + + def _binaryop(self, state, inst): + rhs = state.pop() + lhs = state.pop() + res = state.make_temp() + state.append(inst, lhs=lhs, rhs=rhs, res=res) + state.push(res) + + op_COMPARE_OP = _binaryop + op_IS_OP = _binaryop + op_CONTAINS_OP = _binaryop + + op_INPLACE_ADD = _binaryop + op_INPLACE_SUBTRACT = _binaryop + op_INPLACE_MULTIPLY = _binaryop + op_INPLACE_DIVIDE = _binaryop + op_INPLACE_TRUE_DIVIDE = _binaryop + op_INPLACE_FLOOR_DIVIDE = _binaryop + op_INPLACE_MODULO = _binaryop + op_INPLACE_POWER = _binaryop + op_INPLACE_MATRIX_MULTIPLY = _binaryop + + op_INPLACE_LSHIFT = _binaryop + op_INPLACE_RSHIFT = _binaryop + op_INPLACE_AND = _binaryop + op_INPLACE_OR = _binaryop + op_INPLACE_XOR = _binaryop + + op_BINARY_ADD = _binaryop + op_BINARY_SUBTRACT = _binaryop + op_BINARY_MULTIPLY = _binaryop + op_BINARY_DIVIDE = _binaryop + op_BINARY_TRUE_DIVIDE = _binaryop + op_BINARY_FLOOR_DIVIDE = _binaryop + op_BINARY_MODULO = _binaryop + op_BINARY_POWER = _binaryop + op_BINARY_MATRIX_MULTIPLY = _binaryop + + op_BINARY_LSHIFT = _binaryop + op_BINARY_RSHIFT = _binaryop + op_BINARY_AND = _binaryop + op_BINARY_OR = _binaryop + op_BINARY_XOR = _binaryop + + def op_MAKE_FUNCTION(self, state, inst, MAKE_CLOSURE=False): + if PYVERSION in ((3, 11), (3, 12), (3, 13)): + # https://github.com/python/cpython/commit/2f180ce + # name set via co_qualname + name = None + elif PYVERSION in ( + (3, 9), + (3, 10), + ): + name = state.pop() + else: + raise NotImplementedError(PYVERSION) + code = state.pop() + closure = annotations = kwdefaults = defaults = None + if PYVERSION in ((3, 13),): + assert inst.arg is None + # SET_FUNCTION_ATTRIBUTE is responsible for setting + # closure, annotations, kwdefaults and defaults. 
+        else:
+            if inst.arg & 0x8:
+                closure = state.pop()
+            if inst.arg & 0x4:
+                annotations = state.pop()
+            if inst.arg & 0x2:
+                kwdefaults = state.pop()
+            if inst.arg & 0x1:
+                defaults = state.pop()
+        res = state.make_temp()
+        state.append(
+            inst,
+            name=name,
+            code=code,
+            closure=closure,
+            annotations=annotations,
+            kwdefaults=kwdefaults,
+            defaults=defaults,
+            res=res,
+        )
+        state.push(res)
+
+    def op_SET_FUNCTION_ATTRIBUTE(self, state, inst):
+        assert PYVERSION in ((3, 13),)
+        make_func_stack = state.pop()
+        data = state.pop()
+        if inst.arg == 0x1:
+            # 0x01 a tuple of default values for positional-only and
+            # positional-or-keyword parameters in positional order
+            state.set_function_attribute(make_func_stack, defaults=data)
+        elif inst.arg & 0x2:
+            # 0x02 a dictionary of keyword-only parameters' default values
+            state.set_function_attribute(make_func_stack, kwdefaults=data)
+        elif inst.arg & 0x4:
+            # 0x04 a tuple of strings containing parameters' annotations
+            state.set_function_attribute(make_func_stack, annotations=data)
+        elif inst.arg == 0x8:
+            # 0x08 a tuple containing cells for free variables, making a closure
+            state.set_function_attribute(make_func_stack, closure=data)
+        else:
+            raise AssertionError("unreachable")
+        state.push(make_func_stack)
+
+    def op_MAKE_CLOSURE(self, state, inst):
+        self.op_MAKE_FUNCTION(state, inst, MAKE_CLOSURE=True)
+
+    def op_LOAD_CLOSURE(self, state, inst):
+        res = state.make_temp()
+        state.append(inst, res=res)
+        state.push(res)
+
+    def op_LOAD_ASSERTION_ERROR(self, state, inst):
+        res = state.make_temp("assertion_error")
+        state.append(inst, res=res)
+        state.push(res)
+
+    def op_CHECK_EXC_MATCH(self, state, inst):
+        pred = state.make_temp("predicate")
+        tos = state.pop()
+        tos1 = state.get_tos()
+        state.append(inst, pred=pred, tos=tos, tos1=tos1)
+        state.push(pred)
+
+    def op_JUMP_IF_NOT_EXC_MATCH(self, state, inst):
+        # Tests whether the second value on the stack is an exception matching
+        # TOS, and jumps if it is not. Pops two values from the stack.
+        pred = state.make_temp("predicate")
+        tos = state.pop()
+        tos1 = state.pop()
+        state.append(inst, pred=pred, tos=tos, tos1=tos1)
+        state.fork(pc=inst.next)
+        state.fork(pc=inst.get_jump_target())
+
+    if PYVERSION in ((3, 11), (3, 12), (3, 13)):
+
+        def op_RERAISE(self, state, inst):
+            # This isn't handled, but the state is set up anyway
+            exc = state.pop()
+            if inst.arg != 0:
+                state.pop()  # lasti
+            state.append(inst, exc=exc)
+
+            if state.has_active_try():
+                self._adjust_except_stack(state)
+            else:
+                state.terminate()
+
+    elif PYVERSION in (
+        (3, 9),
+        (3, 10),
+    ):
+
+        def op_RERAISE(self, state, inst):
+            # This isn't handled, but the state is set up anyway
+            exc = state.pop()
+            state.append(inst, exc=exc)
+            state.terminate()
+    else:
+        raise NotImplementedError(PYVERSION)
+
+    # NOTE: Please see notes in `interpreter.py` surrounding the implementation
+    # of LOAD_METHOD and CALL_METHOD.
+ + if PYVERSION in ((3, 12), (3, 13)): + # LOAD_METHOD has become a pseudo-instruction in 3.12 + pass + elif PYVERSION in ((3, 11),): + + def op_LOAD_METHOD(self, state, inst): + item = state.pop() + extra = state.make_null() + state.push(extra) + res = state.make_temp() + state.append(inst, item=item, res=res) + state.push(res) + elif PYVERSION in ( + (3, 9), + (3, 10), + ): + + def op_LOAD_METHOD(self, state, inst): + self.op_LOAD_ATTR(state, inst) + else: + raise NotImplementedError(PYVERSION) + + def op_CALL_METHOD(self, state, inst): + self.op_CALL_FUNCTION(state, inst) + + +@total_ordering +class _State(object): + """State of the trace""" + + def __init__(self, bytecode, pc, nstack, blockstack, nullvals=()): + """ + Parameters + ---------- + bytecode : numba.bytecode.ByteCode + function bytecode + pc : int + program counter + nstack : int + stackdepth at entry + blockstack : Sequence[Dict] + A sequence of dictionary denoting entries on the blockstack. + """ + self._bytecode = bytecode + self._pc_initial = pc + self._pc = pc + self._nstack_initial = nstack + self._stack = [] + self._blockstack_initial = tuple(blockstack) + self._blockstack = list(blockstack) + self._temp_registers = [] + self._insts = [] + self._outedges = [] + self._terminated = False + self._phis = {} + self._outgoing_phis = UniqueDict() + self._used_regs = set() + for i in range(nstack): + if i in nullvals: + phi = self.make_temp("null$") + else: + phi = self.make_temp("phi") + self._phis[phi] = i + self.push(phi) + + def __repr__(self): + return "State(pc_initial={} nstack_initial={})".format( + self._pc_initial, self._nstack_initial + ) + + def get_identity(self): + return (self._pc_initial, self._nstack_initial) + + def __hash__(self): + return hash(self.get_identity()) + + def __lt__(self, other): + return self.get_identity() < other.get_identity() + + def __eq__(self, other): + return self.get_identity() == other.get_identity() + + @property + def pc_initial(self): + """The starting bytecode offset of this State. + The PC given to the constructor. + """ + return self._pc_initial + + @property + def instructions(self): + """The list of instructions information as a 2-tuple of + ``(pc : int, register_map : Dict)`` + """ + return self._insts + + @property + def outgoing_edges(self): + """The list of outgoing edges. + + Returns + ------- + edges : List[State] + """ + return self._outedges + + @property + def outgoing_phis(self): + """The dictionary of outgoing phi nodes. + + The keys are the name of the PHI nodes. + The values are the outgoing states. 
+ """ + return self._outgoing_phis + + @property + def blockstack_initial(self): + """A copy of the initial state of the blockstack""" + return self._blockstack_initial + + @property + def stack_depth(self): + """The current size of the stack + + Returns + ------- + res : int + """ + return len(self._stack) + + def find_initial_try_block(self): + """Find the initial *try* block.""" + for blk in reversed(self._blockstack_initial): + if blk["kind"] == BlockKind("TRY"): + return blk + + def has_terminated(self): + return self._terminated + + def get_inst(self): + return self._bytecode[self._pc] + + def advance_pc(self): + inst = self.get_inst() + self._pc = inst.next + + def make_temp(self, prefix=""): + if not prefix: + name = "${prefix}{offset}{opname}.{tempct}".format( + prefix=prefix, + offset=self._pc, + opname=self.get_inst().opname.lower(), + tempct=len(self._temp_registers), + ) + else: + name = "${prefix}{offset}.{tempct}".format( + prefix=prefix, + offset=self._pc, + tempct=len(self._temp_registers), + ) + + self._temp_registers.append(name) + return name + + def append(self, inst, **kwargs): + """Append new inst""" + self._insts.append((inst.offset, kwargs)) + self._used_regs |= set(_flatten_inst_regs(kwargs.values())) + + def get_tos(self): + return self.peek(1) + + def peek(self, k): + """Return the k'th element on the stack""" + return self._stack[-k] + + def push(self, item): + """Push to stack""" + self._stack.append(item) + + def pop(self): + """Pop the stack""" + return self._stack.pop() + + def swap(self, idx): + """Swap stack[idx] with the tos""" + s = self._stack + s[-1], s[-idx] = s[-idx], s[-1] + + def push_block(self, synblk): + """Push a block to blockstack""" + assert "stack_depth" in synblk + self._blockstack.append(synblk) + + def reset_stack(self, depth): + """Reset the stack to the given stack depth. + Returning the popped items. + """ + self._stack, popped = self._stack[:depth], self._stack[depth:] + return popped + + def make_block(self, kind, end, reset_stack=True, handler=None): + """Make a new block""" + d = { + "kind": BlockKind(kind), + "end": end, + "entry_stack": len(self._stack), + } + if reset_stack: + d["stack_depth"] = len(self._stack) + else: + d["stack_depth"] = None + d["handler"] = handler + return d + + def pop_block(self): + """Pop a block and unwind the stack""" + b = self._blockstack.pop() + self.reset_stack(b["stack_depth"]) + return b + + def pop_block_and_above(self, blk): + """Find *blk* in the blockstack and remove it and all blocks above it + from the stack. 
+ """ + idx = self._blockstack.index(blk) + assert 0 <= idx < len(self._blockstack) + self._blockstack = self._blockstack[:idx] + + def get_top_block(self, kind): + """Find the first block that matches *kind*""" + kind = BlockKind(kind) + for bs in reversed(self._blockstack): + if bs["kind"] == kind: + return bs + + def get_top_block_either(self, *kinds): + """Find the first block that matches *kind*""" + kinds = {BlockKind(kind) for kind in kinds} + for bs in reversed(self._blockstack): + if bs["kind"] in kinds: + return bs + + def has_active_try(self): + """Returns a boolean indicating if the top-block is a *try* block""" + return self.get_top_block("TRY") is not None + + def get_varname(self, inst): + """Get referenced variable name from the instruction's oparg""" + return self.get_varname_by_arg(inst.arg) + + def get_varname_by_arg(self, oparg: int): + """Get referenced variable name from the oparg""" + return self._bytecode.co_varnames[oparg] + + def terminate(self): + """Mark block as terminated""" + self._terminated = True + + def fork(self, pc, npop=0, npush=0, extra_block=None): + """Fork the state""" + # Handle changes on the stack + stack = list(self._stack) + if npop: + assert 0 <= npop <= len(self._stack) + nstack = len(self._stack) - npop + stack = stack[:nstack] + if npush: + assert 0 <= npush + for i in range(npush): + stack.append(self.make_temp()) + # Handle changes on the blockstack + blockstack = list(self._blockstack) + if PYVERSION in ((3, 11), (3, 12), (3, 13)): + # pop expired block in destination pc + while blockstack: + top = blockstack[-1] + end = top.get("end_offset") or top["end"] + if pc >= end: + blockstack.pop() + else: + break + elif PYVERSION in ( + (3, 9), + (3, 10), + ): + pass # intentionally bypass + else: + raise NotImplementedError(PYVERSION) + + if extra_block: + blockstack.append(extra_block) + self._outedges.append( + Edge( + pc=pc, + stack=tuple(stack), + npush=npush, + blockstack=tuple(blockstack), + ) + ) + self.terminate() + + def split_new_block(self): + """Split the state""" + self.fork(pc=self._pc) + + def get_outgoing_states(self): + """Get states for each outgoing edges""" + # Should only call once + assert not self._outgoing_phis + ret = [] + for edge in self._outedges: + state = State( + bytecode=self._bytecode, + pc=edge.pc, + nstack=len(edge.stack), + blockstack=edge.blockstack, + nullvals=[ + i for i, v in enumerate(edge.stack) if _is_null_temp_reg(v) + ], + ) + ret.append(state) + # Map outgoing_phis + for phi, i in state._phis.items(): + self._outgoing_phis[phi] = edge.stack[i] + return ret + + def get_outgoing_edgepushed(self): + """ + Returns + ------- + Dict[int, int] + where keys are the PC + values are the edge-pushed stack values + """ + + return { + edge.pc: tuple(edge.stack[-edge.npush :]) for edge in self._outedges + } + + +class StatePy311(_State): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._kw_names = None + + def pop_kw_names(self): + out = self._kw_names + self._kw_names = None + return out + + def set_kw_names(self, val): + assert self._kw_names is None + self._kw_names = val + + def is_in_exception(self): + bc = self._bytecode + return bc.find_exception_entry(self._pc) is not None + + def get_exception(self): + bc = self._bytecode + return bc.find_exception_entry(self._pc) + + def in_with(self): + for ent in self._blockstack_initial: + if ent["kind"] == BlockKind("WITH"): + return True + + def make_null(self): + return self.make_temp(prefix="null$") + + +class 
StatePy313(StatePy311):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._make_func_attrs = defaultdict(dict)
+
+    def set_function_attribute(self, make_func_res, **kwargs):
+        self._make_func_attrs[make_func_res].update(kwargs)
+
+    def get_function_attributes(self, make_func_res):
+        return self._make_func_attrs[make_func_res]
+
+
+if PYVERSION in ((3, 13),):
+    State = StatePy313
+elif PYVERSION in ((3, 11), (3, 12)):
+    State = StatePy311
+elif PYVERSION < (3, 11):
+    State = _State
+else:
+    raise NotImplementedError(PYVERSION)
+
+
+Edge = namedtuple("Edge", ["pc", "stack", "blockstack", "npush"])
+
+
+class AdaptDFA(object):
+    """Adapt Flow to the old DFA class expected by Interpreter"""
+
+    def __init__(self, flow):
+        self._flow = flow
+
+    @property
+    def infos(self):
+        return self._flow.block_infos
+
+
+AdaptBlockInfo = namedtuple(
+    "AdaptBlockInfo",
+    [
+        "insts",
+        "outgoing_phis",
+        "blockstack",
+        "active_try_block",
+        "outgoing_edgepushed",
+    ],
+)
+
+
+def adapt_state_infos(state):
+    def process_function_attributes(inst_pair):
+        offset, data = inst_pair
+        inst = state._bytecode[offset]
+        if inst.opname == "MAKE_FUNCTION":
+            data.update(state.get_function_attributes(data["res"]))
+        return offset, data
+
+    if PYVERSION in ((3, 13),):
+        insts = tuple(map(process_function_attributes, state.instructions))
+    elif PYVERSION in ((3, 9), (3, 10), (3, 11), (3, 12)):
+        insts = tuple(state.instructions)
+    else:
+        raise NotImplementedError(PYVERSION)
+    return AdaptBlockInfo(
+        insts=insts,
+        outgoing_phis=state.outgoing_phis,
+        blockstack=state.blockstack_initial,
+        active_try_block=state.find_initial_try_block(),
+        outgoing_edgepushed=state.get_outgoing_edgepushed(),
+    )
+
+
+def _flatten_inst_regs(iterable):
+    """Flatten an iterable of registers used in an instruction"""
+    for item in iterable:
+        if isinstance(item, str):
+            yield item
+        elif isinstance(item, (tuple, list)):
+            for x in _flatten_inst_regs(item):
+                yield x
+
+
+class AdaptCFA(object):
+    """Adapt Flow to the old CFA class expected by Interpreter"""
+
+    def __init__(self, flow):
+        self._flow = flow
+        self._blocks = {}
+        for offset, blockinfo in flow.block_infos.items():
+            self._blocks[offset] = AdaptCFBlock(blockinfo, offset)
+
+        graph = flow.cfgraph
+        # Find backbone
+        backbone = graph.backbone()
+        # Filter out in-loop blocks (assuming no other cyclic control blocks).
+        # This is to avoid variables defined in loops being considered as
+        # function scope.
+        inloopblocks = set()
+        for b in self.blocks.keys():
+            if graph.in_loops(b):
+                inloopblocks.add(b)
+        self._backbone = backbone - inloopblocks
+
+    @property
+    def graph(self):
+        return self._flow.cfgraph
+
+    @property
+    def backbone(self):
+        return self._backbone
+
+    @property
+    def blocks(self):
+        return self._blocks
+
+    def iterliveblocks(self):
+        for b in sorted(self.blocks):
+            yield self.blocks[b]
+
+    def dump(self):
+        self._flow.cfgraph.dump()
+
+
+class AdaptCFBlock(object):
+    def __init__(self, blockinfo, offset):
+        self.offset = offset
+        self.body = tuple(i for i, _ in blockinfo.insts)
diff --git a/numba_cuda/numba/cuda/core/controlflow.py b/numba_cuda/numba/cuda/core/controlflow.py
new file mode 100644
index 000000000..2f31805ea
--- /dev/null
+++ b/numba_cuda/numba/cuda/core/controlflow.py
@@ -0,0 +1,989 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause + +import collections +import functools +import sys + +from numba.core.ir import Loc +from numba.core.errors import UnsupportedError +from numba.cuda.utils import PYVERSION + +# List of bytecodes creating a new block in the control flow graph +# (in addition to explicit jump labels). +NEW_BLOCKERS = frozenset( + ["SETUP_LOOP", "FOR_ITER", "SETUP_WITH", "BEFORE_WITH"] +) + + +class CFBlock(object): + def __init__(self, offset): + self.offset = offset + self.body = [] + # A map of jumps to outgoing blocks (successors): + # { offset of outgoing block -> number of stack pops } + self.outgoing_jumps = {} + # A map of jumps to incoming blocks (predecessors): + # { offset of incoming block -> number of stack pops } + self.incoming_jumps = {} + self.terminating = False + + def __repr__(self): + args = ( + self.offset, + sorted(self.outgoing_jumps), + sorted(self.incoming_jumps), + ) + return "block(offset:%d, outgoing: %s, incoming: %s)" % args + + def __iter__(self): + return iter(self.body) + + +class Loop( + collections.namedtuple("Loop", ("entries", "exits", "header", "body")) +): + """ + A control flow loop, as detected by a CFGraph object. + """ + + __slots__ = () + + # The loop header is enough to detect that two loops are really + # the same, assuming they belong to the same graph. + # (note: in practice, only one loop instance is created per graph + # loop, so identity would be fine) + + def __eq__(self, other): + return isinstance(other, Loop) and other.header == self.header + + def __hash__(self): + return hash(self.header) + + +class _DictOfContainers(collections.defaultdict): + """A defaultdict with customized equality checks that ignore empty values. + + Non-empty value is checked by: `bool(value_item) == True`. + """ + + def __eq__(self, other): + if isinstance(other, _DictOfContainers): + mine = self._non_empty_items() + theirs = other._non_empty_items() + return mine == theirs + + return NotImplemented + + def __ne__(self, other): + ret = self.__eq__(other) + if ret is NotImplemented: + return ret + else: + return not ret + + def _non_empty_items(self): + return [(k, vs) for k, vs in sorted(self.items()) if vs] + + +class CFGraph(object): + """ + Generic (almost) implementation of a Control Flow Graph. + """ + + def __init__(self): + self._nodes = set() + self._preds = _DictOfContainers(set) + self._succs = _DictOfContainers(set) + self._edge_data = {} + self._entry_point = None + + def add_node(self, node): + """ + Add *node* to the graph. This is necessary before adding any + edges from/to the node. *node* can be any hashable object. + """ + self._nodes.add(node) + + def add_edge(self, src, dest, data=None): + """ + Add an edge from node *src* to node *dest*, with optional + per-edge *data*. + If such an edge already exists, it is replaced (duplicate edges + are not possible). + """ + if src not in self._nodes: + raise ValueError( + "Cannot add edge as src node %s not in nodes %s" + % (src, self._nodes) + ) + if dest not in self._nodes: + raise ValueError( + "Cannot add edge as dest node %s not in nodes %s" + % (dest, self._nodes) + ) + self._add_edge(src, dest, data) + + def successors(self, src): + """ + Yield (node, data) pairs representing the successors of node *src*. + (*data* will be None if no data was specified when adding the edge) + """ + for dest in self._succs[src]: + yield dest, self._edge_data[src, dest] + + def predecessors(self, dest): + """ + Yield (node, data) pairs representing the predecessors of node *dest*. 
+        (*data* will be None if no data was specified when adding the edge)
+        """
+        for src in self._preds[dest]:
+            yield src, self._edge_data[src, dest]
+
+    def set_entry_point(self, node):
+        """
+        Set the entry point of the graph to *node*.
+        """
+        assert node in self._nodes
+        self._entry_point = node
+
+    def process(self):
+        """
+        Compute essential properties of the control flow graph. The graph
+        must have been fully populated, and its entry point specified. Other
+        graph properties are computed on-demand.
+        """
+        if self._entry_point is None:
+            raise RuntimeError("no entry point defined!")
+        self._eliminate_dead_blocks()
+
+    def dominators(self):
+        """
+        Return a dictionary of {node -> set(nodes)} mapping each node to
+        the nodes dominating it.
+
+        A node D dominates a node N when any path leading to N must go
+        through D.
+        """
+        return self._doms
+
+    def post_dominators(self):
+        """
+        Return a dictionary of {node -> set(nodes)} mapping each node to
+        the nodes post-dominating it.
+
+        A node P post-dominates a node N when any path starting from N must go
+        through P.
+        """
+        return self._post_doms
+
+    def immediate_dominators(self):
+        """
+        Return a dictionary of {node -> node} mapping each node to its
+        immediate dominator (idom).
+
+        The idom(B) is the closest strict dominator of B.
+        """
+        return self._idom
+
+    def dominance_frontier(self):
+        """
+        Return a dictionary of {node -> set(nodes)} mapping each node to
+        the nodes in its dominance frontier.
+
+        The dominance frontier df(N) is the set of all nodes that are
+        immediate successors of blocks dominated by N but which aren't
+        strictly dominated by N.
+        """
+        return self._df
+
+    def dominator_tree(self):
+        """
+        Return a dictionary of {node -> set(nodes)} mapping each node to
+        the set of nodes it immediately dominates.
+
+        The domtree(B) is the set of nodes immediately dominated by B.
+        """
+        return self._domtree
+
+    @functools.cached_property
+    def _exit_points(self):
+        return self._find_exit_points()
+
+    @functools.cached_property
+    def _doms(self):
+        return self._find_dominators()
+
+    @functools.cached_property
+    def _back_edges(self):
+        return self._find_back_edges()
+
+    @functools.cached_property
+    def _topo_order(self):
+        return self._find_topo_order()
+
+    @functools.cached_property
+    def _descs(self):
+        return self._find_descendents()
+
+    @functools.cached_property
+    def _loops(self):
+        return self._find_loops()
+
+    @functools.cached_property
+    def _in_loops(self):
+        return self._find_in_loops()
+
+    @functools.cached_property
+    def _post_doms(self):
+        return self._find_post_dominators()
+
+    @functools.cached_property
+    def _idom(self):
+        return self._find_immediate_dominators()
+
+    @functools.cached_property
+    def _df(self):
+        return self._find_dominance_frontier()
+
+    @functools.cached_property
+    def _domtree(self):
+        return self._find_dominator_tree()
+
+    def descendents(self, node):
+        """
+        Return the set of descendents of the given *node*, in topological
+        order (ignoring back edges).
+        """
+        return self._descs[node]
+
+    def entry_point(self):
+        """
+        Return the entry point node.
+        """
+        assert self._entry_point is not None
+        return self._entry_point
+
+    def exit_points(self):
+        """
+        Return the computed set of exit nodes (may be empty).
+        """
+        return self._exit_points
+
+    def backbone(self):
+        """
+        Return the set of nodes constituting the graph's backbone.
+        (i.e. the nodes that every path starting from the entry point
+        must go through). By construction, it is non-empty: it contains
+        at least the entry point.
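+
+        (Illustrative example: in a diamond-shaped graph with edges
+        0 -> 1, 0 -> 2, 1 -> 3 and 2 -> 3 and entry point 0, the backbone
+        is {0, 3}: every path passes through the entry and the join node.)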
+ """ + return self._post_doms[self._entry_point] + + def loops(self): + """ + Return a dictionary of {node -> loop} mapping each loop header + to the loop (a Loop instance) starting with it. + """ + return self._loops + + def in_loops(self, node): + """ + Return the list of Loop objects the *node* belongs to, + from innermost to outermost. + """ + return [self._loops[x] for x in self._in_loops.get(node, ())] + + def dead_nodes(self): + """ + Return the set of dead nodes (eliminated from the graph). + """ + return self._dead_nodes + + def nodes(self): + """ + Return the set of live nodes. + """ + return self._nodes + + def topo_order(self): + """ + Return the sequence of nodes in topological order (ignoring back + edges). + """ + return self._topo_order + + def topo_sort(self, nodes, reverse=False): + """ + Iterate over the *nodes* in topological order (ignoring back edges). + The sort isn't guaranteed to be stable. + """ + nodes = set(nodes) + it = self._topo_order + if reverse: + it = reversed(it) + for n in it: + if n in nodes: + yield n + + def dump(self, file=None): + """ + Dump extensive debug information. + """ + import pprint + + file = file or sys.stdout + if 1: + print("CFG adjacency lists:", file=file) + self._dump_adj_lists(file) + print("CFG dominators:", file=file) + pprint.pprint(self._doms, stream=file) + print("CFG post-dominators:", file=file) + pprint.pprint(self._post_doms, stream=file) + print("CFG back edges:", sorted(self._back_edges), file=file) + print("CFG loops:", file=file) + pprint.pprint(self._loops, stream=file) + print("CFG node-to-loops:", file=file) + pprint.pprint(self._in_loops, stream=file) + print("CFG backbone:", file=file) + pprint.pprint(self.backbone(), stream=file) + + def render_dot(self, filename="numba_cfg.dot"): + """Render the controlflow graph with GraphViz DOT via the + ``graphviz`` python binding. + + Returns + ------- + g : graphviz.Digraph + Use `g.view()` to open the graph in the default PDF application. + """ + + try: + import graphviz as gv + except ImportError: + raise ImportError( + "The feature requires `graphviz` but it is not available. " + "Please install with `pip install graphviz`" + ) + g = gv.Digraph(filename=filename) + # Populate the nodes + for n in self._nodes: + g.node(str(n)) + # Populate the edges + for n in self._nodes: + for edge in self._succs[n]: + g.edge(str(n), str(edge)) + return g + + # Internal APIs + + def _add_edge(self, from_, to, data=None): + # This internal version allows adding edges to/from unregistered + # (ghost) nodes. + self._preds[to].add(from_) + self._succs[from_].add(to) + self._edge_data[from_, to] = data + + def _remove_node_edges(self, node): + for succ in self._succs.pop(node, ()): + self._preds[succ].remove(node) + del self._edge_data[node, succ] + for pred in self._preds.pop(node, ()): + self._succs[pred].remove(node) + del self._edge_data[pred, node] + + def _dfs(self, entries=None): + if entries is None: + entries = (self._entry_point,) + seen = set() + stack = list(entries) + while stack: + node = stack.pop() + if node not in seen: + yield node + seen.add(node) + for succ in self._succs[node]: + stack.append(succ) + + def _eliminate_dead_blocks(self): + """ + Eliminate all blocks not reachable from the entry point, and + stash them into self._dead_nodes. 
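+
+        (Illustrative example: with nodes {0, 1, 2}, a single edge 0 -> 1
+        and entry point 0, node 2 is unreachable from the entry and is
+        moved into self._dead_nodes.)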
+ """ + live = set() + for node in self._dfs(): + live.add(node) + self._dead_nodes = self._nodes - live + self._nodes = live + # Remove all edges leading from dead nodes + for dead in self._dead_nodes: + self._remove_node_edges(dead) + + def _find_exit_points(self): + """ + Compute the graph's exit points. + """ + exit_points = set() + for n in self._nodes: + if not self._succs.get(n): + exit_points.add(n) + return exit_points + + def _find_postorder(self): + succs = self._succs + back_edges = self._back_edges + post_order = [] + seen = set() + + post_order = [] + + # DFS + def dfs_rec(node): + if node not in seen: + seen.add(node) + stack.append((post_order.append, node)) + for dest in succs[node]: + if (node, dest) not in back_edges: + stack.append((dfs_rec, dest)) + + stack = [(dfs_rec, self._entry_point)] + while stack: + cb, data = stack.pop() + cb(data) + + return post_order + + def _find_immediate_dominators(self): + # The algorithm implemented computes the immediate dominator + # for each node in the CFG which is equivalent to build a dominator tree + # Based on the implementation from NetworkX + # library - nx.immediate_dominators + # https://github.com/networkx/networkx/blob/858e7cb183541a78969fed0cbcd02346f5866c02/networkx/algorithms/dominance.py # noqa: E501 + # References: + # Keith D. Cooper, Timothy J. Harvey, and Ken Kennedy + # A Simple, Fast Dominance Algorithm + # https://www.cs.rice.edu/~keith/EMBED/dom.pdf + def intersect(u, v): + while u != v: + while idx[u] < idx[v]: + u = idom[u] + while idx[u] > idx[v]: + v = idom[v] + return u + + entry = self._entry_point + preds_table = self._preds + + order = self._find_postorder() + idx = {e: i for i, e in enumerate(order)} # index of each node + idom = {entry: entry} + order.pop() + order.reverse() + + changed = True + while changed: + changed = False + for u in order: + new_idom = functools.reduce( + intersect, (v for v in preds_table[u] if v in idom) + ) + if u not in idom or idom[u] != new_idom: + idom[u] = new_idom + changed = True + + return idom + + def _find_dominator_tree(self): + idom = self._idom + domtree = _DictOfContainers(set) + + for u, v in idom.items(): + # v dominates u + if u not in domtree: + domtree[u] = set() + if u != v: + domtree[v].add(u) + + return domtree + + def _find_dominance_frontier(self): + idom = self._idom + preds_table = self._preds + df = {u: set() for u in idom} + + for u in idom: + if len(preds_table[u]) < 2: + continue + for v in preds_table[u]: + while v != idom[u]: + df[v].add(u) + v = idom[v] + + return df + + def _find_dominators_internal(self, post=False): + # See theoretical description in + # http://en.wikipedia.org/wiki/Dominator_%28graph_theory%29 + # The algorithm implemented here uses a todo-list as described + # in http://pages.cs.wisc.edu/~fischer/cs701.f08/finding.loops.html + if post: + entries = set(self._exit_points) + preds_table = self._succs + succs_table = self._preds + else: + entries = set([self._entry_point]) + preds_table = self._preds + succs_table = self._succs + + if not entries: + raise RuntimeError( + "no entry points: dominator algorithm cannot be seeded" + ) + + doms = {} + for e in entries: + doms[e] = set([e]) + + todo = [] + for n in self._nodes: + if n not in entries: + doms[n] = set(self._nodes) + todo.append(n) + + while todo: + n = todo.pop() + if n in entries: + continue + new_doms = set([n]) + preds = preds_table[n] + if preds: + new_doms |= functools.reduce( + set.intersection, [doms[p] for p in preds] + ) + if new_doms != doms[n]: + assert 
len(new_doms) < len(doms[n])
+                doms[n] = new_doms
+                todo.extend(succs_table[n])
+        return doms
+
+    def _find_dominators(self):
+        return self._find_dominators_internal(post=False)
+
+    def _find_post_dominators(self):
+        # To handle infinite loops correctly, we need to add a dummy
+        # exit point, and link members of infinite loops to it.
+        dummy_exit = object()
+        self._exit_points.add(dummy_exit)
+        for loop in self._loops.values():
+            if not loop.exits:
+                for b in loop.body:
+                    self._add_edge(b, dummy_exit)
+        pdoms = self._find_dominators_internal(post=True)
+        # Fix the _post_doms table to make no reference to the dummy exit
+        del pdoms[dummy_exit]
+        for doms in pdoms.values():
+            doms.discard(dummy_exit)
+        self._remove_node_edges(dummy_exit)
+        self._exit_points.remove(dummy_exit)
+        return pdoms
+
+    # Finding loops and back edges: see
+    # http://pages.cs.wisc.edu/~fischer/cs701.f08/finding.loops.html
+
+    def _find_back_edges(self, stats=None):
+        """
+        Find back edges. An edge (src, dest) is a back edge if and
+        only if *dest* dominates *src*.
+        """
+        # Prepare stats to capture execution information
+        if stats is not None:
+            if not isinstance(stats, dict):
+                raise TypeError(f"*stats* must be a dict; got {type(stats)}")
+            stats.setdefault("iteration_count", 0)
+
+        # Uses a simple DFS to find back-edges.
+        # This algorithm is faster than the previous dominator-based
+        # algorithm.
+        back_edges = set()
+        # stack: keeps track of the traversal path
+        stack = []
+        # succs_state: keep track of unvisited successors of a node
+        succs_state = {}
+        entry_point = self.entry_point()
+
+        checked = set()
+
+        def push_state(node):
+            stack.append(node)
+            succs_state[node] = [dest for dest in self._succs[node]]
+
+        push_state(entry_point)
+
+        # Keep track of the iteration count for debugging
+        iter_ct = 0
+        while stack:
+            iter_ct += 1
+            tos = stack[-1]
+            tos_succs = succs_state[tos]
+            # Are there successors not checked?
+            if tos_succs:
+                # Check the next successor
+                cur_node = tos_succs.pop()
+                # Is it in our traversal path?
+                if cur_node in stack:
+                    # Yes, it's a backedge
+                    back_edges.add((tos, cur_node))
+                elif cur_node not in checked:
+                    # Push
+                    push_state(cur_node)
+            else:
+                # Checked all successors. Pop
+                stack.pop()
+                checked.add(tos)
+
+        if stats is not None:
+            stats["iteration_count"] += iter_ct
+        return back_edges
+
+    def _find_topo_order(self):
+        succs = self._succs
+        back_edges = self._back_edges
+        post_order = []
+        seen = set()
+
+        def _dfs_rec(node):
+            if node not in seen:
+                seen.add(node)
+                for dest in succs[node]:
+                    if (node, dest) not in back_edges:
+                        _dfs_rec(dest)
+                post_order.append(node)
+
+        _dfs_rec(self._entry_point)
+        post_order.reverse()
+        return post_order
+
+    def _find_descendents(self):
+        descs = {}
+        for node in reversed(self._topo_order):
+            descs[node] = node_descs = set()
+            for succ in self._succs[node]:
+                if (node, succ) not in self._back_edges:
+                    node_descs.add(succ)
+                    node_descs.update(descs[succ])
+        return descs
+
+    def _find_loops(self):
+        """
+        Find the loops defined by the graph's back edges.
+        """
+        bodies = {}
+        for src, dest in self._back_edges:
+            # The destination of the back edge is the loop header
+            header = dest
+            # Build up the loop body from the back edge's source node,
+            # up to the loop header.
+            body = set([header])
+            queue = [src]
+            while queue:
+                n = queue.pop()
+                if n not in body:
+                    body.add(n)
+                    queue.extend(self._preds[n])
+            # There can be several back edges to a given loop header;
+            # if so, merge the resulting body fragments.
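+            # (Illustrative aside: in `while x: ... continue`, both the
+            # normal loop tail and the `continue` jump back to the same
+            # header, producing two back edges whose body fragments are
+            # merged below.)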
+            if header in bodies:
+                bodies[header].update(body)
+            else:
+                bodies[header] = body
+
+        # Create a Loop object for each header.
+        loops = {}
+        for header, body in bodies.items():
+            entries = set()
+            exits = set()
+            for n in body:
+                entries.update(self._preds[n] - body)
+                exits.update(self._succs[n] - body)
+            loop = Loop(header=header, body=body, entries=entries, exits=exits)
+            loops[header] = loop
+        return loops
+
+    def _find_in_loops(self):
+        loops = self._loops
+        # Compute the loops to which each node belongs.
+        in_loops = dict((n, []) for n in self._nodes)
+        # Sort loops from shortest to longest.
+        # This ensures that inner loops come before outer loops in each
+        # node's list (innermost first).
+        for loop in sorted(loops.values(), key=lambda loop: len(loop.body)):
+            for n in loop.body:
+                in_loops[n].append(loop.header)
+        return in_loops
+
+    def _dump_adj_lists(self, file):
+        adj_lists = dict(
+            (src, sorted(list(dests))) for src, dests in self._succs.items()
+        )
+        import pprint
+
+        pprint.pprint(adj_lists, stream=file)
+
+    def __eq__(self, other):
+        if not isinstance(other, CFGraph):
+            return NotImplemented
+
+        for x in ["_nodes", "_edge_data", "_entry_point", "_preds", "_succs"]:
+            this = getattr(self, x, None)
+            that = getattr(other, x, None)
+            if this != that:
+                return False
+        return True
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class ControlFlowAnalysis(object):
+    """
+    Attributes
+    ----------
+    - bytecode
+
+    - blocks
+
+    - blockseq
+
+    - doms: dict of set
+        Dominators
+
+    - backbone: set of block offsets
+        The set of blocks common to all possible code paths.
+
+    """
+
+    def __init__(self, bytecode):
+        self.bytecode = bytecode
+        self.blocks = {}
+        self.liveblocks = {}
+        self.blockseq = []
+        self.doms = None
+        self.backbone = None
+        # Internal temp states
+        self._force_new_block = True
+        self._curblock = None
+        self._blockstack = []
+        self._loops = []
+        self._withs = []
+
+    def iterblocks(self):
+        """
+        Return all blocks in sequence of occurrence
+        """
+        for i in self.blockseq:
+            yield self.blocks[i]
+
+    def iterliveblocks(self):
+        """
+        Return all live blocks in sequence of occurrence
+        """
+        for i in self.blockseq:
+            if i in self.liveblocks:
+                yield self.blocks[i]
+
+    def incoming_blocks(self, block):
+        """
+        Yield (incoming block, number of stack pops) pairs for *block*.
+        """
+        for i, pops in block.incoming_jumps.items():
+            if i in self.liveblocks:
+                yield self.blocks[i], pops
+
+    def dump(self, file=None):
+        self.graph.dump(file=file)
+
+    def run(self):
+        for inst in self._iter_inst():
+            fname = "op_%s" % inst.opname
+            fn = getattr(self, fname, None)
+            if fn is not None:
+                fn(inst)
+            elif inst.is_jump:
+                # this catches e.g. try...
except
+                l = Loc(self.bytecode.func_id.filename, inst.lineno)
+                if inst.opname in {"SETUP_FINALLY"}:
+                    msg = "'try' block not supported until python3.7 or later"
+                else:
+                    msg = "Use of unsupported opcode (%s) found" % inst.opname
+                raise UnsupportedError(msg, loc=l)
+            else:
+                # Non-jump instructions are ignored
+                pass  # intentionally
+
+        # Close all blocks
+        for cur, nxt in zip(self.blockseq, self.blockseq[1:]):
+            blk = self.blocks[cur]
+            if not blk.outgoing_jumps and not blk.terminating:
+                blk.outgoing_jumps[nxt] = 0
+
+        graph = CFGraph()
+        for b in self.blocks:
+            graph.add_node(b)
+        for b in self.blocks.values():
+            for out, pops in b.outgoing_jumps.items():
+                graph.add_edge(b.offset, out, pops)
+        graph.set_entry_point(min(self.blocks))
+        graph.process()
+        self.graph = graph
+
+        # Fill incoming
+        for b in self.blocks.values():
+            for out, pops in b.outgoing_jumps.items():
+                self.blocks[out].incoming_jumps[b.offset] = pops
+
+        # Find liveblocks
+        self.liveblocks = dict((i, self.blocks[i]) for i in self.graph.nodes())
+
+        for lastblk in reversed(self.blockseq):
+            if lastblk in self.liveblocks:
+                break
+        else:
+            raise AssertionError("No live block that exits!?")
+
+        # Find backbone
+        backbone = self.graph.backbone()
+        # Filter out in-loop blocks (assuming no other cyclic control blocks).
+        # This is to avoid variables defined in loops being considered as
+        # function scope.
+        inloopblocks = set()
+
+        for b in self.blocks.keys():
+            if self.graph.in_loops(b):
+                inloopblocks.add(b)
+
+        self.backbone = backbone - inloopblocks
+
+    def jump(self, target, pops=0):
+        """
+        Register a jump (conditional or not) to *target* offset.
+        *pops* is the number of stack pops implied by the jump (default 0).
+        """
+        self._curblock.outgoing_jumps[target] = pops
+
+    def _iter_inst(self):
+        for inst in self.bytecode:
+            if self._use_new_block(inst):
+                self._guard_with_as(inst)
+                self._start_new_block(inst)
+            self._curblock.body.append(inst.offset)
+            yield inst
+
+    def _use_new_block(self, inst):
+        if inst.offset in self.bytecode.labels:
+            res = True
+        elif inst.opname in NEW_BLOCKERS:
+            res = True
+        else:
+            res = self._force_new_block
+
+        self._force_new_block = False
+        return res
+
+    def _start_new_block(self, inst):
+        self._curblock = CFBlock(inst.offset)
+        self.blocks[inst.offset] = self._curblock
+        self.blockseq.append(inst.offset)
+
+    def _guard_with_as(self, current_inst):
+        """Checks if the next instruction after a SETUP_WITH is something
+        other than a POP_TOP; if so, it will be some sort of store, which is
+        not supported (this corresponds to `with CTXMGR as VAR(S)`)."""
+        if current_inst.opname == "SETUP_WITH":
+            next_op = self.bytecode[current_inst.next].opname
+            if next_op != "POP_TOP":
+                msg = (
+                    "The 'with (context manager) as "
+                    "(variable):' construct is not "
+                    "supported."
+                )
+                raise UnsupportedError(msg)
+
+    def op_SETUP_LOOP(self, inst):
+        end = inst.get_jump_target()
+        self._blockstack.append(end)
+        self._loops.append((inst.offset, end))
+        # TODO: Looplifting requires the loop entry be its own block.
+        #       Forcing a new block here is the simplest solution for now.
+        #       But, we should consider other less ad-hoc ways.
+        self.jump(inst.next)
+        self._force_new_block = True
+
+    def op_SETUP_WITH(self, inst):
+        end = inst.get_jump_target()
+        self._blockstack.append(end)
+        self._withs.append((inst.offset, end))
+        # TODO: WithLifting requires the `with` entry be its own block.
+        #       Forcing a new block here is the simplest solution for now.
+        #       But, we should consider other less ad-hoc ways.
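+        # (Illustrative aside: forcing the split means the body of a `with`
+        # statement always begins its own CFG block, which the with-lifting
+        # transform relies on when outlining it.)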
+ self.jump(inst.next) + self._force_new_block = True + + def op_POP_BLOCK(self, inst): + self._blockstack.pop() + + def op_FOR_ITER(self, inst): + self.jump(inst.get_jump_target()) + self.jump(inst.next) + self._force_new_block = True + + def _op_ABSOLUTE_JUMP_IF(self, inst): + self.jump(inst.get_jump_target()) + self.jump(inst.next) + self._force_new_block = True + + op_POP_JUMP_IF_FALSE = _op_ABSOLUTE_JUMP_IF + op_POP_JUMP_IF_TRUE = _op_ABSOLUTE_JUMP_IF + op_JUMP_IF_FALSE = _op_ABSOLUTE_JUMP_IF + op_JUMP_IF_TRUE = _op_ABSOLUTE_JUMP_IF + + op_POP_JUMP_FORWARD_IF_FALSE = _op_ABSOLUTE_JUMP_IF + op_POP_JUMP_BACKWARD_IF_FALSE = _op_ABSOLUTE_JUMP_IF + op_POP_JUMP_FORWARD_IF_TRUE = _op_ABSOLUTE_JUMP_IF + op_POP_JUMP_BACKWARD_IF_TRUE = _op_ABSOLUTE_JUMP_IF + + def _op_ABSOLUTE_JUMP_OR_POP(self, inst): + self.jump(inst.get_jump_target()) + self.jump(inst.next, pops=1) + self._force_new_block = True + + op_JUMP_IF_FALSE_OR_POP = _op_ABSOLUTE_JUMP_OR_POP + op_JUMP_IF_TRUE_OR_POP = _op_ABSOLUTE_JUMP_OR_POP + + def op_JUMP_ABSOLUTE(self, inst): + self.jump(inst.get_jump_target()) + self._force_new_block = True + + def op_JUMP_FORWARD(self, inst): + self.jump(inst.get_jump_target()) + self._force_new_block = True + + op_JUMP_BACKWARD = op_JUMP_FORWARD + + def op_RETURN_VALUE(self, inst): + self._curblock.terminating = True + self._force_new_block = True + + if PYVERSION in ((3, 12), (3, 13)): + + def op_RETURN_CONST(self, inst): + self._curblock.terminating = True + self._force_new_block = True + elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + pass + else: + raise NotImplementedError(PYVERSION) + + def op_RAISE_VARARGS(self, inst): + self._curblock.terminating = True + self._force_new_block = True + + def op_BREAK_LOOP(self, inst): + self.jump(self._blockstack[-1]) + self._force_new_block = True diff --git a/numba_cuda/numba/cuda/core/inline_closurecall.py b/numba_cuda/numba/cuda/core/inline_closurecall.py index 37e03a5e9..34899f9df 100644 --- a/numba_cuda/numba/cuda/core/inline_closurecall.py +++ b/numba_cuda/numba/cuda/core/inline_closurecall.py @@ -688,9 +688,9 @@ def inline_closure_call( # first, get the IR of the callee if isinstance(callee, pytypes.FunctionType): - from numba.cuda.core import compiler + from numba.cuda.compiler import run_frontend - callee_ir = compiler.run_frontend(callee, inline_closures=True) + callee_ir = run_frontend(callee, inline_closures=True) else: callee_ir = get_ir_of_code(glbls, callee_code) diff --git a/numba_cuda/numba/cuda/core/interpreter.py b/numba_cuda/numba/cuda/core/interpreter.py index bce399842..4cf566bdf 100644 --- a/numba_cuda/numba/cuda/core/interpreter.py +++ b/numba_cuda/numba/cuda/core/interpreter.py @@ -22,14 +22,14 @@ INPLACE_BINOPS_TO_OPERATORS, ) from numba.cuda.utils import _lazy_pformat -from numba.core.byteflow import Flow, AdaptDFA, AdaptCFA, BlockKind +from numba.cuda.core.byteflow import Flow, AdaptDFA, AdaptCFA, BlockKind from numba.cuda.core.unsafe import eh from numba.cpython.unsafe.tuple import unpack_single_tuple if PYVERSION in ((3, 12), (3, 13)): # Operands for CALL_INTRINSIC_1 - from numba.core.byteflow import CALL_INTRINSIC_1_Operand as ci1op + from numba.cuda.core.byteflow import CALL_INTRINSIC_1_Operand as ci1op elif PYVERSION in ((3, 9), (3, 10), (3, 11)): pass else: diff --git a/numba_cuda/numba/cuda/core/ir_utils.py b/numba_cuda/numba/cuda/core/ir_utils.py index 9865063a9..f72253821 100644 --- a/numba_cuda/numba/cuda/core/ir_utils.py +++ b/numba_cuda/numba/cuda/core/ir_utils.py @@ -10,11 +10,11 @@ import warnings import numba 
-from numba.core import types, ir, analysis +from numba.core import types, ir from numba.cuda import typing -from numba.cuda.core import postproc, rewrites, config +from numba.cuda.core import analysis, postproc, rewrites, config from numba.core.typing.templates import signature -from numba.core.analysis import ( +from numba.cuda.core.analysis import ( compute_live_map, compute_use_defs, compute_cfg_from_blocks, diff --git a/numba_cuda/numba/cuda/core/ssa.py b/numba_cuda/numba/cuda/core/ssa.py index f4d1a8852..4c58c7f48 100644 --- a/numba_cuda/numba/cuda/core/ssa.py +++ b/numba_cuda/numba/cuda/core/ssa.py @@ -18,9 +18,10 @@ from collections import defaultdict from numba.cuda import config -from numba.core import ir, ir_utils, errors +from numba.core import ir, errors +from numba.cuda.core import ir_utils from numba.cuda.utils import OrderedSet, _lazy_pformat -from numba.core.analysis import compute_cfg_from_blocks +from numba.cuda.core.analysis import compute_cfg_from_blocks _logger = logging.getLogger(__name__) diff --git a/numba_cuda/numba/cuda/core/untyped_passes.py b/numba_cuda/numba/cuda/core/untyped_passes.py index e4dc35ff6..8a8956266 100644 --- a/numba_cuda/numba/cuda/core/untyped_passes.py +++ b/numba_cuda/numba/cuda/core/untyped_passes.py @@ -24,7 +24,7 @@ from numba.cuda.misc.special import literal_unroll from numba.cuda.core.analysis import dead_branch_prune -from numba.core.analysis import ( +from numba.cuda.core.analysis import ( rewrite_semantic_constants, find_literally_calls, compute_cfg_from_blocks, diff --git a/numba_cuda/numba/cuda/lowering.py b/numba_cuda/numba/cuda/lowering.py index 570c50aa5..a028e6868 100644 --- a/numba_cuda/numba/cuda/lowering.py +++ b/numba_cuda/numba/cuda/lowering.py @@ -28,7 +28,7 @@ ) from numba.cuda.core.funcdesc import default_mangler from numba.cuda.core.environment import Environment -from numba.core.analysis import compute_use_defs, must_use_alloca +from numba.cuda.core.analysis import compute_use_defs, must_use_alloca from numba.cuda.misc.firstlinefinder import get_func_body_first_lineno from numba import version_info diff --git a/numba_cuda/numba/cuda/simulator/compiler.py b/numba_cuda/numba/cuda/simulator/compiler.py index 1bf5a3af0..d2fb21fea 100644 --- a/numba_cuda/numba/cuda/simulator/compiler.py +++ b/numba_cuda/numba/cuda/simulator/compiler.py @@ -14,5 +14,25 @@ compile_all = None -def run_frontend(func, inline_closures=False, emit_dels=False): +def run_frontend(func): pass + + +class DefaultPassBuilder(object): + @staticmethod + def define_nopython_lowering_pipeline(state, name="nopython_lowering"): + pass + + @staticmethod + def define_typed_pipeline(state, name="typed"): + pass + + +class CompilerBase: + def __init__( + self, typingctx, targetctx, library, args, return_type, flags, locals + ): + pass + + +PassManager = None diff --git a/numba_cuda/numba/cuda/simulator/tests/support.py b/numba_cuda/numba/cuda/simulator/tests/support.py new file mode 100644 index 000000000..661b7a58c --- /dev/null +++ b/numba_cuda/numba/cuda/simulator/tests/support.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +MemoryLeakMixin = None diff --git a/numba_cuda/numba/cuda/tests/test_analysis.py b/numba_cuda/numba/cuda/tests/test_analysis.py new file mode 100644 index 000000000..c507438b1 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/test_analysis.py @@ -0,0 +1,1122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +# Tests numba.analysis functions +import collections +import types as pytypes + +import numpy as np +from numba.cuda.compiler import run_frontend +from numba.cuda.flags import Flags +from numba.cuda.core.compiler import StateDict +from numba.cuda import jit +from numba.core import types, errors, ir +from numba.cuda.utils import PYVERSION +from numba.cuda.core import postproc, rewrites, ir_utils +from numba.cuda.core.options import ParallelOptions +from numba.cuda.core.inline_closurecall import InlineClosureCallPass +from numba.cuda.tests.support import ( + TestCase, + MemoryLeakMixin, +) +from numba.cuda.core.analysis import ( + dead_branch_prune, + rewrite_semantic_constants, +) +from numba.cuda.core.untyped_passes import ( + ReconstructSSA, +) +import unittest +from numba.cuda.core import config + +_GLOBAL = 123 + +enable_pyobj_flags = Flags() +enable_pyobj_flags.enable_pyobject = True + +if config.ENABLE_CUDASIM: + raise unittest.SkipTest("Analysis passes not done in simulator") + + +def compile_to_ir(func): + func_ir = run_frontend(func) + state = StateDict() + state.func_ir = func_ir + state.typemap = None + state.calltypes = None + # Transform to SSA + ReconstructSSA().run_pass(state) + # call this to get print etc rewrites + rewrites.rewrite_registry.apply("before-inference", state) + return func_ir + + +class TestBranchPruneBase(MemoryLeakMixin, TestCase): + """ + Tests branch pruning + """ + + _DEBUG = False + + # find *all* branches + def find_branches(self, the_ir): + branches = [] + for blk in the_ir.blocks.values(): + tmp = [_ for _ in blk.find_insts(cls=ir.Branch)] + branches.extend(tmp) + return branches + + def assert_prune(self, func, args_tys, prune, *args, **kwargs): + # This checks that the expected pruned branches have indeed been pruned. + # func is a python function to assess + # args_tys is the numba types arguments tuple + # prune arg is a list, one entry per branch. The value in the entry is + # encoded as follows: + # True: using constant inference only, the True branch will be pruned + # False: using constant inference only, the False branch will be pruned + # None: under no circumstances should this branch be pruned + # *args: the argument instances to pass to the function to check + # execution is still valid post transform + # **kwargs: + # - flags: args to pass to `jit` default is `nopython=True`, + # e.g. permits use of e.g. object mode. 
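+        #
+        # (Illustrative example: for an impl whose only branch is
+        # `if x is None:`, called with args_tys=(types.NoneType("none"), ...)
+        # and prune=[False], the assertion is that constant inference proves
+        # the test always true, so the False branch is expected to be pruned.)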
+ + func_ir = compile_to_ir(func) + before = func_ir.copy() + if self._DEBUG: + print("=" * 80) + print("before inline") + func_ir.dump() + + # run closure inlining to ensure that nonlocals in closures are visible + inline_pass = InlineClosureCallPass( + func_ir, + ParallelOptions(False), + ) + inline_pass.run() + + # Remove all Dels, and re-run postproc + post_proc = postproc.PostProcessor(func_ir) + post_proc.run() + + rewrite_semantic_constants(func_ir, args_tys) + if self._DEBUG: + print("=" * 80) + print("before prune") + func_ir.dump() + + dead_branch_prune(func_ir, args_tys) + + after = func_ir + if self._DEBUG: + print("after prune") + func_ir.dump() + + before_branches = self.find_branches(before) + self.assertEqual(len(before_branches), len(prune)) + + # what is expected to be pruned + expect_removed = [] + for idx, prune in enumerate(prune): + branch = before_branches[idx] + if prune is True: + expect_removed.append(branch.truebr) + elif prune is False: + expect_removed.append(branch.falsebr) + elif prune is None: + pass # nothing should be removed! + elif prune == "both": + expect_removed.append(branch.falsebr) + expect_removed.append(branch.truebr) + else: + assert 0, "unreachable" + + # compare labels + original_labels = set([_ for _ in before.blocks.keys()]) + new_labels = set([_ for _ in after.blocks.keys()]) + # assert that the new labels are precisely the original less the + # expected pruned labels + try: + self.assertEqual(new_labels, original_labels - set(expect_removed)) + except AssertionError as e: + print("new_labels", sorted(new_labels)) + print("original_labels", sorted(original_labels)) + print("expect_removed", sorted(expect_removed)) + raise e + + if [ + arg is types.NoneType("none") or arg is types.Omitted(None) + for arg in args_tys + ].count(True) == 0: + self.run_func(func, args) + + def run_func(self, impl, args): + cres = jit(impl) + dargs = args + out = np.zeros(1) + cout = np.zeros(1) + args += (out,) + dargs += (cout,) + cres.py_func(*args) + cres[1, 1](*dargs) + self.assertPreciseEqual(out[0], cout[0]) + + +class TestBranchPrune(TestBranchPruneBase): + def test_single_if(self): + def impl(x, res): + if 1 == 0: + res[0] = 3.14159 + + self.assert_prune( + impl, + (types.NoneType("none"), types.Array(types.float64, 1, "C")), + [True], + None, + ) + + def impl(x, res): + if 1 == 1: + res[0] = 3.14159 + + self.assert_prune( + impl, + (types.NoneType("none"), types.Array(types.float64, 1, "C")), + [False], + None, + ) + + def impl(x, res): + if x is None: + res[0] = 3.14159 + + self.assert_prune( + impl, + (types.NoneType("none"), types.Array(types.float64, 1, "C")), + [False], + None, + ) + self.assert_prune( + impl, + (types.IntegerLiteral(10), types.Array(types.float64, 1, "C")), + [True], + 10, + ) + + def impl(x, res): + if x == 10: + res[0] = 3.14159 + + self.assert_prune( + impl, + (types.NoneType("none"), types.Array(types.float64, 1, "C")), + [True], + None, + ) + self.assert_prune( + impl, + (types.IntegerLiteral(10), types.Array(types.float64, 1, "C")), + [None], + 10, + ) + + def impl(x, res): + if x == 10: + z = 3.14159 # noqa: F841 # no effect + + self.assert_prune( + impl, + (types.NoneType("none"), types.Array(types.float64, 1, "C")), + [True], + None, + ) + self.assert_prune( + impl, + (types.IntegerLiteral(10), types.Array(types.float64, 1, "C")), + [None], + 10, + ) + + def impl(x, res): + z = None + y = z + if x == y: + res[0] = 100 + + self.assert_prune( + impl, + (types.NoneType("none"), types.Array(types.float64, 1, "C")), + [False], 
None,
+        )
+        self.assert_prune(
+            impl,
+            (types.IntegerLiteral(10), types.Array(types.float64, 1, "C")),
+            [True],
+            10,
+        )
+
+    def test_single_if_else(self):
+        def impl(x, res):
+            if x is None:
+                res[0] = 3.14159
+            else:
+                res[0] = 1.61803
+
+        self.assert_prune(
+            impl,
+            (types.NoneType("none"), types.Array(types.float64, 1, "C")),
+            [False],
+            None,
+        )
+        self.assert_prune(
+            impl,
+            (types.IntegerLiteral(10), types.Array(types.float64, 1, "C")),
+            [True],
+            10,
+        )
+
+    def test_single_if_const_val(self):
+        def impl(x, res):
+            if x == 100:
+                res[0] = 3.14159
+
+        self.assert_prune(
+            impl,
+            (types.NoneType("none"), types.Array(types.float64, 1, "C")),
+            [True],
+            None,
+        )
+        self.assert_prune(
+            impl,
+            (types.IntegerLiteral(100), types.Array(types.float64, 1, "C")),
+            [None],
+            100,
+        )
+
+        def impl(x, res):
+            # switch the condition order
+            if 100 == x:
+                res[0] = 3.14159
+
+        self.assert_prune(
+            impl,
+            (types.NoneType("none"), types.Array(types.float64, 1, "C")),
+            [True],
+            None,
+        )
+        self.assert_prune(
+            impl,
+            (types.IntegerLiteral(100), types.Array(types.float64, 1, "C")),
+            [None],
+            100,
+        )
+
+    def test_single_if_else_two_const_val(self):
+        def impl(x, y, res):
+            if x == y:
+                res[0] = 3.14159
+            else:
+                res[0] = 1.61803
+
+        self.assert_prune(
+            impl,
+            (types.IntegerLiteral(100),) * 2
+            + (types.Array(types.float64, 1, "C"),),
+            [None],
+            100,
+            100,
+        )
+        self.assert_prune(
+            impl,
+            (types.NoneType("none"),) * 2
+            + (types.Array(types.float64, 1, "C"),),
+            [False],
+            None,
+            None,
+        )
+        self.assert_prune(
+            impl,
+            (
+                types.IntegerLiteral(100),
+                types.NoneType("none"),
+                types.Array(types.float64, 1, "C"),
+            ),
+            [True],
+            100,
+            None,
+        )
+        self.assert_prune(
+            impl,
+            (
+                types.IntegerLiteral(100),
+                types.IntegerLiteral(1000),
+                types.Array(types.float64, 1, "C"),
+            ),
+            [None],
+            100,
+            1000,
+        )
+
+    def test_single_if_else_w_following_undetermined(self):
+        def impl(x, res):
+            x_is_none_work = False
+            if x is None:
+                x_is_none_work = True
+            else:
+                dead = 7  # noqa: F841 # no effect
+
+            if x_is_none_work:
+                y = 10
+            else:
+                y = -3
+            res[0] = y
+
+        self.assert_prune(
+            impl,
+            (types.NoneType("none"), types.Array(types.float64, 1, "C")),
+            [False, None],
+            None,
+        )
+        self.assert_prune(
+            impl,
+            (types.IntegerLiteral(10), types.Array(types.float64, 1, "C")),
+            [True, None],
+            10,
+        )
+
+        def impl(x, res):
+            x_is_none_work = False
+            if x is None:
+                x_is_none_work = True
+            else:
+                pass
+
+            if x_is_none_work:
+                y = 10
+            else:
+                y = -3
+            res[0] = y
+
+        # Python 3.10 creates a block with a NOP in it for the `pass` which
+        # means it gets pruned.
+        if PYVERSION >= (3, 10):
+ self.assert_prune( + impl, + (types.NoneType("none"), types.Array(types.float64, 1, "C")), + [False, None], + None, + ) + else: + self.assert_prune( + impl, + (types.NoneType("none"), types.Array(types.float64, 1, "C")), + [None, None], + None, + ) + + self.assert_prune( + impl, + (types.IntegerLiteral(10), types.Array(types.float64, 1, "C")), + [True, None], + 10, + ) + + def test_double_if_else_rt_const(self): + def impl(x, res): + one_hundred = 100 + x_is_none_work = 4 + if x is None: + x_is_none_work = 100 + else: + dead = 7 # noqa: F841 # no effect + + if x_is_none_work == one_hundred: + y = 10 + else: + y = -3 + + res[0] = y + x_is_none_work + + self.assert_prune(impl, (types.NoneType("none"),), [False, None], None) + self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10) + + def test_double_if_else_non_literal_const(self): + def impl(x, res): + one_hundred = 100 + if x == one_hundred: + res[0] = 3.14159 + else: + res[0] = 1.61803 + + # no prune as compilation specialization on literal value not permitted + self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10) + self.assert_prune(impl, (types.IntegerLiteral(100),), [None], 100) + + def test_single_two_branches_same_cond(self): + def impl(x, res): + if x is None: + y = 10 + else: + y = 40 + + if x is not None: + z = 100 + else: + z = 400 + + res[0] = z + y + + self.assert_prune(impl, (types.NoneType("none"),), [False, True], None) + self.assert_prune(impl, (types.IntegerLiteral(10),), [True, False], 10) + + def test_cond_is_kwarg_none(self): + def impl(x=None, res=None): + if x is None: + y = 10 + else: + y = 40 + + if x is not None: + z = 100 + else: + z = 400 + + res[0] = z + y + + self.assert_prune(impl, (types.Omitted(None),), [False, True], None) + self.assert_prune(impl, (types.NoneType("none"),), [False, True], None) + self.assert_prune(impl, (types.IntegerLiteral(10),), [True, False], 10) + + def test_cond_is_kwarg_value(self): + def impl(x=1000, res=None): + if x == 1000: + y = 10 + else: + y = 40 + + if x != 1000: + z = 100 + else: + z = 400 + + res[0] = z + y + + self.assert_prune(impl, (types.Omitted(1000),), [None, None], 1000) + self.assert_prune( + impl, (types.IntegerLiteral(1000),), [None, None], 1000 + ) + self.assert_prune(impl, (types.IntegerLiteral(0),), [None, None], 0) + self.assert_prune(impl, (types.NoneType("none"),), [True, False], None) + + def test_cond_rewrite_is_correct(self): + # this checks that when a condition is replaced, it is replace by a + # true/false bit that correctly represents the evaluated condition + def fn(x): + if x is None: + return 10 + return 12 + + def check(func, arg_tys, bit_val): + func_ir = compile_to_ir(func) + + # check there is 1 branch + before_branches = self.find_branches(func_ir) + self.assertEqual(len(before_branches), 1) + + # check the condition in the branch is a binop + pred_var = before_branches[0].cond + pred_defn = ir_utils.get_definition(func_ir, pred_var) + self.assertEqual(pred_defn.op, "call") + condition_var = pred_defn.args[0] + condition_op = ir_utils.get_definition(func_ir, condition_var) + self.assertEqual(condition_op.op, "binop") + + # do the prune, this should kill the dead branch and rewrite the + #'condition to a true/false const bit + if self._DEBUG: + print("=" * 80) + print("before prune") + func_ir.dump() + dead_branch_prune(func_ir, arg_tys) + if self._DEBUG: + print("=" * 80) + print("after prune") + func_ir.dump() + + # after mutation, the condition should be a const value `bit_val` + new_condition_defn = 
ir_utils.get_definition(func_ir, condition_var)
+            self.assertTrue(isinstance(new_condition_defn, ir.Const))
+            self.assertEqual(new_condition_defn.value, bit_val)
+
+        check(fn, (types.NoneType("none"),), 1)
+        check(fn, (types.IntegerLiteral(10),), 0)
+
+    def test_global_bake_in(self):
+        def impl(x, res):
+            if _GLOBAL == 123:
+                res[0] = x
+            else:
+                res[0] = x + 10
+
+        self.assert_prune(
+            impl,
+            (types.IntegerLiteral(1), types.Array(types.float64, 1, "C")),
+            [False],
+            1,
+        )
+
+        global _GLOBAL
+        tmp = _GLOBAL
+
+        try:
+            _GLOBAL = 5
+
+            def impl(x, res):
+                if _GLOBAL == 123:
+                    res[0] = x
+                else:
+                    res[0] = x + 10
+
+            self.assert_prune(
+                impl,
+                (types.IntegerLiteral(1), types.Array(types.float64, 1, "C")),
+                [True],
+                1,
+            )
+        finally:
+            _GLOBAL = tmp
+
+    def test_freevar_bake_in(self):
+        _FREEVAR = 123
+
+        def impl(x, res):
+            if _FREEVAR == 123:
+                res[0] = x
+            else:
+                res[0] = x + 10
+
+        self.assert_prune(
+            impl,
+            (types.IntegerLiteral(1), types.Array(types.float64, 1, "C")),
+            [False],
+            1,
+        )
+
+        _FREEVAR = 12
+
+        def impl(x, res):
+            if _FREEVAR == 123:
+                res[0] = x
+            else:
+                res[0] = x + 10
+
+        self.assert_prune(
+            impl,
+            (types.IntegerLiteral(1), types.Array(types.float64, 1, "C")),
+            [True],
+            1,
+        )
+
+    def test_redefined_variables_are_not_considered_in_prune(self):
+        # see issue #4163, checks that if a variable that is an argument is
+        # redefined in the user code it is not considered const
+
+        def impl(array, a=None, res=None):
+            if a is None:
+                a = 0
+            if a < 0:
+                res[0] = 10
+            res[0] = 30
+
+        self.assert_prune(
+            impl,
+            (
+                types.Array(types.float64, 2, "C"),
+                types.NoneType("none"),
+                types.Array(types.float64, 1, "C"),
+            ),
+            [None, None],
+            np.zeros((2, 3)),
+            None,
+        )
+
+    def test_redefinition_analysis_same_block(self):
+        # checks that a redefinition in a block with prunable potential doesn't
+        # break
+
+        def impl(array, x, a=None, res=None):
+            b = 2
+            if x < 4:
+                b = 12
+            if a is None:  # known true
+                a = 7  # live
+            else:
+                b = 15  # dead
+            if a < 0:  # valid as a result of the redefinition of 'a'
+                res[0] = 10
+            res[0] = 30 + b + a
+
+        self.assert_prune(
+            impl,
+            (
+                types.Array(types.float64, 2, "C"),
+                types.float64,
+                types.NoneType("none"),
+                types.Array(types.float64, 1, "C"),
+            ),
+            [None, False, None],
+            np.zeros((2, 3)),
+            1.0,
+            None,
+        )
+
+    def test_redefinition_analysis_different_block_can_exec(self):
+        # checks that a redefinition in a block that may be executed prevents
+        # pruning
+
+        def impl(array, x, a=None, res=None):
+            b = 0
+            if x > 5:
+                a = 11  # a redefined, cannot tell statically if this will exec
+            if x < 4:
+                b = 12
+            if a is None:  # cannot prune, cannot determine if re-defn occurred
+                b += 5
+            else:
+                b += 7
+            if a < 0:
+                res[0] = 10
+            res[0] = 30 + b
+
+        self.assert_prune(
+            impl,
+            (
+                types.Array(types.float64, 2, "C"),
+                types.float64,
+                types.NoneType("none"),
+                types.Array(types.float64, 1, "C"),
+            ),
+            [None, None, None, None],
+            np.zeros((2, 3)),
+            1.0,
+            None,
+        )
+
+    def test_redefinition_analysis_different_block_cannot_exec(self):
+        # checks that a redefinition in a block guarded by something that
+        # has prune potential
+
+        def impl(array, x=None, a=None, res=None):
+            b = 0
+            if x is not None:
+                a = 11
+            if a is None:
+                b += 5
+            else:
+                b += 7
+            res[0] = 30 + b
+
+        self.assert_prune(
+            impl,
+            (
+                types.Array(types.float64, 2, "C"),
+                types.NoneType("none"),
+                types.NoneType("none"),
+                types.Array(types.float64, 1, "C"),
+            ),
+            [True, None],
+            np.zeros((2, 3)),
+            None,
+            None,
+        )
+
+        self.assert_prune(
+            impl,
+            (
+                
types.Array(types.float64, 2, "C"),
+                types.NoneType("none"),
+                types.float64,
+                types.Array(types.float64, 1, "C"),
+            ),
+            [True, None],
+            np.zeros((2, 3)),
+            None,
+            1.2,
+        )
+
+        self.assert_prune(
+            impl,
+            (
+                types.Array(types.float64, 2, "C"),
+                types.float64,
+                types.NoneType("none"),
+                types.Array(types.float64, 1, "C"),
+            ),
+            [None, None],
+            np.zeros((2, 3)),
+            1.2,
+            None,
+        )
+
+    def test_closure_and_nonlocal_can_prune(self):
+        # Closures must be inlined ahead of branch pruning in case nonlocal
+        # is used. See issue #6585.
+        def impl(res):
+            x = 1000
+
+            def closure():
+                nonlocal x
+                x = 0
+
+            closure()
+
+            if x == 0:
+                res[0] = True
+            else:
+                res[0] = False
+
+        self.assert_prune(
+            impl,
+            (types.Array(types.float64, 1, "C"),),
+            [
+                False,
+            ],
+        )
+
+    def test_closure_and_nonlocal_cannot_prune(self):
+        # Closures must be inlined ahead of branch pruning in case nonlocal
+        # is used. See issue #6585.
+        def impl(n, res):
+            x = 1000
+
+            def closure(t):
+                nonlocal x
+                x = t
+
+            closure(n)
+
+            if x == 0:
+                res[0] = True
+            else:
+                res[0] = False
+
+        self.assert_prune(
+            impl,
+            (types.int64, types.Array(types.float64, 1, "C")),
+            [
+                None,
+            ],
+            1,
+        )
+
+
+class TestBranchPrunePredicates(TestBranchPruneBase):
+    # Really important thing to remember: branches on predicates end up as
+    # POP_JUMP_IF_ opcodes and the targets are backwards compared to normal,
+    # i.e. the true condition is the far jump and the false one the near, so
+    # `if x` would end up in Numba IR as e.g. `branch x 10, 6`.
+
+    _TRUTHY = (1, "String", True, 7.4, 3j)
+    _FALSEY = (0, "", False, 0.0, 0j, None)
+
+    def _literal_const_sample_generator(self, pyfunc, consts):
+        """
+        This takes a python function, pyfunc, and manipulates the co_consts
+        member of its __code__ object to create a new function with different
+        co_consts, as supplied in the argument consts.
+
+        consts is a dict {index: value} mapping a co_consts tuple index to
+        the constant value used to update a clone of pyfunc.
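+
+        For example, passing {1: 0} replaces the constant at index 1 of
+        co_consts with the literal 0 (see test_literal_const_code_gen below).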
+ """ + pyfunc_code = pyfunc.__code__ + + # translate consts spec to update the constants + co_consts = {k: v for k, v in enumerate(pyfunc_code.co_consts)} + for k, v in consts.items(): + co_consts[k] = v + new_consts = tuple([v for _, v in sorted(co_consts.items())]) + + # create code object with mutation + new_code = pyfunc_code.replace(co_consts=new_consts) + + # get function + return pytypes.FunctionType(new_code, globals()) + + def test_literal_const_code_gen(self): + def impl(x): + _CONST1 = "PLACEHOLDER1" + if _CONST1: + return 3.14159 + else: + _CONST2 = "PLACEHOLDER2" + return _CONST2 + 4 + + new = self._literal_const_sample_generator(impl, {1: 0, 3: 20}) + iconst = impl.__code__.co_consts + nconst = new.__code__.co_consts + self.assertEqual( + iconst, (None, "PLACEHOLDER1", 3.14159, "PLACEHOLDER2", 4) + ) + self.assertEqual(nconst, (None, 0, 3.14159, 20, 4)) + self.assertEqual(impl(None), 3.14159) + self.assertEqual(new(None), 24) + + def test_single_if_const(self): + def impl(x): + _CONST1 = "PLACEHOLDER1" + if _CONST1: + return 3.14159 + + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for const in c_inp: + func = self._literal_const_sample_generator(impl, {1: const}) + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + def test_single_if_negate_const(self): + def impl(x): + _CONST1 = "PLACEHOLDER1" + if not _CONST1: + return 3.14159 + + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for const in c_inp: + func = self._literal_const_sample_generator(impl, {1: const}) + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + def test_single_if_else_const(self): + def impl(x): + _CONST1 = "PLACEHOLDER1" + if _CONST1: + return 3.14159 + else: + return 1.61803 + + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for const in c_inp: + func = self._literal_const_sample_generator(impl, {1: const}) + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + def test_single_if_else_negate_const(self): + def impl(x): + _CONST1 = "PLACEHOLDER1" + if not _CONST1: + return 3.14159 + else: + return 1.61803 + + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for const in c_inp: + func = self._literal_const_sample_generator(impl, {1: const}) + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + def test_single_if_freevar(self): + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for const in c_inp: + + def func(x): + if const: + return 3.14159, const + + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + def test_single_if_negate_freevar(self): + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for const in c_inp: + + def func(x): + if not const: + return 3.14159, const + + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + def test_single_if_else_negate_freevar(self): + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for const in c_inp: + + def func(x): + if not const: + return 3.14159, const + else: + return 1.61803, const + + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + # globals in this section have absurd names after their test usecase names + # so as to prevent collisions and permit tests to run in parallel + def test_single_if_global(self): + global c_test_single_if_global + + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for c in c_inp: + c_test_single_if_global = c + + 
def func(x): + if c_test_single_if_global: + return 3.14159, c_test_single_if_global + + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + def test_single_if_negate_global(self): + global c_test_single_if_negate_global + + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for c in c_inp: + c_test_single_if_negate_global = c + + def func(x): + if c_test_single_if_negate_global: + return 3.14159, c_test_single_if_negate_global + + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + def test_single_if_else_global(self): + global c_test_single_if_else_global + + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for c in c_inp: + c_test_single_if_else_global = c + + def func(x): + if c_test_single_if_else_global: + return 3.14159, c_test_single_if_else_global + else: + return 1.61803, c_test_single_if_else_global + + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + def test_single_if_else_negate_global(self): + global c_test_single_if_else_negate_global + + for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True): + for c in c_inp: + c_test_single_if_else_negate_global = c + + def func(x): + if not c_test_single_if_else_negate_global: + return 3.14159, c_test_single_if_else_negate_global + else: + return 1.61803, c_test_single_if_else_negate_global + + self.assert_prune( + func, (types.NoneType("none"),), [prune], None + ) + + def test_issue_5618(self): + @jit + def foo(res): + tmp = 666 + if tmp: + res[0] = tmp + + self.run_func(foo, ()) + + +class TestBranchPrunePostSemanticConstRewrites(TestBranchPruneBase): + # Tests that semantic constants rewriting works by virtue of branch pruning + + def test_array_ndim_attr(self): + def impl(array, res): + if array.ndim == 2: + if array.shape[1] == 2: + res[0] = 1 + else: + res[0] = 10 + + self.assert_prune( + impl, + (types.Array(types.float64, 2, "C"),), + [False, None], + np.zeros((2, 3)), + ) + self.assert_prune( + impl, + (types.Array(types.float64, 1, "C"),), + [True, "both"], + np.zeros((2,)), + ) + + def test_tuple_len(self): + def impl(tup, res): + if len(tup) == 3: + if tup[2] == 2: + res[0] = 1 + else: + res[0] = 0 + + self.assert_prune( + impl, + (types.UniTuple(types.int64, 3),), + [False, None], + tuple([1, 2, 3]), + ) + self.assert_prune( + impl, + (types.UniTuple(types.int64, 2),), + [True, "both"], + tuple([1, 2]), + ) + + def test_attr_not_len(self): + # The purpose of this test is to make sure that the conditions guarding + # the rewrite part do not themselves raise exceptions. + # This produces an `ir.Expr` call node for `float.as_integer_ratio`, + # which is a getattr() on `float`. 
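+        # The guard should simply decline to rewrite this expression; the
+        # call itself still fails typing, as asserted below.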
+ + @jit + def test(): + float.as_integer_ratio(1.23) + + # this should raise a TypingError + with self.assertRaises(errors.TypingError) as e: + test[1, 1]() + + self.assertIn("Unknown attribute 'as_integer_ratio'", str(e.exception)) + + def test_ndim_not_on_array(self): + FakeArray = collections.namedtuple("FakeArray", ["ndim"]) + fa = FakeArray(ndim=2) + + def impl(fa, res): + if fa.ndim == 2: + res[0] = fa.ndim + + # check prune works for array ndim + self.assert_prune( + impl, + (types.Array(types.float64, 2, "C"),), + [False], + np.zeros((2, 3)), + ) + + # check prune fails for something with `ndim` attr that is not array + FakeArrayType = types.NamedUniTuple(types.int64, 1, FakeArray) + self.assert_prune( + impl, + (FakeArrayType,), + [None], + fa, + flags={"nopython": False, "forceobj": True}, + ) diff --git a/numba_cuda/numba/cuda/tests/test_byteflow.py b/numba_cuda/numba/cuda/tests/test_byteflow.py new file mode 100644 index 000000000..20574c0a7 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/test_byteflow.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +Test byteflow.py specific issues +""" + +import unittest + +from numba.cuda.tests.support import TestCase +from numba.cuda.compiler import run_frontend + + +class TestByteFlowIssues(TestCase): + def test_issue_5087(self): + # This is an odd issue. The exact number of print below is + # necessary to trigger it. Too many or too few will alter the behavior. + # Also note that the function below will not be executed. The problem + # occurs at compilation. The definition below is invalid for execution. + # The problem occurs in the bytecode analysis. + def udt(): + print + print + print + + for i in range: + print + print + print + print + print + print + print + print + print + print + print + print + print + print + print + print + print + print + + for j in range: + print + print + print + print + print + print + print + for k in range: + for l in range: + print + + print + print + print + print + print + print + print + print + print + if print: + for n in range: + print + else: + print + + run_frontend(udt) + + def test_issue_5097(self): + # Inspired by https://github.com/numba/numba/issues/5097 + def udt(): + for i in range(0): + if i > 0: + pass + a = None # noqa: F841 + + run_frontend(udt) + + def test_issue_5680(self): + # From https://github.com/numba/numba/issues/5680#issuecomment-625351336 + def udt(): + for k in range(0): + if 1 == 1: + ... + if "a" == "a": + ... + + run_frontend(udt) + + +if __name__ == "__main__": + unittest.main() diff --git a/numba_cuda/numba/cuda/tests/test_flow_control.py b/numba_cuda/numba/cuda/tests/test_flow_control.py new file mode 100644 index 000000000..49710819d --- /dev/null +++ b/numba_cuda/numba/cuda/tests/test_flow_control.py @@ -0,0 +1,1433 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +import itertools + +import unittest +from numba.cuda import jit +from numba.cuda.core.controlflow import CFGraph, ControlFlowAnalysis +from numba.core import types +from numba.cuda.core.bytecode import ( + FunctionIdentity, + ByteCode, + _fix_LOAD_GLOBAL_arg, +) +from numba.cuda.tests.support import TestCase +from numba.cuda import utils +from numba.cuda.core import config + +if config.ENABLE_CUDASIM: + raise unittest.SkipTest("Analysis passes not done in simulator") + + +def for_loop_usecase1(x, y, res1, res2): + result = 0 + for i in range(x): + result += i + res1 = result # noqa: F841 + + +def for_loop_usecase2(x, y, res1, res2): + result = 0 + for i, j in enumerate(range(x, y, -1)): + result += i * j + res1 = result # noqa: F841 + + +def for_loop_usecase4(x, y, res1, res2): + result = 0 + for i in range(10): + for j in range(10): + result += 1 + res1 = result # noqa: F841 + + +def for_loop_usecase5(x, y, res1, res2): + result = 0 + for i in range(x): + result += 1 + if result > y: + break + res1 = result # noqa: F841 + + +def for_loop_usecase6(x, y, res1, res2): + result = 0 + for i in range(x): + if i > y: + continue + result += 1 + res1 = result # noqa: F841 + + +def for_loop_usecase7(x, y, res1, res2): + for i in range(x): + x = 0 + for j in range(x): + res1 = 1 # noqa: F841 + else: + pass + res1 = 0 # noqa: F841 + + +def for_loop_usecase8(x, y, res1, res2): + result = 0 + for i in range(x, y, y - x + 1): + result += 1 + res1 = result # noqa: F841 + + +def for_loop_usecase9(x, y, res1, res2): + z = 0 + for i in range(x): + x = 0 + for j in range(x): + if j == x / 2: + z += j + break + else: + z += y + + res1 = z # noqa: F841 + + +def for_loop_usecase10(x, y, res1, res2): + for i in range(x): + if i == y: + z = y + break + else: + z = i * 2 + res1 = z # noqa: F841 + + +def while_loop_usecase1(x, y, res1, res2): + result = 0 + i = 0 + while i < x: + result += i + i += 1 + res1 = result # noqa: F841 + + +def while_loop_usecase2(x, y, res1, res2): + result = 0 + while result != x: + result += 1 + res1 = result # noqa: F841 + + +def while_loop_usecase3(x, y, res1, res2): + result = 0 + i = 0 + j = 0 + while i < x: + while j < y: + result += i + j + i += 1 + j += 1 + res1 = result # noqa: F841 + + +def while_loop_usecase4(x, y, res1, res2): + result = 0 + while True: + result += 1 + if result > x: + break + res1 = result # noqa: F841 + + +def while_loop_usecase5(x, y, res1, res2): + result = 0 + while result < x: + if result > y: + result += 2 + continue + result += 1 + res1 = result # noqa: F841 + + +def ifelse_usecase1(x, y, res1, res2): + if x > 0: + pass + elif y > 0: + pass + else: + pass + res1 = True # noqa: F841 + + +def ifelse_usecase2(x, y, res1, res2): + if x > y: + res1 = 1 # noqa: F841 + elif x == 0 or y == 0: + res1 = 2 # noqa: F841 + else: + res1 = 3 # noqa: F841 + + +def ifelse_usecase3(x, y, res1, res2): + if x > 0: + if y > 0: + res1 = 1 # noqa: F841 + elif y < 0: + res1 = 1 # noqa: F841 + else: + res1 = 0 # noqa: F841 + elif x < 0: + res1 = 1 # noqa: F841 + else: + res1 = 0 # noqa: F841 + + +def ifelse_usecase4(x, y, res1, res2): + if x == y: + res1 = 1 # noqa: F841 + + +def ternary_ifelse_usecase1(x, y, res1, res2): + res1 = True if x > y else False # noqa: F841 + + +def double_infinite_loop(x, y, res1, res2): + L = x + i = y + + while True: + while True: + if i == L - 1: + break + i += 1 + i += 1 + if i >= L: + break + + res1 = i # noqa: F841 + res2 = L # noqa: F841 + + +def try_except_usecase(): + try: + pass + except 
Exception: + pass + + +class TestFlowControl(TestCase): + def run_test( + self, + pyfunc, + x_operands, + y_operands, + res1_operands, + res2_operands, + ): + cfunc = jit((types.intp, types.intp, types.intp, types.intp))(pyfunc) + for x, y, res1, res2 in itertools.product( + x_operands, y_operands, res1_operands, res2_operands + ): + pyerr = None + cerr = None + try: + pyfunc(x, y, res1, res2) + except Exception as e: + pyerr = e + + cres1 = 0 + cres2 = 0 + try: + cfunc[1, 1](x, y, cres1, cres2) + except Exception as e: + if pyerr is None: + raise + cerr = e + self.assertEqual(type(pyerr), type(cerr)) + else: + if pyerr is not None: + self.fail( + "Invalid for pure-python but numba-cuda works\n" + + str(pyerr) + ) + self.assertEqual(res1, cres1) + self.assertEqual(res2, cres2) + + def test_for_loop1(self): + self.run_test(for_loop_usecase1, [-10, 0, 10], [0], [0], [0]) + + def test_for_loop1_npm(self): + self.test_for_loop1() + + def test_for_loop2(self): + self.run_test(for_loop_usecase2, [-10, 0, 10], [-10, 0, 10], [0], [0]) + + def test_for_loop2_npm(self): + self.test_for_loop2() + + def test_for_loop4(self): + self.run_test(for_loop_usecase4, [10], [10], [0], [0]) + + def test_for_loop4_npm(self): + self.test_for_loop4() + + def test_for_loop5(self): + self.run_test(for_loop_usecase5, [100], [50], [0], [0]) + + def test_for_loop5_npm(self): + self.test_for_loop5() + + def test_for_loop6(self): + self.run_test(for_loop_usecase6, [100], [50], [0], [0]) + + def test_for_loop6_npm(self): + self.test_for_loop6() + + def test_for_loop7(self): + self.run_test(for_loop_usecase7, [5], [0], [0], [0]) + + def test_for_loop7_npm(self): + self.test_for_loop7() + + @unittest.expectedFailure + def test_for_loop8(self): + self.run_test(for_loop_usecase8, [0, 1], [0, 2, 10], [0], [0]) + + @unittest.expectedFailure + def test_for_loop8_npm(self): + self.test_for_loop8() + + def test_for_loop9(self): + self.run_test(for_loop_usecase9, [0, 1], [0, 2, 10], [0], [0]) + + def test_for_loop9_npm(self): + self.test_for_loop9() + + def test_for_loop10(self): + self.run_test(for_loop_usecase10, [5], [2, 7], [0], [0]) + + def test_for_loop10_npm(self): + self.test_for_loop10() + + def test_while_loop1(self): + self.run_test(while_loop_usecase1, [10], [0], [0], [0]) + + def test_while_loop1_npm(self): + self.test_while_loop1() + + def test_while_loop2(self): + self.run_test(while_loop_usecase2, [10], [0], [0], [0]) + + def test_while_loop2_npm(self): + self.test_while_loop2() + + def test_while_loop3(self): + self.run_test(while_loop_usecase3, [10], [10], [0], [0]) + + def test_while_loop3_npm(self): + self.test_while_loop3() + + def test_while_loop4(self): + self.run_test(while_loop_usecase4, [10], [0], [0], [0]) + + def test_while_loop4_npm(self): + self.test_while_loop4() + + def test_while_loop5(self): + self.run_test(while_loop_usecase5, [0, 5, 10], [0, 5, 10], [0], [0]) + + def test_while_loop5_npm(self): + self.test_while_loop5() + + def test_ifelse1(self): + self.run_test(ifelse_usecase1, [-1, 0, 1], [-1, 0, 1], [0], [0]) + + def test_ifelse1_npm(self): + self.test_ifelse1() + + def test_ifelse2(self): + self.run_test(ifelse_usecase2, [-1, 0, 1], [-1, 0, 1], [0], [0]) + + def test_ifelse2_npm(self): + self.test_ifelse2() + + def test_ifelse3(self): + self.run_test(ifelse_usecase3, [-1, 0, 1], [-1, 0, 1], [0], [0]) + + def test_ifelse3_npm(self): + self.test_ifelse3() + + def test_ifelse4(self): + self.run_test(ifelse_usecase4, [-1, 0, 1], [-1, 0, 1], [0], [0]) + + def test_ifelse4_npm(self): + 
self.test_ifelse4() + + def test_ternary_ifelse1(self): + self.run_test( + ternary_ifelse_usecase1, + [-1, 0, 1], + [-1, 0, 1], + [0], + [0], + ) + + def test_ternary_ifelse1_npm(self): + self.test_ternary_ifelse1() + + def test_double_infinite_loop(self): + self.run_test(double_infinite_loop, [10], [0], [0], [0]) + + def test_double_infinite_loop_npm(self): + self.test_double_infinite_loop() + + +class TestCFGraph(TestCase): + """ + Test the numba.controlflow.CFGraph class. + """ + + def from_adj_list(self, d, entry_point=0): + """ + Build a CFGraph class from a dict of adjacency lists. + """ + g = CFGraph() + # Need to add all nodes before adding edges + for node in d: + g.add_node(node) + for node, dests in d.items(): + for dest in dests: + g.add_edge(node, dest) + return g + + def loopless1(self): + """ + A simple CFG corresponding to the following code structure: + + c = (... if ... else ...) + ... + return b + c + """ + g = self.from_adj_list({0: [18, 12], 12: [21], 18: [21], 21: []}) + g.set_entry_point(0) + g.process() + return g + + def loopless1_dead_nodes(self): + """ + Same as loopless1(), but with added dead blocks (some of them + in a loop). + """ + g = self.from_adj_list( + { + 0: [18, 12], + 12: [21], + 18: [21], + 21: [], + 91: [12, 0], + 92: [91, 93], + 93: [92], + 94: [], + } + ) + g.set_entry_point(0) + g.process() + return g + + def loopless2(self): + """ + A loopless CFG corresponding to the following code structure: + + c = (... if ... else ...) + ... + if c: + return ... + else: + return ... + + Note there are two exit points, and the entry point has been + changed to a non-zero value. + """ + g = self.from_adj_list( + {99: [18, 12], 12: [21], 18: [21], 21: [42, 34], 34: [], 42: []} + ) + g.set_entry_point(99) + g.process() + return g + + def multiple_loops(self): + """ + A CFG with multiple nested loops: + + for y in b: + for x in a: + # This loop has two back edges + if b: + continue + else: + continue + for z in c: + if z: + return ... + """ + g = self.from_adj_list( + { + 0: [7], + 7: [10, 60], + 10: [13], + 13: [20], + 20: [56, 23], + 23: [32, 44], + 32: [20], + 44: [20], + 56: [57], + 57: [7], + 60: [61], + 61: [68], + 68: [87, 71], + 71: [80, 68], + 80: [], + 87: [88], + 88: [], + } + ) + g.set_entry_point(0) + g.process() + return g + + def multiple_exits(self): + """ + A CFG with three loop exits, one of which is also a function + exit point, and another function exit point: + + for x in a: + if a: + return b + elif b: + break + return c + """ + g = self.from_adj_list( + { + 0: [7], + 7: [10, 36], + 10: [19, 23], + 19: [], + 23: [29, 7], + 29: [37], + 36: [37], + 37: [], + } + ) + g.set_entry_point(0) + g.process() + return g + + def infinite_loop1(self): + """ + A CFG with a infinite loop and an alternate exit point: + + if c: + return + while True: + if a: + ... + else: + ... + """ + g = self.from_adj_list( + {0: [10, 6], 6: [], 10: [13], 13: [26, 19], 19: [13], 26: [13]} + ) + g.set_entry_point(0) + g.process() + return g + + def infinite_loop2(self): + """ + A CFG with no exit point at all: + + while True: + if a: + ... + else: + ... 
+ """ + g = self.from_adj_list({0: [3], 3: [16, 9], 9: [3], 16: [3]}) + g.set_entry_point(0) + g.process() + return g + + def test_simple_properties(self): + g = self.loopless1() + self.assertEqual(sorted(g.successors(0)), [(12, None), (18, None)]) + self.assertEqual(sorted(g.successors(21)), []) + self.assertEqual(sorted(g.predecessors(0)), []) + self.assertEqual(sorted(g.predecessors(21)), [(12, None), (18, None)]) + + def test_exit_points(self): + g = self.loopless1() + self.assertEqual(sorted(g.exit_points()), [21]) + g = self.loopless1_dead_nodes() + self.assertEqual(sorted(g.exit_points()), [21]) + g = self.loopless2() + self.assertEqual(sorted(g.exit_points()), [34, 42]) + g = self.multiple_loops() + self.assertEqual(sorted(g.exit_points()), [80, 88]) + g = self.infinite_loop1() + self.assertEqual(sorted(g.exit_points()), [6]) + g = self.infinite_loop2() + self.assertEqual(sorted(g.exit_points()), []) + g = self.multiple_exits() + self.assertEqual(sorted(g.exit_points()), [19, 37]) + + def test_dead_nodes(self): + g = self.loopless1() + self.assertEqual(len(g.dead_nodes()), 0) + self.assertEqual(sorted(g.nodes()), [0, 12, 18, 21]) + g = self.loopless2() + self.assertEqual(len(g.dead_nodes()), 0) + self.assertEqual(sorted(g.nodes()), [12, 18, 21, 34, 42, 99]) + g = self.multiple_loops() + self.assertEqual(len(g.dead_nodes()), 0) + g = self.infinite_loop1() + self.assertEqual(len(g.dead_nodes()), 0) + g = self.multiple_exits() + self.assertEqual(len(g.dead_nodes()), 0) + # Only this example has dead nodes + g = self.loopless1_dead_nodes() + self.assertEqual(sorted(g.dead_nodes()), [91, 92, 93, 94]) + self.assertEqual(sorted(g.nodes()), [0, 12, 18, 21]) + + def test_descendents(self): + g = self.loopless2() + d = g.descendents(34) + self.assertEqual(sorted(d), []) + d = g.descendents(42) + self.assertEqual(sorted(d), []) + d = g.descendents(21) + self.assertEqual(sorted(d), [34, 42]) + d = g.descendents(99) + self.assertEqual(sorted(d), [12, 18, 21, 34, 42]) + g = self.infinite_loop1() + d = g.descendents(26) + self.assertEqual(sorted(d), []) + d = g.descendents(19) + self.assertEqual(sorted(d), []) + d = g.descendents(13) + self.assertEqual(sorted(d), [19, 26]) + d = g.descendents(10) + self.assertEqual(sorted(d), [13, 19, 26]) + d = g.descendents(6) + self.assertEqual(sorted(d), []) + d = g.descendents(0) + self.assertEqual(sorted(d), [6, 10, 13, 19, 26]) + + def test_topo_order(self): + g = self.loopless1() + self.assertIn(g.topo_order(), ([0, 12, 18, 21], [0, 18, 12, 21])) + g = self.loopless2() + self.assertIn( + g.topo_order(), ([99, 18, 12, 21, 34, 42], [99, 12, 18, 21, 34, 42]) + ) + g = self.infinite_loop2() + self.assertIn(g.topo_order(), ([0, 3, 9, 16], [0, 3, 16, 9])) + g = self.infinite_loop1() + self.assertIn( + g.topo_order(), + ( + [0, 6, 10, 13, 19, 26], + [0, 6, 10, 13, 26, 19], + [0, 10, 13, 19, 26, 6], + [0, 10, 13, 26, 19, 6], + ), + ) + + def test_topo_sort(self): + def check_topo_sort(nodes, expected): + self.assertIn(list(g.topo_sort(nodes)), expected) + self.assertIn(list(g.topo_sort(nodes[::-1])), expected) + self.assertIn( + list(g.topo_sort(nodes, reverse=True))[::-1], expected + ) + self.assertIn( + list(g.topo_sort(nodes[::-1], reverse=True))[::-1], expected + ) + self.random.shuffle(nodes) + self.assertIn(list(g.topo_sort(nodes)), expected) + self.assertIn( + list(g.topo_sort(nodes, reverse=True))[::-1], expected + ) + + g = self.loopless2() + check_topo_sort([21, 99, 12, 34], ([99, 12, 21, 34],)) + # NOTE: topo_sort() is not stable + check_topo_sort([18, 
12, 42, 99], ([99, 12, 18, 42], [99, 18, 12, 42])) + g = self.multiple_exits() + check_topo_sort( + [19, 10, 7, 36], ([7, 10, 19, 36], [7, 10, 36, 19], [7, 36, 10, 19]) + ) + + def check_dominators(self, got, expected): + self.assertEqual(sorted(got), sorted(expected)) + for node in sorted(got): + self.assertEqual( + sorted(got[node]), + sorted(expected[node]), + "mismatch for %r" % (node,), + ) + + def test_dominators_loopless(self): + def eq_(d, l): + self.assertEqual(sorted(doms[d]), l) + + for g in [self.loopless1(), self.loopless1_dead_nodes()]: + doms = g.dominators() + eq_(0, [0]) + eq_(12, [0, 12]) + eq_(18, [0, 18]) + eq_(21, [0, 21]) + g = self.loopless2() + doms = g.dominators() + eq_(99, [99]) + eq_(12, [12, 99]) + eq_(18, [18, 99]) + eq_(21, [21, 99]) + eq_(34, [21, 34, 99]) + eq_(42, [21, 42, 99]) + + def test_dominators_loops(self): + g = self.multiple_exits() + doms = g.dominators() + self.check_dominators( + doms, + { + 0: [0], + 7: [0, 7], + 10: [0, 7, 10], + 19: [0, 7, 10, 19], + 23: [0, 7, 10, 23], + 29: [0, 7, 10, 23, 29], + 36: [0, 7, 36], + 37: [0, 7, 37], + }, + ) + g = self.multiple_loops() + doms = g.dominators() + self.check_dominators( + doms, + { + 0: [0], + 7: [0, 7], + 10: [0, 10, 7], + 13: [0, 10, 13, 7], + 20: [0, 10, 20, 13, 7], + 23: [0, 20, 23, 7, 10, 13], + 32: [32, 0, 20, 23, 7, 10, 13], + 44: [0, 20, 23, 7, 10, 44, 13], + 56: [0, 20, 7, 56, 10, 13], + 57: [0, 20, 7, 56, 57, 10, 13], + 60: [0, 60, 7], + 61: [0, 60, 61, 7], + 68: [0, 68, 60, 61, 7], + 71: [0, 68, 71, 7, 60, 61], + 80: [80, 0, 68, 71, 7, 60, 61], + 87: [0, 68, 87, 7, 60, 61], + 88: [0, 68, 87, 88, 7, 60, 61], + }, + ) + g = self.infinite_loop1() + doms = g.dominators() + self.check_dominators( + doms, + { + 0: [0], + 6: [0, 6], + 10: [0, 10], + 13: [0, 10, 13], + 19: [0, 10, 19, 13], + 26: [0, 10, 13, 26], + }, + ) + + def test_post_dominators_loopless(self): + def eq_(d, l): + self.assertEqual(sorted(doms[d]), l) + + for g in [self.loopless1(), self.loopless1_dead_nodes()]: + doms = g.post_dominators() + eq_(0, [0, 21]) + eq_(12, [12, 21]) + eq_(18, [18, 21]) + eq_(21, [21]) + g = self.loopless2() + doms = g.post_dominators() + eq_(34, [34]) + eq_(42, [42]) + eq_(21, [21]) + eq_(18, [18, 21]) + eq_(12, [12, 21]) + eq_(99, [21, 99]) + + def test_post_dominators_loops(self): + g = self.multiple_exits() + doms = g.post_dominators() + self.check_dominators( + doms, + { + 0: [0, 7], + 7: [7], + 10: [10], + 19: [19], + 23: [23], + 29: [29, 37], + 36: [36, 37], + 37: [37], + }, + ) + g = self.multiple_loops() + doms = g.post_dominators() + self.check_dominators( + doms, + { + 0: [0, 60, 68, 61, 7], + 7: [60, 68, 61, 7], + 10: [68, 7, 10, 13, 20, 56, 57, 60, 61], + 13: [68, 7, 13, 20, 56, 57, 60, 61], + 20: [20, 68, 7, 56, 57, 60, 61], + 23: [68, 7, 20, 23, 56, 57, 60, 61], + 32: [32, 68, 7, 20, 56, 57, 60, 61], + 44: [68, 7, 44, 20, 56, 57, 60, 61], + 56: [68, 7, 56, 57, 60, 61], + 57: [57, 60, 68, 61, 7], + 60: [60, 68, 61], + 61: [68, 61], + 68: [68], + 71: [71], + 80: [80], + 87: [88, 87], + 88: [88], + }, + ) + + def test_post_dominators_infinite_loops(self): + # Post-dominators with infinite loops need special care + # (the ordinary algorithm won't work). 
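+        # Nodes inside an exit-less loop have no path to an exit, so each of
+        # them ends up post-dominating only itself (see the expected maps
+        # below).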
+ g = self.infinite_loop1() + doms = g.post_dominators() + self.check_dominators( + doms, + { + 0: [0], + 6: [6], + 10: [10, 13], + 13: [13], + 19: [19], + 26: [26], + }, + ) + g = self.infinite_loop2() + doms = g.post_dominators() + self.check_dominators( + doms, + { + 0: [0, 3], + 3: [3], + 9: [9], + 16: [16], + }, + ) + + def test_dominator_tree(self): + def check(graph, expected): + domtree = graph.dominator_tree() + self.assertEqual(domtree, expected) + + check( + self.loopless1(), {0: {12, 18, 21}, 12: set(), 18: set(), 21: set()} + ) + check( + self.loopless2(), + { + 12: set(), + 18: set(), + 21: {34, 42}, + 34: set(), + 42: set(), + 99: {18, 12, 21}, + }, + ) + check( + self.loopless1_dead_nodes(), + {0: {12, 18, 21}, 12: set(), 18: set(), 21: set()}, + ) + check( + self.multiple_loops(), + { + 0: {7}, + 7: {10, 60}, + 60: {61}, + 61: {68}, + 68: {71, 87}, + 87: {88}, + 88: set(), + 71: {80}, + 80: set(), + 10: {13}, + 13: {20}, + 20: {56, 23}, + 23: {32, 44}, + 44: set(), + 32: set(), + 56: {57}, + 57: set(), + }, + ) + check( + self.multiple_exits(), + { + 0: {7}, + 7: {10, 36, 37}, + 36: set(), + 10: {19, 23}, + 23: {29}, + 29: set(), + 37: set(), + 19: set(), + }, + ) + check( + self.infinite_loop1(), + { + 0: {10, 6}, + 6: set(), + 10: {13}, + 13: {26, 19}, + 19: set(), + 26: set(), + }, + ) + check(self.infinite_loop2(), {0: {3}, 3: {16, 9}, 9: set(), 16: set()}) + + def test_immediate_dominators(self): + def check(graph, expected): + idoms = graph.immediate_dominators() + self.assertEqual(idoms, expected) + + check(self.loopless1(), {0: 0, 12: 0, 18: 0, 21: 0}) + check( + self.loopless2(), {18: 99, 12: 99, 21: 99, 42: 21, 34: 21, 99: 99} + ) + check(self.loopless1_dead_nodes(), {0: 0, 12: 0, 18: 0, 21: 0}) + check( + self.multiple_loops(), + { + 0: 0, + 7: 0, + 10: 7, + 13: 10, + 20: 13, + 23: 20, + 32: 23, + 44: 23, + 56: 20, + 57: 56, + 60: 7, + 61: 60, + 68: 61, + 71: 68, + 80: 71, + 87: 68, + 88: 87, + }, + ) + check( + self.multiple_exits(), + {0: 0, 7: 0, 10: 7, 19: 10, 23: 10, 29: 23, 36: 7, 37: 7}, + ) + check( + self.infinite_loop1(), {0: 0, 6: 0, 10: 0, 13: 10, 19: 13, 26: 13} + ) + check(self.infinite_loop2(), {0: 0, 3: 0, 9: 3, 16: 3}) + + def test_dominance_frontier(self): + def check(graph, expected): + df = graph.dominance_frontier() + self.assertEqual(df, expected) + + check(self.loopless1(), {0: set(), 12: {21}, 18: {21}, 21: set()}) + check( + self.loopless2(), + {18: {21}, 12: {21}, 21: set(), 42: set(), 34: set(), 99: set()}, + ) + check( + self.loopless1_dead_nodes(), + {0: set(), 12: {21}, 18: {21}, 21: set()}, + ) + check( + self.multiple_loops(), + { + 0: set(), + 7: {7}, + 10: {7}, + 13: {7}, + 20: {20, 7}, + 23: {20}, + 32: {20}, + 44: {20}, + 56: {7}, + 57: {7}, + 60: set(), + 61: set(), + 68: {68}, + 71: {68}, + 80: set(), + 87: set(), + 88: set(), + }, + ) + check( + self.multiple_exits(), + { + 0: set(), + 7: {7}, + 10: {37, 7}, + 19: set(), + 23: {37, 7}, + 29: {37}, + 36: {37}, + 37: set(), + }, + ) + check( + self.infinite_loop1(), + {0: set(), 6: set(), 10: set(), 13: {13}, 19: {13}, 26: {13}}, + ) + check(self.infinite_loop2(), {0: set(), 3: {3}, 9: {3}, 16: {3}}) + + def test_backbone_loopless(self): + for g in [self.loopless1(), self.loopless1_dead_nodes()]: + self.assertEqual(sorted(g.backbone()), [0, 21]) + g = self.loopless2() + self.assertEqual(sorted(g.backbone()), [21, 99]) + + def test_backbone_loops(self): + g = self.multiple_loops() + self.assertEqual(sorted(g.backbone()), [0, 7, 60, 61, 68]) + g = self.infinite_loop1() + 
self.assertEqual(sorted(g.backbone()), [0])
+        g = self.infinite_loop2()
+        self.assertEqual(sorted(g.backbone()), [0, 3])
+
+    def test_loops(self):
+        for g in [
+            self.loopless1(),
+            self.loopless1_dead_nodes(),
+            self.loopless2(),
+        ]:
+            self.assertEqual(len(g.loops()), 0)
+
+        g = self.multiple_loops()
+        # Loop headers
+        self.assertEqual(sorted(g.loops()), [7, 20, 68])
+        outer1 = g.loops()[7]
+        inner1 = g.loops()[20]
+        outer2 = g.loops()[68]
+        self.assertEqual(outer1.header, 7)
+        self.assertEqual(sorted(outer1.entries), [0])
+        self.assertEqual(sorted(outer1.exits), [60])
+        self.assertEqual(
+            sorted(outer1.body), [7, 10, 13, 20, 23, 32, 44, 56, 57]
+        )
+        self.assertEqual(inner1.header, 20)
+        self.assertEqual(sorted(inner1.entries), [13])
+        self.assertEqual(sorted(inner1.exits), [56])
+        self.assertEqual(sorted(inner1.body), [20, 23, 32, 44])
+        self.assertEqual(outer2.header, 68)
+        self.assertEqual(sorted(outer2.entries), [61])
+        self.assertEqual(sorted(outer2.exits), [80, 87])
+        self.assertEqual(sorted(outer2.body), [68, 71])
+        for node in [0, 60, 61, 80, 87, 88]:
+            self.assertEqual(g.in_loops(node), [])
+        for node in [7, 10, 13, 56, 57]:
+            self.assertEqual(g.in_loops(node), [outer1])
+        for node in [20, 23, 32, 44]:
+            self.assertEqual(g.in_loops(node), [inner1, outer1])
+        for node in [68, 71]:
+            self.assertEqual(g.in_loops(node), [outer2])
+
+        g = self.infinite_loop1()
+        # Loop headers
+        self.assertEqual(sorted(g.loops()), [13])
+        loop = g.loops()[13]
+        self.assertEqual(loop.header, 13)
+        self.assertEqual(sorted(loop.entries), [10])
+        self.assertEqual(sorted(loop.exits), [])
+        self.assertEqual(sorted(loop.body), [13, 19, 26])
+        for node in [0, 6, 10]:
+            self.assertEqual(g.in_loops(node), [])
+        for node in [13, 19, 26]:
+            self.assertEqual(g.in_loops(node), [loop])
+
+        g = self.infinite_loop2()
+        # Loop headers
+        self.assertEqual(sorted(g.loops()), [3])
+        loop = g.loops()[3]
+        self.assertEqual(loop.header, 3)
+        self.assertEqual(sorted(loop.entries), [0])
+        self.assertEqual(sorted(loop.exits), [])
+        self.assertEqual(sorted(loop.body), [3, 9, 16])
+        for node in [0]:
+            self.assertEqual(g.in_loops(node), [])
+        for node in [3, 9, 16]:
+            self.assertEqual(g.in_loops(node), [loop])
+
+        g = self.multiple_exits()
+        # Loop headers
+        self.assertEqual(sorted(g.loops()), [7])
+        loop = g.loops()[7]
+        self.assertEqual(loop.header, 7)
+        self.assertEqual(sorted(loop.entries), [0])
+        self.assertEqual(sorted(loop.exits), [19, 29, 36])
+        self.assertEqual(sorted(loop.body), [7, 10, 23])
+        for node in [0, 19, 29, 36]:
+            self.assertEqual(g.in_loops(node), [])
+        for node in [7, 10, 23]:
+            self.assertEqual(g.in_loops(node), [loop])
+
+    def test_loop_dfs_pathological(self):
+        # The following adjacency list is an export from the reproducer
+        # in issue #6186
+        g = self.from_adj_list(
+            {
+                0: {38, 14},
+                14: {38, 22},
+                22: {38, 30},
+                30: {42, 38},
+                38: {42},
+                42: {64, 50},
+                50: {64, 58},
+                58: {128},
+                64: {72, 86},
+                72: {80, 86},
+                80: {128},
+                86: {108, 94},
+                94: {108, 102},
+                102: {128},
+                108: {128, 116},
+                116: {128, 124},
+                124: {128},
+                128: {178, 174},
+                174: {178},
+                178: {210, 206},
+                206: {210},
+                210: {248, 252},
+                248: {252},
+                252: {282, 286},
+                282: {286},
+                286: {296, 326},
+                296: {330},
+                326: {330},
+                330: {370, 340},
+                340: {374},
+                370: {374},
+                374: {380, 382},
+                380: {382},
+                382: {818, 390},
+                390: {456, 458},
+                456: {458},
+                458: {538, 566},
+                538: {548, 566},
+                548: set(),
+                566: {586, 572},
+                572: {586},
+                586: {708, 596},
+                596: {608},
+                608: {610},
+                610: {704, 620},
+                620: {666, 630},
+                630: {636, 646},
+                636: {666, 646},
+                646: {666},
+                666: {610},
+                704: {706},
+                706: {818},
+                708: {720},
+                720: {722},
+                722: {816, 732},
+                732: {778, 742},
+                742: {748, 758},
+                748: {778, 758},
+                758: {778},
+                778: {722},
+                816: {818},
+                818: set(),
+            }
+        )
+        g.set_entry_point(0)
+        g.process()
+        stats = {}
+        # Compute backedges and store the iteration count for testing
+        back_edges = g._find_back_edges(stats=stats)
+        self.assertEqual(back_edges, {(666, 610), (778, 722)})
+        self.assertEqual(stats["iteration_count"], 155)
+
+    def test_equals(self):
+        def get_new():
+            g = self.from_adj_list({0: [18, 12], 12: [21], 18: [21], 21: []})
+            g.set_entry_point(0)
+            g.process()
+            return g
+
+        x = get_new()
+        y = get_new()
+
+        # identical
+        self.assertEqual(x, y)
+
+        # identical but defined in a different order
+        g = self.from_adj_list({0: [12, 18], 18: [21], 21: [], 12: [21]})
+        g.set_entry_point(0)
+        g.process()
+        self.assertEqual(x, g)
+
+        # different entry point
+        z = get_new()
+        z.set_entry_point(18)
+        z.process()
+        self.assertNotEqual(x, z)
+
+        # extra node/edge, same entry point
+        z = self.from_adj_list(
+            {0: [18, 12], 12: [21], 18: [21], 21: [22], 22: []}
+        )
+        z.set_entry_point(0)
+        z.process()
+        self.assertNotEqual(x, z)
+
+        # same nodes, different edges
+        a = self.from_adj_list({0: [18, 12], 12: [0], 18: []})
+        a.set_entry_point(0)
+        a.process()
+        z = self.from_adj_list({0: [18, 12], 12: [18], 18: []})
+        z.set_entry_point(0)
+        z.process()
+        self.assertNotEqual(a, z)
+
+
+class TestRealCodeDomFront(TestCase):
+    """Test IDOM and DOMFRONT computation on real python bytecode.
+    Note: there will be less testing on IDOM (esp. in loops) because of
+    the extra blocks inserted by the interpreter. But testing on DOMFRONT
+    (which depends on IDOM) is easier.
+
+    Testing is done by associating names to basic blocks using globals of
+    the pattern "SET_BLOCK_<name>", which are scanned by
+    `.get_cfa_and_namedblocks` into the *namedblocks* dictionary. That way,
+    we can check that a block of a certain name is an IDOM or DOMFRONT of
+    another named block.
+    """
+
+    def cfa(self, bc):
+        cfa = ControlFlowAnalysis(bc)
+        cfa.run()
+        return cfa
+
+    def get_cfa_and_namedblocks(self, fn):
+        fid = FunctionIdentity.from_function(fn)
+        bc = ByteCode(func_id=fid)
+        cfa = self.cfa(bc)
+        namedblocks = self._scan_namedblocks(bc, cfa)
+
+        #### To debug, uncomment below
+        # print(bc.dump())
+        # print("IDOMS")
+        # for k, v in sorted(cfa.graph.immediate_dominators().items()):
+        #     print('{} -> {}'.format(k, v))
+        # print("DOMFRONT")
+        # for k, vs in sorted(cfa.graph.dominance_frontier().items()):
+        #     print('{} -> {}'.format(k, vs))
+        # print(namedblocks)
+        # cfa.graph.render_dot().view()
+
+        return cfa, namedblocks
+
+    def _scan_namedblocks(self, bc, cfa):
+        """Scan namedblocks as denoted by a LOAD_GLOBAL bytecode referring
+        to global variables with the pattern "SET_BLOCK_<name>", where
+        "<name>" would be the name for the current block.
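+
+        For example, a LOAD_GLOBAL of "SET_BLOCK_A" maps the block
+        containing that instruction to the name "A".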
+ """ + namedblocks = {} + blocks = sorted([x.offset for x in cfa.iterblocks()]) + prefix = "SET_BLOCK_" + + for inst in bc: + # Find LOAD_GLOBAL that refers to "SET_BLOCK_" + if inst.opname == "LOAD_GLOBAL": + gv = bc.co_names[_fix_LOAD_GLOBAL_arg(inst.arg)] + if gv.startswith(prefix): + name = gv[len(prefix) :] + # Find the block where this instruction resides + for s, e in zip(blocks, blocks[1:] + [blocks[-1] + 1]): + if s <= inst.offset < e: + break + else: + raise AssertionError("unreachable loop") + blkno = s + namedblocks[name] = blkno + return namedblocks + + def test_loop(self): + def foo(n): + c = 0 + SET_BLOCK_A # noqa: F821 + i = 0 + while SET_BLOCK_B0: # noqa: F821 + SET_BLOCK_B1 # noqa: F821 + c += 1 + i += 1 + SET_BLOCK_C # noqa: F821 + return c + + cfa, blkpts = self.get_cfa_and_namedblocks(foo) + + # Py3.10 turns while loop into if(...) { do {...} while(...) }. + # Also, `SET_BLOCK_B0` is duplicated. As a result, the second B0 + # is picked up by `blkpts`. + domfront = cfa.graph.dominance_frontier() + self.assertFalse(domfront[blkpts["A"]]) + self.assertFalse(domfront[blkpts["C"]]) + + def test_loop_nested_and_break(self): + def foo(n): + SET_BLOCK_A # noqa: F821 + while SET_BLOCK_B0: # noqa: F821 + SET_BLOCK_B1 # noqa: F821 + while SET_BLOCK_C0: # noqa: F821 + SET_BLOCK_C1 # noqa: F821 + if SET_BLOCK_D0: # noqa: F821 + SET_BLOCK_D1 # noqa: F821 + break + elif n: + SET_BLOCK_D2 # noqa: F821 + SET_BLOCK_E # noqa: F821 + SET_BLOCK_F # noqa: F821 + SET_BLOCK_G # noqa: F821 + + cfa, blkpts = self.get_cfa_and_namedblocks(foo) + self.assertEqual(blkpts["D0"], blkpts["C1"]) + + # Py3.10 changes while loop into if-do-while + domfront = cfa.graph.dominance_frontier() + self.assertFalse(domfront[blkpts["A"]]) + self.assertFalse(domfront[blkpts["G"]]) + # 2 domfront members for C1 + # C0 because of the loop; F because of the break. 
+        self.assertEqual({blkpts["C0"], blkpts["F"]}, domfront[blkpts["C1"]])
+ self.assertEqual({blkpts["F"]}, domfront[blkpts["D1"]]) + self.assertEqual({blkpts["E"]}, domfront[blkpts["D2"]]) + + def test_if_else(self): + def foo(a, b): + c = 0 + SET_BLOCK_A # noqa: F821 + if a < b: + SET_BLOCK_B # noqa: F821 + c = 1 + elif SET_BLOCK_C0: # noqa: F821 + SET_BLOCK_C1 # noqa: F821 + c = 2 + else: + SET_BLOCK_D # noqa: F821 + c = 3 + + SET_BLOCK_E # noqa: F821 + if a % b == 0: + SET_BLOCK_F # noqa: F821 + c += 1 + SET_BLOCK_G # noqa: F821 + return c + + cfa, blkpts = self.get_cfa_and_namedblocks(foo) + + idoms = cfa.graph.immediate_dominators() + self.assertEqual(blkpts["A"], idoms[blkpts["B"]]) + self.assertEqual(blkpts["A"], idoms[blkpts["C0"]]) + self.assertEqual(blkpts["C0"], idoms[blkpts["C1"]]) + self.assertEqual(blkpts["C0"], idoms[blkpts["D"]]) + self.assertEqual(blkpts["A"], idoms[blkpts["E"]]) + self.assertEqual(blkpts["E"], idoms[blkpts["F"]]) + self.assertEqual(blkpts["E"], idoms[blkpts["G"]]) + + domfront = cfa.graph.dominance_frontier() + self.assertFalse(domfront[blkpts["A"]]) + self.assertFalse(domfront[blkpts["E"]]) + self.assertFalse(domfront[blkpts["G"]]) + self.assertEqual({blkpts["E"]}, domfront[blkpts["B"]]) + self.assertEqual({blkpts["E"]}, domfront[blkpts["C0"]]) + self.assertEqual({blkpts["E"]}, domfront[blkpts["C1"]]) + self.assertEqual({blkpts["E"]}, domfront[blkpts["D"]]) + self.assertEqual({blkpts["G"]}, domfront[blkpts["F"]]) + + def test_if_else_nested(self): + def foo(): + if SET_BLOCK_A0: # noqa: F821 + SET_BLOCK_A1 # noqa: F821 + if SET_BLOCK_B0: # noqa: F821 + SET_BLOCK_B1 # noqa: F821 + a = 0 + else: + if SET_BLOCK_C0: # noqa: F821 + SET_BLOCK_C1 # noqa: F821 + a = 1 + else: + SET_BLOCK_C2 # noqa: F821 + a = 2 + SET_BLOCK_D # noqa: F821 + SET_BLOCK_E # noqa: F821 + SET_BLOCK_F # noqa: F821 + return a + + cfa, blkpts = self.get_cfa_and_namedblocks(foo) + + idoms = cfa.graph.immediate_dominators() + self.assertEqual(blkpts["A0"], idoms[blkpts["A1"]]) + self.assertEqual(blkpts["A1"], idoms[blkpts["B1"]]) + self.assertEqual(blkpts["A1"], idoms[blkpts["C0"]]) + self.assertEqual(blkpts["C0"], idoms[blkpts["D"]]) + self.assertEqual(blkpts["A1"], idoms[blkpts["E"]]) + self.assertEqual(blkpts["A0"], idoms[blkpts["F"]]) + + domfront = cfa.graph.dominance_frontier() + self.assertFalse(domfront[blkpts["A0"]]) + self.assertFalse(domfront[blkpts["F"]]) + self.assertEqual({blkpts["E"]}, domfront[blkpts["B1"]]) + self.assertEqual({blkpts["D"]}, domfront[blkpts["C1"]]) + self.assertEqual({blkpts["E"]}, domfront[blkpts["D"]]) + self.assertEqual({blkpts["F"]}, domfront[blkpts["E"]]) + + def test_infinite_loop(self): + def foo(): + SET_BLOCK_A # noqa: F821 + while True: # infinite loop + if SET_BLOCK_B: # noqa: F821 + SET_BLOCK_C # noqa: F821 + return + SET_BLOCK_D # noqa: F821 + SET_BLOCK_E # noqa: F821 + + cfa, blkpts = self.get_cfa_and_namedblocks(foo) + + idoms = cfa.graph.immediate_dominators() + if utils.PYVERSION >= (3, 10): + self.assertNotIn("E", blkpts) + else: + self.assertNotIn(blkpts["E"], idoms) + self.assertEqual(blkpts["B"], idoms[blkpts["C"]]) + self.assertEqual(blkpts["B"], idoms[blkpts["D"]]) + + domfront = cfa.graph.dominance_frontier() + if utils.PYVERSION < (3, 10): + self.assertNotIn(blkpts["E"], domfront) + self.assertFalse(domfront[blkpts["A"]]) + self.assertFalse(domfront[blkpts["C"]]) + self.assertEqual({blkpts["B"]}, domfront[blkpts["B"]]) + self.assertEqual({blkpts["B"]}, domfront[blkpts["D"]]) + + +if __name__ == "__main__": + unittest.main()