Skip to content

Commit 13f54e0

Browse files
authored
[BugFix][TVMScript] Fix printer for dependent loops (#9506)
1 parent 08898e1 commit 13f54e0

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/printer/tvmscript_printer.cc

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,27 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
228228
/*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */
229229
Doc PrintLoopStack();
230230
/*!
231-
* \brief Print all simple loops in stack into one line using tir_prefix_.grid().
231+
* \brief Check whether a loop satisfies:
232+
* 1. the loop is serial;
233+
* 2. the loop has no annotation;
234+
* 3. the loop starts from 0;
235+
* 4. there is no optional information.
232236
* \param for_op the for node to be checked
237+
* \return A boolean indicating whether the input loop satisfies the above conditions
233238
*/
234239
bool IsSimpleLoop(const ForNode* for_op) {
235240
return for_op->kind == ForKind::kSerial && for_op->annotations.empty() &&
236241
is_zero(for_op->min) && !ContainsOptionalInfo(GetRef<Stmt>(for_op));
237242
}
243+
/*!
244+
* \brief Check whether the `min` or `extent` of a loop depends on previous loops
245+
* \param for_op The loop to be checked
246+
* \return A boolean indicating whether the input loop depends on previous loops
247+
*/
248+
bool DependOnPrevLoops(const ForNode* for_op) {
249+
auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); };
250+
return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check);
251+
}
238252

239253
/*!
240254
* \brief Print additional info about expr in comment.
@@ -895,7 +909,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
895909
bool simple_loop = IsSimpleLoop(op);
896910
if (simple_loop) simple_loop_stack_.push_back(GetRef<For>(op));
897911
// It is a loop that can be compressed, let the loops below print it out
898-
if (simple_loop && body != nullptr && IsSimpleLoop(body)) {
912+
if (simple_loop && body != nullptr && IsSimpleLoop(body) && !DependOnPrevLoops(body)) {
899913
doc << Print(GetRef<For>(body));
900914
TryDeallocVar(op->loop_var);
901915
loop_var_map_.erase(op->loop_var.get());

tests/python/unittest/test_tvmscript_roundtrip.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3176,5 +3176,19 @@ def test_div_mod():
31763176
assert isinstance(func.body[3].value, tvm.tir.Mod)
31773177

31783178

3179+
@T.prim_func
3180+
def loop_extent_dependent(a: T.handle) -> None:
3181+
A = T.match_buffer(a, [], dtype="int32")
3182+
for i in T.serial(0, 128):
3183+
for j in T.serial(0, i):
3184+
A[()] = A[()] + j
3185+
3186+
3187+
def test_loop_extent_dependent():
3188+
func = loop_extent_dependent
3189+
rt_func = tvm.script.from_source(func.script(show_meta=True))
3190+
tvm.ir.assert_structural_equal(func, rt_func, True)
3191+
3192+
31793193
if __name__ == "__main__":
31803194
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)