From f480fa4ca315a387aebf2d9b16aa402269cddedc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 6 Mar 2024 08:50:38 -0600 Subject: [PATCH] [Bugfix][TIR] Handle AttrStmt of upcoming tir.Var in ConvertSSA In some cases, an `AttrStmt` may legally refer to a TIR variable that hasn't yet been defined. For example, the `"pragma_parallel_launch_point"` attribute, which annotates a variable that is about to occur in a ForNode. Prior to this commit, `ConvertSSA` treated the `AttrStmt` as the usage of a variable, followed by a nested definition to be de-duplicated. This resulted in the output `AttrStmt` containing a reference to an undefined variable. This commit updates `ConvertSSA` to handle this case. If an `AttrStmt` refers to a not-yet-defined variable, the body is visited before marking it as defined. This implementation may be simplified in the future by moving "pragma_parallel_launch_point" to be an annotation on the `ForNode`, rather than an `AttrStmt`. --- src/tir/transforms/ir_utils.cc | 34 +++++++++-- .../test_tir_transform_convert_ssa.py | 61 ++++++++++++++++++- 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index a85bde6787f0..584b3cbf58f4 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -358,6 +358,7 @@ class IRConvertSSA final : public StmtExprMutator { } Var var = iter_var->var; + bool delayed_define = false; if (auto it = function_scope_var_remap_.find(var.get()); it != function_scope_var_remap_.end()) { var = it->second; @@ -373,8 +374,23 @@ class IRConvertSSA final : public StmtExprMutator { function_scope_var_remap_.insert({var.get(), new_var}); var = new_var; } else { - function_scope_var_remap_.insert({var.get(), var}); - defined_.insert(var.get()); + // The AttrStmt refers to an undefined variable. This is + // allowed for some attributes, such as + // "pragma_parallel_launch_point", which annotates a variable + // that is about to occur in a ForNode. In these cases, the + // ForNode and the AttrStmt must continue using the same + // variable defintion. + // + // However, other AttrStmt, such as "thread_extent", act as + // points of definition for the variable they annotate. If + // the variable has not been defined after visiting the body, + // we should mark it as defined before exiting. This ensures + // correct de-duplication between multiple functions. + // + // This implementation may be simplified in the future by + // moving "pragma_parallel_launch_point" to be an annotation + // on the `ForNode`, rather than an `AttrStmt`. + delayed_define = true; } IterVar new_iter_var; @@ -387,12 +403,22 @@ class IRConvertSSA final : public StmtExprMutator { auto value = VisitExpr(op->value); auto body = VisitStmt(op->body); + Stmt output; if (new_iter_var.get() == iter_var && body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); + output = GetRef(op); } else { - return AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span); + output = AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span); } + if (delayed_define) { + if (!defined_.count(var.get())) { + function_scope_var_remap_.insert({var.get(), var}); + defined_.insert(var.get()); + } + } + + return output; + } else if (const VarNode* v = op->node.as()) { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index 140adcd35bd2..644ab3b624ef 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -17,7 +17,7 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tir, ir from tvm.script import tir as T, ir as I @@ -485,5 +485,64 @@ def kernel_2(A: T.Buffer([256], "float32")): return mod +class TestTrackForwardDeclarationsInAttrStmt(BaseBeforeAfter): + """T.attr statements may refer to a about-to-be-defined tir.Var""" + + def before(self): + """Generate the PrimFunc, which is already SSA + + This is constructed directly, rather than using TVMScript or + the `tvm.tir.ir_builder`. This test case requires a + `tir.AttrStmt` that references a variable, followed by the + `tir.For` defining that variable. This is not expressible in + either TVMScript or `tvm.tir.ir_builder`, as they only provide + the loop iterator within the body of the loop. + """ + i0_outer_outer = tir.Var("i0_outer_outer", "int32") + i0_outer_inner = tir.Var("i0_outer_inner", "int32") + i0_inner = tir.Var("i0_inner", "int32") + + A = tir.decl_buffer(1024, "float32", "A") + B = tir.decl_buffer(1024, "float32", "B") + + index = i0_outer_outer * 52 + i0_outer_inner * 4 + i0_inner + + stmt = tir.BufferStore(B, tir.BufferLoad(A, [index]), [index]) + stmt = tir.IfThenElse(i0_outer_outer * 13 + i0_outer_inner < 256, stmt, None) + stmt = tir.For(i0_inner, 0, 4, tir.ForKind.VECTORIZED, stmt) + stmt = tir.For(i0_outer_inner, 0, 13, tir.ForKind.PARALLEL, stmt) + stmt = tir.AttrStmt( + T.iter_var(i0_outer_inner, None, "DataPar", ""), + "pragma_parallal_barrier_when_finish", + 1, + stmt, + ) + stmt = tir.AttrStmt( + T.iter_var(i0_outer_inner, None, "DataPar", ""), + "pragma_parallal_stride_pattern", + 1, + stmt, + ) + stmt = tir.For(i0_outer_outer, 0, 20, tir.ForKind.SERIAL, stmt) + stmt = tir.AttrStmt( + T.iter_var(i0_outer_outer, None, "DataPar", ""), + "pragma_parallal_launch_point", + 1, + stmt, + ) + + A_handle = tir.Var("A_handle", "handle") + B_handle = tir.Var("B_handle", "handle") + + func = tir.PrimFunc( + [A_handle, B_handle], + stmt, + buffer_map={A_handle: A, B_handle: B}, + ) + return func + + expected = before + + if __name__ == "__main__": tvm.testing.main()