Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 29 additions & 21 deletions numba_cuda/numba/cuda/core/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,16 @@ def compute_use_defs(blocks):
func = ir_extension_usedefs[type(stmt)]
func(stmt, use_set, def_set)
continue
if isinstance(stmt, ir.Assign):
if isinstance(stmt.value, ir.Inst):
if isinstance(stmt, ir.assign_types):
if isinstance(stmt.value, ir.inst_types):
rhs_set = set(var.name for var in stmt.value.list_vars())
elif isinstance(stmt.value, ir.Var):
elif isinstance(stmt.value, ir.var_types):
rhs_set = set([stmt.value.name])
elif isinstance(
stmt.value, (ir.Arg, ir.Const, ir.Global, ir.FreeVar)
elif (
isinstance(stmt.value, ir.arg_types)
or isinstance(stmt.value, ir.const_types)
or isinstance(stmt.value, ir.global_types)
or isinstance(stmt.value, ir.freevar_types)
):
rhs_set = ()
else:
Expand Down Expand Up @@ -326,7 +329,7 @@ def rewrite_array_ndim(val, func_ir, called_args):
if getattr(val, "op", None) == "getattr":
if val.attr == "ndim":
arg_def = guard(get_definition, func_ir, val.value)
if isinstance(arg_def, ir.Arg):
if isinstance(arg_def, ir.arg_types):
argty = called_args[arg_def.index]
if isinstance(argty, types.Array):
rewrite_statement(func_ir, stmt, argty.ndim)
Expand All @@ -337,17 +340,17 @@ def rewrite_tuple_len(val, func_ir, called_args):
func = guard(get_definition, func_ir, val.func)
if (
func is not None
and isinstance(func, ir.Global)
and isinstance(func, ir.global_types)
and getattr(func, "value", None) is len
):
(arg,) = val.args
arg_def = guard(get_definition, func_ir, arg)
if isinstance(arg_def, ir.Arg):
if isinstance(arg_def, ir.arg_types):
argty = called_args[arg_def.index]
if isinstance(argty, types.BaseTuple):
rewrite_statement(func_ir, stmt, argty.count)
elif (
isinstance(arg_def, ir.Expr)
isinstance(arg_def, ir.expr_types)
and arg_def.op == "typed_getitem"
):
argty = arg_def.dtype
Expand All @@ -358,9 +361,9 @@ def rewrite_tuple_len(val, func_ir, called_args):

for blk in func_ir.blocks.values():
for stmt in blk.body:
if isinstance(stmt, ir.Assign):
if isinstance(stmt, ir.assign_types):
val = stmt.value
if isinstance(val, ir.Expr):
if isinstance(val, ir.expr_types):
rewrite_array_ndim(val, func_ir, called_args)
rewrite_tuple_len(val, func_ir, called_args)

Expand Down Expand Up @@ -391,7 +394,7 @@ def find_literally_calls(func_ir, argtypes):
for blk in func_ir.blocks.values():
for assign in blk.find_exprs(op="call"):
var = ir_utils.guard(ir_utils.get_definition, func_ir, assign.func)
if isinstance(var, (ir.Global, ir.FreeVar)):
if isinstance(var, ir.global_types + ir.freevar_types):
fnobj = var.value
else:
fnobj = ir_utils.guard(
Expand All @@ -401,7 +404,7 @@ def find_literally_calls(func_ir, argtypes):
# Found
[arg] = assign.args
defarg = func_ir.get_definition(arg)
if isinstance(defarg, ir.Arg):
if isinstance(defarg, ir.arg_types):
argindex = defarg.index
marked_args.add(argindex)
first_loc.setdefault(argindex, assign.loc)
Expand Down Expand Up @@ -473,14 +476,14 @@ def find_branches(func_ir):
branches = []
for blk in func_ir.blocks.values():
branch_or_jump = blk.body[-1]
if isinstance(branch_or_jump, ir.Branch):
if isinstance(branch_or_jump, ir.branch_types):
branch = branch_or_jump
pred = guard(get_definition, func_ir, branch.cond.name)
if pred is not None and getattr(pred, "op", None) == "call":
function = guard(get_definition, func_ir, pred.func)
if (
function is not None
and isinstance(function, ir.Global)
and isinstance(function, ir.global_types)
and function.value is bool
):
condition = guard(get_definition, func_ir, pred.args[0])
Expand Down Expand Up @@ -539,7 +542,9 @@ def prune_by_predicate(branch, pred, blk):
try:
# Just to prevent accidents, whilst already guarded, ensure this
# is an ir.Const
if not isinstance(pred, (ir.Const, ir.FreeVar, ir.Global)):
if not isinstance(
pred, ir.const_types + ir.freevar_types + ir.global_types
):
raise TypeError("Expected constant Numba IR node")
take_truebr = bool(pred.value)
except TypeError:
Expand Down Expand Up @@ -584,8 +589,11 @@ def resolve_input_arg_const(input_arg_idx):
phi2asgn = dict()
for lbl, blk in func_ir.blocks.items():
for stmt in blk.body:
if isinstance(stmt, ir.Assign):
if isinstance(stmt.value, ir.Expr) and stmt.value.op == "phi":
if isinstance(stmt, ir.assign_types):
if (
isinstance(stmt.value, ir.expr_types)
and stmt.value.op == "phi"
):
phi2lbl[stmt.value] = lbl
phi2asgn[stmt.value] = stmt

Expand All @@ -599,12 +607,12 @@ def resolve_input_arg_const(input_arg_idx):

for branch, condition, blk in branch_info:
const_conds = []
if isinstance(condition, ir.Expr) and condition.op == "binop":
if isinstance(condition, ir.expr_types) and condition.op == "binop":
prune = prune_by_value
for arg in [condition.lhs, condition.rhs]:
resolved_const = Unknown()
arg_def = guard(get_definition, func_ir, arg)
if isinstance(arg_def, ir.Arg):
if isinstance(arg_def, ir.arg_types):
# it's an e.g. literal argument to the function
resolved_const = resolve_input_arg_const(arg_def.index)
prune = prune_by_type
Expand Down Expand Up @@ -668,7 +676,7 @@ def resolve_input_arg_const(input_arg_idx):
for _, cond, blk in branch_info:
if cond in deadcond:
for x in blk.body:
if isinstance(x, ir.Assign) and x.value is cond:
if isinstance(x, ir.assign_types) and x.value is cond:
# rewrite the condition as a true/false bit
nullified_info = nullified_conditions[deadcond.index(cond)]
# only do a rewrite of conditions, predicates need to retain
Expand Down
8 changes: 4 additions & 4 deletions numba_cuda/numba/cuda/core/annotations/type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,16 @@ def prepare_annotations(self):
for inst in blk.body:
lineno = inst.loc.line

if isinstance(inst, ir.Assign):
if isinstance(inst, ir.assign_types):
if found_lifted_loop:
atype = "XXX Lifted Loop XXX"
found_lifted_loop = False
elif (
isinstance(inst.value, ir.Expr)
isinstance(inst.value, ir.expr_types)
and inst.value.op == "call"
):
atype = self.calltypes[inst.value]
elif isinstance(inst.value, ir.Const) and isinstance(
elif isinstance(inst.value, ir.const_types) and isinstance(
inst.value.value, LiftedLoop
):
atype = "XXX Lifted Loop XXX"
Expand All @@ -113,7 +113,7 @@ def prepare_annotations(self):
atype = self.typemap.get(inst.target.name, "<missing>")

aline = "%s = %s :: %s" % (inst.target, inst.value, atype)
elif isinstance(inst, ir.SetItem):
elif isinstance(inst, ir.setitem_types):
atype = self.calltypes[inst]
aline = "%s :: %s" % (inst, atype)
else:
Expand Down
2 changes: 1 addition & 1 deletion numba_cuda/numba/cuda/core/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _do_infer(self, name):
try:
const = defn.infer_constant()
except ConstantInferenceError:
if isinstance(defn, ir.Expr):
if isinstance(defn, ir.expr_types):
return self._infer_expr(defn)
self._fail(defn)
return const
Expand Down
Loading
Loading