@@ -228,13 +228,27 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
228228 /* ! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */
229229 Doc PrintLoopStack ();
230230 /* !
231- * \brief Print all simple loops in stack into one line using tir_prefix_.grid().
231+ * \brief Check whether a loop satisfies:
232+ * 1. the loop is serial;
233+ * 2. the loop has no annotation;
234+ * 3. the loop starts from 0;
235+ * 4. there is no optional information.
232236 * \param for_op the for node to be checked
237+ * \return A boolean indicating whether the input loop satisfies the above conditions
233238 */
234239 bool IsSimpleLoop (const ForNode* for_op) {
235240 return for_op->kind == ForKind::kSerial && for_op->annotations .empty () &&
236241 is_zero (for_op->min ) && !ContainsOptionalInfo (GetRef<Stmt>(for_op));
237242 }
243+ /* !
244+ * \brief Check whether the `min` or `extent` of a loop depends on previous loops
245+ * \param for_op The loop to be checked
246+ * \return A boolean indicating whether the input loop depends on previous loops
247+ */
248+ bool DependOnPrevLoops (const ForNode* for_op) {
249+ auto f_check = [&var_map = this ->loop_var_map_ ](const VarNode* v) { return var_map.count (v); };
250+ return UsesVar (for_op->min , f_check) || UsesVar (for_op->extent , f_check);
251+ }
238252
239253 /* !
240254 * \brief Print additional info about expr in comment.
@@ -895,7 +909,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
895909 bool simple_loop = IsSimpleLoop (op);
896910 if (simple_loop) simple_loop_stack_.push_back (GetRef<For>(op));
897911 // It is a loop that can be compressed, let the loops below print it out
898- if (simple_loop && body != nullptr && IsSimpleLoop (body)) {
912+ if (simple_loop && body != nullptr && IsSimpleLoop (body) && ! DependOnPrevLoops (body) ) {
899913 doc << Print (GetRef<For>(body));
900914 TryDeallocVar (op->loop_var );
901915 loop_var_map_.erase (op->loop_var .get ());
0 commit comments