@@ -295,14 +295,15 @@ class TmaSequenceCollector : public IRVisitorWithAnalyzer {
295295
296296 void VisitExpr_ (const CallNode *op) final {
297297 if (op->op .same_as (mbarrier_expect_tx ())) {
298- PrimExpr e = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)]
299- .as <CallNode>()
300- ->args [0 ];
301- auto int_set = arith::EvalSet (e, var_int_set_);
302- expect_.push_back (if_depth_ == 1 );
303- sequence.push_back (0 );
304- int_sets_.push_back (int_set);
305- expect_tx_count_ += 1 ;
298+ auto call_ref = tvm::ffi::GetRef<Call>(op);
299+ if (tma_op_to_barrier_id_.count (call_ref)) {
300+ PrimExpr e = tma_op_to_barrier_id_[call_ref].as <CallNode>()->args [0 ];
301+ auto int_set = arith::EvalSet (e, var_int_set_);
302+ expect_.push_back (if_depth_ == 1 );
303+ sequence.push_back (0 );
304+ int_sets_.push_back (int_set);
305+ expect_tx_count_ += 1 ;
306+ }
306307 } else if (op->op .same_as (builtin::ptx_arrive_barrier ())) {
307308 sequence.push_back (1 );
308309 } else if (op->op .same_as (builtin::ptx_cp_async_barrier ())) {
@@ -337,32 +338,61 @@ class TmaSequenceCollector : public IRVisitorWithAnalyzer {
337338class BarrierCreationRewriter : public StmtExprMutator {
338339public:
339340 BarrierCreationRewriter (std::vector<int > restore_barrier_ids,
340- PrimExpr producer_thread_extent)
341+ PrimExpr producer_thread_extent,
342+ int ensure_min_count = 0 ,
343+ PrimExpr default_barrier_thread_count = 1 )
341344 : restore_barrier_ids_(std::move(restore_barrier_ids)),
342- producer_thread_extent_ (std::move(producer_thread_extent)) {}
345+ producer_thread_extent_ (std::move(producer_thread_extent)),
346+ ensure_min_count_(ensure_min_count),
347+ default_barrier_thread_count_(std::move(default_barrier_thread_count)) {
348+ }
343349
344350 PrimExpr VisitExpr_ (const CallNode *op) {
345351 if (op->op .same_as (create_list_of_mbarrier ())) {
346- std::vector<bool > tmp_ (op->args .size (), false );
347- Array<PrimExpr> new_args;
352+ size_t cur_n = op->args .size ();
353+ size_t need_n =
354+ std::max<size_t >(cur_n, static_cast <size_t >(ensure_min_count_));
355+
356+ // Mark barriers to restore across the full needed length, not just the
357+ // original length, so newly appended entries can be restored as well.
358+ std::vector<bool > replace (need_n, false );
348359 for (auto &id : restore_barrier_ids_) {
349- tmp_[id] = true ;
360+ if (id >= 0 && static_cast <size_t >(id) < replace.size ()) {
361+ replace[id] = true ;
362+ }
350363 }
351364
352- for (size_t i{0 }; i < op->args .size (); ++i) {
353- if (tmp_[i]) {
365+ Array<PrimExpr> new_args;
366+ new_args.reserve (need_n);
367+
368+ // Preserve/override existing entries
369+ for (size_t i{0 }; i < cur_n; ++i) {
370+ if (replace[i]) {
354371 new_args.push_back (producer_thread_extent_);
355372 } else {
356373 new_args.push_back (op->args [i]);
357374 }
358375 }
376+ // Append additional barriers if required
377+ for (size_t i = cur_n; i < need_n; ++i) {
378+ if (replace[i]) {
379+ new_args.push_back (producer_thread_extent_);
380+ } else {
381+ new_args.push_back (default_barrier_thread_count_);
382+ }
383+ }
384+
359385 return Call (op->dtype , op->op , new_args);
360386 } else {
361387 return StmtExprMutator::VisitExpr_ (op);
362388 }
363389 }
390+
391+ private:
364392 std::vector<int > restore_barrier_ids_;
365393 PrimExpr producer_thread_extent_;
394+ int ensure_min_count_{0 };
395+ PrimExpr default_barrier_thread_count_{1 };
366396};
367397
368398// we trust mbarrier_wait_parity to be correct
@@ -399,8 +429,31 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
399429 collector.barrier_id_to_range (),
400430 has_create_list_of_mbarrier);
401431 f.CopyOnWrite ()->body = rewriter (f->body );
432+ // Compute the minimum number of barriers actually referenced in the body
433+ // after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA).
434+ struct GetMbarrierMaxIdxCollector : public StmtExprVisitor {
435+ int max_idx{-1 };
436+ void VisitExpr_ (const CallNode *op) final {
437+ if (op->op .same_as (get_mbarrier ())) {
438+ if (op->args .size () == 1 ) {
439+ if (const auto *imm = op->args [0 ].as <IntImmNode>()) {
440+ max_idx = std::max (max_idx, static_cast <int >(imm->value ));
441+ }
442+ }
443+ }
444+ StmtExprVisitor::VisitExpr_ (op);
445+ }
446+ };
447+
448+ GetMbarrierMaxIdxCollector max_idx_collector;
449+ max_idx_collector (f->body );
450+ int ensure_min_count = max_idx_collector.max_idx + 1 ; // 0-based -> count
451+
452+ // For simple TMA-only producers, default barrier arrive count should be 1
453+ // (only the elected leader performs the TMA arrive/expect).
402454 auto barrier_creation_rewriter = BarrierCreationRewriter (
403- rewriter.restore_barrier_ids_ , rewriter.producer_thread_extent_ );
455+ rewriter.restore_barrier_ids_ , rewriter.producer_thread_extent_ ,
456+ ensure_min_count, Integer (1 ));
404457 f.CopyOnWrite ()->body = barrier_creation_rewriter (f->body );
405458 return f;
406459 }
@@ -453,10 +506,27 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
453506
454507 PrimExpr VisitExpr_ (const CallNode *op) {
455508 if (op->op .same_as (tma_load ()) || op->op .same_as (tma_load_im2col ())) {
456- // check this must be in the tma_op_to_barrier_id_
457- ICHECK (tma_op_to_barrier_id_.count (tvm::ffi::GetRef<Call>(op)))
458- << " tma_load must be in the tma_op_to_barrier_id_" ;
459- auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
509+ auto call_ref = tvm::ffi::GetRef<Call>(op);
510+ if (!tma_op_to_barrier_id_.count (call_ref)) {
511+ // For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id)
512+ // so codegen can emit mbarrier[index]. This handles degenerate
513+ // producer-only kernels where no arrive() is seen and mapping is empty.
514+ auto arg0 = op->args [0 ].as <Call>();
515+ bool is_1d_tma_load =
516+ arg0 && !arg0.value ()->op .same_as (create_tma_descriptor ()) &&
517+ !arg0.value ()->op .same_as (create_tma_im2col_descriptor ());
518+ if (is_1d_tma_load && op->args .size () >= 3 ) {
519+ if (const auto *imm = op->args [2 ].as <IntImmNode>()) {
520+ Array<PrimExpr> new_args = op->args ;
521+ new_args.Set (2 , Call (DataType::Handle (), get_mbarrier (),
522+ {IntImm (DataType::Int (32 ),
523+ static_cast <int >(imm->value ))}));
524+ return Call (op->dtype , op->op , new_args);
525+ }
526+ }
527+ return IRMutatorWithAnalyzer::VisitExpr_ (op);
528+ }
529+ auto barrier_id = tma_op_to_barrier_id_[call_ref];
460530 auto new_args = op->args ;
461531 auto arg0 = op->args [0 ].as <Call>();
462532 auto is_1d_tma_load =
@@ -469,9 +539,11 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
469539 }
470540 return Call (op->dtype , op->op , new_args);
471541 } else if (op->op .same_as (mbarrier_expect_tx ())) {
472- ICHECK (tma_op_to_barrier_id_.count (tvm::ffi::GetRef<Call>(op)))
473- << " mbarrier_expect_tx must be in the tma_op_to_barrier_id_" ;
474- auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
542+ auto call_ref = tvm::ffi::GetRef<Call>(op);
543+ if (!tma_op_to_barrier_id_.count (call_ref)) {
544+ return IRMutatorWithAnalyzer::VisitExpr_ (op);
545+ }
546+ auto barrier_id = tma_op_to_barrier_id_[call_ref];
475547 auto new_args = op->args ;
476548 new_args.Set (0 , barrier_id);
477549 if (!has_warp_specialization_)
0 commit comments