From f6217a60287504a59c1d11b3c429ab0b1b09b0d7 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Mon, 8 May 2023 08:46:56 -0400 Subject: [PATCH 1/5] Update --- python/src/triton.cc | 24 ++++++++++ python/test/unit/language/test_core.py | 54 +++++++++++++++++---- python/triton/compiler/code_generator.py | 61 +++++++++++++++++------- 3 files changed, 112 insertions(+), 27 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 6c2e70e6f386..8c36dd5658b8 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -262,6 +262,11 @@ void init_triton_ir(py::module &&m) { return !self.empty() && self.back().hasTrait(); }) + .def("has_return", + [](mlir::Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) .def("erase", [](mlir::Block &self) { self.erase(); }); // using eattr = ir::attribute_kind_t; @@ -428,6 +433,25 @@ void init_triton_ir(py::module &&m) { self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val)); }, ret::reference) + .def("finalize", + [](mlir::triton::FuncOp &self) -> void { + // Remove dead code + // 1. Unreachable code after return + self.walk([&](mlir::Block *block) { + mlir::Operation *retOp = nullptr; + block->walk([&](mlir::Operation *op) { + if (mlir::isa(op)) + if (retOp == nullptr) + retOp = op; + }); + if (retOp && retOp != &block->back()) { + auto pos = retOp->getIterator(); + pos++; + auto *newBlock = block->splitBlock(pos); + newBlock->erase(); + } + }); + }) .def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType) .def("reset_type", &mlir::triton::FuncOp::setType); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 11981f54c328..8bb71cb8afe5 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2550,24 +2550,30 @@ def kernel(Cond, TrueVal, FalseVal, Out): assert to_numpy(out)[0] == false_val[0] -def test_if_return(): +@pytest.mark.parametrize("mode", ["dynamic", "static"]) +def test_if_return(mode): @triton.jit - def kernel(ExitEarly, Out): - if tl.load(ExitEarly): - tl.store(Out, 0) - return + def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr): + if mode == "dynamic": + if tl.load(ExitEarly): + tl.store(Out, 0) + return + else: + if cond: + tl.store(Out, 0) + return tl.store(Out, 1) out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda') exit_early = to_triton(np.zeros((1,), dtype=np.int32), device='cuda') # exit early path taken exit_early[0] = 1 - kernel[(1,)](exit_early, out) + kernel[(1,)](exit_early, out, True, mode) assert to_numpy(out)[0] == 0 # exit early path not taken exit_early[0] = 0 - kernel[(1,)](exit_early, out) + kernel[(1,)](exit_early, out, False, mode) assert to_numpy(out)[0] == 1 @@ -2576,7 +2582,20 @@ def add_fn(x): return x + 1 -@pytest.mark.parametrize("call_type", ["attribute", "jit_function"]) +@triton.jit +def add_fn_return(x, pid): + if pid == 0: + return x + 1 + else: + return x + 2 + + +@triton.jit +def add_fn_expr(Out, x): + tl.store(Out, x) + + +@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return", "jit_function_ifexp", "jit_function_expr"]) def test_if_call(call_type): @triton.jit def kernel(Out, call_type: tl.constexpr): @@ -2584,13 +2603,30 @@ def kernel(Out, call_type: tl.constexpr): o = tl.load(Out) if pid == 0: if call_type == "attribute": + # call attribute a = o + 1 a = a.to(tl.int32) o = a else: a = o - a = add_fn(a) + if call_type == "jit_function": + # regular function call + a = add_fn(a) + elif call_type == "jit_function_return": + # function without end_if block + a = add_fn_return(a, pid) + elif call_type == "jit_function_ifexp": + # ifexp expression + a = add_fn(a) if pid == 0 else add_fn_return(a, pid) + elif call_type == "jit_function_expr": + if pid == 1: + return + a = add_fn(a) + if pid == 0: + # call without return + add_fn_expr(Out, a) o = a + tl.store(Out, o) out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda') diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 7d0a28cd1f47..4eca04fde04e 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -104,6 +104,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.debug = debug self.noinline = noinline self.scf_stack = [] + self.last_ret_type = None # SSA-construction # name => language.tensor self.local_defs: Dict[str, tensor] = {} @@ -138,7 +139,7 @@ def name_lookup(name: str) -> Any: def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: ''' This function: - called by visit_Assign() & visit_FuncDef() to store left value (lvalue) + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) 1. record local defined name (FIXME: should consider control flow) 2. store tensor in self.lvalue ''' @@ -150,10 +151,9 @@ def set_value(self, name: str, # def visit_compound_statement(self, stmts): for stmt in stmts: - self.last_ret_type = self.visit(stmt) - if isinstance(stmt, ast.Return): - break - return stmts and isinstance(stmt, ast.Return) + ret_type = self.visit(stmt) + if ret_type is not None and isinstance(stmt, ast.Return): + self.last_ret_type = ret_type # TODO: should be its own AST visitor def contains_return_op(self, node): @@ -169,7 +169,11 @@ def contains_return_op(self, node): return any(pred(s) for s in node.body) elif isinstance(node, ast.Call): if isinstance(node.func, ast.Attribute): - return False + # Check if name is an undefined local variable, + # which can only be a tensor or a constexpr + name = node.func.value.id + if name not in self.lscope and name not in self.gscope: + return False fn = self.visit(node.func) if isinstance(fn, JITFunction): old_gscope = self.gscope @@ -184,6 +188,18 @@ def contains_return_op(self, node): if node.orelse: ret = ret or any(pred(s) for s in node.orelse) return ret + elif isinstance(node, ast.IfExp): + return self.contains_return_op(node.body) or self.contains_return_op(node.orelse) + elif isinstance(node, ast.Expr): + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.contains_return_op(item) + elif isinstance(value, ast.AST): + ret = ret or self.contains_return_op(value) + return ret else: return False @@ -257,9 +273,9 @@ def visit_FunctionDef(self, node): self.set_value(arg_name, arg_value) self.builder.set_insertion_point_to_start(entry) # visit function body - has_ret = self.visit_compound_statement(node.body) + self.visit_compound_statement(node.body) # finalize function - if not has_ret: + if self.last_ret_type is None: self.builder.ret([]) else: # update return type @@ -271,6 +287,9 @@ def visit_FunctionDef(self, node): fn.reset_type(self.prototype.to_ir(self.builder)) if insert_pt: self.builder.set_insertion_point_to_end(insert_pt) + # We could finalize here after last_ret_type is set + # because triton requires each return op to be the same type + fn.finalize() def visit_arguments(self, node): arg_names = [] @@ -421,6 +440,7 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types def visit_if_top_level(self, cond, node): + has_endif_block = True with enter_sub_region(self) as sr: liveins, ip_block = sr then_block = self.builder.create_block() @@ -435,20 +455,25 @@ def visit_if_top_level(self, cond, node): self.visit_then_else_blocks(node, liveins, then_block, else_block) # then terminator self.builder.set_insertion_point_to_end(then_block) - if not then_block.has_terminator(): + if not then_block.has_terminator() and not then_block.has_return(): self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) # else terminator self.builder.set_insertion_point_to_end(else_block) - if not else_block.has_terminator(): + if not else_block.has_terminator() and not else_block.has_return(): self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) - for ty in ir_ret_types: - endif_block.add_argument(ty) - # change block - self.builder.set_insertion_point_to_start(endif_block) - # update value - for i, name in enumerate(names): - new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) - self.set_value(name, new_tensor) + if then_block.has_return() and else_block.has_return(): + has_endif_block = False + endif_block.erase() + else: + for ty in ir_ret_types: + endif_block.add_argument(ty) + if has_endif_block: + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) # TODO: refactor def visit_if_scf(self, cond, node): From abc0eca2fe7a6bfa4211d7a6e70591ceef7e7ddc Mon Sep 17 00:00:00 2001 From: Jokeren Date: Mon, 8 May 2023 08:49:53 -0400 Subject: [PATCH 2/5] Simplify --- python/triton/compiler/code_generator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 4eca04fde04e..f262bd31397c 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -455,16 +455,16 @@ def visit_if_top_level(self, cond, node): self.visit_then_else_blocks(node, liveins, then_block, else_block) # then terminator self.builder.set_insertion_point_to_end(then_block) - if not then_block.has_terminator() and not then_block.has_return(): + if then_block.has_return() and else_block.has_return(): + has_endif_block = False + endif_block.erase() + if not then_block.has_terminator() and has_endif_block: self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) # else terminator self.builder.set_insertion_point_to_end(else_block) - if not else_block.has_terminator() and not else_block.has_return(): + if not else_block.has_terminator() and has_endif_block: self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) - if then_block.has_return() and else_block.has_return(): - has_endif_block = False - endif_block.erase() - else: + if has_endif_block: for ty in ir_ret_types: endif_block.add_argument(ty) if has_endif_block: From 2f7bea4cea507ebb3189c2960f892a5ecaf69f3a Mon Sep 17 00:00:00 2001 From: Jokeren Date: Mon, 8 May 2023 08:51:50 -0400 Subject: [PATCH 3/5] Update --- python/triton/compiler/code_generator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index f262bd31397c..3068fc64be99 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -287,8 +287,7 @@ def visit_FunctionDef(self, node): fn.reset_type(self.prototype.to_ir(self.builder)) if insert_pt: self.builder.set_insertion_point_to_end(insert_pt) - # We could finalize here after last_ret_type is set - # because triton requires each return op to be the same type + # Remove dead code fn.finalize() def visit_arguments(self, node): From 0ef3c91ed6120a7881b09ccd465bb0f1e3dde309 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Mon, 8 May 2023 09:29:50 -0400 Subject: [PATCH 4/5] Update --- python/test/unit/language/test_core.py | 24 +++++++++++++++++++++--- python/triton/compiler/code_generator.py | 6 +++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 8bb71cb8afe5..fb1819f80f74 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2582,6 +2582,11 @@ def add_fn(x): return x + 1 +@triton.jit(noinline=True) +def add_fn_noinline(x): + return x + 1 + + @triton.jit def add_fn_return(x, pid): if pid == 0: @@ -2595,7 +2600,16 @@ def add_fn_expr(Out, x): tl.store(Out, x) -@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return", "jit_function_ifexp", "jit_function_expr"]) +@triton.jit +def add_fn_static_cond(x, cond: tl.constexpr): + if cond == "": + return x + else: + return x + 1 + + +@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return", + "ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"]) def test_if_call(call_type): @triton.jit def kernel(Out, call_type: tl.constexpr): @@ -2615,16 +2629,20 @@ def kernel(Out, call_type: tl.constexpr): elif call_type == "jit_function_return": # function without end_if block a = add_fn_return(a, pid) - elif call_type == "jit_function_ifexp": + elif call_type == "ifexp": # ifexp expression a = add_fn(a) if pid == 0 else add_fn_return(a, pid) - elif call_type == "jit_function_expr": + elif call_type == "expr": if pid == 1: return a = add_fn(a) if pid == 0: # call without return add_fn_expr(Out, a) + elif call_type == "jit_function_static_cond": + a = add_fn_static_cond(a, call_type) + elif call_type == "jit_function_noinline": + a = add_fn_noinline(a) o = a tl.store(Out, o) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 3068fc64be99..2f7b0b2bf39e 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -175,7 +175,7 @@ def contains_return_op(self, node): if name not in self.lscope and name not in self.gscope: return False fn = self.visit(node.func) - if isinstance(fn, JITFunction): + if isinstance(fn, JITFunction) and fn.noinline is False: old_gscope = self.gscope self.gscope = sys.modules[fn.fn.__module__].__dict__ ret = self.contains_return_op(fn.parse()) @@ -195,8 +195,8 @@ def contains_return_op(self, node): for _, value in ast.iter_fields(node): if isinstance(value, list): for item in value: - if isinstance(item, ast.AST): - ret = ret or self.contains_return_op(item) + if isinstance(item, ast.AST): + ret = ret or self.contains_return_op(item) elif isinstance(value, ast.AST): ret = ret or self.contains_return_op(value) return ret From a1af77c1a16f3f22f36248bf7c4166c62e9c0bdf Mon Sep 17 00:00:00 2001 From: Jokeren Date: Mon, 8 May 2023 11:09:51 -0400 Subject: [PATCH 5/5] Update --- python/test/unit/language/test_core.py | 2 +- python/triton/compiler/code_generator.py | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index fb1819f80f74..58d18d09ea29 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2619,7 +2619,7 @@ def kernel(Out, call_type: tl.constexpr): if call_type == "attribute": # call attribute a = o + 1 - a = a.to(tl.int32) + a = a.to(tl.int32).to(tl.int32) o = a else: a = o diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 2f7b0b2bf39e..7d83de796033 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -168,12 +168,21 @@ def contains_return_op(self, node): pred = lambda s: self.contains_return_op(s) return any(pred(s) for s in node.body) elif isinstance(node, ast.Call): - if isinstance(node.func, ast.Attribute): + def check_undefined_name(cur_node): # Check if name is an undefined local variable, # which can only be a tensor or a constexpr - name = node.func.value.id - if name not in self.lscope and name not in self.gscope: - return False + if isinstance(cur_node.func, ast.Attribute): + if isinstance(cur_node.func.value, ast.Name): + name = cur_node.func.value.id + if name not in self.lscope and name not in self.gscope: + return True + return False + # chain of calls + # e.g., tl.load(a).to(tl.float32) + return check_undefined_name(cur_node.func.value) + return False + if check_undefined_name(node): + return False fn = self.visit(node.func) if isinstance(fn, JITFunction) and fn.noinline is False: old_gscope = self.gscope