@@ -344,12 +344,14 @@ class BarrierCreationRewriter : public StmtExprMutator {
344344 : restore_barrier_ids_(std::move(restore_barrier_ids)),
345345 producer_thread_extent_ (std::move(producer_thread_extent)),
346346 ensure_min_count_(ensure_min_count),
347- default_barrier_thread_count_(std::move(default_barrier_thread_count)) {}
347+ default_barrier_thread_count_(std::move(default_barrier_thread_count)) {
348+ }
348349
349350 PrimExpr VisitExpr_ (const CallNode *op) {
350351 if (op->op .same_as (create_list_of_mbarrier ())) {
351352 size_t cur_n = op->args .size ();
352- size_t need_n = std::max<size_t >(cur_n, static_cast <size_t >(ensure_min_count_));
353+ size_t need_n =
354+ std::max<size_t >(cur_n, static_cast <size_t >(ensure_min_count_));
353355
354356 std::vector<bool > replace (cur_n, false );
355357 for (auto &id : restore_barrier_ids_) {
@@ -504,13 +506,15 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
504506 // so codegen can emit mbarrier[index]. This handles degenerate
505507 // producer-only kernels where no arrive() is seen and mapping is empty.
506508 auto arg0 = op->args [0 ].as <Call>();
507- bool is_1d_tma_load = arg0 && !arg0.value ()->op .same_as (create_tma_descriptor ()) &&
508- !arg0.value ()->op .same_as (create_tma_im2col_descriptor ());
509+ bool is_1d_tma_load =
510+ arg0 && !arg0.value ()->op .same_as (create_tma_descriptor ()) &&
511+ !arg0.value ()->op .same_as (create_tma_im2col_descriptor ());
509512 if (is_1d_tma_load && op->args .size () >= 3 ) {
510513 if (const auto *imm = op->args [2 ].as <IntImmNode>()) {
511514 Array<PrimExpr> new_args = op->args ;
512515 new_args.Set (2 , Call (DataType::Handle (), get_mbarrier (),
513- {IntImm (DataType::Int (32 ), static_cast <int >(imm->value ))}));
516+ {IntImm (DataType::Int (32 ),
517+ static_cast <int >(imm->value ))}));
514518 return Call (op->dtype , op->op , new_args);
515519 }
516520 }
0 commit comments