diff --git a/src/transform/annotate_warp_group_reg_alloc.cc b/src/transform/annotate_warp_group_reg_alloc.cc index 537c229a2..08be53f20 100644 --- a/src/transform/annotate_warp_group_reg_alloc.cc +++ b/src/transform/annotate_warp_group_reg_alloc.cc @@ -124,7 +124,9 @@ class SetMaxNRegInjector : public StmtExprMutator { } auto producer_body = if_then_else->then_case; Optional consumer_body = if_then_else->else_case; - ICHECK(consumer_body.defined()) << "Consumer body is undefined"; + // In some degenerate warp-specialized patterns (e.g., producer-only), + // the consumer body may be absent. Handle gracefully by only annotating + // the producer side when consumer is missing. auto dec_reg = nreg_[0].as()->value; auto inc_reg = nreg_[1].as()->value; @@ -150,15 +152,20 @@ class SetMaxNRegInjector : public StmtExprMutator { producer_stmts.push_back(producer_body); auto new_producer_body = SeqStmt(producer_stmts); - Array consumer_stmts; - consumer_stmts.push_back(inc_reg_stmt); - consumer_stmts.push_back(consumer_body.value()); - auto new_consumer_body = SeqStmt(consumer_stmts); + Stmt new_if_stmt; + if (consumer_body.defined()) { + Array consumer_stmts; + consumer_stmts.push_back(inc_reg_stmt); + consumer_stmts.push_back(consumer_body.value()); + auto new_consumer_body = SeqStmt(consumer_stmts); + new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body, + new_consumer_body); + } else { + // No consumer branch; keep the if-then form. + new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body); + } - auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body, - new_consumer_body); auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt); - return new_attr; } else { return StmtExprMutator::VisitStmt_(op); diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index aad1f474b..93beb15d4 100644 --- a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -295,14 +295,15 @@ class TmaSequenceCollector : public IRVisitorWithAnalyzer { void VisitExpr_(const CallNode *op) final { if (op->op.same_as(mbarrier_expect_tx())) { - PrimExpr e = tma_op_to_barrier_id_[tvm::ffi::GetRef(op)] - .as() - ->args[0]; - auto int_set = arith::EvalSet(e, var_int_set_); - expect_.push_back(if_depth_ == 1); - sequence.push_back(0); - int_sets_.push_back(int_set); - expect_tx_count_ += 1; + auto call_ref = tvm::ffi::GetRef(op); + if (tma_op_to_barrier_id_.count(call_ref)) { + PrimExpr e = tma_op_to_barrier_id_[call_ref].as()->args[0]; + auto int_set = arith::EvalSet(e, var_int_set_); + expect_.push_back(if_depth_ == 1); + sequence.push_back(0); + int_sets_.push_back(int_set); + expect_tx_count_ += 1; + } } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { sequence.push_back(1); } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { @@ -337,32 +338,61 @@ class TmaSequenceCollector : public IRVisitorWithAnalyzer { class BarrierCreationRewriter : public StmtExprMutator { public: BarrierCreationRewriter(std::vector restore_barrier_ids, - PrimExpr producer_thread_extent) + PrimExpr producer_thread_extent, + int ensure_min_count = 0, + PrimExpr default_barrier_thread_count = 1) : restore_barrier_ids_(std::move(restore_barrier_ids)), - producer_thread_extent_(std::move(producer_thread_extent)) {} + producer_thread_extent_(std::move(producer_thread_extent)), + ensure_min_count_(ensure_min_count), + default_barrier_thread_count_(std::move(default_barrier_thread_count)) { + } PrimExpr VisitExpr_(const CallNode *op) { if (op->op.same_as(create_list_of_mbarrier())) { - std::vector tmp_(op->args.size(), false); - Array new_args; + size_t cur_n = op->args.size(); + size_t need_n = + std::max(cur_n, static_cast(ensure_min_count_)); + + // Mark barriers to restore across the full needed length, not just the + // original length, so newly appended entries can be restored as well. + std::vector replace(need_n, false); for (auto &id : restore_barrier_ids_) { - tmp_[id] = true; + if (id >= 0 && static_cast(id) < replace.size()) { + replace[id] = true; + } } - for (size_t i{0}; i < op->args.size(); ++i) { - if (tmp_[i]) { + Array new_args; + new_args.reserve(need_n); + + // Preserve/override existing entries + for (size_t i{0}; i < cur_n; ++i) { + if (replace[i]) { new_args.push_back(producer_thread_extent_); } else { new_args.push_back(op->args[i]); } } + // Append additional barriers if required + for (size_t i = cur_n; i < need_n; ++i) { + if (replace[i]) { + new_args.push_back(producer_thread_extent_); + } else { + new_args.push_back(default_barrier_thread_count_); + } + } + return Call(op->dtype, op->op, new_args); } else { return StmtExprMutator::VisitExpr_(op); } } + +private: std::vector restore_barrier_ids_; PrimExpr producer_thread_extent_; + int ensure_min_count_{0}; + PrimExpr default_barrier_thread_count_{1}; }; // we trust mbarrier_wait_parity to be correct @@ -399,8 +429,31 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { collector.barrier_id_to_range(), has_create_list_of_mbarrier); f.CopyOnWrite()->body = rewriter(f->body); + // Compute the minimum number of barriers actually referenced in the body + // after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA). + struct GetMbarrierMaxIdxCollector : public StmtExprVisitor { + int max_idx{-1}; + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(get_mbarrier())) { + if (op->args.size() == 1) { + if (const auto *imm = op->args[0].as()) { + max_idx = std::max(max_idx, static_cast(imm->value)); + } + } + } + StmtExprVisitor::VisitExpr_(op); + } + }; + + GetMbarrierMaxIdxCollector max_idx_collector; + max_idx_collector(f->body); + int ensure_min_count = max_idx_collector.max_idx + 1; // 0-based -> count + + // For simple TMA-only producers, default barrier arrive count should be 1 + // (only the elected leader performs the TMA arrive/expect). auto barrier_creation_rewriter = BarrierCreationRewriter( - rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_); + rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_, + ensure_min_count, Integer(1)); f.CopyOnWrite()->body = barrier_creation_rewriter(f->body); return f; } @@ -453,10 +506,27 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CallNode *op) { if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { - // check this must be in the tma_op_to_barrier_id_ - ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef(op))) - << "tma_load must be in the tma_op_to_barrier_id_"; - auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef(op)]; + auto call_ref = tvm::ffi::GetRef(op); + if (!tma_op_to_barrier_id_.count(call_ref)) { + // For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id) + // so codegen can emit mbarrier[index]. This handles degenerate + // producer-only kernels where no arrive() is seen and mapping is empty. + auto arg0 = op->args[0].as(); + bool is_1d_tma_load = + arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && + !arg0.value()->op.same_as(create_tma_im2col_descriptor()); + if (is_1d_tma_load && op->args.size() >= 3) { + if (const auto *imm = op->args[2].as()) { + Array new_args = op->args; + new_args.Set(2, Call(DataType::Handle(), get_mbarrier(), + {IntImm(DataType::Int(32), + static_cast(imm->value))})); + return Call(op->dtype, op->op, new_args); + } + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + auto barrier_id = tma_op_to_barrier_id_[call_ref]; auto new_args = op->args; auto arg0 = op->args[0].as(); auto is_1d_tma_load = @@ -469,9 +539,11 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { } return Call(op->dtype, op->op, new_args); } else if (op->op.same_as(mbarrier_expect_tx())) { - ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef(op))) - << "mbarrier_expect_tx must be in the tma_op_to_barrier_id_"; - auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef(op)]; + auto call_ref = tvm::ffi::GetRef(op); + if (!tma_op_to_barrier_id_.count(call_ref)) { + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + auto barrier_id = tma_op_to_barrier_id_[call_ref]; auto new_args = op->args; new_args.Set(0, barrier_id); if (!has_warp_specialization_)