Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ class IRConvertSSA final : public StmtExprMutator {
}

Var var = iter_var->var;
bool delayed_define = false;
if (auto it = function_scope_var_remap_.find(var.get());
it != function_scope_var_remap_.end()) {
var = it->second;
Expand All @@ -373,8 +374,23 @@ class IRConvertSSA final : public StmtExprMutator {
function_scope_var_remap_.insert({var.get(), new_var});
var = new_var;
} else {
function_scope_var_remap_.insert({var.get(), var});
defined_.insert(var.get());
// The AttrStmt refers to an undefined variable. This is
// allowed for some attributes, such as
// "pragma_parallel_launch_point", which annotates a variable
// that is about to occur in a ForNode. In these cases, the
// ForNode and the AttrStmt must continue using the same
// variable definition.
//
// However, other AttrStmt, such as "thread_extent", act as
// points of definition for the variable they annotate. If
// the variable has not been defined after visiting the body,
// we should mark it as defined before exiting. This ensures
// correct de-duplication between multiple functions.
//
// This implementation may be simplified in the future by
// moving "pragma_parallel_launch_point" to be an annotation
// on the `ForNode`, rather than an `AttrStmt`.
delayed_define = true;
}

IterVar new_iter_var;
Expand All @@ -387,12 +403,22 @@ class IRConvertSSA final : public StmtExprMutator {
auto value = VisitExpr(op->value);
auto body = VisitStmt(op->body);

Stmt output;
if (new_iter_var.get() == iter_var && body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<Stmt>(op);
output = GetRef<Stmt>(op);
} else {
return AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span);
output = AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span);
}

if (delayed_define) {
if (!defined_.count(var.get())) {
function_scope_var_remap_.insert({var.get(), var});
defined_.insert(var.get());
}
}

return output;

} else if (const VarNode* v = op->node.as<VarNode>()) {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
Expand Down
61 changes: 60 additions & 1 deletion tests/python/tir-transform/test_tir_transform_convert_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import tvm
import tvm.testing
from tvm import tir
from tvm import tir, ir
from tvm.script import tir as T, ir as I


Expand Down Expand Up @@ -485,5 +485,64 @@ def kernel_2(A: T.Buffer([256], "float32")):
return mod


class TestTrackForwardDeclarationsInAttrStmt(BaseBeforeAfter):
    """A tir.AttrStmt may annotate a tir.Var that is only defined afterwards.

    The function built by `before` is also the expected output, so the
    test case checks that this input is left unchanged.
    """

    def before(self):
        """Assemble an already-SSA PrimFunc directly from TIR nodes.

        Every `tir.AttrStmt` built here wraps the `tir.For` that defines
        the variable it annotates, so the annotation refers to the loop
        variable ahead of its definition.  Neither TVMScript nor
        `tvm.tir.ir_builder` can express this, because both only hand
        out the loop iterator inside the loop body; the statement tree
        is therefore constructed by hand, innermost statement first.
        """

        def annotate(loop_var, attr_key, body):
            # Each annotation gets its own freshly created IterVar.
            return tir.AttrStmt(T.iter_var(loop_var, None, "DataPar", ""), attr_key, 1, body)

        outer_outer = tir.Var("i0_outer_outer", "int32")
        outer_inner = tir.Var("i0_outer_inner", "int32")
        inner = tir.Var("i0_inner", "int32")

        src = tir.decl_buffer(1024, "float32", "A")
        dst = tir.decl_buffer(1024, "float32", "B")

        # The same index expression is used for both the load and the store.
        offset = outer_outer * 52 + outer_inner * 4 + inner

        body = tir.BufferStore(dst, tir.BufferLoad(src, [offset]), [offset])
        body = tir.IfThenElse(outer_outer * 13 + outer_inner < 256, body, None)
        body = tir.For(inner, 0, 4, tir.ForKind.VECTORIZED, body)

        # NOTE(review): the attribute keys are spelled "parallal" in the
        # original test; they are reproduced verbatim here.
        body = tir.For(outer_inner, 0, 13, tir.ForKind.PARALLEL, body)
        for attr_key in (
            "pragma_parallal_barrier_when_finish",
            "pragma_parallal_stride_pattern",
        ):
            body = annotate(outer_inner, attr_key, body)

        body = tir.For(outer_outer, 0, 20, tir.ForKind.SERIAL, body)
        body = annotate(outer_outer, "pragma_parallal_launch_point", body)

        src_handle = tir.Var("A_handle", "handle")
        dst_handle = tir.Var("B_handle", "handle")

        return tir.PrimFunc(
            [src_handle, dst_handle],
            body,
            buffer_map={src_handle: src, dst_handle: dst},
        )

    expected = before


# When this file is executed directly as a script, hand control to
# tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()