diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index a47712e6b62a..f1c47e78bc45 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -228,13 +228,27 @@ class TVMScriptPrinter : public StmtFunctor, /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */ Doc PrintLoopStack(); /*! - * \brief Print all simple loops in stack into one line using tir_prefix_.grid(). + * \brief Check whether a loop satisfies: + * 1. the loop is serial; + * 2. the loop has no annotation; + * 3. the loop starts from 0; + * 4. there is no optional information. * \param for_op the for node to be checked + * \return A boolean indicating whether the input loop satisfies the above conditions */ bool IsSimpleLoop(const ForNode* for_op) { return for_op->kind == ForKind::kSerial && for_op->annotations.empty() && is_zero(for_op->min) && !ContainsOptionalInfo(GetRef(for_op)); } + /*! + * \brief Check whether the `min` or `extent` of a loop depends on previous loops + * \param for_op The loop to be checked + * \return A boolean indicating whether the input loop depends on previous loops + */ + bool DependOnPrevLoops(const ForNode* for_op) { + auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); }; + return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check); + } /*! * \brief Print additional info about expr in comment. @@ -895,7 +909,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { bool simple_loop = IsSimpleLoop(op); if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); // It is a loop that can be compressed, let the loops below print it out - if (simple_loop && body != nullptr && IsSimpleLoop(body)) { + if (simple_loop && body != nullptr && IsSimpleLoop(body) && !DependOnPrevLoops(body)) { doc << Print(GetRef(body)); TryDeallocVar(op->loop_var); loop_var_map_.erase(op->loop_var.get()); diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 4e1308b030f1..d7a8ff9b69c7 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3176,5 +3176,19 @@ def test_div_mod(): assert isinstance(func.body[3].value, tvm.tir.Mod) +@T.prim_func +def loop_extent_dependent(a: T.handle) -> None: + A = T.match_buffer(a, [], dtype="int32") + for i in T.serial(0, 128): + for j in T.serial(0, i): + A[()] = A[()] + j + + +def test_loop_extent_dependent(): + func = loop_extent_dependent + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))