diff --git a/src/op/copy.cc b/src/op/copy.cc
index 31a0c0092..e9aaa1547 100644
--- a/src/op/copy.cc
+++ b/src/op/copy.cc
@@ -934,7 +934,7 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
   auto body = Evaluate(Call(DataType::Handle(), op, args));
   For for_node =
       For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
-  for_node = LoopPragmaUnroll(for_node);
+  for_node = PragmaUnrollLoop(for_node);
   auto range = T.thread_bounds;
   if (range.defined()) {
     auto thread_var = T.thread_var;
diff --git a/src/op/fill.cc b/src/op/fill.cc
index feab45b56..7eb22daa8 100644
--- a/src/op/fill.cc
+++ b/src/op/fill.cc
@@ -167,18 +167,19 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
                         InferLevel::kFree);
     auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                      par_op->GetLoopLayout());
-    auto vectorized_thread_loop =
-        VectorizeLoop(thread_loop, analyzer, T.layout_map);
+    auto vectorized_loop = VectorizeLoop(thread_loop, analyzer, T.layout_map);
+    auto unrolled_loop = PragmaUnrollLoop(vectorized_loop);
+
     if (par_op->GetPredicate(T.thread_var).defined()) {
       return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
-                        vectorized_thread_loop);
+                        unrolled_loop);
     }
-    return vectorized_thread_loop;
+    return unrolled_loop;
   } else if (IsLocalBuffer(dst) || IsLocalVarBuffer(dst)) {
     auto init_loop = MakeSIMTLoop(analyzer);
-    auto vectorized_thread_loop =
-        VectorizeLoop(init_loop, analyzer, T.layout_map);
-    return vectorized_thread_loop;
+    auto vectorized_loop = VectorizeLoop(init_loop, analyzer, T.layout_map);
+    auto unrolled_loop = PragmaUnrollLoop(vectorized_loop);
+    return unrolled_loop;
   } else if (IsSharedBuffer(dst) || IsGlobalBuffer(dst)) {
     auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
     par_op->InferLayout({T.target,
@@ -191,13 +192,13 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
                         InferLevel::kFree);
     auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                      par_op->GetLoopLayout());
-    auto vectorized_thread_loop =
-        VectorizeLoop(thread_loop, analyzer, T.layout_map);
+    auto vectorized_loop = VectorizeLoop(thread_loop, analyzer, T.layout_map);
+    auto unrolled_loop = PragmaUnrollLoop(vectorized_loop);
     if (par_op->GetPredicate(T.thread_var).defined()) {
       return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
-                        vectorized_thread_loop);
+                        unrolled_loop);
     }
-    return vectorized_thread_loop;
+    return unrolled_loop;
   } else {
     LOG(FATAL) << "Unsupported scope " << dst.scope();
     return Stmt();
diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index 765ec1786..3d5888492 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -448,6 +448,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   if (dst_layout->InputDim() > 0) {
     body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer,
                          red_layout);
+    body = PragmaUnrollLoop(Downcast<For>(body));
   } else {
     auto guard = (T.thread_var == T.thread_bounds->min);
     body = IfThenElse(guard, body);
diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc
index 115b48b35..53daa2be5 100644
--- a/src/transform/loop_partition.cc
+++ b/src/transform/loop_partition.cc
@@ -161,9 +161,7 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
   if (has_thread_offset) {
     body = Substitute(body, thread_offset_map);
   }
-
-  auto for_node = LoopPragmaUnroll(Downcast<For>(body));
-  return for_node;
+  return Downcast<For>(body);
 }
 
 class LoopPramaUnroller : public StmtExprMutator {
@@ -264,7 +262,7 @@ Fragment PlanLoopPartition(const For &op, int vectorize_size,
   return fragment->BindThreadRange(thread_range);
 }
 
-For LoopPragmaUnroll(For stmt) {
+For PragmaUnrollLoop(For stmt) {
   LoopPramaUnroller unroller;
   For unrolled = Downcast<For>(unroller(std::move(stmt)));
   return unrolled;
@@ -297,6 +295,8 @@ Stmt LowerParallelLoop(For loop, const Fragment &loop_layout, Var thread_var,
     result_loop = VectorizeLoop(result_loop, saved_analyzer.get(), layout_map);
   }
 
+  result_loop = PragmaUnrollLoop(result_loop);
+
   // Step 3: Wrap with predicate if provided and this is a parallel loop
   if (predicate.defined() && parallel_loop) {
     return IfThenElse(predicate.value(), result_loop);
diff --git a/src/transform/loop_partition.h b/src/transform/loop_partition.h
index ffc32ec45..033a25615 100644
--- a/src/transform/loop_partition.h
+++ b/src/transform/loop_partition.h
@@ -45,7 +45,7 @@ Fragment PlanLoopPartition(const For &op, size_t num_thread,
 Fragment PlanLoopPartition(const For &op, int vectorize_size,
                            const Range &thread_range);
 
-For LoopPragmaUnroll(For stmt);
+For PragmaUnrollLoop(For stmt);
 
 /*!
  * \brief Lower a parallel loop by partitioning and vectorizing it.
diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc
index 222a6e79a..7dc17102a 100644
--- a/src/transform/loop_vectorize.cc
+++ b/src/transform/loop_vectorize.cc
@@ -658,13 +658,26 @@ class VectorizeRewriter : public StmtExprMutator {
       vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
       Stmt body = Substitute(fnode->body, vmap);
       body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
-      body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
+      // TileLang uses ForKind::kParallel in frontend SIMT loops. After
+      // vectorization, keep semantics equivalent but downgrade to serial so
+      // subsequent passes (e.g. pragma-unroll) can run.
+      ForKind outer_kind = fnode->kind;
+      if (outer_kind == ForKind::kParallel) {
+        outer_kind = ForKind::kSerial;
+      }
+      body = For(outer_var, 0, extent / vector_size_, outer_kind, body,
                  fnode->thread_binding, fnode->annotations, fnode->step,
                  fnode->span);
       return body;
     }
   } else {
-    return ret;
+    // Keep other loops intact, except for TileLang frontend "parallel" loops,
+    // which should behave as serial loops after lowering.
+    For loop = ret.as<For>().value();
+    if (loop->kind == ForKind::kParallel) {
+      loop.CopyOnWrite()->kind = ForKind::kSerial;
+    }
+    return loop;
   }
 }
 
@@ -780,6 +793,38 @@ bool IndicesCanVectorize(const PrimExpr &expr, Var var,
   }
 }
 
+namespace {
+
+/*!
+ * \brief Convert TIR parallel loops into serial loops.
+ *
+ * TileLang uses ForKind::kParallel in a few places as a frontend "SIMT loop"
+ * marker. When the vectorize size resolves to 1 (i.e. no vectorization is
+ * applied), keeping these loops as kParallel can block later loop transforms
+ * that only apply to serial loops (e.g. pragma-unroll rewriting).
+ *
+ * This rewriter is intentionally conservative: it only downgrades kParallel
+ * to kSerial and leaves all other loop kinds untouched.
+ */
+class ParallelToSerialRewriter : public StmtExprMutator {
+private:
+  Stmt VisitStmt_(const ForNode *node) final {
+    Stmt visited = StmtExprMutator::VisitStmt_(node);
+    For loop = Downcast<For>(visited);
+    if (loop->kind == ForKind::kParallel) {
+      loop.CopyOnWrite()->kind = ForKind::kSerial;
+    }
+    return loop;
+  }
+};
+
+For ParallelToSerial(const For &loop) {
+  ParallelToSerialRewriter rewriter;
+  return Downcast<For>(rewriter(loop));
+}
+
+} // namespace
+
 For VectorizeLoop(const For &loop, const LayoutMap &layout_map,
                   int vectorize_hint) {
   if (vectorize_hint <= 0) {
@@ -788,7 +833,7 @@ For VectorizeLoop(const For &loop, const LayoutMap &layout_map,
     vectorize_hint = planner.Plan(loop);
   }
   if (vectorize_hint == 1)
-    return loop;
+    return ParallelToSerial(loop);
   auto rewriter = VectorizeRewriter(vectorize_hint);
   return Downcast<For>(rewriter(loop));
 }
@@ -800,7 +845,7 @@ For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
     vectorize_hint = planner.Plan(loop);
   }
   if (vectorize_hint == 1)
-    return loop;
+    return ParallelToSerial(loop);
   auto rewriter = VectorizeRewriter(vectorize_hint);
   return Downcast<For>(rewriter(loop));
 }
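
Below is a minimal standalone sketch (not part of the patch) of the pipeline this change wires up: frontend kParallel loops are downgraded to kSerial, then marked for unrolling. ToSerial mirrors the ParallelToSerialRewriter added above; MarkUnrolled is a hypothetical stand-in for PragmaUnrollLoop, whose real rewriting lives in LoopPramaUnroller in loop_partition.cc and may attach unroll pragmas rather than just flipping the loop kind. It assumes a TVM development build for headers and linking.

// sketch.cc: illustrative only; build against a TVM checkout, e.g.
//   g++ -std=c++17 sketch.cc -ltvm   (include/lib paths per your build)
#include <tvm/runtime/logging.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

using namespace tvm;
using namespace tvm::tir;

// Mirrors ParallelToSerialRewriter above: downgrade kParallel loops to kSerial.
class ToSerial : public StmtExprMutator {
  Stmt VisitStmt_(const ForNode *node) final {
    For loop = Downcast<For>(StmtExprMutator::VisitStmt_(node));
    if (loop->kind == ForKind::kParallel) {
      loop.CopyOnWrite()->kind = ForKind::kSerial;
    }
    return loop;
  }
};

// Hypothetical stand-in for PragmaUnrollLoop: mark serial loops as unrolled.
// (The real LoopPramaUnroller may differ, e.g. by setting unroll pragmas.)
class MarkUnrolled : public StmtExprMutator {
  Stmt VisitStmt_(const ForNode *node) final {
    For loop = Downcast<For>(StmtExprMutator::VisitStmt_(node));
    if (loop->kind == ForKind::kSerial) {
      loop.CopyOnWrite()->kind = ForKind::kUnrolled;
    }
    return loop;
  }
};

int main() {
  Var i("i", DataType::Int(32));
  // A frontend-style "parallel" loop: for (i = 0; i < 8; ++i) evaluate(0).
  For loop(i, 0, 8, ForKind::kParallel, Evaluate(0));
  Stmt serial = ToSerial()(loop);          // kParallel -> kSerial
  Stmt unrolled = MarkUnrolled()(serial);  // kSerial -> kUnrolled
  LOG(INFO) << unrolled;
  return 0;
}

Note the ordering: without the serial downgrade on the VectorizeLoop early-return path (vectorize hint == 1), the unroll rewriter would never fire on frontend kParallel loops, which is exactly the gap the patch closes.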