Skip to content

Commit 898f87f

Browse files
authored
[Bugfix][TIR] Handle AttrStmt of upcoming tir.Var in ConvertSSA (#16682)
In some cases, an `AttrStmt` may legally refer to a TIR variable that hasn't yet been defined. For example, the `"pragma_parallel_launch_point"` attribute, which annotates a variable that is about to occur in a ForNode. Prior to this commit, `ConvertSSA` treated the `AttrStmt` as the usage of a variable, followed by a nested definition to be de-duplicated. This resulted in the output `AttrStmt` containing a reference to an undefined variable. This commit updates `ConvertSSA` to handle this case. If an `AttrStmt` refers to a not-yet-defined variable, the body is visited before marking it as defined. This implementation may be simplified in the future by moving "pragma_parallel_launch_point" to be an annotation on the `ForNode`, rather than an `AttrStmt`.
1 parent 7b7677f commit 898f87f

File tree

2 files changed

+90
-5
lines changed

2 files changed

+90
-5
lines changed

src/tir/transforms/ir_utils.cc

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ class IRConvertSSA final : public StmtExprMutator {
358358
}
359359

360360
Var var = iter_var->var;
361+
bool delayed_define = false;
361362
if (auto it = function_scope_var_remap_.find(var.get());
362363
it != function_scope_var_remap_.end()) {
363364
var = it->second;
@@ -373,8 +374,23 @@ class IRConvertSSA final : public StmtExprMutator {
373374
function_scope_var_remap_.insert({var.get(), new_var});
374375
var = new_var;
375376
} else {
376-
function_scope_var_remap_.insert({var.get(), var});
377-
defined_.insert(var.get());
377+
// The AttrStmt refers to an undefined variable. This is
378+
// allowed for some attributes, such as
379+
// "pragma_parallel_launch_point", which annotates a variable
380+
// that is about to occur in a ForNode. In these cases, the
381+
// ForNode and the AttrStmt must continue using the same
382+
// variable defintion.
383+
//
384+
// However, other AttrStmt, such as "thread_extent", act as
385+
// points of definition for the variable they annotate. If
386+
// the variable has not been defined after visiting the body,
387+
// we should mark it as defined before exiting. This ensures
388+
// correct de-duplication between multiple functions.
389+
//
390+
// This implementation may be simplified in the future by
391+
// moving "pragma_parallel_launch_point" to be an annotation
392+
// on the `ForNode`, rather than an `AttrStmt`.
393+
delayed_define = true;
378394
}
379395

380396
IterVar new_iter_var;
@@ -387,12 +403,22 @@ class IRConvertSSA final : public StmtExprMutator {
387403
auto value = VisitExpr(op->value);
388404
auto body = VisitStmt(op->body);
389405

406+
Stmt output;
390407
if (new_iter_var.get() == iter_var && body.same_as(op->body) && value.same_as(op->value)) {
391-
return GetRef<Stmt>(op);
408+
output = GetRef<Stmt>(op);
392409
} else {
393-
return AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span);
410+
output = AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span);
394411
}
395412

413+
if (delayed_define) {
414+
if (!defined_.count(var.get())) {
415+
function_scope_var_remap_.insert({var.get(), var});
416+
defined_.insert(var.get());
417+
}
418+
}
419+
420+
return output;
421+
396422
} else if (const VarNode* v = op->node.as<VarNode>()) {
397423
Stmt stmt = StmtExprMutator::VisitStmt_(op);
398424
op = stmt.as<AttrStmtNode>();

tests/python/tir-transform/test_tir_transform_convert_ssa.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import tvm
1919
import tvm.testing
20-
from tvm import tir
20+
from tvm import tir, ir
2121
from tvm.script import tir as T, ir as I
2222

2323

@@ -485,5 +485,64 @@ def kernel_2(A: T.Buffer([256], "float32")):
485485
return mod
486486

487487

488+
class TestTrackForwardDeclarationsInAttrStmt(BaseBeforeAfter):
489+
"""T.attr statements may refer to a about-to-be-defined tir.Var"""
490+
491+
def before(self):
492+
"""Generate the PrimFunc, which is already SSA
493+
494+
This is constructed directly, rather than using TVMScript or
495+
the `tvm.tir.ir_builder`. This test case requires a
496+
`tir.AttrStmt` that references a variable, followed by the
497+
`tir.For` defining that variable. This is not expressible in
498+
either TVMScript or `tvm.tir.ir_builder`, as they only provide
499+
the loop iterator within the body of the loop.
500+
"""
501+
i0_outer_outer = tir.Var("i0_outer_outer", "int32")
502+
i0_outer_inner = tir.Var("i0_outer_inner", "int32")
503+
i0_inner = tir.Var("i0_inner", "int32")
504+
505+
A = tir.decl_buffer(1024, "float32", "A")
506+
B = tir.decl_buffer(1024, "float32", "B")
507+
508+
index = i0_outer_outer * 52 + i0_outer_inner * 4 + i0_inner
509+
510+
stmt = tir.BufferStore(B, tir.BufferLoad(A, [index]), [index])
511+
stmt = tir.IfThenElse(i0_outer_outer * 13 + i0_outer_inner < 256, stmt, None)
512+
stmt = tir.For(i0_inner, 0, 4, tir.ForKind.VECTORIZED, stmt)
513+
stmt = tir.For(i0_outer_inner, 0, 13, tir.ForKind.PARALLEL, stmt)
514+
stmt = tir.AttrStmt(
515+
T.iter_var(i0_outer_inner, None, "DataPar", ""),
516+
"pragma_parallal_barrier_when_finish",
517+
1,
518+
stmt,
519+
)
520+
stmt = tir.AttrStmt(
521+
T.iter_var(i0_outer_inner, None, "DataPar", ""),
522+
"pragma_parallal_stride_pattern",
523+
1,
524+
stmt,
525+
)
526+
stmt = tir.For(i0_outer_outer, 0, 20, tir.ForKind.SERIAL, stmt)
527+
stmt = tir.AttrStmt(
528+
T.iter_var(i0_outer_outer, None, "DataPar", ""),
529+
"pragma_parallal_launch_point",
530+
1,
531+
stmt,
532+
)
533+
534+
A_handle = tir.Var("A_handle", "handle")
535+
B_handle = tir.Var("B_handle", "handle")
536+
537+
func = tir.PrimFunc(
538+
[A_handle, B_handle],
539+
stmt,
540+
buffer_map={A_handle: A, B_handle: B},
541+
)
542+
return func
543+
544+
expected = before
545+
546+
488547
if __name__ == "__main__":
489548
tvm.testing.main()

0 commit comments

Comments
 (0)