Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/transform/annotate_warp_group_reg_alloc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ class SetMaxNRegInjector : public StmtExprMutator {
}
auto producer_body = if_then_else->then_case;
Optional<Stmt> consumer_body = if_then_else->else_case;
ICHECK(consumer_body.defined()) << "Consumer body is undefined";
// In some degenerate warp-specialized patterns (e.g., producer-only),
// the consumer body may be absent. Handle gracefully by only annotating
// the producer side when consumer is missing.

auto dec_reg = nreg_[0].as<IntImmNode>()->value;
auto inc_reg = nreg_[1].as<IntImmNode>()->value;
Expand All @@ -150,15 +152,20 @@ class SetMaxNRegInjector : public StmtExprMutator {
producer_stmts.push_back(producer_body);
auto new_producer_body = SeqStmt(producer_stmts);

Array<Stmt> consumer_stmts;
consumer_stmts.push_back(inc_reg_stmt);
consumer_stmts.push_back(consumer_body.value());
auto new_consumer_body = SeqStmt(consumer_stmts);
Stmt new_if_stmt;
if (consumer_body.defined()) {
Array<Stmt> consumer_stmts;
consumer_stmts.push_back(inc_reg_stmt);
consumer_stmts.push_back(consumer_body.value());
auto new_consumer_body = SeqStmt(consumer_stmts);
new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
new_consumer_body);
} else {
// No consumer branch; keep the if-then form.
new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body);
}

auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
new_consumer_body);
auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);

return new_attr;
} else {
return StmtExprMutator::VisitStmt_(op);
Expand Down
118 changes: 95 additions & 23 deletions src/transform/inject_tma_barrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,15 @@ class TmaSequenceCollector : public IRVisitorWithAnalyzer {

void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(mbarrier_expect_tx())) {
PrimExpr e = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)]
.as<CallNode>()
->args[0];
auto int_set = arith::EvalSet(e, var_int_set_);
expect_.push_back(if_depth_ == 1);
sequence.push_back(0);
int_sets_.push_back(int_set);
expect_tx_count_ += 1;
auto call_ref = tvm::ffi::GetRef<Call>(op);
if (tma_op_to_barrier_id_.count(call_ref)) {
PrimExpr e = tma_op_to_barrier_id_[call_ref].as<CallNode>()->args[0];
auto int_set = arith::EvalSet(e, var_int_set_);
expect_.push_back(if_depth_ == 1);
sequence.push_back(0);
int_sets_.push_back(int_set);
expect_tx_count_ += 1;
}
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
sequence.push_back(1);
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
Expand Down Expand Up @@ -337,32 +338,61 @@ class TmaSequenceCollector : public IRVisitorWithAnalyzer {
class BarrierCreationRewriter : public StmtExprMutator {
public:
BarrierCreationRewriter(std::vector<int> restore_barrier_ids,
PrimExpr producer_thread_extent)
PrimExpr producer_thread_extent,
int ensure_min_count = 0,
PrimExpr default_barrier_thread_count = 1)
: restore_barrier_ids_(std::move(restore_barrier_ids)),
producer_thread_extent_(std::move(producer_thread_extent)) {}
producer_thread_extent_(std::move(producer_thread_extent)),
ensure_min_count_(ensure_min_count),
default_barrier_thread_count_(std::move(default_barrier_thread_count)) {
}

PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(create_list_of_mbarrier())) {
std::vector<bool> tmp_(op->args.size(), false);
Array<PrimExpr> new_args;
size_t cur_n = op->args.size();
size_t need_n =
std::max<size_t>(cur_n, static_cast<size_t>(ensure_min_count_));

// Mark barriers to restore across the full needed length, not just the
// original length, so newly appended entries can be restored as well.
std::vector<bool> replace(need_n, false);
for (auto &id : restore_barrier_ids_) {
tmp_[id] = true;
if (id >= 0 && static_cast<size_t>(id) < replace.size()) {
replace[id] = true;
}
}

for (size_t i{0}; i < op->args.size(); ++i) {
if (tmp_[i]) {
Array<PrimExpr> new_args;
new_args.reserve(need_n);

// Preserve/override existing entries
for (size_t i{0}; i < cur_n; ++i) {
if (replace[i]) {
new_args.push_back(producer_thread_extent_);
} else {
new_args.push_back(op->args[i]);
}
}
// Append additional barriers if required
for (size_t i = cur_n; i < need_n; ++i) {
if (replace[i]) {
new_args.push_back(producer_thread_extent_);
} else {
new_args.push_back(default_barrier_thread_count_);
}
}

return Call(op->dtype, op->op, new_args);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}

private:
std::vector<int> restore_barrier_ids_;
PrimExpr producer_thread_extent_;
int ensure_min_count_{0};
PrimExpr default_barrier_thread_count_{1};
};

// we trust mbarrier_wait_parity to be correct
Expand Down Expand Up @@ -399,8 +429,31 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
collector.barrier_id_to_range(),
has_create_list_of_mbarrier);
f.CopyOnWrite()->body = rewriter(f->body);
// Compute the minimum number of barriers actually referenced in the body
// after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA).
// Visitor that tracks the largest constant index passed to get_mbarrier().
// Only calls with exactly one argument that is an integer immediate are
// counted; any other get_mbarrier() call leaves max_idx untouched.
struct GetMbarrierMaxIdxCollector : public StmtExprVisitor {
  // Largest constant index observed so far; stays -1 if none was observed.
  int max_idx{-1};

  void VisitExpr_(const CallNode *op) final {
    const bool is_unary_get_mbarrier =
        op->op.same_as(get_mbarrier()) && op->args.size() == 1;
    if (is_unary_get_mbarrier) {
      const auto *imm = op->args[0].as<IntImmNode>();
      if (imm != nullptr) {
        const int idx = static_cast<int>(imm->value);
        if (idx > max_idx) {
          max_idx = idx;
        }
      }
    }
    // Always hand the call to the base-class handler afterwards.
    StmtExprVisitor::VisitExpr_(op);
  }
};

GetMbarrierMaxIdxCollector max_idx_collector;
max_idx_collector(f->body);
int ensure_min_count = max_idx_collector.max_idx + 1; // 0-based -> count
Comment on lines +432 to +450
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify that only constant barrier IDs are used in practice.

The GetMbarrierMaxIdxCollector only captures constant integer barrier IDs (line 439). If the codebase uses dynamic barrier IDs—computed from variables or loop indices—those won't be counted, potentially leading to insufficient barrier creation.

Run the following script to check for non-constant barrier ID usage:


🏁 Script executed:

#!/bin/bash
# Description: Search for get_mbarrier calls with non-constant arguments
# to verify the assumption that only constant IDs are used.

# Search for get_mbarrier calls and show context
rg -nP 'get_mbarrier\(' --type=cc -A 2 -B 2

# Also check test files for dynamic usage patterns
rg -nP 'get_mbarrier\(' --type=py -A 2 -B 2 -g '**/*test*.py'

Length of output: 9092


The script output reveals critical findings: dynamic barrier IDs ARE used in practice.


🏁 Script executed:

#!/bin/bash
# Find pass registration in inject_tma_barrier.cc
fd -name "inject_tma_barrier.cc" --type f -exec grep -l "TVM_REGISTER\|CreatePrimFuncPass" {} \;

# Check for registration details
rg -n "TVM_REGISTER.*inject_tma_barrier\|CreatePrimFuncPass" -A 5 -B 5 2>/dev/null || echo "Not found with grep, trying broader search..."

# Look for the full registration
rg -nP "InjectTMABarrier|inject_tma_barrier" --type=cc -A 3 -B 3 -g "*inject_tma_barrier*"

Length of output: 303


🏁 Script executed:

#!/bin/bash
# Find inject_tma_barrier.cc file
fd inject_tma_barrier.cc

# Search for pass registration in that file
rg "TVM_REGISTER|CreatePrimFuncPass" -A 5 -B 5 inject_tma_barrier.cc 2>/dev/null || true

# Look for how the pass handles expressions
rg -nP "GetMbarrierMaxIdxCollector|VisitExpr" inject_tma_barrier.cc -A 2 -B 2 2>/dev/null || true

# Check the end of the file for pass registration
tail -50 $(fd inject_tma_barrier.cc)

Length of output: 1847


🏁 Script executed:

#!/bin/bash
# Find where ensure_min_count is used
rg -n "ensure_min_count" src/transform/inject_tma_barrier.cc -B 2 -A 5

# Find BarrierCreationRewriter to understand impact
rg -n "BarrierCreationRewriter" src/transform/inject_tma_barrier.cc -B 3 -A 10

Length of output: 3343


Fix GetMbarrierMaxIdxCollector to handle dynamic barrier ID expressions.

The collector currently only recognizes constant integer barrier IDs via as<IntImmNode>() at line 439. However, test files demonstrate widespread use of dynamic barrier IDs like k % 3 and k % 3 + 3 within loops (see test_tilelang_transform_warp_specialized.py lines 83, 85, 89, etc.). These dynamic expressions are not IntImmNode and are therefore not counted, causing ensure_min_count to be underestimated. This results in create_list_of_mbarrier() allocating fewer barriers than actually needed, leading to runtime access violations.

You must enhance the collector to compute bounds on dynamic expressions (using arith::Analyzer which is already available as analyzer) to determine the maximum possible barrier ID, ensuring sufficient barrier allocation.

🤖 Prompt for AI Agents
In src/transform/inject_tma_barrier.cc around lines 432-450, the
GetMbarrierMaxIdxCollector only handles IntImmNode constants; update VisitExpr_
so when get_mbarrier() args[0] is not an IntImmNode it uses the existing
arith::Analyzer (named analyzer) to compute an upper bound for the expression
(e.g., via analyzer->int_set(expr) or const-int-bound helper) and updates
max_idx with that bound; if the analyzer reports a finite maximum use that
value, and if the bound is unbounded/unknown set max_idx to a conservative safe
upper value (e.g., 1024) to avoid under-allocation, then continue traversal as
before.


// For simple TMA-only producers, default barrier arrive count should be 1
// (only the elected leader performs the TMA arrive/expect).
auto barrier_creation_rewriter = BarrierCreationRewriter(
rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_);
rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_,
ensure_min_count, Integer(1));
f.CopyOnWrite()->body = barrier_creation_rewriter(f->body);
return f;
}
Expand Down Expand Up @@ -453,10 +506,27 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {

PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
// check this must be in the tma_op_to_barrier_id_
ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef<Call>(op)))
<< "tma_load must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
auto call_ref = tvm::ffi::GetRef<Call>(op);
if (!tma_op_to_barrier_id_.count(call_ref)) {
// For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id)
// so codegen can emit mbarrier[index]. This handles degenerate
// producer-only kernels where no arrive() is seen and mapping is empty.
auto arg0 = op->args[0].as<Call>();
bool is_1d_tma_load =
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
!arg0.value()->op.same_as(create_tma_im2col_descriptor());
if (is_1d_tma_load && op->args.size() >= 3) {
if (const auto *imm = op->args[2].as<IntImmNode>()) {
Array<PrimExpr> new_args = op->args;
new_args.Set(2, Call(DataType::Handle(), get_mbarrier(),
{IntImm(DataType::Int(32),
static_cast<int>(imm->value))}));
return Call(op->dtype, op->op, new_args);
}
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
auto barrier_id = tma_op_to_barrier_id_[call_ref];
auto new_args = op->args;
auto arg0 = op->args[0].as<Call>();
auto is_1d_tma_load =
Expand All @@ -469,9 +539,11 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
}
return Call(op->dtype, op->op, new_args);
} else if (op->op.same_as(mbarrier_expect_tx())) {
ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef<Call>(op)))
<< "mbarrier_expect_tx must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
auto call_ref = tvm::ffi::GetRef<Call>(op);
if (!tma_op_to_barrier_id_.count(call_ref)) {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
auto barrier_id = tma_op_to_barrier_id_[call_ref];
auto new_args = op->args;
new_args.Set(0, barrier_id);
if (!has_warp_specialization_)
Expand Down
Loading