Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,11 @@ void init_triton_ir(py::module &&m) {
return !self.empty() &&
self.back().hasTrait<mlir::OpTrait::IsTerminator>();
})
// Python binding `has_return` on mlir::Block: reports whether the block is
// non-empty and its last operation carries the ReturnLike trait.
.def("has_return",
[](mlir::Block &self) {
// Check emptiness first so back() is only called on a non-empty block.
return !self.empty() &&
self.back().hasTrait<mlir::OpTrait::ReturnLike>();
})
.def("erase", [](mlir::Block &self) { self.erase(); });

// using eattr = ir::attribute_kind_t;
Expand Down Expand Up @@ -428,6 +433,25 @@ void init_triton_ir(py::module &&m) {
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
},
ret::reference)
// Python binding `finalize` on triton::FuncOp: for every block reached by
// the walk, drops whatever follows the first triton ReturnOp found in it.
.def("finalize",
[](mlir::triton::FuncOp &self) -> void {
// Remove dead code
// 1. Unreachable code after return
self.walk([&](mlir::Block *block) {
// First ReturnOp seen while walking this block; stays null if none.
mlir::Operation *retOp = nullptr;
block->walk([&](mlir::Operation *op) {
if (mlir::isa<mlir::triton::ReturnOp>(op))
// Keep only the first match; later returns are ignored.
if (retOp == nullptr)
retOp = op;
});
// Nothing to trim when there is no return or it is already the last op.
if (retOp && retOp != &block->back()) {
// Advance to the operation right after the return ...
auto pos = retOp->getIterator();
pos++;
// ... move everything from there on into a new block, then erase it.
auto *newBlock = block->splitBlock(pos);
newBlock->erase();
}
});
})
.def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType)
.def("reset_type", &mlir::triton::FuncOp::setType);

Expand Down
74 changes: 64 additions & 10 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2550,24 +2550,30 @@ def kernel(Cond, TrueVal, FalseVal, Out):
assert to_numpy(out)[0] == false_val[0]


def test_if_return():
@pytest.mark.parametrize("mode", ["dynamic", "static"])
def test_if_return(mode):

@triton.jit
def kernel(ExitEarly, Out):
if tl.load(ExitEarly):
tl.store(Out, 0)
return
def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr):
if mode == "dynamic":
if tl.load(ExitEarly):
tl.store(Out, 0)
return
else:
if cond:
tl.store(Out, 0)
return
tl.store(Out, 1)

out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
exit_early = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
# exit early path taken
exit_early[0] = 1
kernel[(1,)](exit_early, out)
kernel[(1,)](exit_early, out, True, mode)
assert to_numpy(out)[0] == 0
# exit early path not taken
exit_early[0] = 0
kernel[(1,)](exit_early, out)
kernel[(1,)](exit_early, out, False, mode)
assert to_numpy(out)[0] == 1


Expand All @@ -2576,21 +2582,69 @@ def add_fn(x):
return x + 1


@pytest.mark.parametrize("call_type", ["attribute", "jit_function"])
# Test helper: jitted with noinline=True; called from the
# "jit_function_noinline" branch of test_if_call's kernel.
@triton.jit(noinline=True)
def add_fn_noinline(x):
# Returns x plus one.
return x + 1


# Test helper: returns from both the `if` and the `else` branch, so no
# statement follows the conditional. Called from the "jit_function_return"
# and "ifexp" branches of test_if_call's kernel.
@triton.jit
def add_fn_return(x, pid):
# pid == 0 yields x + 1; any other pid yields x + 2.
if pid == 0:
return x + 1
else:
return x + 2


# Test helper: stores x to Out and has no return statement. Called as a
# bare expression statement from the "expr" branch of test_if_call's kernel.
@triton.jit
def add_fn_expr(Out, x):
tl.store(Out, x)


# Test helper: branches on a tl.constexpr argument and returns from both
# branches. Called from the "jit_function_static_cond" branch of
# test_if_call's kernel with the (non-empty) call_type string as `cond`.
@triton.jit
def add_fn_static_cond(x, cond: tl.constexpr):
# Empty-string cond returns x unchanged; anything else returns x + 1.
if cond == "":
return x
else:
return x + 1


@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return",
"ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"])
def test_if_call(call_type):
@triton.jit
def kernel(Out, call_type: tl.constexpr):
pid = tl.program_id(0)
o = tl.load(Out)
if pid == 0:
if call_type == "attribute":
# call attribute
a = o + 1
a = a.to(tl.int32)
a = a.to(tl.int32).to(tl.int32)
o = a
else:
a = o
a = add_fn(a)
if call_type == "jit_function":
# regular function call
a = add_fn(a)
elif call_type == "jit_function_return":
# function without end_if block
a = add_fn_return(a, pid)
elif call_type == "ifexp":
# ifexp expression
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
elif call_type == "expr":
if pid == 1:
return
a = add_fn(a)
if pid == 0:
# call without return
add_fn_expr(Out, a)
elif call_type == "jit_function_static_cond":
a = add_fn_static_cond(a, call_type)
elif call_type == "jit_function_noinline":
a = add_fn_noinline(a)
o = a

tl.store(Out, o)

out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
Expand Down
71 changes: 52 additions & 19 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
self.debug = debug
self.noinline = noinline
self.scf_stack = []
self.last_ret_type = None
# SSA-construction
# name => language.tensor
self.local_defs: Dict[str, tensor] = {}
Expand Down Expand Up @@ -138,7 +139,7 @@ def name_lookup(name: str) -> Any:
def set_value(self, name: str,
value: Union[tensor, constexpr]) -> None:
''' This function:
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
1. record local defined name (FIXME: should consider control flow)
2. store tensor in self.lvalue
'''
Expand All @@ -150,10 +151,9 @@ def set_value(self, name: str,
#
def visit_compound_statement(self, stmts):
for stmt in stmts:
self.last_ret_type = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
return stmts and isinstance(stmt, ast.Return)
ret_type = self.visit(stmt)
if ret_type is not None and isinstance(stmt, ast.Return):
self.last_ret_type = ret_type

# TODO: should be its own AST visitor
def contains_return_op(self, node):
Expand All @@ -168,10 +168,23 @@ def contains_return_op(self, node):
pred = lambda s: self.contains_return_op(s)
return any(pred(s) for s in node.body)
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Attribute):
def check_undefined_name(cur_node):
    '''Tell whether a method-call chain is rooted at an unknown name.

    Handles calls of the form ``<name>.attr(...)`` and chains of such
    calls, e.g. ``tl.load(a).to(tl.float32)``. A root name found in neither
    ``self.lscope`` nor ``self.gscope`` is an undefined local variable,
    which can only be a tensor or a constexpr.

    :param cur_node: AST node to inspect. Anything other than an
        ``ast.Call`` whose callee is an ``ast.Attribute`` yields False.
    :return: True only when the chain bottoms out in an ``ast.Name`` that
        is missing from both scopes; False otherwise.
    '''
    # The receiver of an attribute call is not always a Call or a Name:
    # ``tl.math.exp(x)`` has an Attribute receiver and ``x[0].to(...)`` a
    # Subscript. Such nodes have no ``func`` field, so stop here instead of
    # raising AttributeError when the recursion reaches them.
    if not isinstance(cur_node, ast.Call):
        return False
    callee = cur_node.func
    if not isinstance(callee, ast.Attribute):
        return False
    receiver = callee.value
    if isinstance(receiver, ast.Name):
        return receiver.id not in self.lscope and receiver.id not in self.gscope
    # chain of calls
    # e.g., tl.load(a).to(tl.float32)
    return check_undefined_name(receiver)
if check_undefined_name(node):
return False
fn = self.visit(node.func)
if isinstance(fn, JITFunction):
if isinstance(fn, JITFunction) and fn.noinline is False:
old_gscope = self.gscope
self.gscope = sys.modules[fn.fn.__module__].__dict__
ret = self.contains_return_op(fn.parse())
Expand All @@ -184,6 +197,18 @@ def contains_return_op(self, node):
if node.orelse:
ret = ret or any(pred(s) for s in node.orelse)
return ret
elif isinstance(node, ast.IfExp):
return self.contains_return_op(node.body) or self.contains_return_op(node.orelse)
elif isinstance(node, ast.Expr):
ret = False
for _, value in ast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, ast.AST):
ret = ret or self.contains_return_op(item)
elif isinstance(value, ast.AST):
ret = ret or self.contains_return_op(value)
return ret
else:
return False

Expand Down Expand Up @@ -257,9 +282,9 @@ def visit_FunctionDef(self, node):
self.set_value(arg_name, arg_value)
self.builder.set_insertion_point_to_start(entry)
# visit function body
has_ret = self.visit_compound_statement(node.body)
self.visit_compound_statement(node.body)
# finalize function
if not has_ret:
if self.last_ret_type is None:
self.builder.ret([])
else:
# update return type
Expand All @@ -271,6 +296,8 @@ def visit_FunctionDef(self, node):
fn.reset_type(self.prototype.to_ir(self.builder))
if insert_pt:
self.builder.set_insertion_point_to_end(insert_pt)
# Remove dead code
fn.finalize()

def visit_arguments(self, node):
arg_names = []
Expand Down Expand Up @@ -421,6 +448,7 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block):
return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types

def visit_if_top_level(self, cond, node):
has_endif_block = True
with enter_sub_region(self) as sr:
liveins, ip_block = sr
then_block = self.builder.create_block()
Expand All @@ -435,20 +463,25 @@ def visit_if_top_level(self, cond, node):
self.visit_then_else_blocks(node, liveins, then_block, else_block)
# then terminator
self.builder.set_insertion_point_to_end(then_block)
if not then_block.has_terminator():
if then_block.has_return() and else_block.has_return():
has_endif_block = False
endif_block.erase()
if not then_block.has_terminator() and has_endif_block:
self.builder.create_branch(endif_block, [then_defs[n].handle for n in names])
# else terminator
self.builder.set_insertion_point_to_end(else_block)
if not else_block.has_terminator():
if not else_block.has_terminator() and has_endif_block:
self.builder.create_branch(endif_block, [else_defs[n].handle for n in names])
for ty in ir_ret_types:
endif_block.add_argument(ty)
# change block
self.builder.set_insertion_point_to_start(endif_block)
# update value
for i, name in enumerate(names):
new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
self.set_value(name, new_tensor)
if has_endif_block:
for ty in ir_ret_types:
endif_block.add_argument(ty)
if has_endif_block:
# change block
self.builder.set_insertion_point_to_start(endif_block)
# update value
for i, name in enumerate(names):
new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
self.set_value(name, new_tensor)

# TODO: refactor
def visit_if_scf(self, cond, node):
Expand Down