Skip to content

Commit 1cb0c43

Browse files
committed
Refactor VisitExpr_ method in inject_tma_barrier.cc for improved readability. Adjusted formatting and spacing for clarity in barrier handling logic.
1 parent c3212aa commit 1cb0c43

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/transform/inject_tma_barrier.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -344,12 +344,14 @@ class BarrierCreationRewriter : public StmtExprMutator {
344344
: restore_barrier_ids_(std::move(restore_barrier_ids)),
345345
producer_thread_extent_(std::move(producer_thread_extent)),
346346
ensure_min_count_(ensure_min_count),
347-
default_barrier_thread_count_(std::move(default_barrier_thread_count)) {}
347+
default_barrier_thread_count_(std::move(default_barrier_thread_count)) {
348+
}
348349

349350
PrimExpr VisitExpr_(const CallNode *op) {
350351
if (op->op.same_as(create_list_of_mbarrier())) {
351352
size_t cur_n = op->args.size();
352-
size_t need_n = std::max<size_t>(cur_n, static_cast<size_t>(ensure_min_count_));
353+
size_t need_n =
354+
std::max<size_t>(cur_n, static_cast<size_t>(ensure_min_count_));
353355

354356
std::vector<bool> replace(cur_n, false);
355357
for (auto &id : restore_barrier_ids_) {
@@ -504,13 +506,15 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
504506
// so codegen can emit mbarrier[index]. This handles degenerate
505507
// producer-only kernels where no arrive() is seen and mapping is empty.
506508
auto arg0 = op->args[0].as<Call>();
507-
bool is_1d_tma_load = arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
508-
!arg0.value()->op.same_as(create_tma_im2col_descriptor());
509+
bool is_1d_tma_load =
510+
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
511+
!arg0.value()->op.same_as(create_tma_im2col_descriptor());
509512
if (is_1d_tma_load && op->args.size() >= 3) {
510513
if (const auto *imm = op->args[2].as<IntImmNode>()) {
511514
Array<PrimExpr> new_args = op->args;
512515
new_args.Set(2, Call(DataType::Handle(), get_mbarrier(),
513-
{IntImm(DataType::Int(32), static_cast<int>(imm->value))}));
516+
{IntImm(DataType::Int(32),
517+
static_cast<int>(imm->value))}));
514518
return Call(op->dtype, op->op, new_args);
515519
}
516520
}

0 commit comments

Comments
 (0)