Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,27 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
/*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */
Doc PrintLoopStack();
/*!
* \brief Print all simple loops in stack into one line using tir_prefix_.grid().
* \brief Check whether a loop satisfies:
* 1. the loop is serial;
* 2. the loop has no annotation;
* 3. the loop starts from 0;
* 4. there is no optional information.
* \param for_op the for node to be checked
* \return A boolean indicating whether the input loop satisfies the above conditions
*/
bool IsSimpleLoop(const ForNode* for_op) {
return for_op->kind == ForKind::kSerial && for_op->annotations.empty() &&
is_zero(for_op->min) && !ContainsOptionalInfo(GetRef<Stmt>(for_op));
}
/*!
* \brief Check whether the `min` or `extent` of a loop depends on previous loops
* \param for_op The loop to be checked
* \return A boolean indicating whether the input loop depends on previous loops
*/
bool DependOnPrevLoops(const ForNode* for_op) {
auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); };
return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check);
}

/*!
* \brief Print additional info about expr in comment.
Expand Down Expand Up @@ -895,7 +909,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
bool simple_loop = IsSimpleLoop(op);
if (simple_loop) simple_loop_stack_.push_back(GetRef<For>(op));
// It is a loop that can be compressed, let the loops below print it out
if (simple_loop && body != nullptr && IsSimpleLoop(body)) {
if (simple_loop && body != nullptr && IsSimpleLoop(body) && !DependOnPrevLoops(body)) {
doc << Print(GetRef<For>(body));
TryDeallocVar(op->loop_var);
loop_var_map_.erase(op->loop_var.get());
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3176,5 +3176,19 @@ def test_div_mod():
assert isinstance(func.body[3].value, tvm.tir.Mod)


@T.prim_func
def loop_extent_dependent(a: T.handle) -> None:
A = T.match_buffer(a, [], dtype="int32")
for i in T.serial(0, 128):
for j in T.serial(0, i):
A[()] = A[()] + j


def test_loop_extent_dependent():
func = loop_extent_dependent
rt_func = tvm.script.from_source(func.script(show_meta=True))
tvm.ir.assert_structural_equal(func, rt_func, True)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))