diff --git a/numba_cuda/numba/cuda/core/analysis.py b/numba_cuda/numba/cuda/core/analysis.py index bda7fdf2f..b849891a8 100644 --- a/numba_cuda/numba/cuda/core/analysis.py +++ b/numba_cuda/numba/cuda/core/analysis.py @@ -38,13 +38,16 @@ def compute_use_defs(blocks): func = ir_extension_usedefs[type(stmt)] func(stmt, use_set, def_set) continue - if isinstance(stmt, ir.Assign): - if isinstance(stmt.value, ir.Inst): + if isinstance(stmt, ir.assign_types): + if isinstance(stmt.value, ir.inst_types): rhs_set = set(var.name for var in stmt.value.list_vars()) - elif isinstance(stmt.value, ir.Var): + elif isinstance(stmt.value, ir.var_types): rhs_set = set([stmt.value.name]) - elif isinstance( - stmt.value, (ir.Arg, ir.Const, ir.Global, ir.FreeVar) + elif ( + isinstance(stmt.value, ir.arg_types) + or isinstance(stmt.value, ir.const_types) + or isinstance(stmt.value, ir.global_types) + or isinstance(stmt.value, ir.freevar_types) ): rhs_set = () else: @@ -326,7 +329,7 @@ def rewrite_array_ndim(val, func_ir, called_args): if getattr(val, "op", None) == "getattr": if val.attr == "ndim": arg_def = guard(get_definition, func_ir, val.value) - if isinstance(arg_def, ir.Arg): + if isinstance(arg_def, ir.arg_types): argty = called_args[arg_def.index] if isinstance(argty, types.Array): rewrite_statement(func_ir, stmt, argty.ndim) @@ -337,17 +340,17 @@ def rewrite_tuple_len(val, func_ir, called_args): func = guard(get_definition, func_ir, val.func) if ( func is not None - and isinstance(func, ir.Global) + and isinstance(func, ir.global_types) and getattr(func, "value", None) is len ): (arg,) = val.args arg_def = guard(get_definition, func_ir, arg) - if isinstance(arg_def, ir.Arg): + if isinstance(arg_def, ir.arg_types): argty = called_args[arg_def.index] if isinstance(argty, types.BaseTuple): rewrite_statement(func_ir, stmt, argty.count) elif ( - isinstance(arg_def, ir.Expr) + isinstance(arg_def, ir.expr_types) and arg_def.op == "typed_getitem" ): argty = arg_def.dtype @@ -358,9 +361,9 @@ def rewrite_tuple_len(val, func_ir, called_args): for blk in func_ir.blocks.values(): for stmt in blk.body: - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): val = stmt.value - if isinstance(val, ir.Expr): + if isinstance(val, ir.expr_types): rewrite_array_ndim(val, func_ir, called_args) rewrite_tuple_len(val, func_ir, called_args) @@ -391,7 +394,7 @@ def find_literally_calls(func_ir, argtypes): for blk in func_ir.blocks.values(): for assign in blk.find_exprs(op="call"): var = ir_utils.guard(ir_utils.get_definition, func_ir, assign.func) - if isinstance(var, (ir.Global, ir.FreeVar)): + if isinstance(var, ir.global_types + ir.freevar_types): fnobj = var.value else: fnobj = ir_utils.guard( @@ -401,7 +404,7 @@ def find_literally_calls(func_ir, argtypes): # Found [arg] = assign.args defarg = func_ir.get_definition(arg) - if isinstance(defarg, ir.Arg): + if isinstance(defarg, ir.arg_types): argindex = defarg.index marked_args.add(argindex) first_loc.setdefault(argindex, assign.loc) @@ -473,14 +476,14 @@ def find_branches(func_ir): branches = [] for blk in func_ir.blocks.values(): branch_or_jump = blk.body[-1] - if isinstance(branch_or_jump, ir.Branch): + if isinstance(branch_or_jump, ir.branch_types): branch = branch_or_jump pred = guard(get_definition, func_ir, branch.cond.name) if pred is not None and getattr(pred, "op", None) == "call": function = guard(get_definition, func_ir, pred.func) if ( function is not None - and isinstance(function, ir.Global) + and isinstance(function, ir.global_types) and function.value is bool ): condition = guard(get_definition, func_ir, pred.args[0]) @@ -539,7 +542,9 @@ def prune_by_predicate(branch, pred, blk): try: # Just to prevent accidents, whilst already guarded, ensure this # is an ir.Const - if not isinstance(pred, (ir.Const, ir.FreeVar, ir.Global)): + if not isinstance( + pred, ir.const_types + ir.freevar_types + ir.global_types + ): raise TypeError("Expected constant Numba IR node") take_truebr = bool(pred.value) except TypeError: @@ -584,8 +589,11 @@ def resolve_input_arg_const(input_arg_idx): phi2asgn = dict() for lbl, blk in func_ir.blocks.items(): for stmt in blk.body: - if isinstance(stmt, ir.Assign): - if isinstance(stmt.value, ir.Expr) and stmt.value.op == "phi": + if isinstance(stmt, ir.assign_types): + if ( + isinstance(stmt.value, ir.expr_types) + and stmt.value.op == "phi" + ): phi2lbl[stmt.value] = lbl phi2asgn[stmt.value] = stmt @@ -599,12 +607,12 @@ def resolve_input_arg_const(input_arg_idx): for branch, condition, blk in branch_info: const_conds = [] - if isinstance(condition, ir.Expr) and condition.op == "binop": + if isinstance(condition, ir.expr_types) and condition.op == "binop": prune = prune_by_value for arg in [condition.lhs, condition.rhs]: resolved_const = Unknown() arg_def = guard(get_definition, func_ir, arg) - if isinstance(arg_def, ir.Arg): + if isinstance(arg_def, ir.arg_types): # it's an e.g. literal argument to the function resolved_const = resolve_input_arg_const(arg_def.index) prune = prune_by_type @@ -668,7 +676,7 @@ def resolve_input_arg_const(input_arg_idx): for _, cond, blk in branch_info: if cond in deadcond: for x in blk.body: - if isinstance(x, ir.Assign) and x.value is cond: + if isinstance(x, ir.assign_types) and x.value is cond: # rewrite the condition as a true/false bit nullified_info = nullified_conditions[deadcond.index(cond)] # only do a rewrite of conditions, predicates need to retain diff --git a/numba_cuda/numba/cuda/core/annotations/type_annotations.py b/numba_cuda/numba/cuda/core/annotations/type_annotations.py index 2cb520ec5..8c426322c 100644 --- a/numba_cuda/numba/cuda/core/annotations/type_annotations.py +++ b/numba_cuda/numba/cuda/core/annotations/type_annotations.py @@ -94,16 +94,16 @@ def prepare_annotations(self): for inst in blk.body: lineno = inst.loc.line - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): if found_lifted_loop: atype = "XXX Lifted Loop XXX" found_lifted_loop = False elif ( - isinstance(inst.value, ir.Expr) + isinstance(inst.value, ir.expr_types) and inst.value.op == "call" ): atype = self.calltypes[inst.value] - elif isinstance(inst.value, ir.Const) and isinstance( + elif isinstance(inst.value, ir.const_types) and isinstance( inst.value.value, LiftedLoop ): atype = "XXX Lifted Loop XXX" @@ -113,7 +113,7 @@ def prepare_annotations(self): atype = self.typemap.get(inst.target.name, "") aline = "%s = %s :: %s" % (inst.target, inst.value, atype) - elif isinstance(inst, ir.SetItem): + elif isinstance(inst, ir.setitem_types): atype = self.calltypes[inst] aline = "%s :: %s" % (inst, atype) else: diff --git a/numba_cuda/numba/cuda/core/consts.py b/numba_cuda/numba/cuda/core/consts.py index a3757c80f..760e7e540 100644 --- a/numba_cuda/numba/cuda/core/consts.py +++ b/numba_cuda/numba/cuda/core/consts.py @@ -68,7 +68,7 @@ def _do_infer(self, name): try: const = defn.infer_constant() except ConstantInferenceError: - if isinstance(defn, ir.Expr): + if isinstance(defn, ir.expr_types): return self._infer_expr(defn) self._fail(defn) return const diff --git a/numba_cuda/numba/cuda/core/inline_closurecall.py b/numba_cuda/numba/cuda/core/inline_closurecall.py index 4b5dd09f7..964b9aceb 100644 --- a/numba_cuda/numba/cuda/core/inline_closurecall.py +++ b/numba_cuda/numba/cuda/core/inline_closurecall.py @@ -54,8 +54,8 @@ def callee_ir_validator(func_ir): """Checks the IR of a callee is supported for inlining""" for blk in func_ir.blocks.values(): - for stmt in blk.find_insts(ir.Assign): - if isinstance(stmt.value, ir.Yield): + for stmt in blk.find_insts(ir.assign_types): + if isinstance(stmt.value, ir.yield_types): msg = "The use of yield in a closure is unsupported." raise errors.UnsupportedError(msg, loc=stmt.loc) @@ -100,9 +100,9 @@ def run(self): while work_list: _label, block = work_list.pop() for i, instr in enumerate(block.body): - if isinstance(instr, ir.Assign): + if isinstance(instr, ir.assign_types): expr = instr.value - if isinstance(expr, ir.Expr) and expr.op == "call": + if isinstance(expr, ir.expr_types) and expr.op == "call": call_name = guard(find_callname, self.func_ir, expr) func_def = guard( get_definition, self.func_ir, expr.func @@ -212,7 +212,8 @@ def reduce_func(f, A, v=None): def _inline_closure(self, work_list, block, i, func_def): require( - isinstance(func_def, ir.Expr) and func_def.op == "make_function" + isinstance(func_def, ir.expr_types) + and func_def.op == "make_function" ) inline_closure_call( self.func_ir, @@ -235,7 +236,9 @@ def check_reduce_func(func_ir, func_var): "Reduce function cannot be found for njit \ analysis" ) - if isinstance(reduce_func, (ir.FreeVar, ir.Global)): + if isinstance(reduce_func, ir.freevar_types) or isinstance( + reduce_func, ir.global_types + ): if HAS_NUMBA: from numba.core.registry import CPUDispatcher @@ -658,7 +661,10 @@ def inline_closure_call( cellget.argtypes = (ctypes.py_object,) items = tuple(cellget(x) for x in closure) else: - assert isinstance(closure, ir.Expr) and closure.op == "build_tuple" + assert ( + isinstance(closure, ir.expr_types) + and closure.op == "build_tuple" + ) items = closure.items assert len(callee_code.co_freevars) == len(items) _replace_freevars(callee_blocks, items) @@ -786,18 +792,18 @@ def stararg_handler(index, param, default): if isinstance(callee_defaults, tuple): # Python 3.5 defaults_list = [] for x in callee_defaults: - if isinstance(x, ir.Var): + if isinstance(x, ir.var_types): defaults_list.append(x) else: # this branch is predominantly for kwargs from # inlinable functions defaults_list.append(ir.Const(value=x, loc=loc)) args = args + defaults_list - elif isinstance(callee_defaults, ir.Var) or isinstance( + elif isinstance(callee_defaults, ir.var_types) or isinstance( callee_defaults, str ): default_tuple = func_ir.get_definition(callee_defaults) - assert isinstance(default_tuple, ir.Expr) + assert isinstance(default_tuple, ir.expr_types) assert default_tuple.op == "build_tuple" const_vals = [ func_ir.get_definition(x) for x in default_tuple.items @@ -839,9 +845,9 @@ def _replace_args_with(blocks, args): Replace ir.Arg(...) with real arguments from call site """ for label, block in blocks.items(): - assigns = block.find_insts(ir.Assign) + assigns = block.find_insts(ir.assign_types) for stmt in assigns: - if isinstance(stmt.value, ir.Arg): + if isinstance(stmt.value, ir.arg_types): idx = stmt.value.index assert idx < len(args) stmt.value = args[idx] @@ -852,12 +858,12 @@ def _replace_freevars(blocks, args): Replace ir.FreeVar(...) with real variables from parent function """ for label, block in blocks.items(): - assigns = block.find_insts(ir.Assign) + assigns = block.find_insts(ir.assign_types) for stmt in assigns: - if isinstance(stmt.value, ir.FreeVar): + if isinstance(stmt.value, ir.freevar_types): idx = stmt.value.index assert idx < len(args) - if isinstance(args[idx], ir.Var): + if isinstance(args[idx], ir.var_types): stmt.value = args[idx] else: stmt.value = ir.Const(args[idx], stmt.loc) @@ -871,7 +877,7 @@ def _replace_returns(blocks, target, return_label): casts = [] for i in range(len(block.body)): stmt = block.body[i] - if isinstance(stmt, ir.Return): + if isinstance(stmt, ir.return_types): assert i + 1 == len(block.body) block.body[i] = ir.Assign(stmt.value, target, stmt.loc) block.body.append(ir.Jump(return_label, stmt.loc)) @@ -880,8 +886,8 @@ def _replace_returns(blocks, target, return_label): if cast.target.name == stmt.value.name: cast.value = cast.value.value elif ( - isinstance(stmt, ir.Assign) - and isinstance(stmt.value, ir.Expr) + isinstance(stmt, ir.assign_types) + and isinstance(stmt.value, ir.expr_types) and stmt.value.op == "cast" ): casts.append(stmt) @@ -892,7 +898,7 @@ def _add_definitions(func_ir, block): Add variable definitions found in a block to parent func_ir. """ definitions = func_ir._definitions - assigns = block.find_insts(ir.Assign) + assigns = block.find_insts(ir.assign_types) for stmt in assigns: definitions[stmt.target.name].append(stmt.value) @@ -910,27 +916,27 @@ def _find_arraycall(func_ir, block): i = 0 while i < len(block.body): instr = block.body[i] - if isinstance(instr, ir.Del): + if isinstance(instr, ir.del_types): # Stop the process if list_var becomes dead if list_var and array_var and instr.value == list_var.name: list_var_dead_after_array_call = True break pass - elif isinstance(instr, ir.Assign): + elif isinstance(instr, ir.assign_types): # Found array_var = array(list_var) lhs = instr.target expr = instr.value if guard(find_callname, func_ir, expr) == ( "array", "numpy", - ) and isinstance(expr.args[0], ir.Var): + ) and isinstance(expr.args[0], ir.var_types): list_var = expr.args[0] array_var = lhs array_stmt_index = i array_kws = dict(expr.kws) elif ( - isinstance(instr, ir.SetItem) - and isinstance(instr.value, ir.Var) + isinstance(instr, ir.setitem_types) + and isinstance(instr.value, ir.var_types) and not list_var ): list_var = instr.value @@ -958,16 +964,17 @@ def _find_iter_range(func_ir, range_iter_var, swapped): range_iter_def = get_definition(func_ir, range_iter_var) debug_print("range_iter_var = ", range_iter_var, " def = ", range_iter_def) require( - isinstance(range_iter_def, ir.Expr) and range_iter_def.op == "getiter" + isinstance(range_iter_def, ir.expr_types) + and range_iter_def.op == "getiter" ) range_var = range_iter_def.value range_def = get_definition(func_ir, range_var) debug_print("range_var = ", range_var, " range_def = ", range_def) - require(isinstance(range_def, ir.Expr) and range_def.op == "call") + require(isinstance(range_def, ir.expr_types) and range_def.op == "call") func_var = range_def.func func_def = get_definition(func_ir, func_var) debug_print("func_var = ", func_var, " func_def = ", func_def) - require(isinstance(func_def, ir.Global) and func_def.value is range) + require(isinstance(func_def, ir.global_types) and func_def.value is range) nargs = len(range_def.args) swapping = [('"array comprehension"', "closure of"), range_def.func.loc] if nargs == 1: @@ -1082,20 +1089,23 @@ def _inline_arraycall( dtype_def = None dtype_mod_def = None if "dtype" in array_kws: - require(isinstance(array_kws["dtype"], ir.Var)) + require(isinstance(array_kws["dtype"], ir.var_types)) # We require that dtype argument to be a constant of getattr Expr, and # we'll remember its definition for later use. dtype_def = get_definition(func_ir, array_kws["dtype"]) - require(isinstance(dtype_def, ir.Expr) and dtype_def.op == "getattr") + require( + isinstance(dtype_def, ir.expr_types) and dtype_def.op == "getattr" + ) dtype_mod_def = get_definition(func_ir, dtype_def.value) list_var_def = get_definition(func_ir, list_var) debug_print("list_var = ", list_var, " def = ", list_var_def) - if isinstance(list_var_def, ir.Expr) and list_var_def.op == "cast": + if isinstance(list_var_def, ir.expr_types) and list_var_def.op == "cast": list_var_def = get_definition(func_ir, list_var_def.value) # Check if the definition is a build_list require( - isinstance(list_var_def, ir.Expr) and list_var_def.op == "build_list" + isinstance(list_var_def, ir.expr_types) + and list_var_def.op == "build_list" ) # The build_list must be empty require(len(list_var_def.items) == 0) @@ -1112,12 +1122,12 @@ def _inline_arraycall( continue block = func_ir.blocks[label] debug_print("check loop body block ", label) - for stmt in block.find_insts(ir.Assign): + for stmt in block.find_insts(ir.assign_types): expr = stmt.value - if isinstance(expr, ir.Expr) and expr.op == "call": + if isinstance(expr, ir.expr_types) and expr.op == "call": func_def = get_definition(func_ir, expr.func) if ( - isinstance(func_def, ir.Expr) + isinstance(func_def, ir.expr_types) and func_def.op == "getattr" and func_def.attr == "append" ): @@ -1146,9 +1156,9 @@ def _inline_arraycall( iter_vars = [] iter_first_vars = [] loop_header = func_ir.blocks[loop.header] - for stmt in loop_header.find_insts(ir.Assign): + for stmt in loop_header.find_insts(ir.assign_types): expr = stmt.value - if isinstance(expr, ir.Expr): + if isinstance(expr, ir.expr_types): if expr.op == "iternext": iter_def = get_definition(func_ir, expr.value) debug_print("iter_def = ", iter_def) @@ -1175,7 +1185,7 @@ def _inline_arraycall( removed = [] def is_removed(val, removed): - if isinstance(val, ir.Var): + if isinstance(val, ir.var_types): for x in removed: if x.name == val.name: return True @@ -1184,7 +1194,7 @@ def is_removed(val, removed): # Skip list construction and skip terminator, add the rest to stmts for i in range(len(loop_entry.body) - 1): stmt = loop_entry.body[i] - if isinstance(stmt, ir.Assign) and ( + if isinstance(stmt, ir.assign_types) and ( stmt.value is list_def or is_removed(stmt.value, removed) ): removed.append(stmt.target) @@ -1319,7 +1329,7 @@ def is_removed(val, removed): # when range doesn't start from 0, index_var becomes loop index # (iter_first_var) minus an offset (range_def[0]) terminator = loop_header.terminator - assert isinstance(terminator, ir.Branch) + assert isinstance(terminator, ir.branch_types) # find the block in the loop body that header jumps to block_id = terminator.truebr blk = func_ir.blocks[block_id] @@ -1377,7 +1387,9 @@ def is_removed(val, removed): # replace array call, by changing "a = array(b)" to "a = b" stmt = func_ir.blocks[exit_block].body[array_call_index] # stmt can be either array call or SetItem, we only replace array call - if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr): + if isinstance(stmt, ir.assign_types) and isinstance( + stmt.value, ir.expr_types + ): stmt.value = array_var func_ir._definitions[stmt.target.name] = [stmt.value] @@ -1386,10 +1398,10 @@ def is_removed(val, removed): def _find_unsafe_empty_inferred(func_ir, expr): unsafe_empty_inferred - require(isinstance(expr, ir.Expr) and expr.op == "call") + require(isinstance(expr, ir.expr_types) and expr.op == "call") callee = expr.func callee_def = get_definition(func_ir, callee) - require(isinstance(callee_def, ir.Global)) + require(isinstance(callee_def, ir.global_types)) _make_debug_print("_find_unsafe_empty_inferred")(callee_def.value) return callee_def.value == unsafe_empty_inferred @@ -1411,7 +1423,7 @@ def find_array_def(arr): """ arr_def = get_definition(func_ir, arr) _make_debug_print("find_array_def")(arr, arr_def) - if isinstance(arr_def, ir.Expr): + if isinstance(arr_def, ir.expr_types): if guard(_find_unsafe_empty_inferred, func_ir, arr_def): return arr_def elif arr_def.op == "getitem": @@ -1430,7 +1442,7 @@ def fix_dependencies(expr, varlist): defined = set() for i in range(len(body)): inst = body[i] - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): defined.add(inst.target.name) if inst.value is expr: new_varlist = [] @@ -1446,7 +1458,7 @@ def fix_dependencies(expr, varlist): else: debug_print(var.name, " not yet defined") var_def = get_definition(func_ir, var.name) - if isinstance(var_def, ir.Const): + if isinstance(var_def, ir.const_types): loc = var.loc new_var = scope.redefine("new_var", loc) new_const = ir.Const(var_def.value, loc) @@ -1475,8 +1487,8 @@ def fix_array_assign(stmt): 3. replace the definition of rhs = numba.unsafe.ndarray.empty_inferred(...) with rhs = lhs[idx] """ - require(isinstance(stmt, ir.SetItem)) - require(isinstance(stmt.value, ir.Var)) + require(isinstance(stmt, ir.setitem_types)) + require(isinstance(stmt.value, ir.var_types)) debug_print = _make_debug_print("fix_array_assign") debug_print("found SetItem: ", stmt) lhs = stmt.target @@ -1485,14 +1497,16 @@ def fix_array_assign(stmt): debug_print("found lhs_def: ", lhs_def) rhs_def = get_definition(func_ir, stmt.value) debug_print("found rhs_def: ", rhs_def) - require(isinstance(rhs_def, ir.Expr)) + require(isinstance(rhs_def, ir.expr_types)) if rhs_def.op == "cast": rhs_def = get_definition(func_ir, rhs_def.value) - require(isinstance(rhs_def, ir.Expr)) + require(isinstance(rhs_def, ir.expr_types)) require(_find_unsafe_empty_inferred(func_ir, rhs_def)) # Find the array dimension of rhs dim_def = get_definition(func_ir, rhs_def.args[0]) - require(isinstance(dim_def, ir.Expr) and dim_def.op == "build_tuple") + require( + isinstance(dim_def, ir.expr_types) and dim_def.op == "build_tuple" + ) debug_print("dim_def = ", dim_def) extra_dims = [ get_definition(func_ir, x, lhs_only=True) for x in dim_def.items @@ -1501,7 +1515,7 @@ def fix_array_assign(stmt): # Expand size tuple when creating lhs_def with extra_dims size_tuple_def = get_definition(func_ir, lhs_def.args[0]) require( - isinstance(size_tuple_def, ir.Expr) + isinstance(size_tuple_def, ir.expr_types) and size_tuple_def.op == "build_tuple" ) debug_print("size_tuple_def = ", size_tuple_def) @@ -1719,13 +1733,13 @@ def list_var_used(self, inst): state = State() for inst in block.body: - if isinstance(inst, ir.Assign): - if isinstance(inst.value, ir.Var): + if isinstance(inst, ir.assign_types): + if isinstance(inst.value, ir.var_types): if inst.value.name in state.list_vars: state.list_vars.append(inst.target.name) state.stmts.append(inst) continue - elif isinstance(inst.value, ir.Expr): + elif isinstance(inst.value, ir.expr_types): expr = inst.value if expr.op == "build_list": # new build_list encountered, reset state @@ -1745,7 +1759,7 @@ def list_var_used(self, inst): ): state.modified = True continue - elif isinstance(inst, ir.Del): + elif isinstance(inst, ir.del_types): removed_var = inst.value if removed_var in state.list_items: state.dels.append(inst) @@ -1764,10 +1778,10 @@ def list_var_used(self, inst): body = [] for inst in state.stmts: if ( - isinstance(inst, ir.Assign) + isinstance(inst, ir.assign_types) and inst.target.name in state.dead_vars ) or ( - isinstance(inst, ir.Del) + isinstance(inst, ir.del_types) and inst.value in state.dead_vars ): continue diff --git a/numba_cuda/numba/cuda/core/interpreter.py b/numba_cuda/numba/cuda/core/interpreter.py index d63f6cf11..90108b231 100644 --- a/numba_cuda/numba/cuda/core/interpreter.py +++ b/numba_cuda/numba/cuda/core/interpreter.py @@ -217,8 +217,8 @@ def _call_function_ex_replace_kws_large( # The first value must be a constant. const_stmt = old_body[search_start] if not ( - isinstance(const_stmt, ir.Assign) - and isinstance(const_stmt.value, ir.Const) + isinstance(const_stmt, ir.assign_types) + and isinstance(const_stmt.value, ir.const_types) ): # We cannot handle this format so raise the # original error message. @@ -231,8 +231,8 @@ def _call_function_ex_replace_kws_large( while search_start <= search_end and not found_getattr: getattr_stmt = old_body[search_start] if ( - isinstance(getattr_stmt, ir.Assign) - and isinstance(getattr_stmt.value, ir.Expr) + isinstance(getattr_stmt, ir.assign_types) + and isinstance(getattr_stmt.value, ir.expr_types) and getattr_stmt.value.op == "getattr" and (getattr_stmt.value.value.name == buildmap_name) and getattr_stmt.value.attr == "__setitem__" @@ -262,8 +262,8 @@ def _call_function_ex_replace_kws_large( raise UnsupportedBytecodeError(errmsg) setitem_stmt = old_body[search_start + 1] if not ( - isinstance(setitem_stmt, ir.Assign) - and isinstance(setitem_stmt.value, ir.Expr) + isinstance(setitem_stmt, ir.assign_types) + and isinstance(setitem_stmt.value, ir.expr_types) and setitem_stmt.value.op == "call" and (setitem_stmt.value.func.name == getattr_stmt.target.name) and len(setitem_stmt.value.args) == 2 @@ -366,8 +366,8 @@ def _call_function_ex_replace_args_large( # tuple. search_start = 0 total_args = [] - if isinstance(vararg_stmt, ir.Assign) and isinstance( - vararg_stmt.value, ir.Var + if isinstance(vararg_stmt, ir.assign_types) and isinstance( + vararg_stmt.value, ir.var_types ): target_name = vararg_stmt.value.name # If there is an initial assignment, delete it @@ -387,9 +387,9 @@ def _call_function_ex_replace_args_large( while search_end >= search_start: concat_stmt = old_body[search_end] if ( - isinstance(concat_stmt, ir.Assign) + isinstance(concat_stmt, ir.assign_types) and concat_stmt.target.name == target_name - and isinstance(concat_stmt.value, ir.Expr) + and isinstance(concat_stmt.value, ir.expr_types) and concat_stmt.value.op == "build_tuple" and not concat_stmt.value.items ): @@ -404,9 +404,9 @@ def _call_function_ex_replace_args_large( # We expect to find another arg to append. # The first stmt must be a binop "add" if (search_end == search_start) or not ( - isinstance(concat_stmt, ir.Assign) + isinstance(concat_stmt, ir.assign_types) and (concat_stmt.target.name == target_name) - and isinstance(concat_stmt.value, ir.Expr) + and isinstance(concat_stmt.value, ir.expr_types) and concat_stmt.value.op == "binop" and concat_stmt.value.fn == operator.add ): @@ -418,8 +418,8 @@ def _call_function_ex_replace_args_large( # build_tuple containing the arg. arg_tuple_stmt = old_body[search_end - 1] if not ( - isinstance(arg_tuple_stmt, ir.Assign) - and isinstance(arg_tuple_stmt.value, ir.Expr) + isinstance(arg_tuple_stmt, ir.assign_types) + and isinstance(arg_tuple_stmt.value, ir.expr_types) and (arg_tuple_stmt.value.op == "build_tuple") and len(arg_tuple_stmt.value.items) == 1 ): @@ -448,7 +448,7 @@ def _call_function_ex_replace_args_large( keep_looking = True while search_end >= search_start and keep_looking: next_stmt = old_body[search_end] - if isinstance(next_stmt, ir.Assign) and ( + if isinstance(next_stmt, ir.assign_types) and ( next_stmt.target.name == target_name ): keep_looking = False @@ -522,8 +522,8 @@ def peep_hole_call_function_ex_to_call_function_kw(func_ir): new_body = [] for i, stmt in enumerate(blk.body): if ( - isinstance(stmt, ir.Assign) - and isinstance(stmt.value, ir.Expr) + isinstance(stmt, ir.assign_types) + and isinstance(stmt.value, ir.expr_types) and stmt.value.op == "call" and stmt.value.varkwarg is not None ): @@ -548,7 +548,7 @@ def peep_hole_call_function_ex_to_call_function_kw(func_ir): while varkwarg_loc >= 0 and not found: keyword_def = blk.body[varkwarg_loc] if ( - isinstance(keyword_def, ir.Assign) + isinstance(keyword_def, ir.assign_types) and keyword_def.target.name == varkwarg.name ): found = True @@ -558,7 +558,7 @@ def peep_hole_call_function_ex_to_call_function_kw(func_ir): kws or not found or not ( - isinstance(keyword_def.value, ir.Expr) + isinstance(keyword_def.value, ir.expr_types) and keyword_def.value.op == "build_map" ) ): @@ -624,7 +624,7 @@ def peep_hole_call_function_ex_to_call_function_kw(func_ir): while vararg_loc >= 0 and not found: args_def = blk.body[vararg_loc] if ( - isinstance(args_def, ir.Assign) + isinstance(args_def, ir.assign_types) and args_def.target.name == vararg.name ): found = True @@ -635,7 +635,7 @@ def peep_hole_call_function_ex_to_call_function_kw(func_ir): # then we can't handle this format. raise UnsupportedBytecodeError(errmsg) if ( - isinstance(args_def.value, ir.Expr) + isinstance(args_def.value, ir.expr_types) and args_def.value.op == "build_tuple" ): # n_args <= 30 case. @@ -656,7 +656,7 @@ def peep_hole_call_function_ex_to_call_function_kw(func_ir): already_deleted_defs, ) elif ( - isinstance(args_def.value, ir.Expr) + isinstance(args_def.value, ir.expr_types) and args_def.value.op == "list_to_tuple" ): # If there is a call with vararg we need to check @@ -708,8 +708,8 @@ def peep_hole_call_function_ex_to_call_function_kw(func_ir): # Update the definition func_ir._definitions[stmt.target.name].append(new_call) elif ( - isinstance(stmt, ir.Assign) - and isinstance(stmt.value, ir.Expr) + isinstance(stmt, ir.assign_types) + and isinstance(stmt.value, ir.expr_types) and stmt.value.op == "call" and stmt.value.vararg is not None ): @@ -725,7 +725,10 @@ def peep_hole_call_function_ex_to_call_function_kw(func_ir): # If this value is still a list to tuple raise the # exception. expr = func_ir._definitions[vararg_name][0] - if isinstance(expr, ir.Expr) and expr.op == "list_to_tuple": + if ( + isinstance(expr, ir.expr_types) + and expr.op == "list_to_tuple" + ): raise UnsupportedBytecodeError(errmsg) new_body.append(stmt) @@ -780,17 +783,17 @@ def find_postive_region(): found = False for idx in reversed(range(len(blk.body))): stmt = blk.body[idx] - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): value = stmt.value if ( - isinstance(value, ir.Expr) + isinstance(value, ir.expr_types) and value.op == "list_to_tuple" ): target_list = value.info[0] found = True bt = (idx, stmt) if found: - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): if stmt.target.name == target_list: region = (bt, (idx, stmt)) return region @@ -812,8 +815,8 @@ def find_postive_region(): # Walk through the peep_hole and find things that are being # "extend"ed and "append"ed to the BUILD_LIST for x in peep_hole: - if isinstance(x, ir.Assign): - if isinstance(x.value, ir.Expr): + if isinstance(x, ir.assign_types): + if isinstance(x.value, ir.expr_types): expr = x.value if ( expr.op == "getattr" @@ -857,8 +860,8 @@ def append_and_fix(x): t2l_agn = region[0][1] acc = the_build_list for x in peep_hole: - if isinstance(x, ir.Assign): - if isinstance(x.value, ir.Expr): + if isinstance(x, ir.assign_types): + if isinstance(x.value, ir.expr_types): expr = x.value if expr.op == "getattr": if ( @@ -877,7 +880,7 @@ def append_and_fix(x): fname = expr.func.name if fname in extends or fname in appends: arg = expr.args[0] - if isinstance(arg, ir.Var): + if isinstance(arg, ir.var_types): tmp_name = "%s_var_%s" % ( fname, arg.name, @@ -997,7 +1000,7 @@ def peep_hole_delete_with_exit(func_ir): # Any assignment that uses any of the dead variable is considered # dead. if used & dead_vars: - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): dead_vars.add(stmt.target) new_body = [] @@ -1112,7 +1115,9 @@ def peep_hole_fuse_dict_add_updates(func_ir): # vars in statement. This is always the lhs with # a build_map. stmt_build_map_out = None - if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr): + if isinstance(stmt, ir.assign_types) and isinstance( + stmt.value, ir.expr_types + ): if stmt.value.op == "build_map": # Skip the output build_map when looking for used vars. stmt_build_map_out = stmt.target.name @@ -1130,9 +1135,9 @@ def peep_hole_fuse_dict_add_updates(func_ir): getattr_stmt = blk.body[i - 1] args = stmt.value.args if ( - isinstance(getattr_stmt, ir.Assign) + isinstance(getattr_stmt, ir.assign_types) and getattr_stmt.target.name == func_name - and isinstance(getattr_stmt.value, ir.Expr) + and isinstance(getattr_stmt.value, ir.expr_types) and getattr_stmt.value.op == "getattr" and getattr_stmt.value.attr in ("__setitem__", "_update_from_bytecode") @@ -1207,8 +1212,8 @@ def peep_hole_fuse_dict_add_updates(func_ir): # will be removed when handling their call in the next # iteration. if not ( - isinstance(stmt, ir.Assign) - and isinstance(stmt.value, ir.Expr) + isinstance(stmt, ir.assign_types) + and isinstance(stmt.value, ir.expr_types) and stmt.value.op == "getattr" and stmt.value.value.name in lit_map_use_idx and stmt.value.attr in ("__setitem__", "_update_from_bytecode") @@ -1249,7 +1254,7 @@ def peep_hole_split_at_pop_block(func_ir): # Gather locations of PopBlock pop_block_locs = [] for i, inst in enumerate(blk.body): - if isinstance(inst, ir.PopBlock): + if isinstance(inst, ir.popblock_types): pop_block_locs.append(i) # Rewrite block with PopBlock if pop_block_locs: @@ -1301,10 +1306,14 @@ def _build_new_build_map(func_ir, name, old_body, old_lineno, new_items): for pair in new_items: k, v = pair key_def = ir_utils.guard(ir_utils.get_definition, func_ir, k) - if isinstance(key_def, (ir.Const, ir.Global, ir.FreeVar)): + if isinstance( + key_def, ir.const_types + ir.global_types + ir.freevar_types + ): literal_keys.append(key_def.value) value_def = ir_utils.guard(ir_utils.get_definition, func_ir, v) - if isinstance(value_def, (ir.Const, ir.Global, ir.FreeVar)): + if isinstance( + value_def, ir.const_types + ir.global_types + ir.freevar_types + ): values.append(value_def.value) else: # Append unknown value if not a literal. @@ -1514,7 +1523,7 @@ def _legalize_exception_vars(self): # Propagate the exception variables to LHS of assignment for varname, defnvars in self.definitions.items(): for v in defnvars: - if isinstance(v, ir.Var): + if isinstance(v, ir.var_types): k = v.name if k in excvars: excvars.add(varname) @@ -1587,7 +1596,7 @@ def _start_new_block(self, offset): while self.syntax_blocks: if offset >= self.syntax_blocks[-1].exit: synblk = self.syntax_blocks.pop() - if isinstance(synblk, ir.With): + if isinstance(synblk, ir.with_types): self.current_block.append(ir.PopBlock(self.loc)) else: break @@ -1684,11 +1693,11 @@ def _remove_unused_temporaries(self): # like a = b[i] = 1, so need to handle replaced temporaries in # later setitem/setattr nodes if ( - isinstance(inst, (ir.SetItem, ir.SetAttr)) + isinstance(inst, ir.setitem_types + ir.setattr_types) and inst.value.name in replaced_var ): inst.value = replaced_var[inst.value.name] - elif isinstance(inst, ir.Assign): + elif isinstance(inst, ir.assign_types): if ( inst.target.is_temp and inst.target.name in self.assigner.unused_dests @@ -1698,7 +1707,7 @@ def _remove_unused_temporaries(self): # like a = b = 1, so need to handle replaced temporaries in # later assignments if ( - isinstance(inst.value, ir.Var) + isinstance(inst.value, ir.var_types) and inst.value.name in replaced_var ): inst.value = replaced_var[inst.value.name] @@ -1707,7 +1716,7 @@ def _remove_unused_temporaries(self): # chained unpack cases may reuse temporary # e.g. a = (b, c) = (x, y) if ( - isinstance(inst.value, ir.Expr) + isinstance(inst.value, ir.expr_types) and inst.value.op == "exhaust_iter" and inst.value.value.name in replaced_var ): @@ -1720,10 +1729,10 @@ def _remove_unused_temporaries(self): # the temporary variable is not reused elsewhere since CPython # bytecode is stack-based and this pattern corresponds to a pop if ( - isinstance(inst.value, ir.Var) + isinstance(inst.value, ir.var_types) and inst.value.is_temp and new_body - and isinstance(new_body[-1], ir.Assign) + and isinstance(new_body[-1], ir.assign_types) ): prev_assign = new_body[-1] # _var_used_in_binop check makes sure we don't create a new @@ -1753,7 +1762,7 @@ def _var_used_in_binop(self, varname, expr): in it as an argument """ return ( - isinstance(expr, ir.Expr) + isinstance(expr, ir.expr_types) and expr.op in ("binop", "inplace_binop") and (varname == expr.lhs.name or varname == expr.rhs.name) ) @@ -1832,7 +1841,7 @@ def _dispatch(self, inst, kws): if PYVERSION in ((3, 11), (3, 12), (3, 13)): if self.syntax_blocks: top = self.syntax_blocks[-1] - if isinstance(top, ir.With): + if isinstance(top, ir.with_types): if inst.offset >= top.exit: self.current_block.append(ir.PopBlock(loc=self.loc)) self.syntax_blocks.pop() @@ -1881,7 +1890,7 @@ def store(self, value, name, redefine=False): ) else: target = self.current_scope.get_or_define(name, loc=self.loc) - if isinstance(value, ir.Var): + if isinstance(value, ir.var_types): value = self.assigner.assign(value, target) stmt = ir.Assign(value=value, target=target, loc=self.loc) self.current_block.append(stmt) @@ -2649,7 +2658,7 @@ def op_CALL_FUNCTION_KW(self, inst, func, args, names, res): # Find names const names = self.get(names) for inst in self.current_block.body: - if isinstance(inst, ir.Assign) and inst.target is names: + if isinstance(inst, ir.assign_types) and inst.target is names: self.current_block.remove(inst) # scan up the block looking for the values, remove them # and find their name strings @@ -2750,7 +2759,7 @@ def op_BUILD_CONST_KEY_MAP(self, inst, keys, keytmps, values, res): keyvar = self.get(keys) # TODO: refactor this pattern. occurred several times. for inst in self.current_block.body: - if isinstance(inst, ir.Assign) and inst.target is keyvar: + if isinstance(inst, ir.assign_types) and inst.target is keyvar: self.current_block.remove(inst) # scan up the block looking for the values, remove them # and find their name strings @@ -2776,7 +2785,7 @@ def op_BUILD_CONST_KEY_MAP(self, inst, keys, keytmps, values, res): if len(defns) != 1: break defn = defns[0] - if not isinstance(defn, ir.Const): + if not isinstance(defn, ir.const_types): break literal_items.append(defn.value) @@ -2785,7 +2794,7 @@ def resolve_const(v): if len(defns) != 1: return _UNKNOWN_VALUE(self.get(v).name) defn = defns[0] - if not isinstance(defn, ir.Const): + if not isinstance(defn, ir.const_types): return _UNKNOWN_VALUE(self.get(v).name) return defn.value @@ -2921,7 +2930,7 @@ def get_literals(target): if len(defns) != 1: break defn = defns[0] - if not isinstance(defn, ir.Const): + if not isinstance(defn, ir.const_types): break literal_items.append(defn.value) return literal_items @@ -3196,7 +3205,7 @@ def op_CONTAINS_OP(self, inst, lhs, rhs, res): def op_BREAK_LOOP(self, inst, end=None): if end is None: loop = self.syntax_blocks[-1] - assert isinstance(loop, ir.Loop) + assert isinstance(loop, ir.loop_types) end = loop.exit jmp = ir.Jump(target=end, loc=self.loc) self.current_block.append(jmp) @@ -3408,7 +3417,7 @@ def op_MAKE_FUNCTION( defaults = self.get(defaults) assume_code_const = self.definitions[code][0] - if not isinstance(assume_code_const, ir.Const): + if not isinstance(assume_code_const, ir.const_types): msg = ( "Unsupported use of closure. " "Probably caused by complex control-flow constructs; " @@ -3504,23 +3513,29 @@ def op_LIST_EXTEND(self, inst, target, value, extendvar, res): # is last emitted statement a build_tuple? stmt = self.current_block.body[-1] - ok = isinstance(stmt.value, ir.Expr) and stmt.value.op == "build_tuple" + ok = ( + isinstance(stmt.value, ir.expr_types) + and stmt.value.op == "build_tuple" + ) # check statements from self.current_block.body[-1] through to target, # make sure they are consts build_empty_list = None if ok: for stmt in reversed(self.current_block.body[:-1]): - if not isinstance(stmt, ir.Assign): + if not isinstance(stmt, ir.assign_types): ok = False break # if its not a const, it needs to be the `build_list` for the # target, else it's something else we don't know about so just # bail - if isinstance(stmt.value, ir.Const): + if isinstance(stmt.value, ir.const_types): continue # it's not a const, check for target - elif isinstance(stmt.value, ir.Expr) and stmt.target == target: + elif ( + isinstance(stmt.value, ir.expr_types) + and stmt.target == target + ): build_empty_list = stmt # it's only ok to do this if the target has no initializer # already diff --git a/numba_cuda/numba/cuda/core/ir.py b/numba_cuda/numba/cuda/core/ir.py index 807733fea..2f66ee0c1 100644 --- a/numba_cuda/numba/cuda/core/ir.py +++ b/numba_cuda/numba/cuda/core/ir.py @@ -337,9 +337,9 @@ def _rec_list_vars(self, val): """ A recursive helper used to implement list_vars() in subclasses. """ - if isinstance(val, Var): + if isinstance(val, var_types): return [val] - elif isinstance(val, Inst): + elif isinstance(val, inst_types): return val.list_vars() elif isinstance(val, (list, tuple)): lst = [] @@ -396,7 +396,7 @@ class Expr(Inst): def __init__(self, op, loc, **kws): assert isinstance(op, str) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) self.op = op self.loc = loc self._kws = kws @@ -415,9 +415,9 @@ def __setattr__(self, name, value): @classmethod def binop(cls, fn, lhs, rhs, loc): assert isinstance(fn, BuiltinFunctionType) - assert isinstance(lhs, Var) - assert isinstance(rhs, Var) - assert isinstance(loc, Loc) + assert isinstance(lhs, var_types) + assert isinstance(rhs, var_types) + assert isinstance(loc, loc_types) op = "binop" return cls( op=op, @@ -433,9 +433,9 @@ def binop(cls, fn, lhs, rhs, loc): def inplace_binop(cls, fn, immutable_fn, lhs, rhs, loc): assert isinstance(fn, BuiltinFunctionType) assert isinstance(immutable_fn, BuiltinFunctionType) - assert isinstance(lhs, Var) - assert isinstance(rhs, Var) - assert isinstance(loc, Loc) + assert isinstance(lhs, var_types) + assert isinstance(rhs, var_types) + assert isinstance(loc, loc_types) op = "inplace_binop" return cls( op=op, @@ -450,8 +450,9 @@ def inplace_binop(cls, fn, immutable_fn, lhs, rhs, loc): @classmethod def unary(cls, fn, value, loc): - assert isinstance(value, (str, Var, FunctionType)) - assert isinstance(loc, Loc) + if not isinstance(value, (str, FunctionType)): + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) op = "unary" fn = UNARY_BUILTINS_TO_OPERATORS.get(fn, fn) return cls(op=op, loc=loc, fn=fn, value=value) @@ -460,8 +461,8 @@ def unary(cls, fn, value, loc): def call( cls, func, args, kws, loc, vararg=None, varkwarg=None, target=None ): - assert isinstance(func, Var) - assert isinstance(loc, Loc) + assert isinstance(func, var_types) + assert isinstance(loc, loc_types) op = "call" return cls( op=op, @@ -476,25 +477,25 @@ def call( @classmethod def build_tuple(cls, items, loc): - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) op = "build_tuple" return cls(op=op, loc=loc, items=items) @classmethod def build_list(cls, items, loc): - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) op = "build_list" return cls(op=op, loc=loc, items=items) @classmethod def build_set(cls, items, loc): - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) op = "build_set" return cls(op=op, loc=loc, items=items) @classmethod def build_map(cls, items, size, literal_value, value_indexes, loc): - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) op = "build_map" return cls( op=op, @@ -507,68 +508,68 @@ def build_map(cls, items, size, literal_value, value_indexes, loc): @classmethod def pair_first(cls, value, loc): - assert isinstance(value, Var) + assert isinstance(value, var_types) op = "pair_first" return cls(op=op, loc=loc, value=value) @classmethod def pair_second(cls, value, loc): - assert isinstance(value, Var) - assert isinstance(loc, Loc) + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) op = "pair_second" return cls(op=op, loc=loc, value=value) @classmethod def getiter(cls, value, loc): - assert isinstance(value, Var) - assert isinstance(loc, Loc) + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) op = "getiter" return cls(op=op, loc=loc, value=value) @classmethod def iternext(cls, value, loc): - assert isinstance(value, Var) - assert isinstance(loc, Loc) + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) op = "iternext" return cls(op=op, loc=loc, value=value) @classmethod def exhaust_iter(cls, value, count, loc): - assert isinstance(value, Var) + assert isinstance(value, var_types) assert isinstance(count, int) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) op = "exhaust_iter" return cls(op=op, loc=loc, value=value, count=count) @classmethod def getattr(cls, value, attr, loc): - assert isinstance(value, Var) + assert isinstance(value, var_types) assert isinstance(attr, str) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) op = "getattr" return cls(op=op, loc=loc, value=value, attr=attr) @classmethod def getitem(cls, value, index, loc): - assert isinstance(value, Var) - assert isinstance(index, Var) - assert isinstance(loc, Loc) + assert isinstance(value, var_types) + assert isinstance(index, var_types) + assert isinstance(loc, loc_types) op = "getitem" fn = operator.getitem return cls(op=op, loc=loc, value=value, index=index, fn=fn) @classmethod def typed_getitem(cls, value, dtype, index, loc): - assert isinstance(value, Var) - assert isinstance(loc, Loc) + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) op = "typed_getitem" return cls(op=op, loc=loc, value=value, dtype=dtype, index=index) @classmethod def static_getitem(cls, value, index, index_var, loc): - assert isinstance(value, Var) - assert index_var is None or isinstance(index_var, Var) - assert isinstance(loc, Loc) + assert isinstance(value, var_types) + assert index_var is None or isinstance(index_var, var_types) + assert isinstance(loc, loc_types) op = "static_getitem" fn = operator.getitem return cls( @@ -580,15 +581,15 @@ def cast(cls, value, loc): """ A node for implicit casting at the return statement """ - assert isinstance(value, Var) - assert isinstance(loc, Loc) + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) op = "cast" return cls(op=op, value=value, loc=loc) @classmethod def phi(cls, loc): """Phi node""" - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) return cls(op="phi", incoming_values=[], incoming_blocks=[], loc=loc) @classmethod @@ -596,7 +597,7 @@ def make_function(cls, name, code, closure, defaults, loc): """ A node for making a function object. """ - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) op = "make_function" return cls( op=op, @@ -615,7 +616,7 @@ def null(cls, loc): This node is not handled by type inference. It is only added by post-typing passes. """ - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) op = "null" return cls(op=op, loc=loc) @@ -624,7 +625,7 @@ def undef(cls, loc): """ A node for undefined value specifically from LOAD_FAST_AND_CLEAR opcode. """ - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) op = "undef" return cls(op=op, loc=loc) @@ -638,7 +639,7 @@ def dummy(cls, op, info, loc): by type inference or lowering. It's presence outside of the interpreter renders IR as illegal. """ - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) assert isinstance(op, str) return cls(op=op, info=info, loc=loc) @@ -682,10 +683,10 @@ class SetItem(Stmt): """ def __init__(self, target, index, value, loc): - assert isinstance(target, Var) - assert isinstance(index, Var) - assert isinstance(value, Var) - assert isinstance(loc, Loc) + assert isinstance(target, var_types) + assert isinstance(index, var_types) + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) self.target = target self.index = index self.value = value @@ -701,11 +702,11 @@ class StaticSetItem(Stmt): """ def __init__(self, target, index, index_var, value, loc): - assert isinstance(target, Var) - assert not isinstance(index, Var) - assert isinstance(index_var, Var) - assert isinstance(value, Var) - assert isinstance(loc, Loc) + assert isinstance(target, var_types) + assert not isinstance(index, var_types) + assert isinstance(index_var, var_types) + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) self.target = target self.index = index self.index_var = index_var @@ -722,9 +723,9 @@ class DelItem(Stmt): """ def __init__(self, target, index, loc): - assert isinstance(target, Var) - assert isinstance(index, Var) - assert isinstance(loc, Loc) + assert isinstance(target, var_types) + assert isinstance(index, var_types) + assert isinstance(loc, loc_types) self.target = target self.index = index self.loc = loc @@ -735,10 +736,10 @@ def __repr__(self): class SetAttr(Stmt): def __init__(self, target, attr, value, loc): - assert isinstance(target, Var) + assert isinstance(target, var_types) assert isinstance(attr, str) - assert isinstance(value, Var) - assert isinstance(loc, Loc) + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) self.target = target self.attr = attr self.value = value @@ -750,9 +751,9 @@ def __repr__(self): class DelAttr(Stmt): def __init__(self, target, attr, loc): - assert isinstance(target, Var) + assert isinstance(target, var_types) assert isinstance(attr, str) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) self.target = target self.attr = attr self.loc = loc @@ -763,10 +764,10 @@ def __repr__(self): class StoreMap(Stmt): def __init__(self, dct, key, value, loc): - assert isinstance(dct, Var) - assert isinstance(key, Var) - assert isinstance(value, Var) - assert isinstance(loc, Loc) + assert isinstance(dct, var_types) + assert isinstance(key, var_types) + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) self.dct = dct self.key = key self.value = value @@ -779,10 +780,7 @@ def __repr__(self): class Del(Stmt): def __init__(self, value, loc): assert isinstance(value, str) - if HAS_NUMBA: - assert isinstance(loc, (Loc, numba.core.ir.Loc)) - else: - assert isinstance(loc, (Loc)) + assert isinstance(loc, loc_types) self.value = value self.loc = loc @@ -794,8 +792,8 @@ class Raise(Terminator): is_exit = True def __init__(self, exception, loc): - assert exception is None or isinstance(exception, Var) - assert isinstance(loc, Loc) + assert exception is None or isinstance(exception, var_types) + assert isinstance(loc, loc_types) self.exception = exception self.loc = loc @@ -817,7 +815,7 @@ class StaticRaise(Terminator): def __init__(self, exc_class, exc_args, loc): assert exc_class is None or isinstance(exc_class, type) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) assert exc_args is None or isinstance(exc_args, tuple) self.exc_class = exc_class self.exc_args = exc_args @@ -849,7 +847,7 @@ class DynamicRaise(Terminator): def __init__(self, exc_class, exc_args, loc): assert exc_class is None or isinstance(exc_class, type) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) assert exc_args is None or isinstance(exc_args, tuple) self.exc_class = exc_class self.exc_args = exc_args @@ -876,8 +874,8 @@ class TryRaise(Stmt): """ def __init__(self, exception, loc): - assert exception is None or isinstance(exception, Var) - assert isinstance(loc, Loc) + assert exception is None or isinstance(exception, var_types) + assert isinstance(loc, loc_types) self.exception = exception self.loc = loc @@ -892,7 +890,7 @@ class StaticTryRaise(Stmt): def __init__(self, exc_class, exc_args, loc): assert exc_class is None or isinstance(exc_class, type) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) assert exc_args is None or isinstance(exc_args, tuple) self.exc_class = exc_class self.exc_args = exc_args @@ -915,7 +913,7 @@ class DynamicTryRaise(Stmt): def __init__(self, exc_class, exc_args, loc): assert exc_class is None or isinstance(exc_class, type) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) assert exc_args is None or isinstance(exc_args, tuple) self.exc_class = exc_class self.exc_args = exc_args @@ -939,8 +937,8 @@ class Return(Terminator): is_exit = True def __init__(self, value, loc): - assert isinstance(value, Var), type(value) - assert isinstance(loc, Loc) + assert isinstance(value, var_types), type(value) + assert isinstance(loc, loc_types) self.value = value self.loc = loc @@ -957,7 +955,7 @@ class Jump(Terminator): """ def __init__(self, target, loc): - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) self.target = target self.loc = loc @@ -974,8 +972,8 @@ class Branch(Terminator): """ def __init__(self, cond, truebr, falsebr, loc): - assert isinstance(cond, Var) - assert isinstance(loc, Loc) + assert isinstance(cond, var_types) + assert isinstance(loc, loc_types) self.cond = cond self.truebr = truebr self.falsebr = falsebr @@ -994,9 +992,9 @@ class Assign(Stmt): """ def __init__(self, value, target, loc): - assert isinstance(value, AbstractRHS) - assert isinstance(target, Var) - assert isinstance(loc, Loc) + assert isinstance(value, abstractrhs_types) + assert isinstance(target, var_types) + assert isinstance(loc, loc_types) self.value = value self.target = target self.loc = loc @@ -1011,9 +1009,10 @@ class Print(Stmt): """ def __init__(self, args, vararg, loc): - assert all(isinstance(x, Var) for x in args) - assert vararg is None or isinstance(vararg, Var) - assert isinstance(loc, Loc) + assert all(isinstance(x, var_types) for x in args) + if vararg is not None: + assert isinstance(vararg, var_types) + assert isinstance(loc, loc_types) self.args = tuple(args) self.vararg = vararg # Constant-inferred arguments @@ -1026,8 +1025,8 @@ def __str__(self): class Yield(Inst): def __init__(self, value, loc, index): - assert isinstance(value, Var) - assert isinstance(loc, Loc) + assert isinstance(value, var_types) + assert isinstance(loc, loc_types) self.value = value self.loc = loc self.index = index @@ -1052,8 +1051,8 @@ def __init__(self, contextmanager, begin, end, loc): loc : ir.Loc instance Source location """ - assert isinstance(contextmanager, Var) - assert isinstance(loc, Loc) + assert isinstance(contextmanager, var_types) + assert isinstance(loc, loc_types) self.contextmanager = contextmanager self.begin = begin self.end = end @@ -1070,7 +1069,7 @@ class PopBlock(Stmt): """Marker statement for a pop block op code""" def __init__(self, loc): - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) self.loc = loc def __str__(self): @@ -1081,7 +1080,7 @@ class Arg(EqualityCheckMixin, AbstractRHS): def __init__(self, name, index, loc): assert isinstance(name, str) assert isinstance(index, int) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) self.name = name self.index = index self.loc = loc @@ -1095,7 +1094,7 @@ def infer_constant(self): class Const(EqualityCheckMixin, AbstractRHS): def __init__(self, value, loc, use_literal_type=True): - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) self.value = value self.loc = loc # Note: need better way to tell if this is a literal or not. @@ -1118,7 +1117,7 @@ def __deepcopy__(self, memo): class Global(EqualityCheckMixin, AbstractRHS): def __init__(self, name, value, loc): - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) self.name = name self.value = value self.loc = loc @@ -1144,7 +1143,7 @@ class FreeVar(EqualityCheckMixin, AbstractRHS): def __init__(self, index, name, value, loc): assert isinstance(index, int) assert isinstance(name, str) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) # index inside __code__.co_freevars self.index = index # variable name @@ -1180,9 +1179,10 @@ class Var(EqualityCheckMixin, AbstractRHS): def __init__(self, scope, name, loc): # NOTE: Use of scope=None should be removed. - assert scope is None or isinstance(scope, Scope) + if scope is not None: + assert isinstance(scope, scope_types) assert isinstance(name, str) - assert isinstance(loc, Loc) + assert isinstance(loc, loc_types) self.scope = scope self.name = name self.loc = loc @@ -1241,8 +1241,9 @@ class Scope(EqualityCheckMixin): """ def __init__(self, parent, loc): - assert parent is None or isinstance(parent, Scope) - assert isinstance(loc, Loc) + if parent is not None: + assert isinstance(parent, scope_types) + assert isinstance(loc, loc_types) self.parent = parent self.localvars = VarMap() self.loc = loc @@ -1348,12 +1349,8 @@ class Block(EqualityCheckMixin): """A code block""" def __init__(self, scope, loc): - if HAS_NUMBA: - assert isinstance(scope, (Scope, numba.core.ir.Scope)) - assert isinstance(loc, (Loc, numba.core.ir.Loc)) - else: - assert isinstance(scope, Scope) - assert isinstance(loc, Loc) + assert isinstance(scope, scope_types) + assert isinstance(loc, loc_types) self.scope = scope self.body = [] self.loc = loc @@ -1368,9 +1365,9 @@ def find_exprs(self, op=None): Iterate over exprs of the given *op* in this block. """ for inst in self.body: - if isinstance(inst, Assign): + if isinstance(inst, assign_types): expr = inst.value - if isinstance(expr, Expr): + if isinstance(expr, expr_types): if op is None or expr.op == op: yield expr @@ -1387,21 +1384,21 @@ def find_variable_assignment(self, name): Returns the assignment inst associated with variable "name", None if it cannot be found. """ - for x in self.find_insts(cls=Assign): + for x in self.find_insts(cls=assign_types): if x.target.name == name: return x return None def prepend(self, inst): - assert isinstance(inst, Stmt) + assert isinstance(inst, stmt_types) self.body.insert(0, inst) def append(self, inst): - assert isinstance(inst, Stmt) + assert isinstance(inst, stmt_types) self.body.append(inst) def remove(self, inst): - assert isinstance(inst, Stmt) + assert isinstance(inst, stmt_types) del self.body[self.body.index(inst)] def clear(self): @@ -1443,7 +1440,7 @@ def insert_after(self, stmt, other): self.body.insert(index + 1, stmt) def insert_before_terminator(self, stmt): - assert isinstance(stmt, Stmt) + assert isinstance(stmt, stmt_types) assert self.is_terminated self.body.insert(-1, stmt) @@ -1527,8 +1524,12 @@ def diff_str(self, other): if block != other_blk: msg.append(("Block %s differs" % label).center(80, "-")) # see if the instructions are just a permutation - block_del = [x for x in block.body if isinstance(x, Del)] - oth_del = [x for x in other_blk.body if isinstance(x, Del)] + block_del = [ + x for x in block.body if isinstance(x, del_types) + ] + oth_del = [ + x for x in other_blk.body if isinstance(x, del_types) + ] if block_del != oth_del: # this is a common issue, dels are all present, but # order shuffled. @@ -1653,7 +1654,7 @@ def infer_constant(self, name): """ Try to infer the constant value of a given variable. """ - if isinstance(name, Var): + if isinstance(name, var_types): name = name.name return self._consts.infer_constant(name) @@ -1665,7 +1666,7 @@ def get_definition(self, value, lhs_only=False): """ lhs = value while True: - if isinstance(value, Var): + if isinstance(value, var_types): lhs = value name = value.name elif isinstance(value, str): @@ -1692,10 +1693,10 @@ def get_assignee(self, rhs_value, in_blocks=None): else: blocks = [self.blocks[blk] for blk in list(in_blocks)] - assert isinstance(rhs_value, AbstractRHS) + assert isinstance(rhs_value, abstractrhs_types) for blk in blocks: - for assign in blk.find_insts(Assign): + for assign in blk.find_insts(assign_types): if assign.value == rhs_value: return assign.target @@ -1810,3 +1811,74 @@ def __repr__(self): UNDEFINED = UndefinedType() + +if HAS_NUMBA: + abstractrhs_types = (AbstractRHS, numba.core.ir.AbstractRHS) + arg_types = (Arg, numba.core.ir.Arg) + assign_types = (Assign, numba.core.ir.Assign) + block_types = (Block, numba.core.ir.Block) + branch_types = (Branch, numba.core.ir.Branch) + const_types = (Const, numba.core.ir.Const) + del_types = (Del, numba.core.ir.Del) + delattr_types = (DelAttr, numba.core.ir.DelAttr) + delitem_types = (DelItem, numba.core.ir.DelItem) + dynamicraise_types = (DynamicRaise, numba.core.ir.DynamicRaise) + dynamictryraise_types = (DynamicTryRaise, numba.core.ir.DynamicTryRaise) + enterwith_types = (EnterWith, numba.core.ir.EnterWith) + expr_types = (Expr, numba.core.ir.Expr) + freevar_types = (FreeVar, numba.core.ir.FreeVar) + global_types = (Global, numba.core.ir.Global) + inst_types = (Inst, numba.core.ir.Inst) + jump_types = (Jump, numba.core.ir.Jump) + loc_types = (Loc, numba.core.ir.Loc) + popblock_types = (PopBlock, numba.core.ir.PopBlock) + print_types = (Print, numba.core.ir.Print) + raise_types = (Raise, numba.core.ir.Raise) + return_types = (Return, numba.core.ir.Return) + scope_types = (Scope, numba.core.ir.Scope) + setattr_types = (SetAttr, numba.core.ir.SetAttr) + setitem_types = (SetItem, numba.core.ir.SetItem) + staticraise_types = (StaticRaise, numba.core.ir.StaticRaise) + staticsetitem_types = (StaticSetItem, numba.core.ir.StaticSetItem) + statictryraise_types = (StaticTryRaise, numba.core.ir.StaticTryRaise) + stmt_types = (Stmt, numba.core.ir.Stmt) + storemap_types = (StoreMap, numba.core.ir.StoreMap) + tryraise_types = (TryRaise, numba.core.ir.TryRaise) + var_types = (Var, numba.core.ir.Var) + with_types = (With, numba.core.ir.With) + yield_types = (Yield, numba.core.ir.Yield) +else: + abstractrhs_types = (AbstractRHS,) + arg_types = (Arg,) + assign_types = (Assign,) + block_types = (Block,) + branch_types = (Branch,) + const_types = (Const,) + del_types = (Del,) + delattr_types = (DelAttr,) + delitem_types = (DelItem,) + dynamicraise_types = (DynamicRaise,) + dynamictryraise_types = (DynamicTryRaise,) + enterwith_types = (EnterWith,) + expr_types = (Expr,) + freevar_types = (FreeVar,) + global_types = (Global,) + inst_types = (Inst,) + jump_types = (Jump,) + loc_types = (Loc,) + popblock_types = (PopBlock,) + print_types = (Print,) + raise_types = (Raise,) + return_types = (Return,) + scope_types = (Scope,) + setattr_types = (SetAttr,) + setitem_types = (SetItem,) + staticraise_types = (StaticRaise,) + staticsetitem_types = (StaticSetItem,) + statictryraise_types = (StaticTryRaise,) + stmt_types = (Stmt,) + storemap_types = (StoreMap,) + tryraise_types = (TryRaise,) + var_types = (Var,) + with_types = (With,) + yield_types = (Yield,) diff --git a/numba_cuda/numba/cuda/core/ir_utils.py b/numba_cuda/numba/cuda/core/ir_utils.py index 225346985..ea3b6d1d7 100644 --- a/numba_cuda/numba/cuda/core/ir_utils.py +++ b/numba_cuda/numba/cuda/core/ir_utils.py @@ -215,7 +215,7 @@ def convert_size_to_var(size_var, typemap, scope, loc, nodes): size_assign = ir.Assign(ir.Const(size_var, loc), new_size, loc) nodes.append(size_assign) return new_size - assert isinstance(size_var, ir.Var) + assert isinstance(size_var, ir.var_types) return size_var @@ -275,7 +275,7 @@ def mk_range_block(typemap, start, stop, step, calltypes, scope, loc): def _mk_range_args(typemap, start, stop, step, scope, loc): nodes = [] - if isinstance(stop, ir.Var): + if isinstance(stop, ir.var_types): g_stop_var = stop else: assert isinstance(stop, int) @@ -287,7 +287,7 @@ def _mk_range_args(typemap, start, stop, step, scope, loc): if start == 0 and step == 1: return nodes, [g_stop_var] - if isinstance(start, ir.Var): + if isinstance(start, ir.var_types): g_start_var = start else: assert isinstance(start, int) @@ -299,7 +299,7 @@ def _mk_range_args(typemap, start, stop, step, scope, loc): if step == 1: return nodes, [g_start_var, g_stop_var] - if isinstance(step, ir.Var): + if isinstance(step, ir.var_types): g_step_var = step else: assert isinstance(step, int) @@ -395,7 +395,7 @@ def replace_var_names(blocks, namedict): new_namedict[l] = r def replace_name(var, namedict): - assert isinstance(var, ir.Var) + assert isinstance(var, ir.var_types) while var.name in namedict: var = ir.Var(var.scope, namedict[var.name], var.loc) return var @@ -404,7 +404,7 @@ def replace_name(var, namedict): def replace_var_callback(var, vardict): - assert isinstance(var, ir.Var) + assert isinstance(var, ir.var_types) while var.name in vardict.keys(): assert vardict[var.name].name != var.name new_var = vardict[var.name] @@ -451,44 +451,44 @@ def visit_vars_stmt(stmt, callback, cbdata): if isinstance(stmt, t): f(stmt, callback, cbdata) return - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): stmt.target = visit_vars_inner(stmt.target, callback, cbdata) stmt.value = visit_vars_inner(stmt.value, callback, cbdata) - elif isinstance(stmt, ir.Arg): + elif isinstance(stmt, ir.arg_types): stmt.name = visit_vars_inner(stmt.name, callback, cbdata) - elif isinstance(stmt, ir.Return): + elif isinstance(stmt, ir.return_types): stmt.value = visit_vars_inner(stmt.value, callback, cbdata) - elif isinstance(stmt, ir.Raise): + elif isinstance(stmt, ir.raise_types): stmt.exception = visit_vars_inner(stmt.exception, callback, cbdata) - elif isinstance(stmt, ir.Branch): + elif isinstance(stmt, ir.branch_types): stmt.cond = visit_vars_inner(stmt.cond, callback, cbdata) - elif isinstance(stmt, ir.Jump): + elif isinstance(stmt, ir.jump_types): stmt.target = visit_vars_inner(stmt.target, callback, cbdata) - elif isinstance(stmt, ir.Del): + elif isinstance(stmt, ir.del_types): # Because Del takes only a var name, we make up by # constructing a temporary variable. var = ir.Var(None, stmt.value, stmt.loc) var = visit_vars_inner(var, callback, cbdata) stmt.value = var.name - elif isinstance(stmt, ir.DelAttr): + elif isinstance(stmt, ir.delattr_types): stmt.target = visit_vars_inner(stmt.target, callback, cbdata) stmt.attr = visit_vars_inner(stmt.attr, callback, cbdata) - elif isinstance(stmt, ir.SetAttr): + elif isinstance(stmt, ir.setattr_types): stmt.target = visit_vars_inner(stmt.target, callback, cbdata) stmt.attr = visit_vars_inner(stmt.attr, callback, cbdata) stmt.value = visit_vars_inner(stmt.value, callback, cbdata) - elif isinstance(stmt, ir.DelItem): + elif isinstance(stmt, ir.delitem_types): stmt.target = visit_vars_inner(stmt.target, callback, cbdata) stmt.index = visit_vars_inner(stmt.index, callback, cbdata) - elif isinstance(stmt, ir.StaticSetItem): + elif isinstance(stmt, ir.staticsetitem_types): stmt.target = visit_vars_inner(stmt.target, callback, cbdata) stmt.index_var = visit_vars_inner(stmt.index_var, callback, cbdata) stmt.value = visit_vars_inner(stmt.value, callback, cbdata) - elif isinstance(stmt, ir.SetItem): + elif isinstance(stmt, ir.setitem_types): stmt.target = visit_vars_inner(stmt.target, callback, cbdata) stmt.index = visit_vars_inner(stmt.index, callback, cbdata) stmt.value = visit_vars_inner(stmt.value, callback, cbdata) - elif isinstance(stmt, ir.Print): + elif isinstance(stmt, ir.print_types): stmt.args = [visit_vars_inner(x, callback, cbdata) for x in stmt.args] else: # TODO: raise NotImplementedError("no replacement for IR node: ", stmt) @@ -497,13 +497,13 @@ def visit_vars_stmt(stmt, callback, cbdata): def visit_vars_inner(node, callback, cbdata): - if isinstance(node, ir.Var): + if isinstance(node, ir.var_types): return callback(node, cbdata) elif isinstance(node, list): return [visit_vars_inner(n, callback, cbdata) for n in node] elif isinstance(node, tuple): return tuple([visit_vars_inner(n, callback, cbdata) for n in node]) - elif isinstance(node, ir.Expr): + elif isinstance(node, ir.expr_types): # if node.op in ['binop', 'inplace_binop']: # lhs = node.lhs.name # rhs = node.rhs.name @@ -511,7 +511,7 @@ def visit_vars_inner(node, callback, cbdata): # node.rhs.name = callback, cbdata.get(rhs, rhs) for arg in node._kws.keys(): node._kws[arg] = visit_vars_inner(node._kws[arg], callback, cbdata) - elif isinstance(node, ir.Yield): + elif isinstance(node, ir.yield_types): node.value = visit_vars_inner(node.value, callback, cbdata) return node @@ -531,9 +531,9 @@ def add_offset_to_labels(blocks, offset): for T, f in add_offset_to_labels_extensions.items(): if isinstance(inst, T): f(inst, offset) - if isinstance(term, ir.Jump): + if isinstance(term, ir.jump_types): b.body[-1] = ir.Jump(term.target + offset, term.loc) - if isinstance(term, ir.Branch): + if isinstance(term, ir.branch_types): b.body[-1] = ir.Branch( term.cond, term.truebr + offset, term.falsebr + offset, term.loc ) @@ -578,9 +578,9 @@ def flatten_labels(blocks): term = None if b.body: term = b.body[-1] - if isinstance(term, ir.Jump): + if isinstance(term, ir.jump_types): b.body[-1] = ir.Jump(l_map[term.target], term.loc) - if isinstance(term, ir.Branch): + if isinstance(term, ir.branch_types): b.body[-1] = ir.Branch( term.cond, l_map[term.truebr], l_map[term.falsebr], term.loc ) @@ -593,7 +593,7 @@ def remove_dels(blocks): for block in blocks.values(): new_body = [] for stmt in block.body: - if not isinstance(stmt, ir.Del): + if not isinstance(stmt, ir.del_types): new_body.append(stmt) block.body = new_body return @@ -604,7 +604,9 @@ def remove_args(blocks): for block in blocks.values(): new_body = [] for stmt in block.body: - if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Arg): + if isinstance(stmt, ir.assign_types) and isinstance( + stmt.value, ir.arg_types + ): continue new_body.append(stmt) block.body = new_body @@ -737,7 +739,7 @@ def remove_dead_block( continue # ignore assignments that their lhs is not live or lhs==rhs - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): lhs = stmt.target rhs = stmt.value if lhs.name not in lives and has_no_side_effect( @@ -747,21 +749,21 @@ def remove_dead_block( print("Statement was removed.") removed = True continue - if isinstance(rhs, ir.Var) and lhs.name == rhs.name: + if isinstance(rhs, ir.var_types) and lhs.name == rhs.name: if config.DEBUG_ARRAY_OPT >= 2: print("Statement was removed.") removed = True continue # TODO: remove other nodes like SetItem etc. - if isinstance(stmt, ir.Del): + if isinstance(stmt, ir.del_types): if stmt.value not in lives: if config.DEBUG_ARRAY_OPT >= 2: print("Statement was removed.") removed = True continue - if isinstance(stmt, ir.SetItem): + if isinstance(stmt, ir.setitem_types): name = stmt.target.name if name not in lives_n_aliases: if config.DEBUG_ARRAY_OPT >= 2: @@ -775,9 +777,9 @@ def remove_dead_block( lives |= uses else: lives |= {v.name for v in stmt.list_vars()} - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): # make sure lhs is not used in rhs, e.g. a = g(a) - if isinstance(stmt.value, ir.Expr): + if isinstance(stmt.value, ir.expr_types): rhs_vars = {v.name for v in stmt.value.list_vars()} if lhs.name not in rhs_vars: lives.remove(lhs.name) @@ -809,7 +811,7 @@ def has_no_side_effect(rhs, lives, call_table): """ from numba.cuda.extending import _Intrinsic - if isinstance(rhs, ir.Expr) and rhs.op == "call": + if isinstance(rhs, ir.expr_types) and rhs.op == "call": func_name = rhs.func.name if func_name not in call_table or call_table[func_name] == []: return False @@ -843,11 +845,11 @@ def has_no_side_effect(rhs, lives, call_table): if f(rhs, lives, call_list): return True return False - if isinstance(rhs, ir.Expr) and rhs.op == "inplace_binop": + if isinstance(rhs, ir.expr_types) and rhs.op == "inplace_binop": return rhs.lhs.name not in lives - if isinstance(rhs, ir.Yield): + if isinstance(rhs, ir.yield_types): return False - if isinstance(rhs, ir.Expr) and rhs.op == "pair_first": + if isinstance(rhs, ir.expr_types) and rhs.op == "pair_first": # don't remove pair_first since prange looks for it return False return True @@ -861,7 +863,7 @@ def is_pure(rhs, lives, call_table): returns the same result. This is not the case for things like calls to numpy.random. """ - if isinstance(rhs, ir.Expr): + if isinstance(rhs, ir.expr_types): if rhs.op == "call": func_name = rhs.func.name if func_name not in call_table or call_table[func_name] == []: @@ -882,7 +884,7 @@ def is_pure(rhs, lives, call_table): return False elif rhs.op == "getiter" or rhs.op == "iternext": return False - if isinstance(rhs, ir.Yield): + if isinstance(rhs, ir.yield_types): return False return True @@ -926,39 +928,42 @@ def find_potential_aliases( if type(instr) in alias_analysis_extensions: f = alias_analysis_extensions[type(instr)] f(instr, args, typemap, func_ir, alias_map, arg_aliases) - if isinstance(instr, ir.Assign): + if isinstance(instr, ir.assign_types): expr = instr.value lhs = instr.target.name # only mutable types can alias if is_immutable_type(lhs, typemap): continue - if isinstance(expr, ir.Var) and lhs != expr.name: + if isinstance(expr, ir.var_types) and lhs != expr.name: _add_alias(lhs, expr.name, alias_map, arg_aliases) # subarrays like A = B[0] for 2D B - if isinstance(expr, ir.Expr) and ( + if isinstance(expr, ir.expr_types) and ( expr.op == "cast" or expr.op in ["getitem", "static_getitem"] ): _add_alias(lhs, expr.value.name, alias_map, arg_aliases) - if isinstance(expr, ir.Expr) and expr.op == "inplace_binop": + if ( + isinstance(expr, ir.expr_types) + and expr.op == "inplace_binop" + ): _add_alias(lhs, expr.lhs.name, alias_map, arg_aliases) # array attributes like A.T if ( - isinstance(expr, ir.Expr) + isinstance(expr, ir.expr_types) and expr.op == "getattr" and expr.attr in ["T", "ctypes", "flat"] ): _add_alias(lhs, expr.value.name, alias_map, arg_aliases) # a = b.c. a should alias b if ( - isinstance(expr, ir.Expr) + isinstance(expr, ir.expr_types) and expr.op == "getattr" and expr.attr not in ["shape"] and expr.value.name in arg_aliases ): _add_alias(lhs, expr.value.name, alias_map, arg_aliases) # calls that can create aliases such as B = A.ravel() - if isinstance(expr, ir.Expr) and expr.op == "call": + if isinstance(expr, ir.expr_types) and expr.op == "call": fdef = guard(find_callname, func_ir, expr, typemap) # TODO: sometimes gufunc backend creates duplicate code # causing find_callname to fail. Example: test_argmax @@ -974,7 +979,10 @@ def find_potential_aliases( _add_alias( lhs, expr.args[0].name, alias_map, arg_aliases ) - if isinstance(fmod, ir.Var) and fname in np_alias_funcs: + if ( + isinstance(fmod, ir.var_types) + and fname in np_alias_funcs + ): _add_alias(lhs, fmod.name, alias_map, arg_aliases) # copy to avoid changing size during iteration @@ -1125,9 +1133,9 @@ def get_block_copies(blocks, typemap): extra_kill[label].add(l) assign_dict = new_assign_dict extra_kill[label] |= kill_set - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): lhs = stmt.target.name - if isinstance(stmt.value, ir.Var): + if isinstance(stmt.value, ir.var_types): rhs = stmt.value.name # copy is valid only if same type (see # TestCFunc.test_locals) @@ -1139,7 +1147,7 @@ def get_block_copies(blocks, typemap): assign_dict[lhs] = rhs continue if ( - isinstance(stmt.value, ir.Expr) + isinstance(stmt.value, ir.expr_types) and stmt.value.op == "inplace_binop" ): in1_var = stmt.value.lhs.name @@ -1195,7 +1203,7 @@ def apply_copy_propagate( ) # only rhs of assignments should be replaced # e.g. if x=y is available, x in x=z shouldn't be replaced - elif isinstance(stmt, ir.Assign): + elif isinstance(stmt, ir.assign_types): stmt.value = replace_vars_inner(stmt.value, var_dict) else: replace_vars_stmt(stmt, var_dict) @@ -1209,7 +1217,9 @@ def apply_copy_propagate( for l, r in var_dict.copy().items(): if l in kill_set or r.name in kill_set: var_dict.pop(l) - if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Var): + if isinstance(stmt, ir.assign_types) and isinstance( + stmt.value, ir.var_types + ): lhs = stmt.target.name rhs = stmt.value.name # rhs could be replaced with lhs from previous copies @@ -1227,8 +1237,8 @@ def apply_copy_propagate( lhs_kill.append(k) for k in lhs_kill: var_dict.pop(k, None) - if isinstance(stmt, ir.Assign) and not isinstance( - stmt.value, ir.Var + if isinstance(stmt, ir.assign_types) and not isinstance( + stmt.value, ir.var_types ): lhs = stmt.target.name var_dict.pop(lhs, None) @@ -1249,7 +1259,7 @@ def fix_setitem_type(stmt, typemap, calltypes): with 'A' layout. The replaced variable can be 'C' or 'F', so we update setitem call type reflect this (from matrix power test) """ - if not isinstance(stmt, (ir.SetItem, ir.StaticSetItem)): + if not isinstance(stmt, ir.setitem_types + ir.staticsetitem_types): return t_typ = typemap[stmt.target.name] s_typ = calltypes[stmt].args[0] @@ -1304,7 +1314,7 @@ def find_topo_order(blocks, cfg=None): seen.add(node) succs = cfg._succs[node] last_inst = blocks[node].body[-1] - if isinstance(last_inst, ir.Branch): + if isinstance(last_inst, ir.branch_types): succs = [last_inst.truebr, last_inst.falsebr] for dest in succs: if (node, dest) not in cfg._back_edges: @@ -1351,12 +1361,12 @@ def get_call_table( for label in reversed(order): for inst in reversed(blocks[label].body): - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): lhs = inst.target.name rhs = inst.value - if isinstance(rhs, ir.Expr) and rhs.op == "call": + if isinstance(rhs, ir.expr_types) and rhs.op == "call": call_table[rhs.func.name] = [] - if isinstance(rhs, ir.Expr) and rhs.op == "getattr": + if isinstance(rhs, ir.expr_types) and rhs.op == "getattr": if lhs in call_table: call_table[lhs].append(rhs.attr) reverse_call_table[rhs.value.name] = lhs @@ -1364,19 +1374,19 @@ def get_call_table( call_var = reverse_call_table[lhs] call_table[call_var].append(rhs.attr) reverse_call_table[rhs.value.name] = call_var - if isinstance(rhs, ir.Global): + if isinstance(rhs, ir.global_types): if lhs in call_table: call_table[lhs].append(rhs.value) if lhs in reverse_call_table: call_var = reverse_call_table[lhs] call_table[call_var].append(rhs.value) - if isinstance(rhs, ir.FreeVar): + if isinstance(rhs, ir.freevar_types): if lhs in call_table: call_table[lhs].append(rhs.value) if lhs in reverse_call_table: call_var = reverse_call_table[lhs] call_table[call_var].append(rhs.value) - if isinstance(rhs, ir.Var): + if isinstance(rhs, ir.var_types): if lhs in call_table: call_table[lhs].append(rhs.name) reverse_call_table[rhs.name] = lhs @@ -1401,12 +1411,14 @@ def get_tuple_table(blocks, tuple_table=None): for block in blocks.values(): for inst in block.body: - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): lhs = inst.target.name rhs = inst.value - if isinstance(rhs, ir.Expr) and rhs.op == "build_tuple": + if isinstance(rhs, ir.expr_types) and rhs.op == "build_tuple": tuple_table[lhs] = rhs.items - if isinstance(rhs, ir.Const) and isinstance(rhs.value, tuple): + if isinstance(rhs, ir.const_types) and isinstance( + rhs.value, tuple + ): tuple_table[lhs] = rhs.value for T, f in tuple_table_extensions.items(): if isinstance(inst, T): @@ -1416,7 +1428,9 @@ def get_tuple_table(blocks, tuple_table=None): def get_stmt_writes(stmt): writes = set() - if isinstance(stmt, (ir.Assign, ir.SetItem, ir.StaticSetItem)): + if isinstance( + stmt, ir.assign_types + ir.setitem_types + ir.staticsetitem_types + ): writes.add(stmt.target.name) return writes @@ -1430,7 +1444,7 @@ def rename_labels(blocks): # make a block with return last if available (just for readability) return_label = -1 for l, b in blocks.items(): - if isinstance(b.body[-1], ir.Return): + if isinstance(b.body[-1], ir.return_types): return_label = l # some cases like generators can have no return blocks if return_label != -1: @@ -1446,9 +1460,9 @@ def rename_labels(blocks): term = b.terminator # create new IR nodes instead of mutating the existing one as copies of # the IR may also refer to the same nodes! - if isinstance(term, ir.Jump): + if isinstance(term, ir.jump_types): b.body[-1] = ir.Jump(label_map[term.target], term.loc) - if isinstance(term, ir.Branch): + if isinstance(term, ir.branch_types): b.body[-1] = ir.Branch( term.cond, label_map[term.truebr], @@ -1472,7 +1486,9 @@ def simplify_CFG(blocks): def find_single_branch(label): block = blocks[label] - return len(block.body) == 1 and isinstance(block.body[0], ir.Branch) + return len(block.body) == 1 and isinstance( + block.body[0], ir.branch_types + ) single_branch_blocks = list(filter(find_single_branch, blocks.keys())) marked_for_del = set() @@ -1482,7 +1498,7 @@ def find_single_branch(label): delete_block = True for p, q in predecessors: block = blocks[p] - if isinstance(block.body[-1], ir.Jump): + if isinstance(block.body[-1], ir.jump_types): block.body[-1] = copy.copy(inst) else: delete_block = False @@ -1523,7 +1539,9 @@ def canonicalize_array_math(func_ir, typemap, calltypes, typingctx): block = blocks[label] new_body = [] for stmt in block.body: - if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr): + if isinstance(stmt, ir.assign_types) and isinstance( + stmt.value, ir.expr_types + ): lhs = stmt.target.name rhs = stmt.value # replace A.func with np.func, and save A in saved_arr_arg @@ -1582,15 +1600,18 @@ def get_array_accesses(blocks, accesses=None): for block in blocks.values(): for inst in block.body: - if isinstance(inst, ir.SetItem): + if isinstance(inst, ir.setitem_types): accesses.add((inst.target.name, inst.index.name)) - if isinstance(inst, ir.StaticSetItem): + if isinstance(inst, ir.staticsetitem_types): accesses.add((inst.target.name, inst.index_var.name)) - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): rhs = inst.value - if isinstance(rhs, ir.Expr) and rhs.op == "getitem": + if isinstance(rhs, ir.expr_types) and rhs.op == "getitem": accesses.add((rhs.value.name, rhs.index.name)) - if isinstance(rhs, ir.Expr) and rhs.op == "static_getitem": + if ( + isinstance(rhs, ir.expr_types) + and rhs.op == "static_getitem" + ): index = rhs.index # slice is unhashable, so just keep the variable if index is None or is_slice_index(index): @@ -1743,7 +1764,7 @@ def build_definitions(blocks, definitions=None): for block in blocks.values(): for inst in block.body: - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): name = inst.target.name definition = definitions.get(name, []) if definition == []: @@ -1771,13 +1792,13 @@ def find_callname( """ from numba.cuda.extending import _Intrinsic - require(isinstance(expr, ir.Expr) and expr.op == "call") + require(isinstance(expr, ir.expr_types) and expr.op == "call") callee = expr.func callee_def = definition_finder(func_ir, callee) attrs = [] obj = None while True: - if isinstance(callee_def, (ir.Global, ir.FreeVar)): + if isinstance(callee_def, ir.global_types + ir.freevar_types): # require(callee_def.value == numpy) # these checks support modules like numpy, numpy.random as well as # calls like len() and intrinsics like assertEquiv @@ -1829,7 +1850,9 @@ def find_callname( if class_name != "module": attrs.append(class_name) break - elif isinstance(callee_def, ir.Expr) and callee_def.op == "getattr": + elif ( + isinstance(callee_def, ir.expr_types) and callee_def.op == "getattr" + ): obj = callee_def.value attrs.append(callee_def.attr) if typemap and obj.name in typemap: @@ -1851,9 +1874,9 @@ def find_build_sequence(func_ir, var): operator, or raise GuardException otherwise. Note: only build_tuple is immutable, so use with care. """ - require(isinstance(var, ir.Var)) + require(isinstance(var, ir.var_types)) var_def = get_definition(func_ir, var) - require(isinstance(var_def, ir.Expr)) + require(isinstance(var_def, ir.expr_types)) build_ops = ["build_tuple", "build_list", "build_set"] require(var_def.op in build_ops) return var_def.items, var_def.op @@ -1863,9 +1886,11 @@ def find_const(func_ir, var): """Check if a variable is defined as constant, and return the constant value, or raise GuardException otherwise. """ - require(isinstance(var, ir.Var)) + require(isinstance(var, ir.var_types)) var_def = get_definition(func_ir, var) - require(isinstance(var_def, (ir.Const, ir.Global, ir.FreeVar))) + require( + isinstance(var_def, ir.const_types + ir.global_types + ir.freevar_types) + ) return var_def.value @@ -2025,7 +2050,9 @@ def replace_arg_nodes(block, args): Replace ir.Arg(...) with variables """ for stmt in block.body: - if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Arg): + if isinstance(stmt, ir.assign_types) and isinstance( + stmt.value, ir.arg_types + ): idx = stmt.value.index assert idx < len(args) stmt.value = args[idx] @@ -2041,12 +2068,12 @@ def replace_returns(blocks, target, return_label): if not block.body: continue stmt = block.terminator - if isinstance(stmt, ir.Return): + if isinstance(stmt, ir.return_types): block.body.pop() # remove return cast_stmt = block.body.pop() assert ( - isinstance(cast_stmt, ir.Assign) - and isinstance(cast_stmt.value, ir.Expr) + isinstance(cast_stmt, ir.assign_types) + and isinstance(cast_stmt.value, ir.expr_types) and cast_stmt.value.op == "cast" ), "invalid return cast" block.body.append( @@ -2093,7 +2120,7 @@ def dump_blocks(blocks): def is_operator_or_getitem(expr): """true if expr is unary or binary operator or getitem""" return ( - isinstance(expr, ir.Expr) + isinstance(expr, ir.expr_types) and getattr(expr, "op", False) and expr.op in ["unary", "binop", "inplace_binop", "getitem", "static_getitem"] @@ -2108,15 +2135,15 @@ def is_get_setitem(stmt): def is_getitem(stmt): """true if stmt is a getitem or static_getitem assignment""" return ( - isinstance(stmt, ir.Assign) - and isinstance(stmt.value, ir.Expr) + isinstance(stmt, ir.assign_types) + and isinstance(stmt.value, ir.expr_types) and stmt.value.op in ["getitem", "static_getitem"] ) def is_setitem(stmt): """true if stmt is a SetItem or StaticSetItem node""" - return isinstance(stmt, (ir.SetItem, ir.StaticSetItem)) + return isinstance(stmt, ir.setitem_types + ir.staticsetitem_types) def index_var_of_get_setitem(stmt): @@ -2128,7 +2155,7 @@ def index_var_of_get_setitem(stmt): return stmt.value.index_var if is_setitem(stmt): - if isinstance(stmt, ir.SetItem): + if isinstance(stmt, ir.setitem_types): return stmt.index else: return stmt.index_var @@ -2143,7 +2170,7 @@ def set_index_var_of_get_setitem(stmt, new_index): else: stmt.value.index_var = new_index elif is_setitem(stmt): - if isinstance(stmt, ir.SetItem): + if isinstance(stmt, ir.setitem_types): stmt.index = new_index else: stmt.index_var = new_index @@ -2242,10 +2269,10 @@ def find_outer_value(func_ir, var): or raise GuardException otherwise. """ dfn = get_definition(func_ir, var) - if isinstance(dfn, (ir.Global, ir.FreeVar)): + if isinstance(dfn, ir.global_types + ir.freevar_types): return dfn.value - if isinstance(dfn, ir.Expr) and dfn.op == "getattr": + if isinstance(dfn, ir.expr_types) and dfn.op == "getattr": prev_val = find_outer_value(func_ir, dfn.value) try: val = getattr(prev_val, dfn.attr) @@ -2288,9 +2315,9 @@ def raise_on_unsupported_feature(func_ir, typemap): raise UnsupportedError(msg, func_ir.loc) for blk in func_ir.blocks.values(): - for stmt in blk.find_insts(ir.Assign): + for stmt in blk.find_insts(ir.assign_types): # This raises on finding `make_function` - if isinstance(stmt.value, ir.Expr): + if isinstance(stmt.value, ir.expr_types): if stmt.value.op == "make_function": val = stmt.value @@ -2321,7 +2348,7 @@ def raise_on_unsupported_feature(func_ir, typemap): raise UnsupportedError(msg, stmt.value.loc) # this checks for gdb initialization calls, only one is permitted - if isinstance(stmt.value, (ir.Global, ir.FreeVar)): + if isinstance(stmt.value, ir.global_types + ir.freevar_types): val = stmt.value val = getattr(val, "value", None) if val is None: @@ -2337,7 +2364,7 @@ def raise_on_unsupported_feature(func_ir, typemap): gdb_calls.append(stmt.loc) # report last seen location # this checks that np. was called if view is called - if isinstance(stmt.value, ir.Expr): + if isinstance(stmt.value, ir.expr_types): if stmt.value.op == "getattr" and stmt.value.attr == "view": var = stmt.value.value.name if isinstance(typemap[var], types.Array): @@ -2363,7 +2390,7 @@ def raise_on_unsupported_feature(func_ir, typemap): ) # checks for globals that are also reflected - if isinstance(stmt.value, ir.Global): + if isinstance(stmt.value, ir.global_types): ty = typemap[stmt.target.name] msg = ( "The use of a %s type, assigned to variable '%s' in " @@ -2380,7 +2407,10 @@ def raise_on_unsupported_feature(func_ir, typemap): # checks for generator expressions (yield in use when func_ir has # not been identified as a generator). - if isinstance(stmt.value, ir.Yield) and not func_ir.is_generator: + if ( + isinstance(stmt.value, ir.yield_types) + and not func_ir.is_generator + ): msg = "The use of generator expressions is unsupported." raise UnsupportedError(msg, loc=stmt.loc) @@ -2443,7 +2473,7 @@ def resolve_mod(mod): except KeyError: # multiple definitions return None return resolve_mod(mod) - elif isinstance(mod, (ir.Global, ir.FreeVar)): + elif isinstance(mod, ir.global_types + ir.freevar_types): if isinstance(mod.value, pytypes.ModuleType): return mod return None @@ -2466,7 +2496,7 @@ def enforce_no_dels(func_ir): Enforce there being no ir.Del nodes in the IR. """ for blk in func_ir.blocks.values(): - dels = [x for x in blk.find_insts(ir.Del)] + dels = [x for x in blk.find_insts(ir.del_types)] if dels: msg = "Illegal IR, del found at: %s" % dels[0] raise CompilerError(msg, loc=dels[0].loc) @@ -2522,7 +2552,7 @@ def convert_code_obj_to_function(code_obj, caller_ir): "are multiple definitions present." % x ) raise TypingError(msg, loc=code_obj.loc) - if isinstance(freevar_def, ir.Const): + if isinstance(freevar_def, ir.const_types): freevars.append(freevar_def.value) else: msg = ( @@ -2619,20 +2649,20 @@ def transfer_scope(block, scope): def is_setup_with(stmt): - return isinstance(stmt, ir.EnterWith) + return isinstance(stmt, ir.enterwith_types) def is_terminator(stmt): - return isinstance(stmt, ir.Terminator) + return isinstance(stmt, ir.terminator_types) def is_raise(stmt): - return isinstance(stmt, ir.Raise) + return isinstance(stmt, ir.raise_types) def is_return(stmt): - return isinstance(stmt, ir.Return) + return isinstance(stmt, ir.return_types) def is_pop_block(stmt): - return isinstance(stmt, ir.PopBlock) + return isinstance(stmt, ir.popblock_types) diff --git a/numba_cuda/numba/cuda/core/postproc.py b/numba_cuda/numba/cuda/core/postproc.py index 2f0d66fa1..61b51c7a7 100644 --- a/numba_cuda/numba/cuda/core/postproc.py +++ b/numba_cuda/numba/cuda/core/postproc.py @@ -7,8 +7,8 @@ class YieldPoint(object): def __init__(self, block, inst): - assert isinstance(block, ir.Block) - assert isinstance(inst, ir.Yield) + assert isinstance(block, ir.block_types) + assert isinstance(inst, ir.yield_types) self.block = block self.inst = inst self.live_vars = None @@ -111,9 +111,9 @@ def _populate_generator_info(self): assert not dct, "rerunning _populate_generator_info" for block in self.func_ir.blocks.values(): for inst in block.body: - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): yieldinst = inst.value - if isinstance(yieldinst, ir.Yield): + if isinstance(yieldinst, ir.yield_types): index = len(dct) + 1 yieldinst.index = index yp = YieldPoint(block, yieldinst) @@ -133,18 +133,18 @@ def _compute_generator_info(self): weak_live_vars = set() stmts = iter(yp.block.body) for stmt in stmts: - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): if stmt.value is yp.inst: break live_vars.add(stmt.target.name) - elif isinstance(stmt, ir.Del): + elif isinstance(stmt, ir.del_types): live_vars.remove(stmt.value) else: assert 0, "couldn't find yield point" # Try to optimize out any live vars that are deleted immediately # after the yield point. for stmt in stmts: - if isinstance(stmt, ir.Del): + if isinstance(stmt, ir.del_types): name = stmt.value if name in live_vars: live_vars.remove(name) @@ -222,7 +222,7 @@ def _patch_var_dels( else: lastloc = stmt.loc # Ignore dels (assuming no user inserted deletes) - if not isinstance(stmt, ir.Del): + if not isinstance(stmt, ir.del_types): body.append(stmt) # note: the reverse sort is not necessary for correctness # it is just to minimize changes to test for now diff --git a/numba_cuda/numba/cuda/core/rewrites/ir_print.py b/numba_cuda/numba/cuda/core/rewrites/ir_print.py index 558fe6816..fe09f5a4e 100644 --- a/numba_cuda/numba/cuda/core/rewrites/ir_print.py +++ b/numba_cuda/numba/cuda/core/rewrites/ir_print.py @@ -16,8 +16,11 @@ def match(self, func_ir, block, typemap, calltypes): self.prints = prints = {} self.block = block # Find all assignments with a right-hand print() call - for inst in block.find_insts(ir.Assign): - if isinstance(inst.value, ir.Expr) and inst.value.op == "call": + for inst in block.find_insts(ir.assign_types): + if ( + isinstance(inst.value, ir.expr_types) + and inst.value.op == "call" + ): expr = inst.value try: callee = func_ir.infer_constant(expr.func) @@ -68,7 +71,7 @@ class DetectConstPrintArguments(Rewrite): def match(self, func_ir, block, typemap, calltypes): self.consts = consts = {} self.block = block - for inst in block.find_insts(ir.Print): + for inst in block.find_insts(ir.print_types): if inst.consts: # Already rewritten continue diff --git a/numba_cuda/numba/cuda/core/rewrites/static_getitem.py b/numba_cuda/numba/cuda/core/rewrites/static_getitem.py index 10d88f789..a650772db 100644 --- a/numba_cuda/numba/cuda/core/rewrites/static_getitem.py +++ b/numba_cuda/numba/cuda/core/rewrites/static_getitem.py @@ -37,7 +37,7 @@ def apply(self): new_block = self.block.copy() new_block.clear() for inst in self.block.body: - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): expr = inst.value if expr in self.getitems: const = self.getitems[expr] @@ -85,7 +85,7 @@ def apply(self): """ new_block = ir.Block(self.block.scope, self.block.loc) for inst in self.block.body: - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): expr = inst.value if expr in self.getitems: const, lit_val = self.getitems[expr] @@ -119,7 +119,7 @@ def match(self, func_ir, block, typemap, calltypes): self.setitems = setitems = {} self.block = block self.calltypes = calltypes - for inst in block.find_insts(ir.SetItem): + for inst in block.find_insts(ir.setitem_types): index_ty = typemap[inst.index.name] if isinstance(index_ty, types.StringLiteral): setitems[inst] = (inst.index, index_ty.literal_value) @@ -133,7 +133,7 @@ def apply(self): """ new_block = ir.Block(self.block.scope, self.block.loc) for inst in self.block.body: - if isinstance(inst, ir.SetItem): + if isinstance(inst, ir.setitem_types): if inst in self.setitems: const, lit_val = self.setitems[inst] new_inst = ir.StaticSetItem( @@ -162,7 +162,7 @@ def match(self, func_ir, block, typemap, calltypes): self.block = block # Detect all setitem statements and find which ones can be # rewritten - for inst in block.find_insts(ir.SetItem): + for inst in block.find_insts(ir.setitem_types): try: const = func_ir.infer_constant(inst.index) except errors.ConstantInferenceError: diff --git a/numba_cuda/numba/cuda/core/rewrites/static_raise.py b/numba_cuda/numba/cuda/core/rewrites/static_raise.py index 5dbac7834..2a68b2464 100644 --- a/numba_cuda/numba/cuda/core/rewrites/static_raise.py +++ b/numba_cuda/numba/cuda/core/rewrites/static_raise.py @@ -56,7 +56,7 @@ def match(self, func_ir, block, typemap, calltypes): self.block = block # Detect all raise statements and find which ones can be # rewritten - for inst in block.find_insts((ir.Raise, ir.TryRaise)): + for inst in block.find_insts(ir.raise_types + ir.tryraise_types): if inst.exception is None: # re-reraise exc_type, exc_args = None, None @@ -72,9 +72,9 @@ def match(self, func_ir, block, typemap, calltypes): loc = inst.exception.loc exc_type, exc_args = self._break_constant(const, loc) - if isinstance(inst, ir.Raise): + if isinstance(inst, ir.raise_types): raises[inst] = exc_type, exc_args - elif isinstance(inst, ir.TryRaise): + elif isinstance(inst, ir.tryraise_types): tryraises[inst] = exc_type, exc_args else: raise ValueError("unexpected: {}".format(type(inst))) diff --git a/numba_cuda/numba/cuda/core/ssa.py b/numba_cuda/numba/cuda/core/ssa.py index a157e83d8..5530ee8f8 100644 --- a/numba_cuda/numba/cuda/core/ssa.py +++ b/numba_cuda/numba/cuda/core/ssa.py @@ -211,7 +211,7 @@ def _run_ssa_block_pass(states, blk, handler): _logger.debug("Running %s", handler) for stmt in blk.body: _logger.debug("on stmt: %s", stmt) - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): ret = handler.on_assign(states, stmt) else: ret = handler.on_other(states, stmt) @@ -335,7 +335,7 @@ def __init__(self, cache_list_vars): def on_assign(self, states, assign): rhs = assign.value - if isinstance(rhs, ir.Inst): + if isinstance(rhs, ir.inst_types): newdef = self._fix_var( states, assign, @@ -353,7 +353,7 @@ def on_assign(self, states, assign): value=rhs, loc=assign.loc, ) - elif isinstance(rhs, ir.Var): + elif isinstance(rhs, ir.var_types): newdef = self._fix_var(states, assign, [rhs]) # Has a replacement that is not the current variable if newdef is not None and newdef.target is not ir.UNDEFINED: diff --git a/numba_cuda/numba/cuda/core/transforms.py b/numba_cuda/numba/cuda/core/transforms.py index 90611557e..d9bed76e3 100644 --- a/numba_cuda/numba/cuda/core/transforms.py +++ b/numba_cuda/numba/cuda/core/transforms.py @@ -55,8 +55,8 @@ def cannot_yield(loop): insiders = set(loop.body) | set(loop.entries) | set(loop.exits) for blk in map(blocks.__getitem__, insiders): for inst in blk.body: - if isinstance(inst, ir.Assign): - if isinstance(inst.value, ir.Yield): + if isinstance(inst, ir.assign_types): + if isinstance(inst.value, ir.yield_types): _logger.debug("has yield") return False _logger.debug("no yield") @@ -347,14 +347,14 @@ def replace_target(term, src, dst): def replace(target): return dst if target == src else target - if isinstance(term, ir.Branch): + if isinstance(term, ir.branch_types): return ir.Branch( cond=term.cond, truebr=replace(term.truebr), falsebr=replace(term.falsebr), loc=term.loc, ) - elif isinstance(term, ir.Jump): + elif isinstance(term, ir.jump_types): return ir.Jump(target=replace(term.target), loc=term.loc) else: assert not term.get_targets() @@ -477,7 +477,7 @@ def get_ctxmgr_obj(var_ref): """ # If the contextmanager used as a Call dfn = func_ir.get_definition(var_ref) - if isinstance(dfn, ir.Expr) and dfn.op == "call": + if isinstance(dfn, ir.expr_types) and dfn.op == "call": args = [get_var_dfn(x) for x in dfn.args] kws = {k: get_var_dfn(v) for k, v in dfn.kws} extra = {"args": args, "kwargs": kws} @@ -501,7 +501,7 @@ def get_ctxmgr_obj(var_ref): # Scan the start of the with-region for the contextmanager for stmt in blocks[blk_start].body: - if isinstance(stmt, ir.EnterWith): + if isinstance(stmt, ir.enterwith_types): var_ref = stmt.contextmanager ctxobj, extra = get_ctxmgr_obj(var_ref) if not hasattr(ctxobj, "mutate_with_body"): @@ -523,7 +523,22 @@ def _legalize_with_head(blk): """ counters = defaultdict(int) for stmt in blk.body: - counters[type(stmt)] += 1 + # The counters dict is keyed on the IR node type. As the rest of the + # function pops out specific node types, we normalize these node types + # to be the Numba-CUDA versions of them so that we don't have to have + # more complicated logic looking for both Numba and Numba-CUDA IR nodes + # of the same kind. + if isinstance(stmt, ir.enterwith_types): + stmt_type = ir.EnterWith + elif isinstance(stmt, ir.jump_types): + stmt_type = ir.Jump + elif isinstance(stmt, ir.del_types): + stmt_type = ir.Del + else: + stmt_type = type(stmt) + + counters[stmt_type] += 1 + if counters.pop(ir.EnterWith) != 1: raise errors.CompilerError( "with's head-block must have exactly 1 ENTER_WITH", @@ -744,10 +759,10 @@ def _rewrite_return(func_ir, target_block_label): # JUMP # ----------------- top_body, bottom_body = [], [] - pop_blocks = [*target_block.find_insts(ir.PopBlock)] + pop_blocks = [*target_block.find_insts(ir.popblock_types)] assert len(pop_blocks) == 1 - assert len([*target_block.find_insts(ir.Jump)]) == 1 - assert isinstance(target_block.body[-1], ir.Jump) + assert len([*target_block.find_insts(ir.jump_types)]) == 1 + assert isinstance(target_block.body[-1], ir.jump_types) pb_marker = pop_blocks[0] pb_is = target_block.body.index(pb_marker) top_body.extend(target_block.body[:pb_is]) diff --git a/numba_cuda/numba/cuda/core/typed_passes.py b/numba_cuda/numba/cuda/core/typed_passes.py index 380507e80..958e47c51 100644 --- a/numba_cuda/numba/cuda/core/typed_passes.py +++ b/numba_cuda/numba/cuda/core/typed_passes.py @@ -166,15 +166,15 @@ def legalize_return_type(return_type, interp, targetctx): argvars = set() for bid, blk in interp.blocks.items(): for inst in blk.body: - if isinstance(inst, ir.Return): + if isinstance(inst, ir.return_types): retstmts.append(inst.value.name) - elif isinstance(inst, ir.Assign): + elif isinstance(inst, ir.assign_types): if ( - isinstance(inst.value, ir.Expr) + isinstance(inst.value, ir.expr_types) and inst.value.op == "cast" ): caststmts[inst.target.name] = inst.value - elif isinstance(inst.value, ir.Arg): + elif isinstance(inst.value, ir.arg_types): argvars.add(inst.target.name) assert retstmts, "No return statements?" @@ -518,9 +518,9 @@ def run_pass(self, state): label, block = work_list.pop() for i, instr in enumerate(block.body): # TO-DO: other statements (setitem) - if isinstance(instr, ir.Assign): + if isinstance(instr, ir.assign_types): expr = instr.value - if isinstance(expr, ir.Expr): + if isinstance(expr, ir.expr_types): workfn = self._do_work_expr if guard( @@ -822,8 +822,8 @@ def _strip_phi_nodes(self, func_ir): phis = set() # Find all variables that needs to be exported for label, block in func_ir.blocks.items(): - for assign in block.find_insts(ir.Assign): - if isinstance(assign.value, ir.Expr): + for assign in block.find_insts(ir.assign_types): + if isinstance(assign.value, ir.expr_types): if assign.value.op == "phi": phis.add(assign) phi = assign.value @@ -854,7 +854,7 @@ def _strip_phi_nodes(self, func_ir): # last assignment to rhs assignments = [ stmt - for stmt in newblk.find_insts(ir.Assign) + for stmt in newblk.find_insts(ir.assign_types) if stmt.target == rhs ] if assignments: diff --git a/numba_cuda/numba/cuda/core/typeinfer.py b/numba_cuda/numba/cuda/core/typeinfer.py index 6b71463b8..4f32e9613 100644 --- a/numba_cuda/numba/cuda/core/typeinfer.py +++ b/numba_cuda/numba/cuda/core/typeinfer.py @@ -653,7 +653,7 @@ def resolve(self, typeinfer, typevars, fnty): unsatisfied = set() for idx in e.requested_args: maybe_arg = typeinfer.func_ir.get_definition(folded[idx]) - if isinstance(maybe_arg, ir.Arg): + if isinstance(maybe_arg, ir.arg_types): requested.add(maybe_arg.index) else: unsatisfied.add(idx) @@ -1081,7 +1081,7 @@ def _get_return_vars(self): rets = [] for blk in self.blocks.values(): inst = blk.terminator - if isinstance(inst, ir.Return): + if isinstance(inst, ir.return_types): rets.append(inst.value) return rets @@ -1241,7 +1241,7 @@ def diagnose_imprecision(offender): call_name = offender.value.func.name # find the offender based on the call name offender = find_offender(call_name) - if isinstance(offender.value, ir.Global): + if isinstance(offender.value, ir.global_types): if offender.value.name == "list": return list_msg except (AttributeError, KeyError): @@ -1447,9 +1447,9 @@ def check_type(atype): returns = {} for x in reversed(lst): for block in self.func_ir.blocks.values(): - for instr in block.find_insts(ir.Return): + for instr in block.find_insts(ir.return_types): value = instr.value - if isinstance(value, ir.Var): + if isinstance(value, ir.var_types): name = value.name else: pass @@ -1497,27 +1497,37 @@ def get_state_token(self): return [tv.type for name, tv in sorted(self.typevars.items())] def constrain_statement(self, inst): - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): self.typeof_assign(inst) - elif isinstance(inst, ir.SetItem): + elif isinstance(inst, ir.setitem_types): self.typeof_setitem(inst) - elif isinstance(inst, ir.StaticSetItem): + elif isinstance(inst, ir.staticsetitem_types): self.typeof_static_setitem(inst) - elif isinstance(inst, ir.DelItem): + elif isinstance(inst, ir.delitem_types): self.typeof_delitem(inst) - elif isinstance(inst, ir.SetAttr): + elif isinstance(inst, ir.setattr_types): self.typeof_setattr(inst) - elif isinstance(inst, ir.Print): + elif isinstance(inst, ir.print_types): self.typeof_print(inst) - elif isinstance(inst, ir.StoreMap): + elif isinstance(inst, ir.storemap_types): self.typeof_storemap(inst) - elif isinstance(inst, (ir.Jump, ir.Branch, ir.Return, ir.Del)): + elif isinstance(inst, ir.jump_types): pass - elif isinstance(inst, (ir.DynamicRaise, ir.DynamicTryRaise)): + elif isinstance(inst, ir.branch_types): pass - elif isinstance(inst, (ir.StaticRaise, ir.StaticTryRaise)): + elif isinstance(inst, ir.return_types): pass - elif isinstance(inst, ir.PopBlock): + elif isinstance(inst, ir.del_types): + pass + elif isinstance(inst, ir.dynamicraise_types): + pass + elif isinstance(inst, ir.dynamictryraise_types): + pass + elif isinstance(inst, ir.staticraise_types): + pass + elif isinstance(inst, ir.statictryraise_types): + pass + elif isinstance(inst, ir.popblock_types): pass # It's a marker statement elif type(inst) in typeinfer_extensions: # let external calls handle stmt if type matches @@ -1575,19 +1585,21 @@ def typeof_print(self, inst): def typeof_assign(self, inst): value = inst.value - if isinstance(value, ir.Const): + if isinstance(value, ir.const_types): self.typeof_const(inst, inst.target, value.value) - elif isinstance(value, ir.Var): + elif isinstance(value, ir.var_types): self.constraints.append( Propagate(dst=inst.target.name, src=value.name, loc=inst.loc) ) - elif isinstance(value, (ir.Global, ir.FreeVar)): + elif isinstance(value, ir.global_types) or isinstance( + value, ir.freevar_types + ): self.typeof_global(inst, inst.target, value) - elif isinstance(value, ir.Arg): + elif isinstance(value, ir.arg_types): self.typeof_arg(inst, inst.target, value) - elif isinstance(value, ir.Expr): + elif isinstance(value, ir.expr_types): self.typeof_expr(inst, inst.target, value) - elif isinstance(value, ir.Yield): + elif isinstance(value, ir.yield_types): self.typeof_yield(inst, inst.target, value) else: msg = "Unsupported assignment encountered: %s %s" % ( diff --git a/numba_cuda/numba/cuda/core/untyped_passes.py b/numba_cuda/numba/cuda/core/untyped_passes.py index 6bced45f9..98ca36687 100644 --- a/numba_cuda/numba/cuda/core/untyped_passes.py +++ b/numba_cuda/numba/cuda/core/untyped_passes.py @@ -350,9 +350,9 @@ def run_pass(self, state): while work_list: label, block = work_list.pop() for i, instr in enumerate(block.body): - if isinstance(instr, ir.Assign): + if isinstance(instr, ir.assign_types): expr = instr.value - if isinstance(expr, ir.Expr) and expr.op == "call": + if isinstance(expr, ir.expr_types) and expr.op == "call": if guard( self._do_work, state, @@ -561,14 +561,14 @@ def _split_entry_block(self, fir, cfg, loop, entry_label): # Find the start of loop entry statement that needs to be included. startpt = None - list_of_insts = list(entry_block.find_insts(ir.Assign)) + list_of_insts = list(entry_block.find_insts(ir.assign_types)) for assign in reversed(list_of_insts): if assign.target in deps: rhs = assign.value - if isinstance(rhs, ir.Var): + if isinstance(rhs, ir.var_types): if rhs.is_temp: deps.add(rhs) - elif isinstance(rhs, ir.Expr): + elif isinstance(rhs, ir.expr_types): expr = rhs if expr.op == "getiter": startpt = assign @@ -576,11 +576,11 @@ def _split_entry_block(self, fir, cfg, loop, entry_label): deps.add(expr.value) elif expr.op == "call": defn = guard(get_definition, fir, expr.func) - if isinstance(defn, ir.Global): + if isinstance(defn, ir.global_types): if expr.func.is_temp: deps.add(expr.func) elif ( - isinstance(rhs, ir.Global) + isinstance(rhs, ir.global_types) and rhs.value in self._supported_globals ): startpt = assign @@ -634,30 +634,30 @@ def run_pass(self, state): mutated = False for idx, blk in func_ir.blocks.items(): for stmt in blk.body: - if isinstance(stmt, ir.Assign): - if isinstance(stmt.value, ir.Expr): + if isinstance(stmt, ir.assign_types): + if isinstance(stmt.value, ir.expr_types): if stmt.value.op == "make_function": node = stmt.value getdef = func_ir.get_definition kw_default = getdef(node.defaults) ok = False if kw_default is None or isinstance( - kw_default, ir.Const + kw_default, ir.const_types ): ok = True elif isinstance(kw_default, tuple): ok = all( [ - isinstance(getdef(x), ir.Const) + isinstance(getdef(x), ir.const_types) for x in kw_default ] ) - elif isinstance(kw_default, ir.Expr): + elif isinstance(kw_default, ir.expr_types): if kw_default.op != "build_tuple": continue ok = all( [ - isinstance(getdef(x), ir.Const) + isinstance(getdef(x), ir.const_types) for x in kw_default.items ] ) @@ -700,7 +700,10 @@ def run_pass(self, state): calls = [_ for _ in blk.find_exprs("call")] for call in calls: glbl = guard(get_definition, func_ir, call.func) - if glbl and isinstance(glbl, (ir.Global, ir.FreeVar)): + if glbl and ( + isinstance(glbl, ir.global_types) + or isinstance(glbl, ir.freevar_types) + ): # find a literal_unroll if glbl.value is literal_unroll: if len(call.args) > 1: @@ -712,7 +715,7 @@ def run_pass(self, state): unroll_var = call.args[0] to_unroll = guard(get_definition, func_ir, unroll_var) if ( - isinstance(to_unroll, ir.Expr) + isinstance(to_unroll, ir.expr_types) and to_unroll.op == "build_list" ): # make sure they are all const items in the list @@ -726,7 +729,7 @@ def run_pass(self, state): raise errors.UnsupportedError( msg % item, to_unroll.loc ) - if not isinstance(val, ir.Const): + if not isinstance(val, ir.const_types): msg = ( "Found non-constant value at " "position %s in a list argument to " @@ -777,17 +780,18 @@ def run_pass(self, state): asgn.value = tup mutated = True elif ( - isinstance(to_unroll, ir.Expr) + isinstance(to_unroll, ir.expr_types) and to_unroll.op == "build_tuple" ): # this is fine, do nothing pass - elif isinstance( - to_unroll, (ir.Global, ir.FreeVar) + elif ( + isinstance(to_unroll, ir.global_types) + or isinstance(to_unroll, ir.freevar_types) ) and isinstance(to_unroll.value, tuple): # this is fine, do nothing pass - elif isinstance(to_unroll, ir.Arg): + elif isinstance(to_unroll, ir.arg_types): # this is only fine if the arg is a tuple ty = state.typemap[to_unroll.name] if not isinstance(ty, self._accepted_types): @@ -802,7 +806,7 @@ def run_pass(self, state): ) else: extra = None - if isinstance(to_unroll, ir.Expr): + if isinstance(to_unroll, ir.expr_types): # probably a slice if to_unroll.op == "getitem": ty = state.typemap[to_unroll.value.name] @@ -810,7 +814,7 @@ def run_pass(self, state): if not isinstance(ty, self._accepted_types): extra = "operation %s" % to_unroll.op loc = to_unroll.loc - elif isinstance(to_unroll, ir.Arg): + elif isinstance(to_unroll, ir.arg_types): extra = "non-const argument %s" % to_unroll.name loc = to_unroll.loc else: @@ -868,10 +872,10 @@ def add_offset_to_labels_w_ignore(self, blocks, offset, ignore=None): term = None if b.body: term = b.body[-1] - if isinstance(term, ir.Jump): + if isinstance(term, ir.jump_types): if term.target not in ignore: b.body[-1] = ir.Jump(term.target + offset, term.loc) - if isinstance(term, ir.Branch): + if isinstance(term, ir.branch_types): if term.truebr not in ignore: new_true = term.truebr + offset else: @@ -925,7 +929,7 @@ def inject_loop_body( sentinel_blocks = [] for lbl, blk in switch_ir.blocks.items(): for i, stmt in enumerate(blk.body): - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): if "SENTINEL" in stmt.target.name: sentinel_blocks.append(lbl) sentinel_exits.add(blk.body[-1].target) @@ -939,10 +943,10 @@ def inject_loop_body( local_lbl = [x for x in loop_ir.blocks.keys()] for lbl, blk in loop_ir.blocks.items(): for i, stmt in enumerate(blk.body): - if isinstance(stmt, ir.Jump): + if isinstance(stmt, ir.jump_types): if stmt.target not in local_lbl: ignore_set.add(stmt.target) - if isinstance(stmt, ir.Branch): + if isinstance(stmt, ir.branch_types): if stmt.truebr not in local_lbl: ignore_set.add(stmt.truebr) if stmt.falsebr not in local_lbl: @@ -968,9 +972,9 @@ def inject_loop_body( for blk in loop_blocks.values(): new_body = [] for stmt in blk.body: - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): if ( - isinstance(stmt.value, ir.Expr) + isinstance(stmt.value, ir.expr_types) and stmt.value.op == "typed_getitem" ): if isinstance(branch_ty, types.Literal): @@ -1130,8 +1134,8 @@ def foo(): branches = compile_to_numba_ir(bfunc, {}) for lbl, blk in branches.blocks.items(): for stmt in blk.body: - if isinstance(stmt, ir.Assign): - if isinstance(stmt.value, ir.Global): + if isinstance(stmt, ir.assign_types): + if isinstance(stmt.value, ir.global_types): if stmt.value.name == "PLACEHOLDER_INDEX": stmt.value = index return branches @@ -1154,12 +1158,12 @@ def get_call_args(init_arg, want): # call to a global function "want" and returns the arguments # supplied to that function's call some_call = get_definition(func_ir, init_arg) - if not isinstance(some_call, ir.Expr): + if not isinstance(some_call, ir.expr_types): raise GuardException if not some_call.op == "call": raise GuardException the_global = get_definition(func_ir, some_call.func) - if not isinstance(the_global, ir.Global): + if not isinstance(the_global, ir.global_types): raise GuardException if the_global.value is not want: raise GuardException @@ -1206,7 +1210,7 @@ def find_unroll_loops(loops): ) if literal_unroll_call is None: continue - if not isinstance(literal_unroll_call, ir.Expr): + if not isinstance(literal_unroll_call, ir.expr_types): continue if literal_unroll_call.op != "call": continue @@ -1263,9 +1267,9 @@ def collect_literal_unroll_info(literal_unroll_loops): for lbli in loop.body: blk = func_ir.blocks[lbli] for stmt in blk.body: - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): if ( - isinstance(stmt.value, ir.Expr) + isinstance(stmt.value, ir.expr_types) and stmt.value.op == "getitem" ): # check for something like a[i] @@ -1346,9 +1350,9 @@ def unroll_loop(self, state, loop_info): for lbl in loop_info.loop.body: blk = func_ir.blocks[lbl] for stmt in blk.body: - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): if ( - isinstance(stmt.value, ir.Expr) + isinstance(stmt.value, ir.expr_types) and stmt.value.op == "getitem" ): # try a couple of spellings... a[i] and ref(a)[i] @@ -1508,7 +1512,7 @@ def assess_loop(self, loop, func_ir, partial_typemap=None): # confident that tuple unrolling is behaving require opt-in # guard of `literal_unroll`, remove this later! phi_val_defn = guard(get_definition, func_ir, phi.value) - if not isinstance(phi_val_defn, ir.Expr): + if not isinstance(phi_val_defn, ir.expr_types): return False if not phi_val_defn.op == "call": return False @@ -1518,7 +1522,7 @@ def assess_loop(self, loop, func_ir, partial_typemap=None): func_var = guard(get_definition, func_ir, call.func) func = guard(get_definition, func_ir, func_var) if func is None or not isinstance( - func, (ir.Global, ir.FreeVar) + func, ir.global_types + ir.freevar_types ): return False if ( @@ -1558,9 +1562,9 @@ def get_range(a): # look for iternext idx = 0 for stmt in entry_block.body: - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): if ( - isinstance(stmt.value, ir.Expr) + isinstance(stmt.value, ir.expr_types) and stmt.value.op == "getiter" ): break @@ -1615,7 +1619,7 @@ def get_range(a): # replace RHS use of induction var with getitem for lbl in check_blocks: for stmt in func_ir.blocks[lbl].body: - if isinstance(stmt, ir.Assign): + if isinstance(stmt, ir.assign_types): # check for aliases try: lookup = getattr(stmt.value, "name", None) @@ -1675,15 +1679,20 @@ def run_pass(self, state): changed = False for block in func_ir.blocks.values(): - for assign in block.find_insts(ir.Assign): + for assign in block.find_insts(ir.assign_types): value = assign.value - if isinstance(value, (ir.Arg, ir.Const, ir.FreeVar, ir.Global)): + if ( + isinstance(value, ir.arg_types) + or isinstance(value, ir.const_types) + or isinstance(value, ir.freevar_types) + or isinstance(value, ir.global_types) + ): continue # 1) Don't change return stmt in the form # $return_xyz = cast(value=ABC) # 2) Don't propagate literal values that are not primitives - if isinstance(value, ir.Expr) and value.op in ( + if isinstance(value, ir.expr_types) and value.op in ( "cast", "build_map", "build_list", @@ -1716,13 +1725,13 @@ def run_pass(self, state): # At the moment, one avoid propagating the literal # value if the argument is a PHI node - if isinstance(value, ir.Expr) and value.op == "call": + if isinstance(value, ir.expr_types) and value.op == "call": fn = guard(get_definition, func_ir, value.func.name) if fn is None: continue if not ( - isinstance(fn, ir.Global) + isinstance(fn, ir.global_types) and fn.name in accepted_functions ): continue @@ -1731,7 +1740,10 @@ def run_pass(self, state): # check if any of the args to isinstance is a PHI node iv = func_ir._definitions[arg.name] assert len(iv) == 1 # SSA! - if isinstance(iv[0], ir.Expr) and iv[0].op == "phi": + if ( + isinstance(iv[0], ir.expr_types) + and iv[0].op == "phi" + ): msg = ( f"{fn.name}() cannot determine the " f'type of variable "{arg.unversioned_name}" ' @@ -1741,7 +1753,7 @@ def run_pass(self, state): # Only propagate a PHI node if all arguments are the same # constant - if isinstance(value, ir.Expr) and value.op == "phi": + if isinstance(value, ir.expr_types) and value.op == "phi": # typemap will return None in case `inc.name` not in typemap v = [typemap.get(inc.name) for inc in value.incoming_values] # stop if the elements in `v` do not hold the same value @@ -1788,8 +1800,10 @@ def run_pass(self, state): found = False func_ir = state.func_ir for blk in func_ir.blocks.values(): - for asgn in blk.find_insts(ir.Assign): - if isinstance(asgn.value, (ir.Global, ir.FreeVar)): + for asgn in blk.find_insts(ir.assign_types): + if isinstance(asgn.value, ir.global_types) or isinstance( + asgn.value, ir.freevar_types + ): value = asgn.value.value if value is isinstance or value is hasattr: found = True @@ -1835,8 +1849,10 @@ def run_pass(self, state): found = False func_ir = state.func_ir for blk in func_ir.blocks.values(): - for asgn in blk.find_insts(ir.Assign): - if isinstance(asgn.value, (ir.Global, ir.FreeVar)): + for asgn in blk.find_insts(ir.assign_types): + if isinstance(asgn.value, ir.global_types) or isinstance( + asgn.value, ir.freevar_types + ): if asgn.value.value is literal_unroll: found = True break @@ -1953,7 +1969,7 @@ def run_pass(self, state): changed = False for block in func_ir.blocks.values(): - for raise_ in block.find_insts((ir.Raise, ir.TryRaise)): + for raise_ in block.find_insts(ir.raise_types + ir.tryraise_types): call_inst = guard(get_definition, func_ir, raise_.exception) if call_inst is None: continue diff --git a/numba_cuda/numba/cuda/lowering.py b/numba_cuda/numba/cuda/lowering.py index 56d417c6e..c849e9ddd 100644 --- a/numba_cuda/numba/cuda/lowering.py +++ b/numba_cuda/numba/cuda/lowering.py @@ -442,7 +442,9 @@ def _find_singly_assigned_variable(self): # Ensure that the variable is not defined multiple times # in the block [defblk] = var_assign_map[var] - assign_stmts = self.blocks[defblk].find_insts(ir.Assign) + assign_stmts = self.blocks[defblk].find_insts( + ir.assign_types + ) assigns = [ stmt for stmt in assign_stmts @@ -469,7 +471,7 @@ def pre_block(self, block): self.builder.position_at_end(bb) all_names = set() for block in self.blocks.values(): - for x in block.find_insts(ir.Del): + for x in block.find_insts(ir.del_types): if x.value not in all_names: all_names.add(x.value) for name in all_names: @@ -484,9 +486,9 @@ def pre_block(self, block): self.func_ir, call.func, ) - if defn is not None and isinstance(defn, ir.Global): + if defn is not None and isinstance(defn, ir.global_types): if defn.value is eh.exception_check: - if isinstance(block.terminator, ir.Branch): + if isinstance(block.terminator, ir.branch_types): targetblk = self.blkmap[block.terminator.truebr] # NOTE: This hacks in an attribute for call_conv to # pick up. This hack is no longer needed when @@ -506,19 +508,19 @@ def lower_inst(self, inst): self.debuginfo.mark_location(self.builder, self.loc.line) self.notify_loc(self.loc) self.debug_print(str(inst)) - if isinstance(inst, ir.Assign): + if isinstance(inst, ir.assign_types): ty = self.typeof(inst.target.name) val = self.lower_assign(ty, inst) argidx = None # If this is a store from an arg, like x = arg.x then tell debuginfo # that this is the arg - if isinstance(inst.value, ir.Arg): + if isinstance(inst.value, ir.arg_types): # NOTE: debug location is the `def ` line self.debuginfo.mark_location(self.builder, self.defn_loc.line) argidx = inst.value.index + 1 # args start at 1 self.storevar(val, inst.target.name, argidx=argidx) - elif isinstance(inst, ir.Branch): + elif isinstance(inst, ir.branch_types): cond = self.loadvar(inst.cond.name) tr = self.blkmap[inst.truebr] fl = self.blkmap[inst.falsebr] @@ -530,11 +532,11 @@ def lower_inst(self, inst): ) self.builder.cbranch(pred, tr, fl) - elif isinstance(inst, ir.Jump): + elif isinstance(inst, ir.jump_types): target = self.blkmap[inst.target] self.builder.branch(target) - elif isinstance(inst, ir.Return): + elif isinstance(inst, ir.return_types): if self.generator_info: # StopIteration self.genlower.return_from_generator(self) @@ -552,10 +554,10 @@ def lower_inst(self, inst): retval = self.context.get_return_value(self.builder, ty, val) self.call_conv.return_value(self.builder, retval) - elif isinstance(inst, ir.PopBlock): + elif isinstance(inst, ir.popblock_types): pass # this is just a marker - elif isinstance(inst, ir.StaticSetItem): + elif isinstance(inst, ir.staticsetitem_types): signature = self.fndesc.calltypes[inst] assert signature is not None try: @@ -573,22 +575,22 @@ def lower_inst(self, inst): ) return impl(self.builder, (target, inst.index, value)) - elif isinstance(inst, ir.Print): + elif isinstance(inst, ir.print_types): self.lower_print(inst) - elif isinstance(inst, ir.SetItem): + elif isinstance(inst, ir.setitem_types): signature = self.fndesc.calltypes[inst] assert signature is not None return self.lower_setitem( inst.target, inst.index, inst.value, signature ) - elif isinstance(inst, ir.StoreMap): + elif isinstance(inst, ir.storemap_types): signature = self.fndesc.calltypes[inst] assert signature is not None return self.lower_setitem(inst.dct, inst.key, inst.value, signature) - elif isinstance(inst, ir.DelItem): + elif isinstance(inst, ir.delitem_types): target = self.loadvar(inst.target.name) index = self.loadvar(inst.index.name) @@ -614,10 +616,10 @@ def lower_inst(self, inst): return impl(self.builder, (target, index)) - elif isinstance(inst, ir.Del): + elif isinstance(inst, ir.del_types): self.delvar(inst.value) - elif isinstance(inst, ir.SetAttr): + elif isinstance(inst, ir.setattr_types): target = self.loadvar(inst.target.name) value = self.loadvar(inst.value.name) signature = self.fndesc.calltypes[inst] @@ -635,16 +637,16 @@ def lower_inst(self, inst): return impl(self.builder, (target, value)) - elif isinstance(inst, ir.DynamicRaise): + elif isinstance(inst, ir.dynamicraise_types): self.lower_dynamic_raise(inst) - elif isinstance(inst, ir.DynamicTryRaise): + elif isinstance(inst, ir.dynamictryraise_types): self.lower_try_dynamic_raise(inst) - elif isinstance(inst, ir.StaticRaise): + elif isinstance(inst, ir.staticraise_types): self.lower_static_raise(inst) - elif isinstance(inst, ir.StaticTryRaise): + elif isinstance(inst, ir.statictryraise_types): self.lower_static_try_raise(inst) else: @@ -696,7 +698,7 @@ def lower_dynamic_raise(self, inst): args = [] nb_types = [] for exc_arg in exc_args: - if isinstance(exc_arg, ir.Var): + if isinstance(exc_arg, ir.var_types): # dynamic values typ = self.typeof(exc_arg.name) val = self.loadvar(exc_arg.name) @@ -728,24 +730,28 @@ def lower_static_try_raise(self, inst): def lower_assign(self, ty, inst): value = inst.value # In nopython mode, closure vars are frozen like globals - if isinstance(value, (ir.Const, ir.Global, ir.FreeVar)): + if ( + isinstance(value, ir.const_types) + or isinstance(value, ir.global_types) + or isinstance(value, ir.freevar_types) + ): res = self.context.get_constant_generic( self.builder, ty, value.value ) self.incref(ty, res) return res - elif isinstance(value, ir.Expr): + elif isinstance(value, ir.expr_types): return self.lower_expr(ty, value) - elif isinstance(value, ir.Var): + elif isinstance(value, ir.var_types): val = self.loadvar(value.name) oty = self.typeof(value.name) res = self.context.cast(self.builder, val, oty, ty) self.incref(ty, res) return res - elif isinstance(value, ir.Arg): + elif isinstance(value, ir.arg_types): # Suspend debug info else all the arg repacking ends up being # associated with some line or other and it's actually just a detail # of Numba's CC. @@ -771,7 +777,7 @@ def lower_assign(self, ty, inst): self.incref(ty, res) return res - elif isinstance(value, ir.Yield): + elif isinstance(value, ir.yield_types): res = self.lower_yield(ty, value) self.incref(ty, res) return res @@ -1814,7 +1820,7 @@ def pre_block(self, block): self.dbg_val_names = set() if self.context.enable_debuginfo and self._disable_sroa_like_opt: - for x in block.find_insts(ir.Assign): + for x in block.find_insts(ir.assign_types): if x.target.name.startswith("$"): continue ssa_name = x.target.name @@ -1847,7 +1853,7 @@ def pre_lower(self): poly_map = {} # pre-scan all blocks for block in self.blocks.values(): - for x in block.find_insts(ir.Assign): + for x in block.find_insts(ir.assign_types): if x.target.name.startswith("$"): continue ssa_name = x.target.name diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_numba_interop.py b/numba_cuda/numba/cuda/tests/cudapy/test_numba_interop.py new file mode 100644 index 000000000..b7cfb5bc8 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/cudapy/test_numba_interop.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import numpy as np + +from numba import cuda +from numba.cuda import HAS_NUMBA +from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim + +if HAS_NUMBA: + from numba.extending import overload + + +@skip_on_cudasim("Simulator does not support the extension API") +@unittest.skipUnless(HAS_NUMBA, "Tests interoperability with Numba") +class TestNumbaInterop(CUDATestCase): + def test_overload_inline_always(self): + # From Issue #624 + def get_42(): + raise NotImplementedError() + + @overload(get_42, target="cuda", inline="always") + def ol_blas_get_accumulator(): + def impl(): + return 42 + + return impl + + @cuda.jit + def kernel(a): + a[0] = get_42() + + a = np.empty(1, dtype=np.float32) + kernel[1, 1](a) + np.testing.assert_equal(a[0], 42) diff --git a/numba_cuda/numba/cuda/typing/context.py b/numba_cuda/numba/cuda/typing/context.py index 6c2c933bf..9514d1e8b 100644 --- a/numba_cuda/numba/cuda/typing/context.py +++ b/numba_cuda/numba/cuda/typing/context.py @@ -460,7 +460,9 @@ def is_for_this_target(ftcls): def is_external(obj): """Check if obj is from outside numba.* namespace.""" try: - return not obj.__module__.startswith("numba.") + is_numba_module = obj.__module__.startswith("numba.") + is_test_module = obj.__module__.startswith("numba.cuda.tests.") + return not is_numba_module or is_test_module except AttributeError: return True