Skip to content

Commit 98008c2

Browse files
authored
[Fix][TVMScript] Fix LetStmt printing logic (#13900)
This PR is the bug fix reported in #13892. Initially, we mix the logic of `LetStmt` docsifying method with and without concise scoping. For example, in ```python x = T.var("int32") with T.let(x, 0): ``` `x` in the `LetStmt` works as a right value, while in ```python x: T.int32 = 0 ``` `x` in the `LetStmt` works as a left value as result. Our old logic mixed them together to generate the wrong code for the first case. Meanwhile, during the fix, we found another bug in concise scoping check. For example, we have ```python x = T.var("int32") y = T.var("int32") with T.let(x, y): with T.let(y, 0): ``` here we should not output ```python x = T.var("int32") y = T.var("int32") with T.let(x, y): y: int32 = 0 ``` becase this will define a new `y_1: int32 = 0` indeed, due the the variable shadowing logic of the parser, which is different from the `y` we define and refer to. Our concise scoping `v: ... = ...` should launch if and only if the `v` is never defined before. Otherwise, we use `with T.let(v, ...):` instead.
1 parent e34506c commit 98008c2

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

src/script/printer/tir/stmt.cc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
5757
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
5858
.set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
5959
bool concise = AllowConciseScoping(d);
60-
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
61-
With<TIRFrame> f(d, stmt);
62-
ExprDoc lhs = d->IsVarDefined(stmt->var) ? d->GetVarDoc(stmt->var).value()
63-
: DefineVar(stmt->var, *f, d);
64-
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
65-
Array<StmtDoc>* stmts = &(*f)->stmts;
66-
if (concise) {
60+
if (concise && !d->IsVarDefined(stmt->var)) {
61+
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
62+
With<TIRFrame> f(d, stmt);
63+
ExprDoc lhs = DefineVar(stmt->var, *f, d);
64+
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
65+
Array<StmtDoc>* stmts = &(*f)->stmts;
6766
Type type = stmt->var->type_annotation;
6867
Optional<ExprDoc> type_doc =
6968
d->AsDoc<ExprDoc>(type, p->Attr("var")->Attr("type_annotation"));
@@ -75,6 +74,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
7574
stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc));
7675
return StmtBlockDoc(*stmts);
7776
} else {
77+
ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
78+
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
79+
With<TIRFrame> f(d, stmt);
80+
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
81+
Array<StmtDoc>* stmts = &(*f)->stmts;
7882
rhs = TIR(d, "let")->Call({lhs, rhs});
7983
return ScopeDoc(NullOpt, rhs, *stmts);
8084
}

tests/python/unittest/test_tvmscript_printer_tir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def test_let_stmt():
254254
_assert_print(
255255
obj,
256256
"""
257+
v = T.var("float32")
257258
with T.let(v, T.float32(10)):
258259
T.evaluate(0)
259260
""",

tests/python/unittest/test_tvmscript_roundtrip.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3543,6 +3543,32 @@ def func():
35433543
return func
35443544

35453545

3546+
def let_stmt_var():
3547+
@T.prim_func
3548+
def func():
3549+
x = T.var("int32")
3550+
y = T.var("int32")
3551+
with T.let(x, 0):
3552+
with T.let(y, 0):
3553+
T.evaluate(0)
3554+
T.evaluate(0)
3555+
3556+
return func
3557+
3558+
3559+
def let_stmt_value():
3560+
@T.prim_func
3561+
def func():
3562+
x = T.var("int32")
3563+
y = T.var("int32")
3564+
with T.let(x, y):
3565+
with T.let(y, 0):
3566+
T.evaluate(0)
3567+
T.evaluate(0)
3568+
3569+
return func
3570+
3571+
35463572
ir_generator = tvm.testing.parameter(
35473573
opt_gemm_normalize,
35483574
opt_gemm_lower,
@@ -3601,6 +3627,8 @@ def func():
36013627
*nested_boolean_expressions(),
36023628
multi_env_threads,
36033629
intrinsic_pow,
3630+
let_stmt_var,
3631+
let_stmt_value,
36043632
)
36053633

36063634

0 commit comments

Comments
 (0)