2 changes: 1 addition & 1 deletion src/op/copy.cc
@@ -934,7 +934,7 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
auto body = Evaluate(Call(DataType::Handle(), op, args));
For for_node =
For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
for_node = LoopPragmaUnroll(for_node);
for_node = PragmaUnrollLoop(for_node);
auto range = T.thread_bounds;
if (range.defined()) {
auto thread_var = T.thread_var;
23 changes: 12 additions & 11 deletions src/op/fill.cc
@@ -167,18 +167,19 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop =
VectorizeLoop(thread_loop, analyzer, T.layout_map);
auto vectorized_loop = VectorizeLoop(thread_loop, analyzer, T.layout_map);
auto unrolled_loop = PragmaUnrollLoop(vectorized_loop);

if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
unrolled_loop);
}
return vectorized_thread_loop;
return unrolled_loop;
} else if (IsLocalBuffer(dst) || IsLocalVarBuffer(dst)) {
auto init_loop = MakeSIMTLoop(analyzer);
auto vectorized_thread_loop =
VectorizeLoop(init_loop, analyzer, T.layout_map);
return vectorized_thread_loop;
auto vectorized_loop = VectorizeLoop(init_loop, analyzer, T.layout_map);
auto unrolled_loop = PragmaUnrollLoop(vectorized_loop);
return unrolled_loop;
} else if (IsSharedBuffer(dst) || IsGlobalBuffer(dst)) {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target,
@@ -191,13 +192,13 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop =
VectorizeLoop(thread_loop, analyzer, T.layout_map);
auto vectorized_loop = VectorizeLoop(thread_loop, analyzer, T.layout_map);
auto unrolled_loop = PragmaUnrollLoop(vectorized_loop);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
unrolled_loop);
}
return vectorized_thread_loop;
return unrolled_loop;
} else {
LOG(FATAL) << "Unsupported scope " << dst.scope();
return Stmt();
1 change: 1 addition & 0 deletions src/op/reduce.cc
@@ -448,6 +448,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst_layout->InputDim() > 0) {
body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer,
red_layout);
body = PragmaUnrollLoop(Downcast<For>(body));
} else {
auto guard = (T.thread_var == T.thread_bounds->min);
body = IfThenElse(guard, body);
8 changes: 4 additions & 4 deletions src/transform/loop_partition.cc
@@ -161,9 +161,7 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
if (has_thread_offset) {
body = Substitute(body, thread_offset_map);
}

auto for_node = LoopPragmaUnroll(Downcast<For>(body));
return for_node;
return Downcast<For>(body);
}

class LoopPramaUnroller : public StmtExprMutator {
@@ -264,7 +262,7 @@ Fragment PlanLoopPartition(const For &op, int vectorize_size,
return fragment->BindThreadRange(thread_range);
}

For LoopPragmaUnroll(For stmt) {
For PragmaUnrollLoop(For stmt) {
LoopPramaUnroller unroller;
For unrolled = Downcast<For>(unroller(std::move(stmt)));
return unrolled;
@@ -297,6 +295,8 @@ Stmt LowerParallelLoop(For loop, const Fragment &loop_layout, Var thread_var,
result_loop = VectorizeLoop(result_loop, saved_analyzer.get(), layout_map);
}

result_loop = PragmaUnrollLoop(result_loop);

// Step 3: Wrap with predicate if provided and this is a parallel loop
if (predicate.defined() && parallel_loop) {
return IfThenElse(predicate.value(), result_loop);
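For reference, `PragmaUnrollLoop` (the renamed `LoopPragmaUnroll`) keeps delegating to the existing `LoopPramaUnroller` mutator. Below is a minimal sketch of that kind of rewriter, written against the same StmtExprMutator pattern used elsewhere in this diff, assuming it simply retags serial loops as `ForKind::kUnrolled`; the real `LoopPramaUnroller` may instead attach explicit unroll pragma annotations.

```cpp
// Hypothetical sketch only; mirrors the StmtExprMutator pattern used in this
// diff. The actual LoopPramaUnroller implementation may differ.
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tir {

class SimpleUnrollRewriter : public StmtExprMutator {
 private:
  Stmt VisitStmt_(const ForNode *node) final {
    Stmt visited = StmtExprMutator::VisitStmt_(node);
    For loop = Downcast<For>(visited);
    // Only serial loops are retagged; vectorized or thread-bound loops stay as-is.
    if (loop->kind == ForKind::kSerial) {
      loop.CopyOnWrite()->kind = ForKind::kUnrolled;
    }
    return loop;
  }
};

For SimplePragmaUnrollLoop(For stmt) {
  SimpleUnrollRewriter rewriter;
  return Downcast<For>(rewriter(std::move(stmt)));
}

}  // namespace tir
}  // namespace tvm
```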
2 changes: 1 addition & 1 deletion src/transform/loop_partition.h
@@ -45,7 +45,7 @@ Fragment PlanLoopPartition(const For &op, size_t num_thread,
Fragment PlanLoopPartition(const For &op, int vectorize_size,
const Range &thread_range);

For LoopPragmaUnroll(For stmt);
For PragmaUnrollLoop(For stmt);

/*!
* \brief Lower a parallel loop by partitioning and vectorizing it.
53 changes: 49 additions & 4 deletions src/transform/loop_vectorize.cc
@@ -658,13 +658,26 @@ class VectorizeRewriter : public StmtExprMutator {
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
// TileLang uses ForKind::kParallel in frontend SIMT loops. After
// vectorization, keep semantics equivalent but downgrade to serial so
// subsequent passes (e.g. pragma-unroll) can run.
ForKind outer_kind = fnode->kind;
if (outer_kind == ForKind::kParallel) {
outer_kind = ForKind::kSerial;
}
body = For(outer_var, 0, extent / vector_size_, outer_kind, body,
fnode->thread_binding, fnode->annotations, fnode->step,
fnode->span);
return body;
}
} else {
return ret;
// Keep other loops intact, except for TileLang frontend "parallel" loops
// which should behave as serial loops after lowering.
For loop = ret.as<For>().value();
if (loop->kind == ForKind::kParallel) {
loop.CopyOnWrite()->kind = ForKind::kSerial;
}
return loop;
}
}

@@ -780,6 +793,38 @@ bool IndicesCanVectorize(const PrimExpr &expr, Var var,
}
}

namespace {

/*!
* \brief Convert TIR parallel loops into serial loops.
*
* TileLang uses ForKind::kParallel in a few places as a frontend "SIMT loop"
* marker. When vectorize size resolves to 1 (i.e. no vectorization is applied),
* keeping these loops as kParallel can block later loop transforms that only
* apply to serial loops (e.g. pragma-unroll rewriting).
*
* This rewriter is intentionally conservative: it only downgrades kParallel to
* kSerial and leaves all other loop kinds untouched.
*/
class ParallelToSerialRewriter : public StmtExprMutator {
private:
Stmt VisitStmt_(const ForNode *node) final {
Stmt visited = StmtExprMutator::VisitStmt_(node);
For loop = Downcast<For>(visited);
if (loop->kind == ForKind::kParallel) {
loop.CopyOnWrite()->kind = ForKind::kSerial;
}
return loop;
}
};

For ParallelToSerial(const For &loop) {
ParallelToSerialRewriter rewriter;
return Downcast<For>(rewriter(loop));
}

} // namespace

For VectorizeLoop(const For &loop, const LayoutMap &layout_map,
int vectorize_hint) {
if (vectorize_hint <= 0) {
@@ -788,7 +833,7 @@ For VectorizeLoop(const For &loop, const LayoutMap &layout_map,
vectorize_hint = planner.Plan(loop);
}
if (vectorize_hint == 1)
return loop;
return ParallelToSerial(loop);
auto rewriter = VectorizeRewriter(vectorize_hint);
return Downcast<For>(rewriter(loop));
}
@@ -800,7 +845,7 @@ For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
vectorize_hint = planner.Plan(loop);
}
if (vectorize_hint == 1)
return loop;
return ParallelToSerial(loop);
auto rewriter = VectorizeRewriter(vectorize_hint);
return Downcast<For>(rewriter(loop));
}
Expand Down
Loading