Commit 7a80b6d

[Bugfix] Enable code lowering with producer‑copy‑only program (#1168)
* bugfix
* lint fix
* Enhance warp group register allocation to handle missing consumer bodies gracefully. Updated the logic to annotate the producer side when the consumer is absent, ensuring robustness in degenerate warp-specialized patterns.
* Refactor the VisitExpr_ method in inject_tma_barrier.cc for improved readability. Adjusted formatting and spacing for clarity in the barrier-handling logic.
* Update barrier handling in inject_tma_barrier.cc to accommodate newly appended entries. Adjusted the size of the replace vector so it covers the full needed length, and modified the logic for appending barriers based on the updated replace conditions.
1 parent 10911e2 commit 7a80b6d

File tree

2 files changed: +110 −31 lines

src/transform/annotate_warp_group_reg_alloc.cc

Lines changed: 15 additions & 8 deletions
@@ -124,7 +124,9 @@ class SetMaxNRegInjector : public StmtExprMutator {
     }
     auto producer_body = if_then_else->then_case;
     Optional<Stmt> consumer_body = if_then_else->else_case;
-    ICHECK(consumer_body.defined()) << "Consumer body is undefined";
+    // In some degenerate warp-specialized patterns (e.g., producer-only),
+    // the consumer body may be absent. Handle gracefully by only annotating
+    // the producer side when consumer is missing.
 
     auto dec_reg = nreg_[0].as<IntImmNode>()->value;
     auto inc_reg = nreg_[1].as<IntImmNode>()->value;
@@ -150,15 +152,20 @@ class SetMaxNRegInjector : public StmtExprMutator {
     producer_stmts.push_back(producer_body);
     auto new_producer_body = SeqStmt(producer_stmts);
 
-    Array<Stmt> consumer_stmts;
-    consumer_stmts.push_back(inc_reg_stmt);
-    consumer_stmts.push_back(consumer_body.value());
-    auto new_consumer_body = SeqStmt(consumer_stmts);
+    Stmt new_if_stmt;
+    if (consumer_body.defined()) {
+      Array<Stmt> consumer_stmts;
+      consumer_stmts.push_back(inc_reg_stmt);
+      consumer_stmts.push_back(consumer_body.value());
+      auto new_consumer_body = SeqStmt(consumer_stmts);
+      new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
+                               new_consumer_body);
+    } else {
+      // No consumer branch; keep the if-then form.
+      new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body);
+    }
 
-    auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
-                                  new_consumer_body);
     auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);
-
     return new_attr;
   } else {
     return StmtExprMutator::VisitStmt_(op);
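
A note on this change: tir::IfThenElse stores its else branch as an Optional<Stmt>, so a producer-copy-only kernel legally reaches this pass with no consumer branch at all. Below is a minimal sketch of the rebuilt control flow, assuming placeholder dec_reg_stmt / inc_reg_stmt register-hint statements (the actual hint calls emitted by the pass are not shown here):

    // Hedged sketch, not the pass itself: rebuild the warp-specialized branch,
    // annotating only the producer side when the consumer branch is absent.
    #include <tvm/tir/stmt.h>

    using namespace tvm;
    using namespace tvm::tir;

    Stmt RebuildWarpSpecializedIf(const IfThenElse &ite, Stmt dec_reg_stmt,
                                  Stmt inc_reg_stmt) {
      // The producer branch always exists and receives its register hint.
      Stmt new_producer = SeqStmt({dec_reg_stmt, ite->then_case});
      Optional<Stmt> consumer = ite->else_case;
      if (consumer.defined()) {
        // Normal warp-specialized kernel: annotate both branches.
        Stmt new_consumer = SeqStmt({inc_reg_stmt, consumer.value()});
        return IfThenElse(ite->condition, new_producer, new_consumer);
      }
      // Producer-copy-only program: keep the if-then form.
      return IfThenElse(ite->condition, new_producer);
    }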

src/transform/inject_tma_barrier.cc

Lines changed: 95 additions & 23 deletions
@@ -295,14 +295,15 @@ class TmaSequenceCollector : public IRVisitorWithAnalyzer {
 
   void VisitExpr_(const CallNode *op) final {
     if (op->op.same_as(mbarrier_expect_tx())) {
-      PrimExpr e = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)]
-                       .as<CallNode>()
-                       ->args[0];
-      auto int_set = arith::EvalSet(e, var_int_set_);
-      expect_.push_back(if_depth_ == 1);
-      sequence.push_back(0);
-      int_sets_.push_back(int_set);
-      expect_tx_count_ += 1;
+      auto call_ref = tvm::ffi::GetRef<Call>(op);
+      if (tma_op_to_barrier_id_.count(call_ref)) {
+        PrimExpr e = tma_op_to_barrier_id_[call_ref].as<CallNode>()->args[0];
+        auto int_set = arith::EvalSet(e, var_int_set_);
+        expect_.push_back(if_depth_ == 1);
+        sequence.push_back(0);
+        int_sets_.push_back(int_set);
+        expect_tx_count_ += 1;
+      }
     } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
       sequence.push_back(1);
     } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
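
The collector now consults tma_op_to_barrier_id_ only when the entry exists; in a producer-copy-only program no mapping may ever be recorded, so an unguarded lookup would touch a missing entry. The guard is the ordinary check-before-index idiom; a self-contained illustration with a plain std::unordered_map standing in for the real map (all names here are placeholders):

    #include <iostream>
    #include <string>
    #include <unordered_map>

    int main() {
      // Stand-in for tma_op_to_barrier_id_: call -> barrier id.
      std::unordered_map<std::string, int> barrier_of;
      barrier_of["tma_load_A"] = 0;

      for (std::string call : {"tma_load_A", "mbarrier_expect_tx_B"}) {
        if (barrier_of.count(call)) {  // guard before indexing
          std::cout << call << " -> barrier " << barrier_of[call] << "\n";
        } else {
          std::cout << call << " has no recorded barrier; skipped\n";
        }
      }
      return 0;
    }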
@@ -337,32 +338,61 @@ class TmaSequenceCollector : public IRVisitorWithAnalyzer {
 class BarrierCreationRewriter : public StmtExprMutator {
 public:
   BarrierCreationRewriter(std::vector<int> restore_barrier_ids,
-                          PrimExpr producer_thread_extent)
+                          PrimExpr producer_thread_extent,
+                          int ensure_min_count = 0,
+                          PrimExpr default_barrier_thread_count = 1)
       : restore_barrier_ids_(std::move(restore_barrier_ids)),
-        producer_thread_extent_(std::move(producer_thread_extent)) {}
+        producer_thread_extent_(std::move(producer_thread_extent)),
+        ensure_min_count_(ensure_min_count),
+        default_barrier_thread_count_(std::move(default_barrier_thread_count)) {
+  }
 
   PrimExpr VisitExpr_(const CallNode *op) {
     if (op->op.same_as(create_list_of_mbarrier())) {
-      std::vector<bool> tmp_(op->args.size(), false);
-      Array<PrimExpr> new_args;
+      size_t cur_n = op->args.size();
+      size_t need_n =
+          std::max<size_t>(cur_n, static_cast<size_t>(ensure_min_count_));
+
+      // Mark barriers to restore across the full needed length, not just the
+      // original length, so newly appended entries can be restored as well.
+      std::vector<bool> replace(need_n, false);
       for (auto &id : restore_barrier_ids_) {
-        tmp_[id] = true;
+        if (id >= 0 && static_cast<size_t>(id) < replace.size()) {
+          replace[id] = true;
+        }
       }
 
-      for (size_t i{0}; i < op->args.size(); ++i) {
-        if (tmp_[i]) {
+      Array<PrimExpr> new_args;
+      new_args.reserve(need_n);
+
+      // Preserve/override existing entries
+      for (size_t i{0}; i < cur_n; ++i) {
+        if (replace[i]) {
           new_args.push_back(producer_thread_extent_);
         } else {
           new_args.push_back(op->args[i]);
         }
       }
+      // Append additional barriers if required
+      for (size_t i = cur_n; i < need_n; ++i) {
+        if (replace[i]) {
+          new_args.push_back(producer_thread_extent_);
+        } else {
+          new_args.push_back(default_barrier_thread_count_);
+        }
+      }
+
      return Call(op->dtype, op->op, new_args);
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }
+
+private:
  std::vector<int> restore_barrier_ids_;
  PrimExpr producer_thread_extent_;
+  int ensure_min_count_{0};
+  PrimExpr default_barrier_thread_count_{1};
 };
 
 // we trust mbarrier_wait_parity to be correct
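
The rewritten create_list_of_mbarrier handling does two things: entries named in restore_barrier_ids_ get their arrive count overridden with the producer thread extent, and the list is padded out to ensure_min_count entries using a default arrive count. The arithmetic is easy to check in isolation; a self-contained sketch with plain ints standing in for the PrimExpr arrive counts (function and variable names are placeholders):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Sketch of the override/padding logic, ints in place of PrimExprs.
    std::vector<int> RewriteBarrierCounts(std::vector<int> args,
                                          const std::vector<int> &restore_ids,
                                          int producer_extent, int ensure_min,
                                          int default_count) {
      size_t need_n = std::max(args.size(), static_cast<size_t>(ensure_min));
      std::vector<bool> replace(need_n, false);
      for (int id : restore_ids)
        if (id >= 0 && static_cast<size_t>(id) < need_n) replace[id] = true;

      std::vector<int> out;
      out.reserve(need_n);
      for (size_t i = 0; i < need_n; ++i) {
        if (replace[i])
          out.push_back(producer_extent);   // restored barrier
        else if (i < args.size())
          out.push_back(args[i]);           // keep existing entry
        else
          out.push_back(default_count);     // newly appended barrier
      }
      return out;
    }

    int main() {
      // Original list has 2 barriers, but the body references get_mbarrier(2),
      // so ensure_min = 3; barrier 0 is restored to the producer extent (128).
      auto out = RewriteBarrierCounts({256, 256}, {0}, 128, 3, 1);
      for (int c : out) std::printf("%d ", c);  // prints: 128 256 1
      std::printf("\n");
      return 0;
    }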
@@ -399,8 +429,31 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
                               collector.barrier_id_to_range(),
                               has_create_list_of_mbarrier);
     f.CopyOnWrite()->body = rewriter(f->body);
+    // Compute the minimum number of barriers actually referenced in the body
+    // after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA).
+    struct GetMbarrierMaxIdxCollector : public StmtExprVisitor {
+      int max_idx{-1};
+      void VisitExpr_(const CallNode *op) final {
+        if (op->op.same_as(get_mbarrier())) {
+          if (op->args.size() == 1) {
+            if (const auto *imm = op->args[0].as<IntImmNode>()) {
+              max_idx = std::max(max_idx, static_cast<int>(imm->value));
+            }
+          }
+        }
+        StmtExprVisitor::VisitExpr_(op);
+      }
+    };
+
+    GetMbarrierMaxIdxCollector max_idx_collector;
+    max_idx_collector(f->body);
+    int ensure_min_count = max_idx_collector.max_idx + 1; // 0-based -> count
+
+    // For simple TMA-only producers, default barrier arrive count should be 1
+    // (only the elected leader performs the TMA arrive/expect).
     auto barrier_creation_rewriter = BarrierCreationRewriter(
-        rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_);
+        rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_,
+        ensure_min_count, Integer(1));
     f.CopyOnWrite()->body = barrier_creation_rewriter(f->body);
     return f;
   }
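
The ensure_min_count handed to the rewriter is just the largest constant index seen in get_mbarrier(i) plus one (indices are 0-based, and -1 is the "nothing referenced" sentinel). Checked in isolation with placeholder names:

    #include <algorithm>
    #include <vector>

    // Constant indices referenced via get_mbarrier(i) in the rewritten body.
    int EnsureMinCount(const std::vector<int> &referenced) {
      int max_idx = -1;
      for (int i : referenced) max_idx = std::max(max_idx, i);
      return max_idx + 1;  // 0-based index -> required list length
    }
    // EnsureMinCount({})     == 0  (no barriers referenced)
    // EnsureMinCount({0})    == 1  (producer-copy-only kernel using barrier 0)
    // EnsureMinCount({0, 2}) == 3  (list must be padded through index 2)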
@@ -453,10 +506,27 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
 
   PrimExpr VisitExpr_(const CallNode *op) {
     if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
-      // check this must be in the tma_op_to_barrier_id_
-      ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef<Call>(op)))
-          << "tma_load must be in the tma_op_to_barrier_id_";
-      auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
+      auto call_ref = tvm::ffi::GetRef<Call>(op);
+      if (!tma_op_to_barrier_id_.count(call_ref)) {
+        // For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id)
+        // so codegen can emit mbarrier[index]. This handles degenerate
+        // producer-only kernels where no arrive() is seen and mapping is empty.
+        auto arg0 = op->args[0].as<Call>();
+        bool is_1d_tma_load =
+            arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
+            !arg0.value()->op.same_as(create_tma_im2col_descriptor());
+        if (is_1d_tma_load && op->args.size() >= 3) {
+          if (const auto *imm = op->args[2].as<IntImmNode>()) {
+            Array<PrimExpr> new_args = op->args;
+            new_args.Set(2, Call(DataType::Handle(), get_mbarrier(),
+                                 {IntImm(DataType::Int(32),
+                                         static_cast<int>(imm->value))}));
+            return Call(op->dtype, op->op, new_args);
+          }
+        }
+        return IRMutatorWithAnalyzer::VisitExpr_(op);
+      }
+      auto barrier_id = tma_op_to_barrier_id_[call_ref];
       auto new_args = op->args;
       auto arg0 = op->args[0].as<Call>();
       auto is_1d_tma_load =
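
When the mapping is empty, a 1D tma_load may still carry a raw integer barrier id in argument slot 2; the rewriter wraps that constant in get_mbarrier(id) so codegen can index the mbarrier array. A hedged sketch of just the argument rewrite, with the intrinsic passed in as a parameter rather than assumed (names are placeholders):

    // Sketch: promote a constant barrier id in args[2] to get_mbarrier(id).
    #include <tvm/ir/op.h>
    #include <tvm/tir/expr.h>

    using namespace tvm;
    using namespace tvm::tir;

    Array<PrimExpr> PromoteBarrierArg(const Array<PrimExpr> &args,
                                      const Op &get_mbarrier_op) {
      if (args.size() < 3) return args;
      const auto *imm = args[2].as<IntImmNode>();
      if (imm == nullptr) return args;  // already an expression, leave it alone
      Array<PrimExpr> new_args = args;
      new_args.Set(2, Call(DataType::Handle(), get_mbarrier_op,
                           {IntImm(DataType::Int(32),
                                   static_cast<int>(imm->value))}));
      return new_args;
    }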
@@ -469,9 +539,11 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
       }
       return Call(op->dtype, op->op, new_args);
     } else if (op->op.same_as(mbarrier_expect_tx())) {
-      ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef<Call>(op)))
-          << "mbarrier_expect_tx must be in the tma_op_to_barrier_id_";
-      auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
+      auto call_ref = tvm::ffi::GetRef<Call>(op);
+      if (!tma_op_to_barrier_id_.count(call_ref)) {
+        return IRMutatorWithAnalyzer::VisitExpr_(op);
+      }
+      auto barrier_id = tma_op_to_barrier_id_[call_ref];
       auto new_args = op->args;
       new_args.Set(0, barrier_id);
       if (!has_warp_specialization_)
