From cd11fcb052ee1054d0aeb94a3fd4a46fbfa21b1e Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Wed, 7 Jan 2026 15:51:24 +0800 Subject: [PATCH 01/27] reimplement ThreadSync pass --- src/transform/common/constr_visitor.h | 3 + src/transform/thread_storage_sync.cc | 1716 ++++++++++++++++--------- 2 files changed, 1078 insertions(+), 641 deletions(-) diff --git a/src/transform/common/constr_visitor.h b/src/transform/common/constr_visitor.h index f54540c4f..bc855ced6 100644 --- a/src/transform/common/constr_visitor.h +++ b/src/transform/common/constr_visitor.h @@ -104,6 +104,7 @@ struct ConstrSet { struct ConstrVisitor : public tir::StmtExprVisitor { private: using Base = tir::StmtExprVisitor; + struct Guard { std::vector &constrs; ~Guard() { constrs.pop_back(); } @@ -114,6 +115,8 @@ struct ConstrVisitor : public tir::StmtExprVisitor { } public: + using StmtExprVisitor::VisitExpr_; + using StmtExprVisitor::VisitStmt_; void VisitIfThenElseExpr(const PrimExpr cond, const PrimExpr true_value, const PrimExpr false_value) { { diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 0627678e1..4cd8c5f40 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -33,392 +33,372 @@ #include #include "../op/builtin.h" +#include "./common/constr_visitor.h" #include "./common/thread_sync_types.h" -#include "./storage_access.h" #include "arith/ir_mutator_with_analyzer.h" #include "runtime/thread_storage_scope.h" #include "tir/transforms/ir_utils.h" +#include +#include +#include +#include + +#include +#include + +#include "arith/ir_visitor_with_analyzer.h" +#include "runtime/thread_storage_scope.h" +#include +#include +#include + +#include +#include + +#include "../op/builtin.h" +#include "tir/transforms/ir_utils.h" namespace tvm { namespace tl { +using namespace tir; +using namespace ffi; +using arith::IRVisitorWithAnalyzer; +using runtime::StorageRank; +using runtime::StorageScope; + +bool IsThreadInvariant_(const PrimExpr &cond) { + if (auto call = cond.as()) { + if (auto opt_call_op = call->op.as()) { + const auto &call_op = opt_call_op.value(); + if (call_op.same_as(builtin::tvm_thread_invariant())) { + return true; + } + } + } + return false; +} + using namespace tir; using arith::IRMutatorWithAnalyzer; -class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor { +// There are cases where necessary syncthreads is not inserted by +// ThreadSyncInserter. For example, syncthreads is needed after async_wait_queue +// in the second loop below, but since ThreadSyncInserter is not aware of the +// asynchronous semantics, it cannot tell that the syncthreads is needed there. +// +// // Pipeline prologue +// for i in range(125): +// async_commit_queue(0): +// async_scope: +// shared[(i + 3) % 4] = ... +// ... +// +// // Pipeline Epilogue +// for i in range(3): +// async_wait_queue(0, 2 - i): +// local[...] = shared[(i + 125) % 4] + +// This class adds syncthreads after all async_wait_queue. That includes +// syncthreads that can be inserted by ThreadSyncInserter as well, but +// ThreadSyncInserter will not insert duplicate syncthreads if it finds an +// existing one at the synchronization point. 
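+//
+// After this pass, the epilogue above becomes (a sketch; the barrier string
+// follows whatever sync scope this mutator is constructed with, e.g.
+// "shared.dyn"):
+//
+//   for i in range(3):
+//     async_wait_queue(0, 2 - i):
+//       tvm_storage_sync("shared.dyn")
+//       local[...] = shared[(i + 125) % 4]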
+class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator { public: - explicit TileLangThreadSyncPlanner(StorageScope sync_scope) + explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) : sync_scope_(std::move(sync_scope)) {} - // The syncs inserted before each statement - std::unordered_set syncs_inserted_; + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tvm::tir::attr::async_wait_queue_scope) { + auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm(sync_scope_.to_string())})); + auto inner = op->body.as(); + ICHECK(inner && + inner->attr_key == tvm::tir::attr::async_wait_inflight_count); + auto zero = make_zero(DataType::Int(32)); + auto new_body = SeqStmt({sync, inner->body}); + return AttrStmt(zero, tvm::tir::attr::async_wait_queue_scope, op->value, + AttrStmt(zero, tvm::tir::attr::async_wait_inflight_count, + inner->value, new_body)); + } + return StmtExprMutator::VisitStmt_(op); + } -protected: - bool Enabled(const VarNode *buf, const StorageScope &scope) const final { - return in_device_env() && scope == sync_scope_; +private: + StorageScope sync_scope_; +}; + +class ThreadSyncInserter : public StmtExprMutator { +public: + ThreadSyncInserter(StorageScope sync_scope, + const std::unordered_set &syncs) + : sync_scope_(std::move(sync_scope)), syncs_(syncs) {} + + Stmt VisitStmt(const Stmt &stmt) final { + if (syncs_.empty()) + return stmt; + if (syncs_.count(stmt.get())) { + Stmt barrier; + if (sync_scope_.rank == StorageRank::kGlobal) { + barrier = MakeGlobalBarrier(); + } else { + barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm(sync_scope_.to_string())})); + } + // Mutate after query, to avoid stmt change. + auto ret = StmtExprMutator::VisitStmt(stmt); + ret = SeqStmt({barrier, ret}); + return ret; + } else { + return StmtExprMutator::VisitStmt(stmt); + } } - // Plan the sync - std::vector Summarize(std::vector seq, - const ForNode *loop) final { - // Redirect all "shared.dyn" buffer access to the same buffer var - // so that the accesses can be planned together. - Var shared_dyn_buf; - for (StmtEntry &entry : seq) { - for (AccessEntry &access : entry.access) { - if (access.scope.rank == StorageRank::kShared && - access.scope.tag == ".dyn" && access.buffer.defined()) { - if (!shared_dyn_buf.defined()) { - shared_dyn_buf = access.buffer; - } else { - access.buffer = shared_dyn_buf; - } - } + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + if (sync_scope_.rank == StorageRank::kGlobal && + GetScope(op->buffer->data).rank == StorageRank::kGlobal) { + ++rw_stats_[op->buffer->data].read_count; + } + return StmtExprMutator::VisitExpr_(op); + } + Stmt VisitStmt_(const BufferStoreNode *op) final { + if (sync_scope_.rank == StorageRank::kGlobal && + GetScope(op->buffer->data).rank == StorageRank::kGlobal) { + ++rw_stats_[op->buffer->data].write_count; + } + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tvm::tir::attr::thread_extent) { + bool temp = true; + std::swap(temp, in_thread_env_); + thread_extents_.push_back(op); + Stmt ret = StmtExprMutator::VisitStmt_(op); + thread_extents_.pop_back(); + std::swap(temp, in_thread_env_); + // first thread scope. 
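+    // Leaving the outermost thread_extent attribute means we are back at the
+    // kernel root; for the global sync scope this is the point where the
+    // barrier preparation emitted by InitGlobalBarrier below is attached.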
+ if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) { + ret = InitGlobalBarrier(ret.as()); + num_blocks_ = PrimExpr(); + is_lead_ = PrimExpr(); } + return ret; + } else { + return StmtExprMutator::VisitStmt_(op); } + } - // Unsynced reads and writes - std::vector reads; - std::vector writes; - // if it is a loop, rotate two times to consider effect of loop. - // simulation based approach to find dependencies - for (size_t i = 0; i < seq.size(); ++i) { - const StmtEntry &s = seq[i]; - // check if sync before statement is needed. - bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0); - // Apply the syncs added already. - - if (sync_before_stmt) { - reads.clear(); - writes.clear(); + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK_EQ(op->args.size(), 5U); + Var buffer_var(Downcast(op->args[1])); + const IntImmNode *flag = op->args[4].as(); + if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[buffer_var].read_count; + } + if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[buffer_var].write_count; } + return expr; + } else if (op->op.same_as(builtin::address_of())) { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK_EQ(op->args.size(), 1U) + << "address_of should only have one argument (Buffer)"; - for (const AccessEntry &acc : s.access) { - if (acc.type == kRead) { - if (FindConflict(writes, acc, false)) { - sync_before_stmt = true; - break; - } - } else if (acc.type == kWrite) { - if (FindConflict(reads, acc, false) || - FindConflict(writes, acc, false)) { - sync_before_stmt = true; - break; - } - } else if (acc.type == kSync) { - reads.clear(); - writes.clear(); + if (auto load = op->args[0].as()) { + Var buffer_var(Downcast(load->buffer->data)); + if (sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[buffer_var].read_count; } - } - // If sync is inserted. remove the irrelevant things. - if (sync_before_stmt) { - reads.clear(); - writes.clear(); - } - // Add the read/write of current statement - for (const AccessEntry &acc : s.access) { - if (acc.type == kRead) { - reads.push_back(acc); - } else if (acc.type == kWrite) { - writes.push_back(acc); - } else if (acc.type == kSync) { - reads.clear(); - writes.clear(); + if (sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[buffer_var].write_count; } + return expr; + } else { + return StmtExprMutator::VisitExpr_(op); } + } else { + return StmtExprMutator::VisitExpr_(op); + } + } - if (sync_before_stmt) { - insert_syncs(s.stmt); +private: + // RW statistics about data + struct Entry { + int read_count{0}; + int write_count{0}; + }; + + // Get current storage scope. + StorageScope GetScope(Var buffer_var) const { + return StorageScope::Create(GetPtrStorageScope(std::move(buffer_var))); + } + + // private functions. 
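+
+  // InitGlobalBarrier wraps the outermost thread-extent body with a
+  // tvm_prepare_global_barrier packed call, marks every global buffer that is
+  // both read and written as volatile, and prepends tvm_global_barrier_kinit
+  // so the grid-wide barrier state is initialized before the kernel body runs.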
+ Stmt InitGlobalBarrier(const AttrStmtNode *op) { + ICHECK(op != nullptr); + Array pargs = { + StringImm(runtime::symbol::tvm_prepare_global_barrier)}; + Stmt prep = + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); + Stmt body = op->body; + for (const auto &kv : rw_stats_) { + const auto &e = kv.second; + if (e.read_count != 0 && e.write_count != 0) { + body = AttrStmt(kv.first, tvm::tir::attr::volatile_scope, 1, body); } } - if (loop != nullptr) { - // Check if the loop body contains any reads in the same sync scope. - // If there are reads, we conservatively keep the sync within the loop - // body to preserve per-iteration ordering when needed. If there are no - // reads (e.g., only writes to shared.dyn), we can safely hoist the sync - // to before the loop to avoid redundant barriers. - bool has_read_in_scope = false; - for (const StmtEntry &s : seq) { - for (const AccessEntry &acc : s.access) { - if (acc.type == kRead && acc.scope == sync_scope_) { - has_read_in_scope = true; - break; - } - } - if (has_read_in_scope) - break; - } - // If there is a loop-carried dependency, insert a single sync - // before the loop rather than hoisting a sync into the loop body. - // This reduces redundant per-iteration synchronizations for cases - // where each iteration touches disjoint regions (e.g., stmatrix - // writes to shared.dyn) and only a global ordering before/after the - // loop is required. - for (size_t i = 0; i < seq.size(); ++i) { - const StmtEntry &s = seq[i]; - if (syncs_inserted_.count(s.stmt) != 0) - break; - if (reads.empty() && writes.empty()) - break; - bool need_loop_sync = false; - for (const AccessEntry &acc : s.access) { - if (acc.type == kRead) { - if (FindConflict(writes, acc, true)) { - need_loop_sync = true; - break; - } - } else if (acc.type == kWrite) { - if (FindConflict(reads, acc, true) || - FindConflict(writes, acc, true)) { - need_loop_sync = true; - break; - } - } else if (acc.type == kSync) { - reads.clear(); - writes.clear(); - } - } - if (need_loop_sync) { - if (!has_read_in_scope) { - // Mark the loop itself to receive a sync before it, instead of - // inserting inside the loop body. This ensures a single sync is - // emitted outside the loop and avoids per-iteration overhead. - insert_syncs(loop); - } else { - // Fall back to inserting before the first conflicting statement - // inside the loop to maintain correctness when reads are present. - insert_syncs(s.stmt); - } - break; + rw_stats_.clear(); + Stmt kinit = Evaluate( + Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {})); + body = SeqStmt({kinit, body}); + body = AttrStmt(op->node, op->attr_key, op->value, body); + return SeqStmt({prep, body}); + } + Stmt MakeGlobalBarrier() { + ICHECK(sync_scope_.rank == StorageRank::kGlobal); + if (!num_blocks_.defined()) { + ICHECK(!is_lead_.defined()); + num_work_dim_ = thread_extents_.size(); + for (const AttrStmtNode *attr : thread_extents_) { + IterVar iv = Downcast(attr->node); + runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag); + if (s.rank == 0) { + num_blocks_ = + (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value); + } else if (s.rank == 1) { + PrimExpr cond = iv->var == make_zero(iv->var.dtype()); + is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond; } } + } else { + ICHECK_EQ(num_work_dim_, thread_extents_.size()); } - // return the exposed entries, remove unnecessary ones. 
- int sync_count = 0; - // head are before first sync, tail are after last sync - std::vector head, tail; - AccessEntry esync; - esync.threads = this->env_threads(); - esync.thread_range = this->ComputeThreadRange(esync.threads); - esync.type = kSync; - esync.scope = sync_scope_; + return Evaluate( + Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_})); + } + // data structure. + StorageScope sync_scope_; + const std::unordered_set &syncs_; - for (const StmtEntry &s : seq) { - if (syncs_inserted_.count(s.stmt)) { - if (sync_count != 0) { - tail.clear(); - } else { - head.push_back(esync); - } - ++sync_count; - } - for (const AccessEntry &acc : s.access) { - if (acc.type == kSync) { - if (sync_count != 0) { - tail.clear(); - } else { - head.push_back(esync); - } - ++sync_count; - } else { - if (sync_count != 0) { - tail.push_back(acc); - } else { - head.push_back(acc); - } - } - } - } - head.insert(head.end(), tail.begin(), tail.end()); - if (loop != nullptr) { - // clear double buffer flag after a loop is finished. - for (AccessEntry &e : head) { - e.double_buffer_write = false; - } - } - return head; + // The read write statistics of storage + std::unordered_map rw_stats_; + // The statistics for global barrier + bool in_thread_env_{false}; + // memorized results + std::vector thread_extents_; + size_t num_work_dim_{0}; + PrimExpr num_blocks_; + PrimExpr is_lead_; +}; + +class ThreadPartialSyncRewriter : public IRMutatorWithAnalyzer { +public: + static Stmt Rewrite(Stmt stmt) { + arith::Analyzer analyzer; + ThreadPartialSyncRewriter rewriter(&analyzer); + return rewriter(std::move(stmt)); } private: - // find conflicting entry in vec. - bool FindConflict(const std::vector &prev, - const AccessEntry &curr, bool loop_carry) { - for (const AccessEntry &x : prev) { - if (FindConflict(x, curr, loop_carry)) { - return true; - } - } - return false; - } + explicit ThreadPartialSyncRewriter(arith::Analyzer *analyzer) + : IRMutatorWithAnalyzer(analyzer) {} - bool FindConflict(const AccessEntry &prev, const AccessEntry &curr, - bool loop_carry) { - // Special case: ignore conflicts between async-copy writes (e.g., TMA - // loads into shared memory). Multiple async writes do not require - // interspersed barriers among themselves. We still respect conflicts with - // reads to ensure visibility before consumption. - if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy && - curr.is_async_copy) { - return false; - } - // Access to different buffers does not conflict. - if (!prev.buffer.same_as(curr.buffer)) { - return false; - } + Stmt VisitStmt_(const EvaluateNode *op) final { + const CallNode *call = nullptr; + if (op->value->IsInstance()) { + call = op->value.as(); + if (call->op.same_as(builtin::tvm_storage_sync())) { + const auto &args = call->args; + ICHECK(!args.empty()); + const auto *scope_node = args[0].as(); + ICHECK(scope_node != nullptr); + const std::string &scope = scope_node->value; - // Assumes no race between threads - // Same index value means no conflicts - // TODO(tqchen) more standard set based testing. 
- bool has_same_index = true; - bool range_is_equal = true; - bool range_is_overlap = true; + if (args.size() != 1 || (scope != "shared" && scope != "shared.dyn")) { + return IRMutatorWithAnalyzer::VisitStmt_(op); + } - for (const auto &kv : prev.thread_range) { - if (!StructuralEqual()(kv.second, curr.thread_range[kv.first])) { - range_is_equal = false; - break; + return ProcessSharedSync(call, scope); } } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } - if (prev.buffer_indices.size() != curr.buffer_indices.size()) { - // They are not the same indices, should be conflict. - return true; - } - if (prev.is_pointer_access || curr.is_pointer_access) { - // For accesses created via tvm_access_ptr we may still be able to prove - // disjointness using their byte ranges. If both sides expose a touched - // interval and we can show they don't overlap, skip the conflict. - if (prev.is_pointer_access && curr.is_pointer_access && - PointerAccessIsDisjoint(prev, curr)) { - return false; - } - // Otherwise fall back to the conservative answer: treat them as - // overlapping. - return true; + Stmt ProcessSharedSync(const CallNode *op, const std::string &scope) { + // Get thread bounds + auto bound_tx = analyzer_->const_int_bound(tx_); + auto bound_ty = analyzer_->const_int_bound(ty_); + auto bound_tz = analyzer_->const_int_bound(tz_); + + // Check if all threads are participating (full extent) + if (IsFullThreadExtent(tx_, bound_tx) && + IsFullThreadExtent(ty_, bound_ty) && + IsFullThreadExtent(tz_, bound_tz)) { + return Evaluate(IRMutatorWithAnalyzer::VisitExpr_(op)); } - for (size_t i = 0; i < prev.buffer_indices.size(); i++) { - auto prev_dtype = prev.dtype; - auto curr_dtype = curr.dtype; + // Calculate thread extents + auto extent_tx = CalculateThreadExtent(tx_, bound_tx); + auto extent_ty = CalculateThreadExtent(ty_, bound_ty); + auto extent_tz = CalculateThreadExtent(tz_, bound_tz); - const auto &prev_indice = prev.buffer_indices[i]; - const auto &curr_indice = curr.buffer_indices[i]; + // Create or get barrier info + ThreadBoundKey key{bound_tx->min_value, bound_tx->max_value, + bound_ty->min_value, bound_ty->max_value, + bound_tz->min_value, bound_tz->max_value}; - if (!ExprDeepEqual()(prev_indice, curr_indice)) { - PrimExpr prev_indice_bytes = - analyzer_.Simplify(prev_indice * prev_dtype.bytes()); - PrimExpr curr_indice_bytes = - analyzer_.Simplify(curr_indice * curr_dtype.bytes()); + auto [barrier_id, thread_count] = + GetOrCreateBarrier(key, extent_tx, extent_ty, extent_tz); + if (thread_count % 32 != 0) { + // TODO(lei): This is a workaround for the case where the thread count is + // not a multiple of 32. we should enhance the pass to analysis index + // instead of buffer expression etc. 
+ return Stmt(); + } - has_same_index = false; + // Create new sync call with barrier info + Array new_args = {StringImm(scope), + IntImm(DataType::Int(32), barrier_id), + IntImm(DataType::Int(32), thread_count)}; + return Evaluate(Call(op->dtype, op->op, new_args)); + } - // If both are const, we can check if they are disjoint - // by checking if the bounds are disjoint - // [1024, 2048], [2048, 3072] are disjoint - // [1024, 2048], [1024, 1024] are not disjoint - auto prev_bound = analyzer_.const_int_bound(prev_indice_bytes); - auto curr_bound = analyzer_.const_int_bound(curr_indice_bytes); - if (prev_bound.defined() && curr_bound.defined()) { - if ((prev_bound->min_value) > (curr_bound->max_value) || - (curr_bound->min_value) > (prev_bound->max_value)) { - range_is_overlap = false; - break; - } - } + std::pair GetOrCreateBarrier(const ThreadBoundKey &key, + size_t extent_tx, + size_t extent_ty, + size_t extent_tz) { + if (barrier_id_map_.count(key)) { + return {barrier_id_map_[key], thread_count_map_[key]}; + } - // if we can prove prev_indice < curr_indice or prev_indice > - // curr_indice, then they are not overlap - auto prev_indices_dtype = prev_indice.dtype(); - auto curr_indices_dtype = curr_indice.dtype(); - if (prev_indices_dtype.lanes() != curr_indices_dtype.lanes()) { - // can not support different lanes binary op like <, >, <=, >= - // skip otherwise it will lead to error - continue; - } + size_t barrier_id = + barrier_id_map_.size() + + static_cast(ReservedNamedBarriers::kFirstUsedBarrier); + size_t thread_count = extent_tx * extent_ty * extent_tz; - // provably disjoint means no overlap, for example: - // we can prove that tx - 128 < tx + 128, tx in [0, 128] - // However, we should apply tx split because - // tx < tx + 32 when tx in [0, 128] is not disjoint - // because [0, 128] is not disjoint with [32, 160] - // so we should split tx into tx0 and tx1. + barrier_id_map_[key] = barrier_id; + thread_count_map_[key] = thread_count; - struct ThreadVarInfo { - const char *name_prev; - const char *name_curr; - IterVar iv; - } thread_vars[] = { - {"tx1", "tx2", tx_}, - {"ty1", "ty2", ty_}, - {"tz1", "tz2", tz_}, - }; + return {barrier_id, thread_count}; + } - for (const auto &info : thread_vars) { - Var prev_var(info.name_prev, info.iv->var.dtype()); - Var curr_var(info.name_curr, info.iv->var.dtype()); - analyzer_.Bind(prev_var, info.iv->dom); - analyzer_.Bind(curr_var, info.iv->dom); - prev_indice_bytes = - Substitute(prev_indice_bytes, {{info.iv->var, prev_var}}); - curr_indice_bytes = - Substitute(curr_indice_bytes, {{info.iv->var, curr_var}}); - } - - bool provably_disjoint = - analyzer_.CanProve(prev_indice_bytes < curr_indice_bytes, - arith::ProofStrength::kSymbolicBound) || - analyzer_.CanProve(prev_indice_bytes > curr_indice_bytes, - arith::ProofStrength::kSymbolicBound); - - if (provably_disjoint) { - range_is_overlap = false; - break; - } - } - - if (!has_same_index) { - break; - } - } - - if (has_same_index && range_is_equal) { - return false; - } - - // If this is a read into a double buffer that was previously - // swapped out, then it doesn't conflict. - if (prev.double_buffer_write && curr.type == kRead && !loop_carry) { - return false; - } - - // If nothing else allows sharing the same buffer, then they are - // in conflict. - // if range_is_overlap is true, then they are in conflict, we should return - // true. if range_is_overlap is false, then they are not in conflict, we - // should return false. 
- return range_is_overlap; - } - - bool PointerAccessIsDisjoint(const AccessEntry &lhs, const AccessEntry &rhs) { - if (lhs.touched.size() != 1 || rhs.touched.size() != 1) { - return false; - } - PrimExpr lhs_min = analyzer_.Simplify(lhs.touched[0].min()); - PrimExpr lhs_max = analyzer_.Simplify(lhs.touched[0].max()); - PrimExpr rhs_min = analyzer_.Simplify(rhs.touched[0].min()); - PrimExpr rhs_max = analyzer_.Simplify(rhs.touched[0].max()); - - if (analyzer_.CanProve(lhs_max < rhs_min, - arith::ProofStrength::kSymbolicBound)) { - return true; - } - if (analyzer_.CanProve(rhs_max < lhs_min, - arith::ProofStrength::kSymbolicBound)) { - return true; + size_t CalculateThreadExtent(const IterVar &iv, + const arith::ConstIntBound &bound) { + if (!analyzer_->const_int_bound.IsBound(iv->var)) { + return 1; } - return false; + return bound->max_value - bound->min_value + 1; } - void VisitStmt_(const AttrStmtNode *op) final { + Stmt VisitStmt_(const AttrStmtNode *op) final { if (op->attr_key == tvm::tir::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag == "threadIdx.x") { @@ -429,16 +409,29 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor { tz_ = iv; } } - TileLangStorageAccessVisitor::VisitStmt_(op); + return IRMutatorWithAnalyzer::VisitStmt_(op); } - void insert_syncs(const Object *obj) { - if (syncs_inserted_.count(obj)) - return; - syncs_inserted_.insert(obj); + bool IsFullThreadExtent(const IterVar &iv, + const arith::ConstIntBound &bound) { + if (!analyzer_->const_int_bound.IsBound(iv->var)) { + return true; + } + + if (!iv->dom.defined()) { + return true; + } + + const auto *min_node = iv->dom->min.as(); + const auto *extent_node = iv->dom->extent.as(); + + int64_t min = min_node->value; + int64_t extent = extent_node->value; + int64_t max = min + extent - 1; + + return min == bound->min_value && max == bound->max_value; } -private: // Member variables IterVar tx_ = IterVar(Range::FromMinExtent(0, 1), Var("tx"), IterVarType::kDataPar); @@ -446,369 +439,810 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor { IterVar(Range::FromMinExtent(0, 1), Var("ty"), IterVarType::kDataPar); IterVar tz_ = IterVar(Range::FromMinExtent(0, 1), Var("tz"), IterVarType::kDataPar); - // synchronization scope - StorageScope sync_scope_; + std::unordered_map barrier_id_map_; + std::unordered_map thread_count_map_; }; -// There are cases where necessary syncthreads is not inserted by -// ThreadSyncInserter. For example, syncthreads is needed after async_wait_queue -// in the second loop below, but since ThreadSyncInserter is not aware of the -// asynchronous semantics, it cannot tell that the syncthreads is needed there. -// -// // Pipeline prologue -// for i in range(125): -// async_commit_queue(0): -// async_scope: -// shared[(i + 3) % 4] = ... -// ... -// -// // Pipeline Epilogue -// for i in range(3): -// async_wait_queue(0, 2 - i): -// local[...] = shared[(i + 125) % 4] - -// This class adds syncthreads after all async_wait_queue. That includes -// syncthreads that can be inserted by ThreadSyncInserter as well, but -// ThreadSyncInserter will not insert duplicate syncthreads if it finds an -// existing one at the synchronization point. 
-class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator { -public: - explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) +struct TileLangThreadSyncPlanner : public ConstrVisitor { + explicit TileLangThreadSyncPlanner(StorageScope sync_scope) : sync_scope_(std::move(sync_scope)) {} - - Stmt VisitStmt_(const AttrStmtNode *op) final { - if (op->attr_key == tvm::tir::attr::async_wait_queue_scope) { - auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), - {StringImm(sync_scope_.to_string())})); - auto inner = op->body.as(); - ICHECK(inner && - inner->attr_key == tvm::tir::attr::async_wait_inflight_count); - auto zero = make_zero(DataType::Int(32)); - auto new_body = SeqStmt({sync, inner->body}); - return AttrStmt(zero, tvm::tir::attr::async_wait_queue_scope, op->value, - AttrStmt(zero, tvm::tir::attr::async_wait_inflight_count, - inner->value, new_body)); + /*! \brief Storage access type */ + enum AccessType : uint8_t { + kRead, + kWrite, + kSync, + kAlloc, + // acquired version of read, only need to handle WAR dep. + kReadAcquire + }; + /*! \brief An access entry */ + struct AccessEntry { + /*! \brief The thread index that access this entry */ + Array threads; + /*! \brief The buffer variable, if any */ + Array buffer_indices; + ConstrSet cset; + /*! \brief The buffer ranges for pointer access */ + Array buffer_ranges; + Var buffer = NullValue(); + /*! \brief The access data type */ + DataType dtype; + /*! \brief The touched access range + * + * Has one IntSet for each index in the buffer being accessed. + */ + Array touched; + /*! \brief The type of access */ + AccessType type; + /*! \brief The storage scope */ + StorageScope scope; + /*! \brief Whether the access is double buffer write */ + bool double_buffer_write = false; + /*! \brief Whether the access is pointer access */ + bool is_pointer_access = false; + /*! \brief Whether this access originates from an async copy context + * (e.g., inside a TMA load) and therefore multiple writes + * among themselves should not force barriers between them. */ + bool is_async_copy = false; + }; + /*! \brief Access pattern about a single statement */ + struct StmtEntry { + /*! \brief The statement */ + const Object *stmt{}; + /*! 
\brief access patterns in the statement */ + std::vector access; + }; + // access scope + std::vector> scope_; + StorageScope GetScope(Var buffer_var) const { + return StorageScope::Create(GetPtrStorageScope(std::move(buffer_var))); + } + void VisitExpr_(const BufferLoadNode *op) final { + Var buf = op->buffer->data; + buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buf.get()), op->buffer); + StorageScope scope = GetScope(buf); + if (Enabled(buf.get(), scope)) { + ICHECK(allow_append_) + << tvm::ffi::GetRef(op) << " " << scope.to_string(); + AccessEntry e{.cset = constr_stack_}; + e.threads = env_threads(); + e.buffer = buf; + e.buffer_indices = op->indices; + e.dtype = op->dtype.element_of(); + for (const auto &index : op->indices) { + e.touched.push_back(arith::IntSet::Vector(index)); + } + e.type = kRead; + e.scope = scope; + curr_stmt_.access.emplace_back(std::move(e)); } - return StmtExprMutator::VisitStmt_(op); + // traverse child + ConstrVisitor::VisitExpr_(op); } - -private: - StorageScope sync_scope_; -}; - -class ThreadSyncInserter : public StmtExprMutator { -public: - ThreadSyncInserter(StorageScope sync_scope, - const std::unordered_set &syncs) - : sync_scope_(std::move(sync_scope)), syncs_(syncs) {} - - Stmt VisitStmt(const Stmt &stmt) final { - if (syncs_.empty()) - return stmt; - if (syncs_.count(stmt.get())) { - Stmt barrier; - if (sync_scope_.rank == StorageRank::kGlobal) { - barrier = MakeGlobalBarrier(); - } else { - barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), - {StringImm(sync_scope_.to_string())})); + void VisitStmt_(const BufferStoreNode *op) final { + allow_append_ = true; + ICHECK_EQ(curr_stmt_.access.size(), 0U); + curr_stmt_.stmt = op; + + Var buf = op->buffer->data; + buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buf.get()), op->buffer); + StorageScope scope = GetScope(buf); + if (Enabled(buf.get(), scope)) { + AccessEntry e{.cset = constr_stack_}; + e.threads = env_threads(); + e.buffer = buf; + e.buffer_indices = op->indices; + e.dtype = op->value.dtype().element_of(); + for (const auto &index : op->indices) { + e.touched.push_back(arith::IntSet::Vector(index)); } - // Mutate after query, to avoid stmt change. - auto ret = StmtExprMutator::VisitStmt(stmt); - ret = SeqStmt({barrier, ret}); - return ret; - } else { - return StmtExprMutator::VisitStmt(stmt); + e.type = kWrite; + e.scope = scope; + curr_stmt_.access.emplace_back(std::move(e)); } + // traverse child + ConstrVisitor::VisitStmt_(op); + // push to the scope + scope_.back().push_back(curr_stmt_); + // clear access entry. 
+ curr_stmt_.access.clear(); + allow_append_ = false; } - PrimExpr VisitExpr_(const BufferLoadNode *op) final { - if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer->data).rank == StorageRank::kGlobal) { - ++rw_stats_[op->buffer->data].read_count; + void VisitStmt_(const EvaluateNode *op) final { + allow_append_ = true; + ICHECK_EQ(curr_stmt_.access.size(), 0U); + curr_stmt_.stmt = op; + ConstrVisitor::VisitStmt_(op); + // push to the scope + if (!curr_stmt_.access.empty()) { + scope_.back().push_back(curr_stmt_); + curr_stmt_.access.clear(); } - return StmtExprMutator::VisitExpr_(op); + allow_append_ = false; } - Stmt VisitStmt_(const BufferStoreNode *op) final { - if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer->data).rank == StorageRank::kGlobal) { - ++rw_stats_[op->buffer->data].write_count; + + void VisitStmt_(const LetStmtNode *op) final { + allow_append_ = true; + ICHECK_EQ(curr_stmt_.access.size(), 0U); + curr_stmt_.stmt = op; + this->VisitExpr(op->value); + // push to the scope + scope_.back().push_back(curr_stmt_); + // clear access entry. + curr_stmt_.access.clear(); + allow_append_ = false; + // traverse body block + this->VisitStmt(op->body); + } + void VisitStmt_(const BlockNode *op) final { + auto block = Downcast(op); + for (const auto &buffer : block->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); } - return StmtExprMutator::VisitStmt_(op); + ConstrVisitor::VisitStmt_(op); } - Stmt VisitStmt_(const AttrStmtNode *op) final { - if (op->attr_key == tvm::tir::attr::thread_extent) { - bool temp = true; - std::swap(temp, in_thread_env_); - thread_extents_.push_back(op); - Stmt ret = StmtExprMutator::VisitStmt_(op); - thread_extents_.pop_back(); - std::swap(temp, in_thread_env_); - // first thread scope. - if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) { - ret = InitGlobalBarrier(ret.as()); - num_blocks_ = PrimExpr(); - is_lead_ = PrimExpr(); + void VisitStmt_(const AttrStmtNode *op) override { + if (op->attr_key == tvm::tir::attr::double_buffer_write) { + ICHECK(double_buffer_write_ == nullptr); + double_buffer_write_ = op->node.as(); + scope_.push_back(std::vector()); + ConstrVisitor::VisitStmt_(op); + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + if (!s.access.empty()) { + for (AccessEntry &e : s.access) { + if (e.type == kWrite && e.buffer.get() == double_buffer_write_) { + e.double_buffer_write = true; + } + } + scope_.back().emplace_back(std::move(s)); } - return ret; + double_buffer_write_ = nullptr; + } else if (op->attr_key == tvm::tir::attr::coproc_scope) { + IterVar iv = Downcast(op->node); + env_threads_.push_back(iv); + ConstrVisitor::VisitStmt_(op); + env_threads_.pop_back(); + } else if (op->attr_key == tvm::tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + env_threads_.push_back(iv); + ICHECK_NE(iv->thread_tag.length(), 0U); + // analyzer_.Bind( + // iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), + // op->value)); + + if (!in_device_env_) { + in_device_env_ = true; + scope_.push_back(std::vector()); + ConstrVisitor::VisitStmt_(op); + // no need to take the result as the thread barrier automatically syncs. 
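+      // (a thread_extent boundary is itself a grid-wide synchronization
+      // point, so accesses still exposed here cannot race with anything
+      // outside the kernel launch)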
+ Summarize(std::move(scope_.back()), nullptr); + in_device_env_ = false; + scope_.pop_back(); + } else { + ConstrVisitor::VisitStmt_(op); + } + env_threads_.pop_back(); + } else if (op->attr_key == tvm::tir::attr::hand_threaded) { + // skip this pass on blocks that were hand_threaded + // this avoids control flow and read/write conflicts + // between hand-threaded kernels and automatic threading } else { - return StmtExprMutator::VisitStmt_(op); + ConstrVisitor::VisitStmt_(op); } } - PrimExpr VisitExpr_(const CallNode *op) final { - if (op->op.same_as(builtin::tvm_access_ptr())) { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - ICHECK_EQ(op->args.size(), 5U); - Var buffer_var(Downcast(op->args[1])); - const IntImmNode *flag = op->args[4].as(); - if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && - GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[buffer_var].read_count; + void VisitStmt_(const ForNode *op) final { + scope_.push_back(std::vector()); + ConstrVisitor::VisitStmt_(op); + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), op); + scope_.pop_back(); + if (!s.access.empty()) { + // relax the touched set to contain all ranges in the loop. + std::unordered_map relax_map; + relax_map[op->loop_var.get()] = + arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent)); + for (AccessEntry &e : s.access) { + if (e.buffer.defined()) { + ICHECK(!e.touched.empty()); + Array new_touched; + for (const auto &touched : e.touched) { + new_touched.push_back(arith::EvalSet(touched, relax_map)); + } + e.touched = std::move(new_touched); + } } - if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && - GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[buffer_var].write_count; + } + if (!s.access.empty()) { + scope_.back().emplace_back(std::move(s)); + } + } + /** + * @brief Visit an IfThenElse statement and collect storage access summaries + * for its branches. + * + * Visits the if-then-else node's condition and both branches to summarize + * buffer reads, writes, and synchronization events under the condition's + * constraints. If the condition is not thread-invariant, increments an + * internal condition counter for the duration of processing. + * + * Behavior and side effects: + * - Evaluates the condition expression (using ExtractRealCondition) and + * applies it as a constraint while summarizing the then-branch. + * - For the else-branch (when present), applies the negated, + * analyzer-simplified condition + * (analyzer_.rewrite_simplify(Not(real_condition))) as the constraint. + * - Accumulates summarized StmtEntry access information for the then/else + * branches and appends a combined StmtEntry for the IfThenElseNode into the + * current scope. + * - Temporarily toggles allow_append_ and clears curr_stmt_.access during + * condition evaluation and branch summarization. + * - Modifies internal state: scope_ (push/pop of temporary branch scopes), + * curr_stmt_.access, and condition_counter_ (incremented/decremented when the + * condition is not thread-invariant). + */ + void VisitStmt_(const IfThenElseNode *op) final { + bool is_thread_invariant = IsThreadInvariant_(op->condition); + if (!is_thread_invariant) { + ++condition_counter_; + } + + allow_append_ = true; + this->VisitExpr(op->condition); + + // Preserve accesses collected from the condition expression so they + // participate in dependency analysis. 
Otherwise, a write to shared memory + // immediately followed by an if-condition reading that memory would not + // trigger a sync before the if-statement. + std::vector cond_access = std::move(curr_stmt_.access); + allow_append_ = false; + + scope_.push_back(std::vector()); + { + this->VisitStmt(op->then_case); + } + + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + // Merge the condition's access summary into the if-statement's access list + // so the planner can insert a sync before the if when necessary. + if (!cond_access.empty()) { + s.access.insert(s.access.begin(), cond_access.begin(), cond_access.end()); + } + if (op->else_case) { + scope_.push_back(std::vector()); + { + this->VisitStmt(op->else_case.value()); } - return expr; - } else if (op->op.same_as(builtin::address_of())) { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - ICHECK_EQ(op->args.size(), 1U) - << "address_of should only have one argument (Buffer)"; + auto v = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + s.access.insert(s.access.end(), v.begin(), v.end()); + } + scope_.back().emplace_back(std::move(s)); + if (!is_thread_invariant) { + --condition_counter_; + } + } + + void VisitStmt_(const WhileNode *op) final { + bool is_thread_invariant = IsThreadInvariant_(op->condition); + if (!is_thread_invariant) { + ++condition_counter_; + } + this->VisitExpr(op->condition); + scope_.push_back(std::vector()); + this->VisitStmt(op->body); + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + scope_.back().emplace_back(std::move(s)); + if (!is_thread_invariant) { + --condition_counter_; + } + } + void VisitExpr_(const CallNode *op) final { + // Mark async TMA load context so that tvm_access_ptr within the call + // can be tagged accordingly. + auto is_tma_load = [&]() { + if (auto opt = op->op.as()) { + const Op &call_op = opt.value(); + return call_op.same_as(tl::tma_load()) || + call_op.same_as(tl::tma_load_im2col()); + } + return false; + }(); + if (is_tma_load) { + tma_depth_++; + for (const auto &a : op->args) { + this->VisitExpr(a); + } + tma_depth_--; + return; + } + if (op->op.same_as(builtin::address_of())) { + ICHECK_EQ(op->args.size(), 1U); if (auto load = op->args[0].as()) { - Var buffer_var(Downcast(load->buffer->data)); - if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[buffer_var].read_count; + Buffer buffer = load->buffer; + DataType dtype = buffer->dtype; + const VarNode *buffer_var = buffer->data.as(); + buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buffer_var), buffer); + StorageScope scope = GetScope(tvm::ffi::GetRef(buffer_var)); + Array buffer_ranges; + // from indices to buffer indices + ICHECK(buffer->shape.size() == load->indices.size()); + // Use buffer shape and indices to compute the buffer_ranges for each + // dimension. 
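+        // e.g. address_of(A[i, j]) on a buffer of shape (M, N) records the
+        // unit ranges [i, i + 1) and [j, j + 1), one Range per dimension
+        // (names are illustrative).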
+ for (size_t i = 0; i < buffer->shape.size(); ++i) { + PrimExpr min = load->indices[i]; + PrimExpr extent = make_const(buffer->shape[i].dtype(), 1); + buffer_ranges.push_back(Range::FromMinExtent(min, extent)); } - if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[buffer_var].write_count; + if (Enabled(buffer_var, scope)) { + ICHECK(allow_append_); + AccessEntry e{.cset = constr_stack_}; + e.threads = env_threads(); + e.dtype = dtype; + e.buffer = Downcast(buffer->data); + e.buffer_ranges = buffer_ranges; + for (const auto &index : load->indices) { + e.touched.push_back(arith::IntSet::Vector(index)); + } + e.is_pointer_access = true; + e.type = kRead; + e.scope = scope; + curr_stmt_.access.emplace_back(e); } - return expr; + ConstrVisitor::VisitExpr_(load); } else { - return StmtExprMutator::VisitExpr_(op); + ConstrVisitor::VisitExpr_(op); + } + } else if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5U); + DataType dtype = op->args[0].dtype(); + const VarNode *buffer_var = op->args[1].as(); + PrimExpr offset = op->args[2]; + PrimExpr extent = op->args[3]; + const IntImmNode *flag = op->args[4].as(); + StorageScope scope = GetScope(tvm::ffi::GetRef(buffer_var)); + // The buffer scope. + if (Enabled(buffer_var, scope)) { + ICHECK(allow_append_); + Array buffer_ranges; + if (buffer_data_to_buffer_.find(tvm::ffi::GetRef(buffer_var)) == + buffer_data_to_buffer_.end()) { + // cannot find buffer map, use the default buffer + buffer_ranges = {Range::FromMinExtent(offset, extent)}; + } else { + Buffer buffer = + buffer_data_to_buffer_.at(tvm::ffi::GetRef(buffer_var)); + auto buffer_shape = buffer->shape; + // convert 1d offset to multi-dimensional index + auto linear_to_indices = [this](PrimExpr offset, + const Array &shape) { + Array indices; + PrimExpr remaining = std::move(offset); + for (size_t i = 0; i < shape.size(); ++i) { + PrimExpr stride = make_const(DataType::Int(32), 1); + for (size_t j = i + 1; j < shape.size(); ++j) { + stride = stride * shape[j]; + } + PrimExpr idx = FloorDiv(remaining, stride); + remaining = FloorMod(remaining, stride); + indices.push_back(idx); + } + return indices; + }; + Array start_indices = + linear_to_indices(offset, buffer_shape); + Array end_indices = + linear_to_indices(offset + extent, buffer_shape); + for (size_t i = 0; i < buffer_shape.size(); ++i) { + buffer_ranges.push_back(Range::FromMinExtent( + start_indices[i], end_indices[i] - start_indices[i])); + } + } + AccessEntry e{.cset = constr_stack_}; + e.threads = env_threads(); + e.dtype = dtype; + e.buffer = tvm::ffi::GetRef(buffer_var); + e.buffer_ranges = buffer_ranges; + e.is_pointer_access = true; + e.touched = { + arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))}; + e.scope = scope; + if (flag->value & 1) { + e.type = kRead; + e.is_async_copy = (tma_depth_ > 0); + curr_stmt_.access.emplace_back(e); + } + if (flag->value & 2) { + e.type = kWrite; + e.is_async_copy = (tma_depth_ > 0); + curr_stmt_.access.emplace_back(e); + } + } + ConstrVisitor::VisitExpr_(op); + } else if (op->op.same_as(builtin::tvm_storage_sync())) { + ICHECK(allow_append_); + const std::string &s = op->args[0].as()->value; + if (s != "warp") { + StorageScope scope = StorageScope::Create(s); + AccessEntry e{.cset = constr_stack_}; + e.threads = env_threads(); + e.type = kSync; + e.scope = StorageScope::Create(s); + curr_stmt_.access.emplace_back(std::move(e)); } } else { - return StmtExprMutator::VisitExpr_(op); + 
ConstrVisitor::VisitExpr_(op); } } -private: - // RW statistics about data - struct Entry { - int read_count{0}; - int write_count{0}; - }; - - // Get current storage scope. - StorageScope GetScope(Var buffer_var) const { - return StorageScope::Create(GetPtrStorageScope(std::move(buffer_var))); + void SetBufferDataToBuffer(const Var &buffer_var, const Buffer &buffer) { + buffer_data_to_buffer_.Set(buffer_var, buffer); } - // private functions. - Stmt InitGlobalBarrier(const AttrStmtNode *op) { - ICHECK(op != nullptr); - Array pargs = { - StringImm(runtime::symbol::tvm_prepare_global_barrier)}; - Stmt prep = - Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); - Stmt body = op->body; - for (const auto &kv : rw_stats_) { - const auto &e = kv.second; - if (e.read_count != 0 && e.write_count != 0) { - body = AttrStmt(kv.first, tvm::tir::attr::volatile_scope, 1, body); + std::vector Summarize(std::vector seq, + const ForNode *loop) { + // Redirect all "shared.dyn" buffer access to the same buffer var + // so that the accesses can be planned together. + Var shared_dyn_buf; + for (StmtEntry &entry : seq) { + for (AccessEntry &access : entry.access) { + if (access.scope.rank == StorageRank::kShared && + access.scope.tag == ".dyn" && access.buffer.defined()) { + if (!shared_dyn_buf.defined()) { + shared_dyn_buf = access.buffer; + } else { + access.buffer = shared_dyn_buf; + } + } } } - rw_stats_.clear(); - Stmt kinit = Evaluate( - Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {})); - body = SeqStmt({kinit, body}); - body = AttrStmt(op->node, op->attr_key, op->value, body); - return SeqStmt({prep, body}); - } - Stmt MakeGlobalBarrier() { - ICHECK(sync_scope_.rank == StorageRank::kGlobal); - if (!num_blocks_.defined()) { - ICHECK(!is_lead_.defined()); - num_work_dim_ = thread_extents_.size(); - for (const AttrStmtNode *attr : thread_extents_) { - IterVar iv = Downcast(attr->node); - runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag); - if (s.rank == 0) { - num_blocks_ = - (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value); - } else if (s.rank == 1) { - PrimExpr cond = iv->var == make_zero(iv->var.dtype()); - is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond; + + // Unsynced reads and writes + std::vector reads; + std::vector writes; + // if it is a loop, rotate two times to consider effect of loop. + // simulation based approach to find dependencies + for (size_t i = 0; i < seq.size(); ++i) { + const StmtEntry &s = seq[i]; + // check if sync before statement is needed. + bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0); + // Apply the syncs added already. + + if (sync_before_stmt) { + reads.clear(); + writes.clear(); + } + + for (const AccessEntry &acc : s.access) { + if (acc.type == kRead) { + if (FindConflict(writes, acc, false)) { + sync_before_stmt = true; + break; + } + } else if (acc.type == kWrite) { + if (FindConflict(reads, acc, false) || + FindConflict(writes, acc, false)) { + sync_before_stmt = true; + break; + } + } else if (acc.type == kSync) { + reads.clear(); + writes.clear(); + } + } + // If sync is inserted. remove the irrelevant things. 
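+      // (everything issued before the new barrier is now ordered with respect
+      // to this statement, so only accesses recorded after the barrier can
+      // still take part in a conflict)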
+ if (sync_before_stmt) { + reads.clear(); + writes.clear(); + } + // Add the read/write of current statement + for (const AccessEntry &acc : s.access) { + if (acc.type == kRead) { + reads.push_back(acc); + } else if (acc.type == kWrite) { + writes.push_back(acc); + } else if (acc.type == kSync) { + reads.clear(); + writes.clear(); + } + } + + if (sync_before_stmt) { + insert_syncs(s.stmt); + } + } + if (loop != nullptr) { + // Check if the loop body contains any reads in the same sync scope. + // If there are reads, we conservatively keep the sync within the loop + // body to preserve per-iteration ordering when needed. If there are no + // reads (e.g., only writes to shared.dyn), we can safely hoist the sync + // to before the loop to avoid redundant barriers. + bool has_read_in_scope = false; + for (const StmtEntry &s : seq) { + for (const AccessEntry &acc : s.access) { + if (acc.type == kRead && acc.scope == sync_scope_) { + has_read_in_scope = true; + break; + } + } + if (has_read_in_scope) + break; + } + // If there is a loop-carried dependency, insert a single sync + // before the loop rather than hoisting a sync into the loop body. + // This reduces redundant per-iteration synchronizations for cases + // where each iteration touches disjoint regions (e.g., stmatrix + // writes to shared.dyn) and only a global ordering before/after the + // loop is required. + for (size_t i = 0; i < seq.size(); ++i) { + const StmtEntry &s = seq[i]; + if (syncs_inserted_.count(s.stmt) != 0) + break; + if (reads.empty() && writes.empty()) + break; + bool need_loop_sync = false; + for (const AccessEntry &acc : s.access) { + if (acc.type == kRead) { + if (FindConflict(writes, acc, true)) { + need_loop_sync = true; + break; + } + } else if (acc.type == kWrite) { + if (FindConflict(reads, acc, true) || + FindConflict(writes, acc, true)) { + need_loop_sync = true; + break; + } + } else if (acc.type == kSync) { + reads.clear(); + writes.clear(); + } + } + if (need_loop_sync) { + if (!has_read_in_scope) { + // Mark the loop itself to receive a sync before it, instead of + // inserting inside the loop body. This ensures a single sync is + // emitted outside the loop and avoids per-iteration overhead. + insert_syncs(loop); + } else { + // Fall back to inserting before the first conflicting statement + // inside the loop to maintain correctness when reads are present. + insert_syncs(s.stmt); + } + break; + } + } + } + // return the exposed entries, remove unnecessary ones. + int sync_count = 0; + // head are before first sync, tail are after last sync + std::vector head, tail; + AccessEntry esync{.cset = constr_stack_}; + ; + esync.type = kSync; + esync.scope = sync_scope_; + + for (const StmtEntry &s : seq) { + if (syncs_inserted_.count(s.stmt)) { + if (sync_count != 0) { + tail.clear(); + } else { + head.push_back(esync); + } + ++sync_count; + } + for (const AccessEntry &acc : s.access) { + if (acc.type == kSync) { + if (sync_count != 0) { + tail.clear(); + } else { + head.push_back(esync); + } + ++sync_count; + } else { + if (sync_count != 0) { + tail.push_back(acc); + } else { + head.push_back(acc); + } } } - } else { - ICHECK_EQ(num_work_dim_, thread_extents_.size()); } - return Evaluate( - Call(DataType::Int(32), builtin::tvm_storage_sync(), - {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_})); - } - // data structure. 
- StorageScope sync_scope_; - const std::unordered_set &syncs_; - - // The read write statistics of storage - std::unordered_map rw_stats_; - // The statistics for global barrier - bool in_thread_env_{false}; - // memorized results - std::vector thread_extents_; - size_t num_work_dim_{0}; - PrimExpr num_blocks_; - PrimExpr is_lead_; -}; - -class ThreadPartialSyncRewriter : public IRMutatorWithAnalyzer { -public: - static Stmt Rewrite(Stmt stmt) { - arith::Analyzer analyzer; - ThreadPartialSyncRewriter rewriter(&analyzer); - return rewriter(std::move(stmt)); + head.insert(head.end(), tail.begin(), tail.end()); + if (loop != nullptr) { + // clear double buffer flag after a loop is finished. + for (AccessEntry &e : head) { + e.double_buffer_write = false; + } + } + return head; } + // The syncs inserted before each statement + std::unordered_set syncs_inserted_; + const Array &env_threads() const { return env_threads_; } private: - explicit ThreadPartialSyncRewriter(arith::Analyzer *analyzer) - : IRMutatorWithAnalyzer(analyzer) {} - - Stmt VisitStmt_(const EvaluateNode *op) final { - const CallNode *call = nullptr; - if (op->value->IsInstance()) { - call = op->value.as(); - if (call->op.same_as(builtin::tvm_storage_sync())) { - const auto &args = call->args; - ICHECK(!args.empty()); - const auto *scope_node = args[0].as(); - ICHECK(scope_node != nullptr); - const std::string &scope = scope_node->value; - - if (args.size() != 1 || (scope != "shared" && scope != "shared.dyn")) { - return IRMutatorWithAnalyzer::VisitStmt_(op); - } - - return ProcessSharedSync(call, scope); - } - } - return IRMutatorWithAnalyzer::VisitStmt_(op); + bool Enabled(const VarNode *buf, const StorageScope &scope) { + return in_device_env() && scope == sync_scope_; + } + /*! \return whether we are in device environment. */ + bool in_device_env() const { return in_device_env_; } + // whether access appending is enabled. + bool allow_append_{false}; + // Whether we are in device environment + bool in_device_env_{false}; + // Nesting depth of tma_load/tma_load_im2col calls + int tma_depth_{0}; + // Whether we are inside condition. + int condition_counter_{0}; + // The current double buffer write scope. + const VarNode *double_buffer_write_{nullptr}; + // the current free stmt entry. + StmtEntry curr_stmt_; + // The involving threads + Array env_threads_; + // The buffer map + Map buffer_data_to_buffer_; + // Member variables + IterVar tx_ = + IterVar(Range::FromMinExtent(0, 1), Var("tx"), IterVarType::kDataPar); + IterVar ty_ = + IterVar(Range::FromMinExtent(0, 1), Var("ty"), IterVarType::kDataPar); + IterVar tz_ = + IterVar(Range::FromMinExtent(0, 1), Var("tz"), IterVarType::kDataPar); + // synchronization scope + StorageScope sync_scope_; + void insert_syncs(const Object *obj) { + if (syncs_inserted_.count(obj)) + return; + syncs_inserted_.insert(obj); } + bool FindConflict(const AccessEntry &prev, const AccessEntry &curr, + bool loop_carry) { + // Special case: ignore conflicts between async-copy writes (e.g., TMA + // loads into shared memory). Multiple async writes do not require + // interspersed barriers among themselves. We still respect conflicts with + // reads to ensure visibility before consumption. + // print_access_tentry(prev); + // print_access_tentry(curr); + if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy && + curr.is_async_copy) { + return false; + } + // Access to different buffers does not conflict. 
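+    // (Summarize already redirected every "shared.dyn" access to a single
+    // buffer var, so this same_as test does not miss aliasing through the
+    // dynamic shared memory pool.)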
+ if (!prev.buffer.same_as(curr.buffer)) { + return false; + } - Stmt ProcessSharedSync(const CallNode *op, const std::string &scope) { - // Get thread bounds - auto bound_tx = analyzer_->const_int_bound(tx_); - auto bound_ty = analyzer_->const_int_bound(ty_); - auto bound_tz = analyzer_->const_int_bound(tz_); + // Assumes no race between threads + // Same index value means no conflicts + // TODO(tqchen) more standard set based testing. + bool has_same_index = true; + bool range_is_equal = true; + bool range_is_overlap = true; - // Check if all threads are participating (full extent) - if (IsFullThreadExtent(tx_, bound_tx) && - IsFullThreadExtent(ty_, bound_ty) && - IsFullThreadExtent(tz_, bound_tz)) { - return Evaluate(IRMutatorWithAnalyzer::VisitExpr_(op)); + // for (const auto &kv : prev.thread_range) { + // if (!StructuralEqual()(kv.second, curr.thread_range[kv.first])) { + // range_is_equal = false; + // break; + // } + // } + + if (prev.buffer_indices.size() != curr.buffer_indices.size()) { + // They are not the same indices, should be conflict. + return true; } + // if (prev.is_pointer_access || curr.is_pointer_access) { + // // For accesses created via tvm_access_ptr we may still be able to + // prove + // // disjointness using their byte ranges. If both sides expose a + // touched + // // interval and we can show they don't overlap, skip the conflict. + // if (prev.is_pointer_access && curr.is_pointer_access && + // PointerAccessIsDisjoint(prev, curr)) { + // return false; + // } + // // Otherwise fall back to the conservative answer: treat them as + // // overlapping. + // return true; + // } - // Calculate thread extents - auto extent_tx = CalculateThreadExtent(tx_, bound_tx); - auto extent_ty = CalculateThreadExtent(ty_, bound_ty); - auto extent_tz = CalculateThreadExtent(tz_, bound_tz); + for (size_t i = 0; i < prev.buffer_indices.size(); i++) { + auto prev_dtype = prev.dtype; + auto curr_dtype = curr.dtype; - // Create or get barrier info - ThreadBoundKey key{bound_tx->min_value, bound_tx->max_value, - bound_ty->min_value, bound_ty->max_value, - bound_tz->min_value, bound_tz->max_value}; + const auto &prev_indice = prev.buffer_indices[i]; + const auto &curr_indice = curr.buffer_indices[i]; - auto [barrier_id, thread_count] = - GetOrCreateBarrier(key, extent_tx, extent_ty, extent_tz); - if (thread_count % 32 != 0) { - // TODO(lei): This is a workaround for the case where the thread count is - // not a multiple of 32. we should enhance the pass to analysis index - // instead of buffer expression etc. 
- return Stmt(); - } + if (!ExprDeepEqual()(prev_indice, curr_indice)) { + PrimExpr prev_indice_bytes = prev_indice * prev_dtype.bytes(); + PrimExpr curr_indice_bytes = curr_indice * curr_dtype.bytes(); - // Create new sync call with barrier info - Array new_args = {StringImm(scope), - IntImm(DataType::Int(32), barrier_id), - IntImm(DataType::Int(32), thread_count)}; - return Evaluate(Call(op->dtype, op->op, new_args)); - } + has_same_index = false; - std::pair GetOrCreateBarrier(const ThreadBoundKey &key, - size_t extent_tx, - size_t extent_ty, - size_t extent_tz) { - if (barrier_id_map_.count(key)) { - return {barrier_id_map_[key], thread_count_map_[key]}; - } + ConstrSet prev_cset{prev.cset}; + ConstrSet curr_cset{curr.cset}; + arith::Analyzer analyzer; - size_t barrier_id = - barrier_id_map_.size() + - static_cast(ReservedNamedBarriers::kFirstUsedBarrier); - size_t thread_count = extent_tx * extent_ty * extent_tz; + struct ThreadVarInfo { + const char *name_prev; + const char *name_curr; + IterVar iv; + } thread_vars[] = { + {"tx1", "tx2", tx_}, + {"ty1", "ty2", ty_}, + {"tz1", "tz2", tz_}, + }; - barrier_id_map_[key] = barrier_id; - thread_count_map_[key] = thread_count; + for (const auto &info : thread_vars) { + Var prev_var(info.name_prev, info.iv->var.dtype()); + Var curr_var(info.name_curr, info.iv->var.dtype()); + prev_indice_bytes = + Substitute(prev_indice_bytes, {{info.iv->var, prev_var}}); + prev_cset.Substitute({{info.iv->var, prev_var}}); + curr_indice_bytes = + Substitute(curr_indice_bytes, {{info.iv->var, curr_var}}); + curr_cset.Substitute({{info.iv->var, curr_var}}); + } + prev_cset.Populate(analyzer); + curr_cset.Populate(analyzer); - return {barrier_id, thread_count}; - } + bool provably_disjoint = + analyzer.CanProve(prev_indice_bytes != curr_indice_bytes); - size_t CalculateThreadExtent(const IterVar &iv, - const arith::ConstIntBound &bound) { - if (!analyzer_->const_int_bound.IsBound(iv->var)) { - return 1; - } - return bound->max_value - bound->min_value + 1; - } + if (provably_disjoint) { + range_is_overlap = false; + break; + } + } - Stmt VisitStmt_(const AttrStmtNode *op) final { - if (op->attr_key == tvm::tir::attr::thread_extent) { - IterVar iv = Downcast(op->node); - if (iv->thread_tag == "threadIdx.x") { - tx_ = iv; - } else if (iv->thread_tag == "threadIdx.y") { - ty_ = iv; - } else if (iv->thread_tag == "threadIdx.z") { - tz_ = iv; + if (!has_same_index) { + break; } } - return IRMutatorWithAnalyzer::VisitStmt_(op); - } - bool IsFullThreadExtent(const IterVar &iv, - const arith::ConstIntBound &bound) { - if (!analyzer_->const_int_bound.IsBound(iv->var)) { - return true; + if (has_same_index && range_is_equal) { + return false; } - if (!iv->dom.defined()) { - return true; + // If this is a read into a double buffer that was previously + // swapped out, then it doesn't conflict. + if (prev.double_buffer_write && curr.type == kRead && !loop_carry) { + return false; } - const auto *min_node = iv->dom->min.as(); - const auto *extent_node = iv->dom->extent.as(); - - int64_t min = min_node->value; - int64_t extent = extent_node->value; - int64_t max = min + extent - 1; - - return min == bound->min_value && max == bound->max_value; + // If nothing else allows sharing the same buffer, then they are + // in conflict. + // if range_is_overlap is true, then they are in conflict, we should return + // true. if range_is_overlap is false, then they are not in conflict, we + // should return false. 
+ // LOG(WARNING) << range_is_overlap; + return range_is_overlap; + } + bool FindConflict(const std::vector &prev, + const AccessEntry &curr, bool loop_carry) { + // LOG(WARNING) << "FIND: "; + // print_access_tentry(curr); + // LOG(WARNING) << prev.size() << " " << loop_carry; + for (const AccessEntry &x : prev) { + if (FindConflict(x, curr, loop_carry)) { + return true; + } + } + return false; } - - // Member variables - IterVar tx_ = - IterVar(Range::FromMinExtent(0, 1), Var("tx"), IterVarType::kDataPar); - IterVar ty_ = - IterVar(Range::FromMinExtent(0, 1), Var("ty"), IterVarType::kDataPar); - IterVar tz_ = - IterVar(Range::FromMinExtent(0, 1), Var("tz"), IterVarType::kDataPar); - std::unordered_map barrier_id_map_; - std::unordered_map thread_count_map_; }; PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) { From fcd3659dfe071fc155d3c390faaf0240302caeaf Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Wed, 7 Jan 2026 16:28:03 +0800 Subject: [PATCH 02/27] fix bugs --- src/transform/thread_storage_sync.cc | 35 +++++++++++++++------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 4cd8c5f40..7a770b971 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -67,18 +67,6 @@ using arith::IRVisitorWithAnalyzer; using runtime::StorageRank; using runtime::StorageScope; -bool IsThreadInvariant_(const PrimExpr &cond) { - if (auto call = cond.as()) { - if (auto opt_call_op = call->op.as()) { - const auto &call_op = opt_call_op.value(); - if (call_op.same_as(builtin::tvm_thread_invariant())) { - return true; - } - } - } - return false; -} - using namespace tir; using arith::IRMutatorWithAnalyzer; @@ -445,7 +433,9 @@ class ThreadPartialSyncRewriter : public IRMutatorWithAnalyzer { struct TileLangThreadSyncPlanner : public ConstrVisitor { explicit TileLangThreadSyncPlanner(StorageScope sync_scope) - : sync_scope_(std::move(sync_scope)) {} + : sync_scope_(std::move(sync_scope)) { + scope_.push_back(std::vector()); + } /*! \brief Storage access type */ enum AccessType : uint8_t { kRead, @@ -662,6 +652,18 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { scope_.back().emplace_back(std::move(s)); } } + bool IsThreadInvariant_(const PrimExpr &cond) { + if (auto call = cond.as()) { + if (auto opt_call_op = call->op.as()) { + const auto &call_op = opt_call_op.value(); + if (call_op.same_as(builtin::tvm_thread_invariant())) { + return true; + } + } + } + return false; + } + /** * @brief Visit an IfThenElse statement and collect storage access summaries * for its branches. 
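The subtlest fix in this patch is easy to miss: judging from the call sites
changed below, ConstrSet::Substitute returns the rewritten set rather than
mutating in place, so the original code silently discarded the result and
kept reasoning about the old thread variables. A minimal standalone sketch
of that contract (ToyConstrSet and its string-based Substitute are
illustrative stand-ins, not the real API):

#include <cassert>
#include <string>

struct ToyConstrSet {
  std::string expr;
  // Value-returning substitution: the receiver is left untouched.
  ToyConstrSet Substitute(const std::string &from,
                          const std::string &to) const {
    ToyConstrSet out = *this;
    for (std::size_t pos = 0;
         (pos = out.expr.find(from, pos)) != std::string::npos;
         pos += to.size()) {
      out.expr.replace(pos, from.size(), to);
    }
    return out;
  }
};

int main() {
  ToyConstrSet cs{"tx < 16"};
  cs.Substitute("tx", "tx1");      // result discarded: cs is unchanged
  assert(cs.expr == "tx < 16");
  cs = cs.Substitute("tx", "tx1"); // reassign, as the fixed code does
  assert(cs.expr == "tx1 < 16");
  return 0;
}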
@@ -1030,7 +1032,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { // head are before first sync, tail are after last sync std::vector head, tail; AccessEntry esync{.cset = constr_stack_}; - ; + esync.threads = this->env_threads(); esync.type = kSync; esync.scope = sync_scope_; @@ -1125,6 +1127,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { if (!prev.buffer.same_as(curr.buffer)) { return false; } + return true; // Assumes no race between threads // Same index value means no conflicts @@ -1191,10 +1194,10 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { Var curr_var(info.name_curr, info.iv->var.dtype()); prev_indice_bytes = Substitute(prev_indice_bytes, {{info.iv->var, prev_var}}); - prev_cset.Substitute({{info.iv->var, prev_var}}); + prev_cset = prev_cset.Substitute({{info.iv->var, prev_var}}); curr_indice_bytes = Substitute(curr_indice_bytes, {{info.iv->var, curr_var}}); - curr_cset.Substitute({{info.iv->var, curr_var}}); + curr_cset = curr_cset.Substitute({{info.iv->var, curr_var}}); } prev_cset.Populate(analyzer); curr_cset.Populate(analyzer); From fab1648c9c7c380191985d43fa5e69c5153a933a Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Wed, 7 Jan 2026 16:34:00 +0800 Subject: [PATCH 03/27] remove old storage_access --- .../eliminate_storage_sync_for_mbarrier.cc | 1 - src/transform/inject_ptx_async_copy.cc | 1 - src/transform/storage_access.cc | 483 ------------------ src/transform/storage_access.h | 182 ------- 4 files changed, 667 deletions(-) delete mode 100644 src/transform/storage_access.cc delete mode 100644 src/transform/storage_access.h diff --git a/src/transform/eliminate_storage_sync_for_mbarrier.cc b/src/transform/eliminate_storage_sync_for_mbarrier.cc index 504de732c..90c37cac8 100644 --- a/src/transform/eliminate_storage_sync_for_mbarrier.cc +++ b/src/transform/eliminate_storage_sync_for_mbarrier.cc @@ -2,7 +2,6 @@ * \file eliminate_storage_sync_for_mbarrier.cc */ #include "../op/builtin.h" -#include "./storage_access.h" #include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_visitor_with_analyzer.h" #include diff --git a/src/transform/inject_ptx_async_copy.cc b/src/transform/inject_ptx_async_copy.cc index a62bac762..19346d462 100644 --- a/src/transform/inject_ptx_async_copy.cc +++ b/src/transform/inject_ptx_async_copy.cc @@ -29,7 +29,6 @@ #include #include -#include "storage_access.h" #include "tir/ir/buffer_common.h" #include "tvm/tir/stmt.h" diff --git a/src/transform/storage_access.cc b/src/transform/storage_access.cc deleted file mode 100644 index 49c839929..000000000 --- a/src/transform/storage_access.cc +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file storage_access.cc - */ -#include "storage_access.h" - -#include -#include -#include - -#include -#include - -#include "../op/builtin.h" -#include "tir/transforms/ir_utils.h" - -namespace tvm { -namespace tl { - -using namespace tir; - -void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) { - Var buf = op->buffer->data; - buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buf.get()), op->buffer); - StorageScope scope = GetScope(buf); - if (Enabled(buf.get(), scope)) { - ICHECK(allow_append_) << tvm::ffi::GetRef(op) << " " - << scope.to_string(); - AccessEntry e; - e.threads = env_threads(); - e.thread_range = this->ComputeThreadRange(e.threads); - e.buffer = buf; - e.buffer_indices = op->indices; - e.dtype = op->dtype.element_of(); - for (const auto &index : op->indices) { - e.touched.push_back(arith::IntSet::Vector(index)); - } - e.type = kRead; - e.scope = scope; - curr_stmt_.access.emplace_back(std::move(e)); - } - // traverse child - IRVisitorWithAnalyzer::VisitExpr_(op); -} - -void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) { - allow_append_ = true; - ICHECK_EQ(curr_stmt_.access.size(), 0U); - curr_stmt_.stmt = op; - - Var buf = op->buffer->data; - buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buf.get()), op->buffer); - StorageScope scope = GetScope(buf); - if (Enabled(buf.get(), scope)) { - AccessEntry e; - e.threads = env_threads(); - e.thread_range = this->ComputeThreadRange(e.threads); - e.buffer = buf; - e.buffer_indices = op->indices; - e.dtype = op->value.dtype().element_of(); - for (const auto &index : op->indices) { - e.touched.push_back(arith::IntSet::Vector(index)); - } - e.type = kWrite; - e.scope = scope; - curr_stmt_.access.emplace_back(std::move(e)); - } - // traverse child - IRVisitorWithAnalyzer::VisitStmt_(op); - // push to the scope - scope_.back().push_back(curr_stmt_); - // clear access entry. - curr_stmt_.access.clear(); - allow_append_ = false; -} - -void TileLangStorageAccessVisitor::VisitStmt_(const EvaluateNode *op) { - allow_append_ = true; - ICHECK_EQ(curr_stmt_.access.size(), 0U); - curr_stmt_.stmt = op; - IRVisitorWithAnalyzer::VisitStmt_(op); - // push to the scope - if (!curr_stmt_.access.empty()) { - scope_.back().push_back(curr_stmt_); - curr_stmt_.access.clear(); - } - allow_append_ = false; -} - -void TileLangStorageAccessVisitor::VisitStmt_(const LetStmtNode *op) { - allow_append_ = true; - ICHECK_EQ(curr_stmt_.access.size(), 0U); - curr_stmt_.stmt = op; - this->VisitExpr(op->value); - // push to the scope - scope_.back().push_back(curr_stmt_); - // clear access entry. 
- curr_stmt_.access.clear(); - allow_append_ = false; - // traverse body block - this->VisitStmt(op->body); -} - -void TileLangStorageAccessVisitor::VisitStmt_(const BlockNode *op) { - auto block = Downcast(op); - for (const auto &buffer : block->alloc_buffers) { - ICHECK(buffer->IsInstance()); - buffer_data_to_buffer_.Set(buffer->data, buffer); - } - IRVisitorWithAnalyzer::VisitStmt_(op); -} - -void TileLangStorageAccessVisitor::VisitStmt_(const AttrStmtNode *op) { - if (op->attr_key == tvm::tir::attr::double_buffer_write) { - ICHECK(double_buffer_write_ == nullptr); - double_buffer_write_ = op->node.as(); - scope_.push_back(std::vector()); - IRVisitorWithAnalyzer::VisitStmt_(op); - StmtEntry s; - s.stmt = op; - s.access = Summarize(std::move(scope_.back()), nullptr); - scope_.pop_back(); - if (!s.access.empty()) { - for (AccessEntry &e : s.access) { - if (e.type == kWrite && e.buffer.get() == double_buffer_write_) { - e.double_buffer_write = true; - } - } - scope_.back().emplace_back(std::move(s)); - } - double_buffer_write_ = nullptr; - } else if (op->attr_key == tvm::tir::attr::coproc_scope) { - IterVar iv = Downcast(op->node); - env_threads_.push_back(iv); - IRVisitorWithAnalyzer::VisitStmt_(op); - env_threads_.pop_back(); - } else if (op->attr_key == tvm::tir::attr::thread_extent) { - IterVar iv = Downcast(op->node); - env_threads_.push_back(iv); - ICHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind( - iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); - - if (!in_device_env_) { - in_device_env_ = true; - scope_.push_back(std::vector()); - IRVisitorWithAnalyzer::VisitStmt_(op); - // no need to take the result as the thread barrier automatically syncs. - Summarize(std::move(scope_.back()), nullptr); - in_device_env_ = false; - scope_.pop_back(); - } else { - IRVisitorWithAnalyzer::VisitStmt_(op); - } - env_threads_.pop_back(); - } else if (op->attr_key == tvm::tir::attr::hand_threaded) { - // skip this pass on blocks that were hand_threaded - // this avoids control flow and read/write conflicts - // between hand-threaded kernels and automatic threading - } else { - IRVisitorWithAnalyzer::VisitStmt_(op); - } -} - -void TileLangStorageAccessVisitor::VisitStmt_(const ForNode *op) { - scope_.push_back(std::vector()); - IRVisitorWithAnalyzer::VisitStmt_(op); - StmtEntry s; - s.stmt = op; - s.access = Summarize(std::move(scope_.back()), op); - scope_.pop_back(); - if (!s.access.empty()) { - // relax the touched set to contain all ranges in the loop. - std::unordered_map relax_map; - relax_map[op->loop_var.get()] = - arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent)); - for (AccessEntry &e : s.access) { - if (e.buffer.defined()) { - ICHECK(!e.touched.empty()); - Array new_touched; - for (const auto &touched : e.touched) { - new_touched.push_back(arith::EvalSet(touched, relax_map)); - } - e.touched = std::move(new_touched); - } - } - } - if (!s.access.empty()) { - scope_.back().emplace_back(std::move(s)); - } -} - -bool IsThreadInvariant(const PrimExpr &cond) { - if (auto call = cond.as()) { - if (auto opt_call_op = call->op.as()) { - const auto &call_op = opt_call_op.value(); - if (call_op.same_as(builtin::tvm_thread_invariant())) { - return true; - } - } - } - return false; -} - -/** - * @brief Visit an IfThenElse statement and collect storage access summaries for - * its branches. 
- * - * Visits the if-then-else node's condition and both branches to summarize - * buffer reads, writes, and synchronization events under the condition's - * constraints. If the condition is not thread-invariant, increments an internal - * condition counter for the duration of processing. - * - * Behavior and side effects: - * - Evaluates the condition expression (using ExtractRealCondition) and applies - * it as a constraint while summarizing the then-branch. - * - For the else-branch (when present), applies the negated, - * analyzer-simplified condition - * (analyzer_.rewrite_simplify(Not(real_condition))) as the constraint. - * - Accumulates summarized StmtEntry access information for the then/else - * branches and appends a combined StmtEntry for the IfThenElseNode into the - * current scope. - * - Temporarily toggles allow_append_ and clears curr_stmt_.access during - * condition evaluation and branch summarization. - * - Modifies internal state: scope_ (push/pop of temporary branch scopes), - * curr_stmt_.access, and condition_counter_ (incremented/decremented when the - * condition is not thread-invariant). - */ -void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) { - bool is_thread_invariant = IsThreadInvariant(op->condition); - if (!is_thread_invariant) { - ++condition_counter_; - } - - allow_append_ = true; - this->VisitExpr(op->condition); - PrimExpr real_condition = ExtractRealCondition(op->condition); - - // Preserve accesses collected from the condition expression so they - // participate in dependency analysis. Otherwise, a write to shared memory - // immediately followed by an if-condition reading that memory would not - // trigger a sync before the if-statement. - std::vector cond_access = std::move(curr_stmt_.access); - allow_append_ = false; - - scope_.push_back(std::vector()); - { - With constraint(&analyzer_, real_condition); - this->VisitStmt(op->then_case); - } - - StmtEntry s; - s.stmt = op; - s.access = Summarize(std::move(scope_.back()), nullptr); - scope_.pop_back(); - // Merge the condition's access summary into the if-statement's access list - // so the planner can insert a sync before the if when necessary. - if (!cond_access.empty()) { - s.access.insert(s.access.begin(), cond_access.begin(), cond_access.end()); - } - if (op->else_case) { - scope_.push_back(std::vector()); - { - With constraint( - &analyzer_, analyzer_.rewrite_simplify(Not(real_condition))); - this->VisitStmt(op->else_case.value()); - } - auto v = Summarize(std::move(scope_.back()), nullptr); - scope_.pop_back(); - s.access.insert(s.access.end(), v.begin(), v.end()); - } - scope_.back().emplace_back(std::move(s)); - if (!is_thread_invariant) { - --condition_counter_; - } -} - -void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) { - bool is_thread_invariant = IsThreadInvariant(op->condition); - if (!is_thread_invariant) { - ++condition_counter_; - } - this->VisitExpr(op->condition); - scope_.push_back(std::vector()); - this->VisitStmt(op->body); - StmtEntry s; - s.stmt = op; - s.access = Summarize(std::move(scope_.back()), nullptr); - scope_.pop_back(); - scope_.back().emplace_back(std::move(s)); - if (!is_thread_invariant) { - --condition_counter_; - } -} - -void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) { - // Mark async TMA load context so that tvm_access_ptr within the call - // can be tagged accordingly. 
- auto is_tma_load = [&]() { - if (auto opt = op->op.as()) { - const Op &call_op = opt.value(); - return call_op.same_as(tl::tma_load()) || - call_op.same_as(tl::tma_load_im2col()); - } - return false; - }(); - if (is_tma_load) { - tma_depth_++; - for (const auto &a : op->args) { - this->VisitExpr(a); - } - tma_depth_--; - return; - } - if (op->op.same_as(builtin::address_of())) { - ICHECK_EQ(op->args.size(), 1U); - if (auto load = op->args[0].as()) { - Buffer buffer = load->buffer; - DataType dtype = buffer->dtype; - const VarNode *buffer_var = buffer->data.as(); - buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buffer_var), buffer); - StorageScope scope = GetScope(tvm::ffi::GetRef(buffer_var)); - Array buffer_ranges; - // from indices to buffer indices - ICHECK(buffer->shape.size() == load->indices.size()); - // Use buffer shape and indices to compute the buffer_ranges for each - // dimension. - for (size_t i = 0; i < buffer->shape.size(); ++i) { - PrimExpr min = load->indices[i]; - PrimExpr extent = make_const(buffer->shape[i].dtype(), 1); - buffer_ranges.push_back(Range::FromMinExtent(min, extent)); - } - if (Enabled(buffer_var, scope)) { - ICHECK(allow_append_); - AccessEntry e; - e.threads = env_threads(); - e.thread_range = this->ComputeThreadRange(e.threads); - e.dtype = dtype; - e.buffer = Downcast(buffer->data); - e.buffer_ranges = buffer_ranges; - for (const auto &index : load->indices) { - e.touched.push_back(arith::IntSet::Vector(index)); - } - e.is_pointer_access = true; - e.type = kRead; - e.scope = scope; - curr_stmt_.access.emplace_back(e); - } - IRVisitorWithAnalyzer::VisitExpr_(load); - } else { - IRVisitorWithAnalyzer::VisitExpr_(op); - } - } else if (op->op.same_as(builtin::tvm_access_ptr())) { - ICHECK_EQ(op->args.size(), 5U); - DataType dtype = op->args[0].dtype(); - const VarNode *buffer_var = op->args[1].as(); - PrimExpr offset = op->args[2]; - PrimExpr extent = op->args[3]; - const IntImmNode *flag = op->args[4].as(); - StorageScope scope = GetScope(tvm::ffi::GetRef(buffer_var)); - // The buffer scope. 
- if (Enabled(buffer_var, scope)) { - ICHECK(allow_append_); - Array buffer_ranges; - if (buffer_data_to_buffer_.find(tvm::ffi::GetRef(buffer_var)) == - buffer_data_to_buffer_.end()) { - // cannot find buffer map, use the default buffer - buffer_ranges = {Range::FromMinExtent(offset, extent)}; - } else { - Buffer buffer = - buffer_data_to_buffer_.at(tvm::ffi::GetRef(buffer_var)); - auto buffer_shape = buffer->shape; - // convert 1d offset to multi-dimensional index - auto linear_to_indices = [this](PrimExpr offset, - const Array &shape) { - Array indices; - PrimExpr remaining = std::move(offset); - for (size_t i = 0; i < shape.size(); ++i) { - PrimExpr stride = make_const(DataType::Int(32), 1); - for (size_t j = i + 1; j < shape.size(); ++j) { - stride = stride * shape[j]; - } - PrimExpr idx = FloorDiv(remaining, stride); - remaining = FloorMod(remaining, stride); - indices.push_back(analyzer_.Simplify(idx)); - } - return indices; - }; - Array start_indices = linear_to_indices(offset, buffer_shape); - Array end_indices = - linear_to_indices(offset + extent, buffer_shape); - for (size_t i = 0; i < buffer_shape.size(); ++i) { - buffer_ranges.push_back(Range::FromMinExtent( - start_indices[i], - analyzer_.Simplify(end_indices[i] - start_indices[i]))); - } - } - AccessEntry e; - e.threads = env_threads(); - e.thread_range = this->ComputeThreadRange(e.threads); - e.dtype = dtype; - e.buffer = tvm::ffi::GetRef(buffer_var); - e.buffer_ranges = buffer_ranges; - e.is_pointer_access = true; - e.touched = { - arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))}; - e.scope = scope; - if (flag->value & 1) { - e.type = kRead; - e.is_async_copy = (tma_depth_ > 0); - curr_stmt_.access.emplace_back(e); - } - if (flag->value & 2) { - e.type = kWrite; - e.is_async_copy = (tma_depth_ > 0); - curr_stmt_.access.emplace_back(e); - } - } - IRVisitorWithAnalyzer::VisitExpr_(op); - } else if (op->op.same_as(builtin::tvm_storage_sync())) { - ICHECK(allow_append_); - const std::string &s = op->args[0].as()->value; - if (s != "warp") { - StorageScope scope = StorageScope::Create(s); - AccessEntry e; - e.threads = env_threads(); - e.thread_range = this->ComputeThreadRange(e.threads); - e.type = kSync; - e.scope = StorageScope::Create(s); - curr_stmt_.access.emplace_back(std::move(e)); - } - } else { - IRVisitorWithAnalyzer::VisitExpr_(op); - } -} - -Map TileLangStorageAccessVisitor::ComputeThreadRange( - const Array &threads) { - Map thread_range; - for (const auto &th : threads) { - auto thread_tag = th->thread_tag; - if (thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" || - thread_tag == "threadIdx.z") { - auto const_int_bound = analyzer_.const_int_bound(th->var); - auto min_value = const_int_bound->min_value; - auto max_value = const_int_bound->max_value; - auto extent = max_value - min_value + 1; - auto dtype = th->var.dtype(); - thread_range.Set(th->var, Range::FromMinExtent(IntImm(dtype, min_value), - IntImm(dtype, extent))); - } - } - return thread_range; -} - -StorageScope -TileLangStorageAccessVisitor::GetScope(const Var &buffer_var) const { - if (buffer_var->type_annotation.as()) { - return StorageScope::Create(GetPtrStorageScope(buffer_var)); - } - return StorageScope(); // global by default -} - -} // namespace tl -} // namespace tvm diff --git a/src/transform/storage_access.h b/src/transform/storage_access.h deleted file mode 100644 index 54114ace2..000000000 --- a/src/transform/storage_access.h +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) 
under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file storage_access.h - * \brief Common data structure for storage access analysis. - */ -#ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ -#define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ - -#include -#include -#include -#include - -#include -#include - -#include "arith/ir_visitor_with_analyzer.h" -#include "runtime/thread_storage_scope.h" - -namespace tvm { -namespace tl { - -using namespace tir; -using namespace ffi; -using arith::IRVisitorWithAnalyzer; -using runtime::StorageRank; -using runtime::StorageScope; - -/*! - * \brief Base class of storage access analysis - */ -class TileLangStorageAccessVisitor : public IRVisitorWithAnalyzer { -public: - /*! \brief Storage access type */ - enum AccessType : uint8_t { - kRead, - kWrite, - kSync, - kAlloc, - // acquired version of read, only need to handle WAR dep. - kReadAcquire - }; - /*! \brief An access entry */ - struct AccessEntry { - /*! \brief The thread index that access this entry */ - Array threads; - /*! \brief The touched thread range */ - Map thread_range; - /*! \brief The buffer variable, if any */ - Array buffer_indices; - /*! \brief The buffer ranges for pointer access */ - Array buffer_ranges; - Var buffer = NullValue(); - /*! \brief The access data type */ - DataType dtype; - /*! \brief The touched access range - * - * Has one IntSet for each index in the buffer being accessed. - */ - Array touched; - /*! \brief The type of access */ - AccessType type; - /*! \brief The storage scope */ - StorageScope scope; - /*! \brief Whether the access is double buffer write */ - bool double_buffer_write = false; - /*! \brief Whether the access is pointer access */ - bool is_pointer_access = false; - /*! \brief Whether this access originates from an async copy context - * (e.g., inside a TMA load) and therefore multiple writes - * among themselves should not force barriers between them. */ - bool is_async_copy = false; - }; - - /*! \brief Access pattern about a single statement */ - struct StmtEntry { - /*! \brief The statement */ - const Object *stmt{}; - /*! 
\brief access patterns in the statement */ - std::vector access; - }; - // override visitor pattern - void VisitExpr_(const BufferLoadNode *op) final; - void VisitStmt_(const BufferStoreNode *op) final; - void VisitStmt_(const EvaluateNode *op) final; - void VisitStmt_(const LetStmtNode *op) final; - void VisitStmt_(const AttrStmtNode *op) override; - void VisitStmt_(const ForNode *op) final; - void VisitStmt_(const IfThenElseNode *op) final; - void VisitStmt_(const WhileNode *op) final; - void VisitExpr_(const CallNode *op) final; - void VisitStmt_(const BlockNode *op) final; - - void SetBufferDataToBuffer(const Var &buffer_var, const Buffer &buffer) { - buffer_data_to_buffer_.Set(buffer_var, buffer); - } - -protected: - TileLangStorageAccessVisitor() { scope_.push_back(std::vector()); } - /*! \return number of conditions in the current scope. */ - int condition_counter() const { return condition_counter_; } - /*! \return whether we are in device environment. */ - bool in_device_env() const { return in_device_env_; } - /*! \return environment threads */ - const Array &env_threads() const { return env_threads_; } - /*! - * \brief Whether we need analyze the buffer in current scope. - * \param buffer The buffer to be checked - * \param scope The scope of the buffer. - * \return Whether the analysis of buffer is enabled. - */ - virtual bool Enabled(const VarNode *buffer, const StorageScope &scope) const { - return true; - } - /*! - * \brief Summarize the sequence of operations into parent. - * - * Insert synchronization if necessary and remove un-necessary - * memory access which are already synced. - * - * \param seq The sequence of the access operations. - * \param loop Pass loop node if it is a loop, otherwise nullptr. - * \return The summarized sequence that represent access that - * the parent should taken care of to synchronize. - */ - virtual std::vector Summarize(std::vector seq, - const ForNode *loop) = 0; - - /*! - * \brief Compute the thread range for the given threads. - * \param threads The threads to compute the range for. - * \return The thread range. - */ - Map ComputeThreadRange(const Array &threads); - - /*! - * \brief Get the scope of the buffer array. - * \return The scope of the final buffer array. - */ - StorageScope GetScope(const Var &buffer_var) const; - // access scope - std::vector> scope_; - -private: - // whether access appending is enabled. - bool allow_append_{false}; - // Whether we are in device environment - bool in_device_env_{false}; - // Nesting depth of tma_load/tma_load_im2col calls - int tma_depth_{0}; - // Whether we are inside condition. - int condition_counter_{0}; - // The current double buffer write scope. - const VarNode *double_buffer_write_{nullptr}; - // the current free stmt entry. 
- StmtEntry curr_stmt_; - // The involving threads - Array env_threads_; - // The buffer map - Map buffer_data_to_buffer_; -}; -} // namespace tl -} // namespace tvm -#endif // TVM_TL_TRANSFORMS_STORAGE_ACCESS_H_ From bd73f7ae7bbf99a17f5915d2615366efa196de00 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Wed, 7 Jan 2026 16:34:53 +0800 Subject: [PATCH 04/27] remove deadcode for debugging --- src/transform/thread_storage_sync.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 7a770b971..c21185afd 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1127,7 +1127,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { if (!prev.buffer.same_as(curr.buffer)) { return false; } - return true; // Assumes no race between threads // Same index value means no conflicts From fea912396f25766b8e828f5ff1e0ee56b5e010d4 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Wed, 7 Jan 2026 16:47:22 +0800 Subject: [PATCH 05/27] format --- src/transform/thread_storage_sync.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index c21185afd..ac801969b 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -494,7 +494,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { if (Enabled(buf.get(), scope)) { ICHECK(allow_append_) << tvm::ffi::GetRef(op) << " " << scope.to_string(); - AccessEntry e{.cset = constr_stack_}; + AccessEntry e{.cset = {constr_stack_}}; e.threads = env_threads(); e.buffer = buf; e.buffer_indices = op->indices; @@ -518,7 +518,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buf.get()), op->buffer); StorageScope scope = GetScope(buf); if (Enabled(buf.get(), scope)) { - AccessEntry e{.cset = constr_stack_}; + AccessEntry e{.cset = {constr_stack_}}; e.threads = env_threads(); e.buffer = buf; e.buffer_indices = op->indices; @@ -790,7 +790,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { } if (Enabled(buffer_var, scope)) { ICHECK(allow_append_); - AccessEntry e{.cset = constr_stack_}; + AccessEntry e{.cset = {constr_stack_}}; e.threads = env_threads(); e.dtype = dtype; e.buffer = Downcast(buffer->data); @@ -852,7 +852,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { start_indices[i], end_indices[i] - start_indices[i])); } } - AccessEntry e{.cset = constr_stack_}; + AccessEntry e{.cset = {constr_stack_}}; e.threads = env_threads(); e.dtype = dtype; e.buffer = tvm::ffi::GetRef(buffer_var); @@ -878,7 +878,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { const std::string &s = op->args[0].as()->value; if (s != "warp") { StorageScope scope = StorageScope::Create(s); - AccessEntry e{.cset = constr_stack_}; + AccessEntry e{.cset = {constr_stack_}}; e.threads = env_threads(); e.type = kSync; e.scope = StorageScope::Create(s); @@ -1031,7 +1031,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { int sync_count = 0; // head are before first sync, tail are after last sync std::vector head, tail; - AccessEntry esync{.cset = constr_stack_}; + AccessEntry esync{.cset = {constr_stack_}}; esync.threads = this->env_threads(); esync.type = kSync; esync.scope = sync_scope_; From e3541a66b287f365fdab46cb5a6b586d020147cf Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Thu, 8 Jan 2026 10:42:57 +0800 Subject: [PATCH 
06/27] add WhileOp & expose MakeGuard in constr_visitor --- src/transform/common/constr_visitor.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transform/common/constr_visitor.h b/src/transform/common/constr_visitor.h index bc855ced6..99e24d600 100644 --- a/src/transform/common/constr_visitor.h +++ b/src/transform/common/constr_visitor.h @@ -109,6 +109,8 @@ struct ConstrVisitor : public tir::StmtExprVisitor { std::vector &constrs; ~Guard() { constrs.pop_back(); } }; + +protected: template Guard MakeGuard(const Args... args) { constr_stack_.push_back(Constr(args...)); return Guard{constr_stack_}; @@ -184,6 +186,12 @@ struct ConstrVisitor : public tir::StmtExprVisitor { Base::VisitStmt_(op); } } + void VisitStmt_(const tir::WhileNode *op) override { + { + auto guard = MakeGuard(op->condition); + Base::VisitStmt(op->body); + } + } std::vector constr_stack_; }; } // namespace tvm::tl From bf85fcbe2aacd4c113a5fe0b7c59cf64271b2b55 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Thu, 8 Jan 2026 10:43:20 +0800 Subject: [PATCH 07/27] bugfix --- src/transform/thread_storage_sync.cc | 234 +++++++++++++++++---------- 1 file changed, 151 insertions(+), 83 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index ac801969b..05b47df26 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -652,17 +652,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { scope_.back().emplace_back(std::move(s)); } } - bool IsThreadInvariant_(const PrimExpr &cond) { - if (auto call = cond.as()) { - if (auto opt_call_op = call->op.as()) { - const auto &call_op = opt_call_op.value(); - if (call_op.same_as(builtin::tvm_thread_invariant())) { - return true; - } - } - } - return false; - } /** * @brief Visit an IfThenElse statement and collect storage access summaries @@ -689,36 +678,36 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { * condition is not thread-invariant). */ void VisitStmt_(const IfThenElseNode *op) final { - bool is_thread_invariant = IsThreadInvariant_(op->condition); - if (!is_thread_invariant) { - ++condition_counter_; - } - - allow_append_ = true; - this->VisitExpr(op->condition); + StmtEntry s; + { + auto guard = MakeGuard(op->condition); + allow_append_ = true; + this->VisitExpr(op->condition); - // Preserve accesses collected from the condition expression so they - // participate in dependency analysis. Otherwise, a write to shared memory - // immediately followed by an if-condition reading that memory would not - // trigger a sync before the if-statement. - std::vector cond_access = std::move(curr_stmt_.access); - allow_append_ = false; + // Preserve accesses collected from the condition expression so they + // participate in dependency analysis. Otherwise, a write to shared memory + // immediately followed by an if-condition reading that memory would not + // trigger a sync before the if-statement. + std::vector cond_access = std::move(curr_stmt_.access); + allow_append_ = false; - scope_.push_back(std::vector()); - { - this->VisitStmt(op->then_case); - } + scope_.push_back(std::vector()); + { + this->VisitStmt(op->then_case); + } - StmtEntry s; - s.stmt = op; - s.access = Summarize(std::move(scope_.back()), nullptr); - scope_.pop_back(); - // Merge the condition's access summary into the if-statement's access list - // so the planner can insert a sync before the if when necessary. 
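      // Net effect of the guards: every access recorded inside a branch
      // snapshots the active predicate (op->condition for the then-branch
      // here, and its negation for the else-branch below) into its cset, so
      // FindConflict can later reason about which threads can actually reach
      // each access instead of assuming both branches execute everywhere.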
- if (!cond_access.empty()) { - s.access.insert(s.access.begin(), cond_access.begin(), cond_access.end()); + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + // Merge the condition's access summary into the if-statement's access + // list so the planner can insert a sync before the if when necessary. + if (!cond_access.empty()) { + s.access.insert(s.access.begin(), cond_access.begin(), + cond_access.end()); + } } if (op->else_case) { + auto guard = MakeGuard(tir::Not(op->condition)); scope_.push_back(std::vector()); { this->VisitStmt(op->else_case.value()); @@ -727,17 +716,11 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { scope_.pop_back(); s.access.insert(s.access.end(), v.begin(), v.end()); } + scope_.back().emplace_back(std::move(s)); - if (!is_thread_invariant) { - --condition_counter_; - } } void VisitStmt_(const WhileNode *op) final { - bool is_thread_invariant = IsThreadInvariant_(op->condition); - if (!is_thread_invariant) { - ++condition_counter_; - } this->VisitExpr(op->condition); scope_.push_back(std::vector()); this->VisitStmt(op->body); @@ -746,9 +729,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { s.access = Summarize(std::move(scope_.back()), nullptr); scope_.pop_back(); scope_.back().emplace_back(std::move(s)); - if (!is_thread_invariant) { - --condition_counter_; - } } void VisitExpr_(const CallNode *op) final { @@ -1087,8 +1067,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { bool in_device_env_{false}; // Nesting depth of tma_load/tma_load_im2col calls int tma_depth_{0}; - // Whether we are inside condition. - int condition_counter_{0}; // The current double buffer write scope. const VarNode *double_buffer_write_{nullptr}; // the current free stmt entry. 
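The guard-based rewrite above leans entirely on the RAII behaviour of
ConstrVisitor's Guard/MakeGuard pair. A minimal standalone sketch of the
pattern, assuming only the standard library (the Constr alias and the
assert-based driver are illustrative; like the real MakeGuard, the sketch
relies on C++17 guaranteed copy elision so the pop runs exactly once):

#include <cassert>
#include <functional>
#include <vector>

using Constr = std::function<bool(int)>;

struct Guard {
  std::vector<Constr> &constrs;
  ~Guard() { constrs.pop_back(); }
};

static Guard MakeGuard(std::vector<Constr> &stack, Constr c) {
  stack.push_back(std::move(c));
  return Guard{stack};
}

int main() {
  std::vector<Constr> stack;
  {
    auto g = MakeGuard(stack, [](int tx) { return tx < 16; });
    assert(stack.size() == 1); // predicate visible while the branch is visited
    assert(stack.back()(7));   // tx = 7 satisfies tx < 16
  }                            // guard destructed: predicate popped
  assert(stack.empty());
  return 0;
}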
@@ -1097,13 +1075,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { Array env_threads_; // The buffer map Map buffer_data_to_buffer_; - // Member variables - IterVar tx_ = - IterVar(Range::FromMinExtent(0, 1), Var("tx"), IterVarType::kDataPar); - IterVar ty_ = - IterVar(Range::FromMinExtent(0, 1), Var("ty"), IterVarType::kDataPar); - IterVar tz_ = - IterVar(Range::FromMinExtent(0, 1), Var("tz"), IterVarType::kDataPar); // synchronization scope StorageScope sync_scope_; void insert_syncs(const Object *obj) { @@ -1111,14 +1082,98 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { return; syncs_inserted_.insert(obj); } + void print_access_tentry(const AccessEntry &access) { + std::ostringstream output; + + output << "Access Entry Information:\n"; + output << " Buffer: " << access.buffer << "\n"; + output << " Data Type: " << access.dtype << "\n"; + + std::string type_str; + switch (access.type) { + case kRead: + type_str = "Read"; + break; + case kWrite: + type_str = "Write"; + break; + case kSync: + type_str = "Sync"; + break; + case kAlloc: + type_str = "Alloc"; + break; + case kReadAcquire: + type_str = "ReadAcquire"; + break; + default: + type_str = "Unknown"; + break; + } + output << " Access Type: " << type_str << "\n"; + + output << " Storage Scope: " << access.scope.to_string() << "\n"; + + output << " Threads: ["; + for (size_t i = 0; i < access.threads.size(); ++i) { + if (i > 0) + output << ", "; + output << access.threads[i]->thread_tag; + } + output << "]\n"; + + { + output << " Constraint: {"; + arith::Analyzer analyzer_; + access.cset.Populate(analyzer_); + output << analyzer_.z3_prover.GetSMTLIB2(std::nullopt); + output << "}\n"; + } + + output << " Buffer Indices: ["; + for (size_t i = 0; i < access.buffer_indices.size(); ++i) { + if (i > 0) + output << ", "; + output << access.buffer_indices[i]; + } + output << "]\n"; + + if (!access.buffer_ranges.empty()) { + output << " Buffer Ranges: ["; + for (size_t i = 0; i < access.buffer_ranges.size(); ++i) { + if (i > 0) + output << ", "; + output << "[" << access.buffer_ranges[i]->min << ", " + << access.buffer_ranges[i]->extent << "]"; + } + output << "]\n"; + } + + if (!access.touched.empty()) { + output << " Touched Ranges: ["; + for (size_t i = 0; i < access.touched.size(); ++i) { + if (i > 0) + output << ", "; + output << access.touched[i]; + } + output << "]\n"; + } + + output << " Flags: "; + output << "double_buffer_write=" + << (access.double_buffer_write ? "true" : "false"); + output << ", is_pointer_access=" + << (access.is_pointer_access ? "true" : "false"); + output << ", is_async_copy=" << (access.is_async_copy ? "true" : "false"); + + LOG(WARNING) << output.str(); + } bool FindConflict(const AccessEntry &prev, const AccessEntry &curr, bool loop_carry) { // Special case: ignore conflicts between async-copy writes (e.g., TMA // loads into shared memory). Multiple async writes do not require // interspersed barriers among themselves. We still respect conflicts with // reads to ensure visibility before consumption. - // print_access_tentry(prev); - // print_access_tentry(curr); if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy && curr.is_async_copy) { return false; @@ -1132,16 +1187,8 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { // Same index value means no conflicts // TODO(tqchen) more standard set based testing. 
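    // As used below: has_same_index is cleared as soon as an index pair has
    // to be compared symbolically, and range_is_overlap starts pessimistic,
    // cleared only once the analyzer proves the byte ranges disjoint, so the
    // default verdict remains "conflict".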
bool has_same_index = true; - bool range_is_equal = true; bool range_is_overlap = true; - // for (const auto &kv : prev.thread_range) { - // if (!StructuralEqual()(kv.second, curr.thread_range[kv.first])) { - // range_is_equal = false; - // break; - // } - // } - if (prev.buffer_indices.size() != curr.buffer_indices.size()) { // They are not the same indices, should be conflict. return true; @@ -1181,28 +1228,43 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { struct ThreadVarInfo { const char *name_prev; const char *name_curr; - IterVar iv; } thread_vars[] = { - {"tx1", "tx2", tx_}, - {"ty1", "ty2", ty_}, - {"tz1", "tz2", tz_}, + {"tx1", "tx2"}, + {"ty1", "ty2"}, + {"tz1", "tz2"}, }; - for (const auto &info : thread_vars) { - Var prev_var(info.name_prev, info.iv->var.dtype()); - Var curr_var(info.name_curr, info.iv->var.dtype()); + for (unsigned idx = 0; idx != 3; ++idx) { + auto &info = thread_vars[idx]; + Var old_prev_var = prev.threads[prev.threads.size() + idx - 3]->var; + Var old_curr_var = curr.threads[curr.threads.size() + idx - 3]->var; + Var prev_var(info.name_prev, old_prev_var.dtype()); + Var curr_var(info.name_curr, old_curr_var.dtype()); prev_indice_bytes = - Substitute(prev_indice_bytes, {{info.iv->var, prev_var}}); - prev_cset = prev_cset.Substitute({{info.iv->var, prev_var}}); + Substitute(prev_indice_bytes, {{old_prev_var, prev_var}}); + prev_cset = prev_cset.Substitute({{old_prev_var, prev_var}}); curr_indice_bytes = - Substitute(curr_indice_bytes, {{info.iv->var, curr_var}}); - curr_cset = curr_cset.Substitute({{info.iv->var, curr_var}}); + Substitute(curr_indice_bytes, {{old_curr_var, curr_var}}); + curr_cset = curr_cset.Substitute({{old_curr_var, curr_var}}); } prev_cset.Populate(analyzer); curr_cset.Populate(analyzer); - - bool provably_disjoint = - analyzer.CanProve(prev_indice_bytes != curr_indice_bytes); + bool provably_disjoint = false; + if (prev_indice_bytes.dtype().is_scalar() && + curr_indice_bytes.dtype().is_scalar()) { + provably_disjoint = + analyzer.CanProve(prev_indice_bytes != curr_indice_bytes); + } else { + auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); + auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); + if (prev_bound.defined() && curr_bound.defined()) { + if ((prev_bound->min_value) > (curr_bound->max_value) || + (curr_bound->min_value) > (prev_bound->max_value)) { + range_is_overlap = false; + break; + } + } + } if (provably_disjoint) { range_is_overlap = false; @@ -1215,6 +1277,16 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { } } + // TODO(silent-coder): check whether range is equal + bool range_is_equal = false; + + // for (const auto &kv : prev.thread_range) { + // if (!StructuralEqual()(kv.second, curr.thread_range[kv.first])) { + // range_is_equal = false; + // break; + // } + // } + if (has_same_index && range_is_equal) { return false; } @@ -1230,14 +1302,10 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { // if range_is_overlap is true, then they are in conflict, we should return // true. if range_is_overlap is false, then they are not in conflict, we // should return false. 
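    // Worked example of the renaming above (float A, so 4-byte elements):
    //   if (tx < 16) { A[tx] = v; } else { w = A[tx]; }
    // After renaming, the write carries {tx1 < 16} and the read
    // {!(tx2 < 16)}, so even though the analyzer only assumes the two
    // threads differ somewhere, tx1 < 16 <= tx2 already forces the byte
    // indices 4*tx1 and 4*tx2 apart and the barrier is skipped. Drop the
    // branch constraints and the proof fails (threads differing only in
    // threadIdx.y hit the same element), so a sync is inserted, as required.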
- // LOG(WARNING) << range_is_overlap; return range_is_overlap; } bool FindConflict(const std::vector &prev, const AccessEntry &curr, bool loop_carry) { - // LOG(WARNING) << "FIND: "; - // print_access_tentry(curr); - // LOG(WARNING) << prev.size() << " " << loop_carry; for (const AccessEntry &x : prev) { if (FindConflict(x, curr, loop_carry)) { return true; From 5918b7be2d5dbcee87e296ffb8e4311373ce83c8 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Thu, 8 Jan 2026 11:28:11 +0800 Subject: [PATCH 08/27] bugfix --- src/transform/thread_storage_sync.cc | 119 ++++++++++++--------------- 1 file changed, 53 insertions(+), 66 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 05b47df26..1fcc373ff 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1215,61 +1215,62 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { const auto &prev_indice = prev.buffer_indices[i]; const auto &curr_indice = curr.buffer_indices[i]; - if (!ExprDeepEqual()(prev_indice, curr_indice)) { - PrimExpr prev_indice_bytes = prev_indice * prev_dtype.bytes(); - PrimExpr curr_indice_bytes = curr_indice * curr_dtype.bytes(); - - has_same_index = false; - - ConstrSet prev_cset{prev.cset}; - ConstrSet curr_cset{curr.cset}; - arith::Analyzer analyzer; - - struct ThreadVarInfo { - const char *name_prev; - const char *name_curr; - } thread_vars[] = { - {"tx1", "tx2"}, - {"ty1", "ty2"}, - {"tz1", "tz2"}, - }; - - for (unsigned idx = 0; idx != 3; ++idx) { - auto &info = thread_vars[idx]; - Var old_prev_var = prev.threads[prev.threads.size() + idx - 3]->var; - Var old_curr_var = curr.threads[curr.threads.size() + idx - 3]->var; - Var prev_var(info.name_prev, old_prev_var.dtype()); - Var curr_var(info.name_curr, old_curr_var.dtype()); - prev_indice_bytes = - Substitute(prev_indice_bytes, {{old_prev_var, prev_var}}); - prev_cset = prev_cset.Substitute({{old_prev_var, prev_var}}); - curr_indice_bytes = - Substitute(curr_indice_bytes, {{old_curr_var, curr_var}}); - curr_cset = curr_cset.Substitute({{old_curr_var, curr_var}}); - } - prev_cset.Populate(analyzer); - curr_cset.Populate(analyzer); - bool provably_disjoint = false; - if (prev_indice_bytes.dtype().is_scalar() && - curr_indice_bytes.dtype().is_scalar()) { - provably_disjoint = - analyzer.CanProve(prev_indice_bytes != curr_indice_bytes); - } else { - auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); - auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); - if (prev_bound.defined() && curr_bound.defined()) { - if ((prev_bound->min_value) > (curr_bound->max_value) || - (curr_bound->min_value) > (prev_bound->max_value)) { - range_is_overlap = false; - break; - } + PrimExpr prev_indice_bytes = prev_indice * prev_dtype.bytes(); + PrimExpr curr_indice_bytes = curr_indice * curr_dtype.bytes(); + + has_same_index = false; + + ConstrSet prev_cset{prev.cset}; + ConstrSet curr_cset{curr.cset}; + arith::Analyzer analyzer; + + struct ThreadVarInfo { + const char *name_prev; + const char *name_curr; + } thread_vars[] = { + {"tx1", "tx2"}, + {"ty1", "ty2"}, + {"tz1", "tz2"}, + }; + PrimExpr thread_condition = Bool(false); + for (unsigned idx = 0; idx != 3; ++idx) { + auto &info = thread_vars[idx]; + Var old_prev_var = prev.threads[prev.threads.size() + idx - 3]->var; + Var old_curr_var = curr.threads[curr.threads.size() + idx - 3]->var; + Var prev_var(info.name_prev, old_prev_var.dtype()); + Var curr_var(info.name_curr, old_curr_var.dtype()); + thread_condition 
= + tir::Or(thread_condition, tir::Not(tir::EQ(prev_var, curr_var))); + prev_indice_bytes = + Substitute(prev_indice_bytes, {{old_prev_var, prev_var}}); + prev_cset = prev_cset.Substitute({{old_prev_var, prev_var}}); + curr_indice_bytes = + Substitute(curr_indice_bytes, {{old_curr_var, curr_var}}); + curr_cset = curr_cset.Substitute({{old_curr_var, curr_var}}); + } + analyzer.EnterConstraint(thread_condition); + prev_cset.Populate(analyzer); + curr_cset.Populate(analyzer); + bool provably_disjoint = false; + if (prev_indice_bytes.dtype().is_scalar() && + curr_indice_bytes.dtype().is_scalar()) { + provably_disjoint = analyzer.CanProve( + tir::Not(tir::EQ(prev_indice_bytes, curr_indice_bytes))); + } else { + auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); + auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); + if (prev_bound.defined() && curr_bound.defined()) { + if ((prev_bound->min_value) > (curr_bound->max_value) || + (curr_bound->min_value) > (prev_bound->max_value)) { + range_is_overlap = false; + break; } } + } - if (provably_disjoint) { - range_is_overlap = false; - break; - } + if (provably_disjoint) { + range_is_overlap = false; + break; } if (!has_same_index) { @@ -1277,20 +1278,6 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { } } - // TODO(silent-coder): check whether range is equal - bool range_is_equal = false; - - // for (const auto &kv : prev.thread_range) { - // if (!StructuralEqual()(kv.second, curr.thread_range[kv.first])) { - // range_is_equal = false; - // break; - // } - // } - - if (has_same_index && range_is_equal) { - return false; - } - // If this is a read into a double buffer that was previously // swapped out, then it doesn't conflict. if (prev.double_buffer_write && curr.type == kRead && !loop_carry) { From e1fdb594ffd9cc7a5bbbe20763a0514e0fc862bb Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Thu, 8 Jan 2026 12:40:04 +0800 Subject: [PATCH 09/27] bugfix --- src/transform/thread_storage_sync.cc | 96 ++++++++++++++++++++++------ 1 file changed, 77 insertions(+), 19 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 1fcc373ff..d4fab90ea 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1082,7 +1082,54 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { return; syncs_inserted_.insert(obj); } - void print_access_tentry(const AccessEntry &access) { + bool PointerAccessIsDisjoint(const AccessEntry &lhs, const AccessEntry &rhs) { + if (lhs.touched.size() != 1 || rhs.touched.size() != 1) { + return false; + } + ConstrSet prev_cset{lhs.cset}; + ConstrSet curr_cset{rhs.cset}; + arith::Analyzer analyzer; + + struct ThreadVarInfo { + const char *name_prev; + const char *name_curr; + } thread_vars[] = { + {"tx1", "tx2"}, + {"ty1", "ty2"}, + {"tz1", "tz2"}, + }; + PrimExpr lhs_min = analyzer.Simplify(lhs.touched[0].min()); + PrimExpr lhs_max = analyzer.Simplify(lhs.touched[0].max()); + PrimExpr rhs_min = analyzer.Simplify(rhs.touched[0].min()); + PrimExpr rhs_max = analyzer.Simplify(rhs.touched[0].max()); + for (unsigned idx = 0; idx != 3; ++idx) { + auto &info = thread_vars[idx]; + Var old_prev_var = lhs.threads[lhs.threads.size() + idx - 3]->var; + Var old_curr_var = rhs.threads[rhs.threads.size() + idx - 3]->var; + Var prev_var(info.name_prev, old_prev_var.dtype()); + Var curr_var(info.name_curr, old_curr_var.dtype()); + lhs_min = Substitute(lhs_min, {{old_prev_var, prev_var}}); + lhs_max = Substitute(lhs_max, 
{{old_prev_var, prev_var}}); + prev_cset = prev_cset.Substitute({{old_prev_var, prev_var}}); + rhs_min = Substitute(rhs_min, {{old_curr_var, curr_var}}); + rhs_max = Substitute(rhs_max, {{old_curr_var, curr_var}}); + curr_cset = curr_cset.Substitute({{old_curr_var, curr_var}}); + } + prev_cset.Populate(analyzer); + curr_cset.Populate(analyzer); + + if (analyzer.CanProve(lhs_max < rhs_min, + arith::ProofStrength::kSymbolicBound)) { + return true; + } + if (analyzer.CanProve(rhs_max < lhs_min, + arith::ProofStrength::kSymbolicBound)) { + return true; + } + return false; + } + void print_access_tentry(const AccessEntry &access, + bool print_constr = false) { std::ostringstream output; output << "Access Entry Information:\n"; @@ -1122,7 +1169,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { } output << "]\n"; - { + if (print_constr) { output << " Constraint: {"; arith::Analyzer analyzer_; access.cset.Populate(analyzer_); @@ -1193,20 +1240,18 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { // They are not the same indices, should be conflict. return true; } - // if (prev.is_pointer_access || curr.is_pointer_access) { - // // For accesses created via tvm_access_ptr we may still be able to - // prove - // // disjointness using their byte ranges. If both sides expose a - // touched - // // interval and we can show they don't overlap, skip the conflict. - // if (prev.is_pointer_access && curr.is_pointer_access && - // PointerAccessIsDisjoint(prev, curr)) { - // return false; - // } - // // Otherwise fall back to the conservative answer: treat them as - // // overlapping. - // return true; - // } + if (prev.is_pointer_access || curr.is_pointer_access) { + // For accesses created via tvm_access_ptr we may still be able to prove + // disjointness using their byte ranges. If both sides expose a touched + // interval and we can show they don't overlap, skip the conflict. + if (prev.is_pointer_access && curr.is_pointer_access && + PointerAccessIsDisjoint(prev, curr)) { + return false; + } + // Otherwise fall back to the conservative answer: treat them as + // overlapping. 
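      // The conservative fallback is deliberate: a tvm_access_ptr only
      // exposes a linear byte offset and extent, so once the interval proof
      // above fails there is no per-element index left to compare and
      // assuming overlap is the only safe answer.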
+ return true; + } for (size_t i = 0; i < prev.buffer_indices.size(); i++) { auto prev_dtype = prev.dtype; @@ -1240,7 +1285,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { Var prev_var(info.name_prev, old_prev_var.dtype()); Var curr_var(info.name_curr, old_curr_var.dtype()); thread_condition = - tir::Or(thread_condition, tir::Not(tir::EQ(prev_var, curr_var))); + tir::Or(thread_condition, tir::NE(prev_var, curr_var)); prev_indice_bytes = Substitute(prev_indice_bytes, {{old_prev_var, prev_var}}); prev_cset = prev_cset.Substitute({{old_prev_var, prev_var}}); @@ -1249,13 +1294,26 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { curr_cset = curr_cset.Substitute({{old_curr_var, curr_var}}); } analyzer.EnterConstraint(thread_condition); + prev_indice_bytes = analyzer.Simplify(prev_indice_bytes); + curr_indice_bytes = analyzer.Simplify(curr_indice_bytes); prev_cset.Populate(analyzer); curr_cset.Populate(analyzer); bool provably_disjoint = false; if (prev_indice_bytes.dtype().is_scalar() && curr_indice_bytes.dtype().is_scalar()) { - provably_disjoint = analyzer.CanProve( - tir::Not(tir::EQ(prev_indice_bytes, curr_indice_bytes))); + if (prev_indice_bytes.dtype() != curr_indice_bytes.dtype()) { + if (prev_indice_bytes.dtype().bits() < + curr_indice_bytes.dtype().bits()) { + prev_indice_bytes = + tir::Cast(curr_indice_bytes.dtype(), prev_indice_bytes); + } else { + curr_indice_bytes = + tir::Cast(prev_indice_bytes.dtype(), curr_indice_bytes); + } + } + ICHECK(prev_indice_bytes.dtype() == curr_indice_bytes.dtype()); + provably_disjoint = + analyzer.CanProve(tir::NE(prev_indice_bytes, curr_indice_bytes)); } else { auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); From 7d91bca8ce760283716ba861041b57c0cb2b1599 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Fri, 9 Jan 2026 14:41:03 +0800 Subject: [PATCH 10/27] bugfix --- src/transform/common/constr_visitor.h | 40 ++++ src/transform/thread_storage_sync.cc | 181 +++++++++++------- .../test_tilelang_transform_thread_sync.py | 22 ++- 3 files changed, 171 insertions(+), 72 deletions(-) diff --git a/src/transform/common/constr_visitor.h b/src/transform/common/constr_visitor.h index 99e24d600..24f1a7a38 100644 --- a/src/transform/common/constr_visitor.h +++ b/src/transform/common/constr_visitor.h @@ -8,6 +8,7 @@ #include "tvm/tir/op.h" #include "tvm/tir/stmt.h" #include "tvm/tir/var.h" +#include #include #include #include @@ -40,6 +41,31 @@ struct Constr { Constr(Constr &&other) = default; Constr &operator=(const Constr &other) = default; + void format(std::ostream &os) const { + os << "Constr(kind="; + switch (kind) { + case kConstr: + os << "kConstr"; + os << ", is_assume=" << (is_assume ? 
"true" : "false"); + os << ", value=" << value; + break; + case kBindValue: + os << "kBindValue"; + os << ", var=" << var->name_hint; + os << ", value=" << value; + break; + case kBindRange: + os << "kBindRange"; + os << ", var=" << var->name_hint; + os << ", range=Range(min=" << range->min; + os << ", extent=" << range->extent << ")"; + break; + default: + os << "Unknown"; + } + os << ")"; + } + PrimExpr ToGenericConstr() const { switch (kind) { case kConstr: @@ -98,6 +124,17 @@ struct ConstrSet { constrs_.push_back(c); } } + + void format(std::ostream &os) const { + os << "ConstrSet(size=" << constrs_.size() << ") {\n"; + for (size_t i = 0; i < constrs_.size(); ++i) { + os << " [" << i << "] "; + constrs_[i].format(os); + os << "\n"; + } + os << "}"; + } + std::vector constrs_; }; @@ -183,6 +220,9 @@ struct ConstrVisitor : public tir::StmtExprVisitor { auto guard_2 = MakeGuard(op->extent > 0); Base::VisitStmt_(op); } else { + auto guard_1 = + MakeGuard(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + auto guard_2 = MakeGuard(op->extent > 0); Base::VisitStmt_(op); } } diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index d4fab90ea..f66933226 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -52,6 +52,7 @@ #include #include +#include #include #include @@ -455,6 +456,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { /*! \brief The buffer ranges for pointer access */ Array buffer_ranges; Var buffer = NullValue(); + Buffer buffer_name; /*! \brief The access data type */ DataType dtype; /*! \brief The touched access range @@ -497,6 +499,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { AccessEntry e{.cset = {constr_stack_}}; e.threads = env_threads(); e.buffer = buf; + e.buffer_name = op->buffer; e.buffer_indices = op->indices; e.dtype = op->dtype.element_of(); for (const auto &index : op->indices) { @@ -521,6 +524,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { AccessEntry e{.cset = {constr_stack_}}; e.threads = env_threads(); e.buffer = buf; + e.buffer_name = op->buffer; e.buffer_indices = op->indices; e.dtype = op->value.dtype().element_of(); for (const auto &index : op->indices) { @@ -774,6 +778,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { e.threads = env_threads(); e.dtype = dtype; e.buffer = Downcast(buffer->data); + e.buffer_name = buffer; e.buffer_ranges = buffer_ranges; for (const auto &index : load->indices) { e.touched.push_back(arith::IntSet::Vector(index)); @@ -1134,6 +1139,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { output << "Access Entry Information:\n"; output << " Buffer: " << access.buffer << "\n"; + output << " Buffer Name: " << access.buffer_name << "\n"; output << " Data Type: " << access.dtype << "\n"; std::string type_str; @@ -1240,6 +1246,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { // They are not the same indices, should be conflict. return true; } + if (prev.is_pointer_access || curr.is_pointer_access) { // For accesses created via tvm_access_ptr we may still be able to prove // disjointness using their byte ranges. 
If both sides expose a touched @@ -1260,80 +1267,119 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { const auto &prev_indice = prev.buffer_indices[i]; const auto &curr_indice = curr.buffer_indices[i]; - PrimExpr prev_indice_bytes = prev_indice * prev_dtype.bytes(); - PrimExpr curr_indice_bytes = curr_indice * curr_dtype.bytes(); - - has_same_index = false; - - ConstrSet prev_cset{prev.cset}; - ConstrSet curr_cset{curr.cset}; - arith::Analyzer analyzer; - - struct ThreadVarInfo { - const char *name_prev; - const char *name_curr; - } thread_vars[] = { - {"tx1", "tx2"}, - {"ty1", "ty2"}, - {"tz1", "tz2"}, - }; - PrimExpr thread_condition = Bool(false); - for (unsigned idx = 0; idx != 3; ++idx) { - auto &info = thread_vars[idx]; - Var old_prev_var = prev.threads[prev.threads.size() + idx - 3]->var; - Var old_curr_var = curr.threads[curr.threads.size() + idx - 3]->var; - Var prev_var(info.name_prev, old_prev_var.dtype()); - Var curr_var(info.name_curr, old_curr_var.dtype()); - thread_condition = - tir::Or(thread_condition, tir::NE(prev_var, curr_var)); - prev_indice_bytes = - Substitute(prev_indice_bytes, {{old_prev_var, prev_var}}); - prev_cset = prev_cset.Substitute({{old_prev_var, prev_var}}); - curr_indice_bytes = - Substitute(curr_indice_bytes, {{old_curr_var, curr_var}}); - curr_cset = curr_cset.Substitute({{old_curr_var, curr_var}}); - } - analyzer.EnterConstraint(thread_condition); - prev_indice_bytes = analyzer.Simplify(prev_indice_bytes); - curr_indice_bytes = analyzer.Simplify(curr_indice_bytes); - prev_cset.Populate(analyzer); - curr_cset.Populate(analyzer); - bool provably_disjoint = false; - if (prev_indice_bytes.dtype().is_scalar() && - curr_indice_bytes.dtype().is_scalar()) { - if (prev_indice_bytes.dtype() != curr_indice_bytes.dtype()) { - if (prev_indice_bytes.dtype().bits() < - curr_indice_bytes.dtype().bits()) { - prev_indice_bytes = - tir::Cast(curr_indice_bytes.dtype(), prev_indice_bytes); - } else { - curr_indice_bytes = - tir::Cast(prev_indice_bytes.dtype(), curr_indice_bytes); - } + if (!ExprDeepEqual()(prev_indice, curr_indice)) { + + PrimExpr prev_indice_bytes = prev_indice * prev_dtype.bytes(); + PrimExpr curr_indice_bytes = curr_indice * curr_dtype.bytes(); + + has_same_index = false; + + ConstrSet prev_cset{prev.cset}; + ConstrSet curr_cset{curr.cset}; + arith::Analyzer analyzer; + + struct ThreadVarInfo { + const char *name_prev; + const char *name_curr; + } thread_vars[] = { + {"tx1", "tx2"}, + {"ty1", "ty2"}, + {"tz1", "tz2"}, + }; + PrimExpr thread_condition = Bool(false); + ffi::Map prev_sub, curr_sub; + for (unsigned idx = 0; idx != 3; ++idx) { + auto &info = thread_vars[idx]; + Var old_prev_var = prev.threads[prev.threads.size() + idx - 3]->var; + Var old_curr_var = curr.threads[curr.threads.size() + idx - 3]->var; + Var prev_var(info.name_prev, old_prev_var.dtype()); + Var curr_var(info.name_curr, old_curr_var.dtype()); + thread_condition = + tir::Or(thread_condition, tir::NE(prev_var, curr_var)); + prev_sub.Set(old_prev_var, prev_var); + curr_sub.Set(old_curr_var, curr_var); } - ICHECK(prev_indice_bytes.dtype() == curr_indice_bytes.dtype()); - provably_disjoint = - analyzer.CanProve(tir::NE(prev_indice_bytes, curr_indice_bytes)); - } else { - auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); - auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); - if (prev_bound.defined() && curr_bound.defined()) { - if ((prev_bound->min_value) > (curr_bound->max_value) || - (curr_bound->min_value) > (prev_bound->max_value)) { - 
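
Worth spelling out the For-node guards added to ConstrVisitor earlier in this patch: both branches now record the loop variable's range and extent > 0, and those facts are exactly what later lets the analyzer discharge bound proofs on shared-memory indices. A stripped-down, standalone sketch of the effect, with illustrative values:

// Sketch: a bound loop variable makes index facts provable.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>

void LoopGuardDemo() {
  using namespace tvm;
  tir::Var i("i", DataType::Int(32));
  arith::Analyzer analyzer;
  // Equivalent of MakeGuard(op->loop_var, Range::FromMinExtent(min, extent)).
  analyzer.Bind(i, Range::FromMinExtent(0, 16));
  ICHECK(analyzer.CanProve(i < 32));     // follows from i in [0, 16)
  ICHECK(analyzer.CanProve(i * 4 < 64)); // scaled index stays in range
}
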
range_is_overlap = false; - break; + analyzer.EnterConstraint(thread_condition); + prev_cset.Substitute(prev_sub).Populate(analyzer); + curr_cset.Substitute(curr_sub).Populate(analyzer); + bool provably_disjoint = false; + if (prev_indice_bytes.dtype().is_scalar() && + curr_indice_bytes.dtype().is_scalar()) { + prev_indice_bytes = + analyzer.Simplify(Substitute(prev_indice_bytes, prev_sub)); + curr_indice_bytes = + analyzer.Simplify(Substitute(curr_indice_bytes, curr_sub)); + if (prev_indice_bytes.dtype() != curr_indice_bytes.dtype()) { + if (prev_indice_bytes.dtype().bits() < + curr_indice_bytes.dtype().bits()) { + prev_indice_bytes = + tir::Cast(curr_indice_bytes.dtype(), prev_indice_bytes); + } else { + curr_indice_bytes = + tir::Cast(prev_indice_bytes.dtype(), curr_indice_bytes); + } } + ICHECK(prev_indice_bytes.dtype() == curr_indice_bytes.dtype()); + provably_disjoint = + analyzer.CanProve(tir::NE(prev_indice_bytes, curr_indice_bytes)); + } else { + auto prev_min = analyzer.Simplify( + Substitute(prev.touched[i].min() * prev_dtype.bytes(), prev_sub)); + auto prev_max = analyzer.Simplify( + Substitute(prev.touched[i].max() * prev_dtype.bytes(), prev_sub)); + auto curr_min = analyzer.Simplify( + Substitute(curr.touched[i].min() * curr_dtype.bytes(), curr_sub)); + auto curr_max = analyzer.Simplify( + Substitute(curr.touched[i].max() * curr_dtype.bytes(), curr_sub)); + // analyzer.z3_prover.SetRLimit(100000000); + provably_disjoint = analyzer.CanProve(analyzer.Simplify( + tir::Or(prev_min > curr_max, curr_min > prev_max))); + // if (!provably_disjoint) { + // LOG(WARNING) << analyzer.z3_prover.GetStats(); + // LOG(WARNING) << + // analyzer.z3_prover.GetSMTLIB2(tir::Not(tir::Or(prev_min > + // curr_max, curr_min > prev_max))); + // } + // auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); + // auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); + // if (prev_bound.defined() && curr_bound.defined()) { + // if ((prev_bound->min_value) > (curr_bound->max_value) || + // (curr_bound->min_value) > (prev_bound->max_value)) { + // range_is_overlap = false; + // break; + // } + // } + } + + if (provably_disjoint) { + range_is_overlap = false; + break; } - } - if (provably_disjoint) { - range_is_overlap = false; - break; + if (!has_same_index) { + break; + } } + } - if (!has_same_index) { - break; + if (has_same_index) { + bool range_is_equal = true; + arith::Analyzer prev_analyzer, curr_analyer; + prev.cset.Populate(prev_analyzer); + curr.cset.Populate(curr_analyer); + for (unsigned idx = 0; idx != 3; ++idx) { + Var prev_var = prev.threads[prev.threads.size() + idx - 3]->var; + Var curr_var = curr.threads[curr.threads.size() + idx - 3]->var; + auto prev_bound = prev_analyzer.const_int_bound(prev_var); + auto curr_bound = curr_analyer.const_int_bound(curr_var); + if (prev_bound->min_value != curr_bound->min_value || + prev_bound->max_value != curr_bound->max_value) { + range_is_equal = false; + break; + } } + if (range_is_equal) + return false; } // If this is a read into a double buffer that was previously @@ -1372,7 +1418,6 @@ PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) { planner.SetBufferDataToBuffer(buffer->data, buffer); } planner(stmt); - stmt = ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); n->body = ThreadPartialSyncRewriter::Rewrite(std::move(stmt)); diff --git a/testing/python/transform/test_tilelang_transform_thread_sync.py b/testing/python/transform/test_tilelang_transform_thread_sync.py index 
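
Stepping back to the conflict check reworked in this patch: the core trick is renaming the thread variables apart in the two accesses (tx1/tx2, ty1/ty2, tz1/tz2) and telling the analyzer that at least one coordinate differs; the sync can be skipped only when the two byte offsets are provably unequal under that constraint. A minimal single-dimension sketch, mirroring the pass's ffi::Map/Substitute usage and assuming the EnterConstraint entry point this pass already relies on:

// Sketch: can two *different* threads ever touch the same byte offset?
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

bool ProvablyDisjointAcrossThreads(const tvm::PrimExpr &offset,
                                   const tvm::tir::Var &tx,
                                   const tvm::PrimExpr &num_threads) {
  using namespace tvm;
  tir::Var tx1("tx1", tx.dtype()), tx2("tx2", tx.dtype());
  arith::Analyzer analyzer;
  analyzer.Bind(tx1, Range::FromMinExtent(0, num_threads));
  analyzer.Bind(tx2, Range::FromMinExtent(0, num_threads));
  ffi::Map<tir::Var, PrimExpr> s1, s2;
  s1.Set(tx, tx1); // prev access, seen from thread tx1
  s2.Set(tx, tx2); // curr access, seen from thread tx2
  PrimExpr prev = tir::Substitute(offset, s1);
  PrimExpr curr = tir::Substitute(offset, s2);
  analyzer.EnterConstraint(tir::NE(tx1, tx2)); // distinct threads
  return analyzer.CanProve(tir::NE(prev, curr));
}

With each access's constraint set entered the same way (e.g. tx % 2 == 0 from a surrounding if), patterns such as a read at tx paired with a write at tx ^ 1 become provably disjoint, which is what the issue-1026 regression test added later in this series relies on.
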
046ed447a..024c3026b 100644 --- a/testing/python/transform/test_tilelang_transform_thread_sync.py +++ b/testing/python/transform/test_tilelang_transform_thread_sync.py @@ -30,6 +30,8 @@ def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared") T.launch_thread(blockIdx_x, 8) T.launch_thread(threadIdx_x, 4) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) result_local[0] = T.float32(0) if threadIdx_y < 8: temp_shared[threadIdx_x] = p0[0] @@ -51,6 +53,8 @@ def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared") T.launch_thread(blockIdx_x, 8) T.launch_thread(threadIdx_x, 4) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) result_local[0] = T.float32(0) if threadIdx_x < 1: temp_shared[0] = p0[0] @@ -72,6 +76,8 @@ def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): C = T.allocate([1], "float32", "local") D = T.allocate([16], "float32", "shared") threadIdx_x = T.launch_thread("threadIdx.x", 16) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) B_1 = T.Buffer((24,), data=B, scope="shared") A_1 = T.Buffer((16,), data=A.data) B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x] @@ -89,6 +95,8 @@ def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): C_1 = T.allocate([1], "float32", "local") D_1 = T.allocate([16], "float32", "shared") threadIdx_x = T.launch_thread("threadIdx.x", 16) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) B_1_1 = T.Buffer((24,), data=B_1, scope="shared") A_1 = T.Buffer((16,), data=A.data) B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x] @@ -113,9 +121,11 @@ def func(A: T.Buffer((16 * 512), "float32")): in_thread_A_temp = T.allocate([1], "float32", "local") cross_thread_A_temp = T.allocate([1], "float32", "local") threadIdx_x = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) A_shared_1 = T.Buffer((512,), data=A_shared, scope="shared") - for ax0 in range(512): - A_shared_1[ax0] = A[blockIdx_x * 512 + ax0] + ax0 = threadIdx_x + A_shared_1[ax0] = A[blockIdx_x * 512 + ax0] in_thread_A_temp_1 = T.Buffer((1,), data=in_thread_A_temp, scope="local") in_thread_A_temp_1[0] = T.float32(0) with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp: @@ -147,9 +157,11 @@ def expected(A: T.Buffer((8192,), "float32")): in_thread_A_temp_1 = T.allocate([1], "float32", "local") cross_thread_A_temp_1 = T.allocate([1], "float32", "local") threadIdx_x = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) A_shared_1_1 = T.Buffer((512,), data=A_shared_1, scope="shared") - for ax0 in range(512): - A_shared_1_1[ax0] = A[blockIdx_x * 512 + ax0] + ax0 = threadIdx_x + A_shared_1_1[ax0] = A[blockIdx_x * 512 + ax0] in_thread_A_temp_1_1 = T.Buffer((1,), data=in_thread_A_temp_1, scope="local") in_thread_A_temp_1_1[0] = T.float32(0) T.tvm_storage_sync("shared") @@ -186,6 +198,8 @@ def test_sync_shared_dyn_stmatrix_loop_hoist(): def func(): buf_dyn_shmem = T.alloc_buffer((98304,), "uint8", scope="shared.dyn") tx = T.launch_thread("threadIdx.x", 384) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) for i in T.unroll(8): off = ( i // 4 * 8192 From 
d05b15ee361af8851bcdb6c722ba13d67a521a9a Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Fri, 9 Jan 2026 14:52:18 +0800 Subject: [PATCH 11/27] trigger ci From f9f8e8498eac03bd16b48f2b48ef9b963101d3ed Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Fri, 9 Jan 2026 15:08:53 +0800 Subject: [PATCH 12/27] add try for checking --- src/transform/thread_storage_sync.cc | 47 ++++++++++++++++------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index f66933226..1cfdf76b0 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1322,33 +1322,40 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { ICHECK(prev_indice_bytes.dtype() == curr_indice_bytes.dtype()); provably_disjoint = analyzer.CanProve(tir::NE(prev_indice_bytes, curr_indice_bytes)); + // if (!provably_disjoint) { + // LOG(WARNING) << analyzer.z3_prover.GetModel( + // tir::EQ(prev_indice_bytes, curr_indice_bytes)); + // } } else { - auto prev_min = analyzer.Simplify( - Substitute(prev.touched[i].min() * prev_dtype.bytes(), prev_sub)); - auto prev_max = analyzer.Simplify( - Substitute(prev.touched[i].max() * prev_dtype.bytes(), prev_sub)); - auto curr_min = analyzer.Simplify( - Substitute(curr.touched[i].min() * curr_dtype.bytes(), curr_sub)); - auto curr_max = analyzer.Simplify( - Substitute(curr.touched[i].max() * curr_dtype.bytes(), curr_sub)); - // analyzer.z3_prover.SetRLimit(100000000); - provably_disjoint = analyzer.CanProve(analyzer.Simplify( - tir::Or(prev_min > curr_max, curr_min > prev_max))); + try { + auto prev_min = analyzer.Simplify(Substitute( + prev.touched[i].min() * prev_dtype.bytes(), prev_sub)); + auto prev_max = analyzer.Simplify(Substitute( + prev.touched[i].max() * prev_dtype.bytes(), prev_sub)); + auto curr_min = analyzer.Simplify(Substitute( + curr.touched[i].min() * curr_dtype.bytes(), curr_sub)); + auto curr_max = analyzer.Simplify(Substitute( + curr.touched[i].max() * curr_dtype.bytes(), curr_sub)); + // analyzer.z3_prover.SetRLimit(100000000); + provably_disjoint = analyzer.CanProve(analyzer.Simplify( + tir::Or(prev_min > curr_max, curr_min > prev_max))); + } catch (...) 
{ + auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); + auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); + if (prev_bound.defined() && curr_bound.defined()) { + if ((prev_bound->min_value) > (curr_bound->max_value) || + (curr_bound->min_value) > (prev_bound->max_value)) { + range_is_overlap = false; + break; + } + } + } // if (!provably_disjoint) { // LOG(WARNING) << analyzer.z3_prover.GetStats(); // LOG(WARNING) << // analyzer.z3_prover.GetSMTLIB2(tir::Not(tir::Or(prev_min > // curr_max, curr_min > prev_max))); // } - // auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); - // auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); - // if (prev_bound.defined() && curr_bound.defined()) { - // if ((prev_bound->min_value) > (curr_bound->max_value) || - // (curr_bound->min_value) > (prev_bound->max_value)) { - // range_is_overlap = false; - // break; - // } - // } } if (provably_disjoint) { From 00f0fae4e084b1c0e013704328c95cdba8175f0b Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Fri, 9 Jan 2026 15:34:33 +0800 Subject: [PATCH 13/27] bugfix --- src/transform/thread_storage_sync.cc | 201 ++++++++++++++------------- 1 file changed, 103 insertions(+), 98 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 1cfdf76b0..cdb97fc9c 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1268,107 +1268,10 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { const auto &curr_indice = curr.buffer_indices[i]; if (!ExprDeepEqual()(prev_indice, curr_indice)) { - - PrimExpr prev_indice_bytes = prev_indice * prev_dtype.bytes(); - PrimExpr curr_indice_bytes = curr_indice * curr_dtype.bytes(); - has_same_index = false; - - ConstrSet prev_cset{prev.cset}; - ConstrSet curr_cset{curr.cset}; - arith::Analyzer analyzer; - - struct ThreadVarInfo { - const char *name_prev; - const char *name_curr; - } thread_vars[] = { - {"tx1", "tx2"}, - {"ty1", "ty2"}, - {"tz1", "tz2"}, - }; - PrimExpr thread_condition = Bool(false); - ffi::Map prev_sub, curr_sub; - for (unsigned idx = 0; idx != 3; ++idx) { - auto &info = thread_vars[idx]; - Var old_prev_var = prev.threads[prev.threads.size() + idx - 3]->var; - Var old_curr_var = curr.threads[curr.threads.size() + idx - 3]->var; - Var prev_var(info.name_prev, old_prev_var.dtype()); - Var curr_var(info.name_curr, old_curr_var.dtype()); - thread_condition = - tir::Or(thread_condition, tir::NE(prev_var, curr_var)); - prev_sub.Set(old_prev_var, prev_var); - curr_sub.Set(old_curr_var, curr_var); - } - analyzer.EnterConstraint(thread_condition); - prev_cset.Substitute(prev_sub).Populate(analyzer); - curr_cset.Substitute(curr_sub).Populate(analyzer); - bool provably_disjoint = false; - if (prev_indice_bytes.dtype().is_scalar() && - curr_indice_bytes.dtype().is_scalar()) { - prev_indice_bytes = - analyzer.Simplify(Substitute(prev_indice_bytes, prev_sub)); - curr_indice_bytes = - analyzer.Simplify(Substitute(curr_indice_bytes, curr_sub)); - if (prev_indice_bytes.dtype() != curr_indice_bytes.dtype()) { - if (prev_indice_bytes.dtype().bits() < - curr_indice_bytes.dtype().bits()) { - prev_indice_bytes = - tir::Cast(curr_indice_bytes.dtype(), prev_indice_bytes); - } else { - curr_indice_bytes = - tir::Cast(prev_indice_bytes.dtype(), curr_indice_bytes); - } - } - ICHECK(prev_indice_bytes.dtype() == curr_indice_bytes.dtype()); - provably_disjoint = - analyzer.CanProve(tir::NE(prev_indice_bytes, curr_indice_bytes)); - // if 
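
The const_int_bound fallback introduced in PATCH 12 is the conservative half of the check: when Simplify/CanProve throws, the planner drops back to constant integer bounds and only treats the two byte intervals as disjoint when they cannot overlap. The pattern, as a standalone sketch:

// Sketch: conservative disjointness via constant integer bounds.
#include <tvm/arith/analyzer.h>

bool BoundsDisjoint(const tvm::PrimExpr &a, const tvm::PrimExpr &b,
                    tvm::arith::Analyzer *analyzer) {
  auto ba = analyzer->const_int_bound(a);
  auto bb = analyzer->const_int_bound(b);
  if (!ba.defined() || !bb.defined()) return false; // unknown: keep the sync
  // Disjoint only when one interval lies strictly above the other.
  return ba->min_value > bb->max_value || bb->min_value > ba->max_value;
}
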
(!provably_disjoint) { - // LOG(WARNING) << analyzer.z3_prover.GetModel( - // tir::EQ(prev_indice_bytes, curr_indice_bytes)); - // } - } else { - try { - auto prev_min = analyzer.Simplify(Substitute( - prev.touched[i].min() * prev_dtype.bytes(), prev_sub)); - auto prev_max = analyzer.Simplify(Substitute( - prev.touched[i].max() * prev_dtype.bytes(), prev_sub)); - auto curr_min = analyzer.Simplify(Substitute( - curr.touched[i].min() * curr_dtype.bytes(), curr_sub)); - auto curr_max = analyzer.Simplify(Substitute( - curr.touched[i].max() * curr_dtype.bytes(), curr_sub)); - // analyzer.z3_prover.SetRLimit(100000000); - provably_disjoint = analyzer.CanProve(analyzer.Simplify( - tir::Or(prev_min > curr_max, curr_min > prev_max))); - } catch (...) { - auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); - auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); - if (prev_bound.defined() && curr_bound.defined()) { - if ((prev_bound->min_value) > (curr_bound->max_value) || - (curr_bound->min_value) > (prev_bound->max_value)) { - range_is_overlap = false; - break; - } - } - } - // if (!provably_disjoint) { - // LOG(WARNING) << analyzer.z3_prover.GetStats(); - // LOG(WARNING) << - // analyzer.z3_prover.GetSMTLIB2(tir::Not(tir::Or(prev_min > - // curr_max, curr_min > prev_max))); - // } - } - - if (provably_disjoint) { - range_is_overlap = false; - break; - } - - if (!has_same_index) { - break; - } + break; } } - if (has_same_index) { bool range_is_equal = true; arith::Analyzer prev_analyzer, curr_analyer; @@ -1389,6 +1292,108 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { return false; } + for (size_t i = 0; i < prev.buffer_indices.size(); i++) { + auto prev_dtype = prev.dtype; + auto curr_dtype = curr.dtype; + + const auto &prev_indice = prev.buffer_indices[i]; + const auto &curr_indice = curr.buffer_indices[i]; + + PrimExpr prev_indice_bytes = prev_indice * prev_dtype.bytes(); + PrimExpr curr_indice_bytes = curr_indice * curr_dtype.bytes(); + + has_same_index = false; + + ConstrSet prev_cset{prev.cset}; + ConstrSet curr_cset{curr.cset}; + arith::Analyzer analyzer; + + struct ThreadVarInfo { + const char *name_prev; + const char *name_curr; + } thread_vars[] = { + {"tx1", "tx2"}, + {"ty1", "ty2"}, + {"tz1", "tz2"}, + }; + PrimExpr thread_condition = Bool(false); + ffi::Map prev_sub, curr_sub; + for (unsigned idx = 0; idx != 3; ++idx) { + auto &info = thread_vars[idx]; + Var old_prev_var = prev.threads[prev.threads.size() + idx - 3]->var; + Var old_curr_var = curr.threads[curr.threads.size() + idx - 3]->var; + Var prev_var(info.name_prev, old_prev_var.dtype()); + Var curr_var(info.name_curr, old_curr_var.dtype()); + thread_condition = + tir::Or(thread_condition, tir::NE(prev_var, curr_var)); + prev_sub.Set(old_prev_var, prev_var); + curr_sub.Set(old_curr_var, curr_var); + } + analyzer.EnterConstraint(thread_condition); + prev_cset.Substitute(prev_sub).Populate(analyzer); + curr_cset.Substitute(curr_sub).Populate(analyzer); + bool provably_disjoint = false; + if (prev_indice_bytes.dtype().is_scalar() && + curr_indice_bytes.dtype().is_scalar()) { + prev_indice_bytes = + analyzer.Simplify(Substitute(prev_indice_bytes, prev_sub)); + curr_indice_bytes = + analyzer.Simplify(Substitute(curr_indice_bytes, curr_sub)); + if (prev_indice_bytes.dtype() != curr_indice_bytes.dtype()) { + if (prev_indice_bytes.dtype().bits() < + curr_indice_bytes.dtype().bits()) { + prev_indice_bytes = + tir::Cast(curr_indice_bytes.dtype(), prev_indice_bytes); + } else { + curr_indice_bytes = + 
tir::Cast(prev_indice_bytes.dtype(), curr_indice_bytes); + } + } + ICHECK(prev_indice_bytes.dtype() == curr_indice_bytes.dtype()); + provably_disjoint = + analyzer.CanProve(tir::NE(prev_indice_bytes, curr_indice_bytes)); + if (!provably_disjoint) { + LOG(WARNING) << analyzer.z3_prover.GetModel( + tir::EQ(prev_indice_bytes, curr_indice_bytes)); + } + } else { + try { + auto prev_min = analyzer.Simplify( + Substitute(prev.touched[i].min() * prev_dtype.bytes(), prev_sub)); + auto prev_max = analyzer.Simplify( + Substitute(prev.touched[i].max() * prev_dtype.bytes(), prev_sub)); + auto curr_min = analyzer.Simplify( + Substitute(curr.touched[i].min() * curr_dtype.bytes(), curr_sub)); + auto curr_max = analyzer.Simplify( + Substitute(curr.touched[i].max() * curr_dtype.bytes(), curr_sub)); + // analyzer.z3_prover.SetRLimit(100000000); + provably_disjoint = analyzer.CanProve(analyzer.Simplify( + tir::Or(prev_min > curr_max, curr_min > prev_max))); + } catch (...) { + auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); + auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); + if (prev_bound.defined() && curr_bound.defined()) { + if ((prev_bound->min_value) > (curr_bound->max_value) || + (curr_bound->min_value) > (prev_bound->max_value)) { + range_is_overlap = false; + break; + } + } + } + // if (!provably_disjoint) { + // LOG(WARNING) << analyzer.z3_prover.GetStats(); + // LOG(WARNING) << + // analyzer.z3_prover.GetSMTLIB2(tir::Not(tir::Or(prev_min > + // curr_max, curr_min > prev_max))); + // } + } + + if (provably_disjoint) { + range_is_overlap = false; + break; + } + } + // If this is a read into a double buffer that was previously // swapped out, then it doesn't conflict. if (prev.double_buffer_write && curr.type == kRead && !loop_carry) { From 1aa8bd40f1a0ab8a233c91b15316670a19228a48 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Fri, 9 Jan 2026 15:44:00 +0800 Subject: [PATCH 14/27] update tvm with better analyzer --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 001022bdb..65ae814ba 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 001022bdb2dbb337d242eed9d208f8555b8edc98 +Subproject commit 65ae814bab4df9f181a820f198efdf321826cce3 From c8e3f6b9c70127a67b52b08aab75b65dd4f190da Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Mon, 12 Jan 2026 11:11:10 +0800 Subject: [PATCH 15/27] update tvm --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 65ae814ba..b82c74d2e 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 65ae814bab4df9f181a820f198efdf321826cce3 +Subproject commit b82c74d2e2c614b80bfb41d4b9364f26bd03b004 From 9c5341be9d2be0a6b976c4ae6fb3a33362d64f27 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Mon, 12 Jan 2026 11:14:06 +0800 Subject: [PATCH 16/27] Log for debugging in thread_storage_sync --- src/transform/thread_storage_sync.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index cdb97fc9c..f355307ee 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1369,7 +1369,9 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { // analyzer.z3_prover.SetRLimit(100000000); provably_disjoint = analyzer.CanProve(analyzer.Simplify( tir::Or(prev_min > curr_max, curr_min > prev_max))); - } catch (...) 
{ + } catch (const std::exception& e) { + // Log for debugging; fall back to conservative bound check + LOG(WARNING) << "Exception in conflict detection: " << e.what(); auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); auto curr_bound = analyzer.const_int_bound(curr_indice_bytes); if (prev_bound.defined() && curr_bound.defined()) { From 00db1ff45e981dbb9362e9dabe75bf4ab17936a3 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Mon, 12 Jan 2026 11:15:12 +0800 Subject: [PATCH 17/27] format --- src/transform/thread_storage_sync.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index f355307ee..c8a4310b6 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1369,7 +1369,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { // analyzer.z3_prover.SetRLimit(100000000); provably_disjoint = analyzer.CanProve(analyzer.Simplify( tir::Or(prev_min > curr_max, curr_min > prev_max))); - } catch (const std::exception& e) { + } catch (const std::exception &e) { // Log for debugging; fall back to conservative bound check LOG(WARNING) << "Exception in conflict detection: " << e.what(); auto prev_bound = analyzer.const_int_bound(prev_indice_bytes); From d61093ddfded3542d2aa61a8c5830fd9c83a6bb1 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Mon, 12 Jan 2026 11:24:44 +0800 Subject: [PATCH 18/27] consider let node --- src/transform/thread_storage_sync.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index c8a4310b6..25c223024 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -566,7 +566,7 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { curr_stmt_.access.clear(); allow_append_ = false; // traverse body block - this->VisitStmt(op->body); + ConstrVisitor::VisitStmt_(op); } void VisitStmt_(const BlockNode *op) final { auto block = Downcast(op); From 1b5ddeffce585d8b15e56a29def77e22c1f15475 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Mon, 12 Jan 2026 11:25:38 +0800 Subject: [PATCH 19/27] typo --- src/transform/thread_storage_sync.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 25c223024..742dc33c2 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1274,14 +1274,14 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { } if (has_same_index) { bool range_is_equal = true; - arith::Analyzer prev_analyzer, curr_analyer; + arith::Analyzer prev_analyzer, curr_analyzer; prev.cset.Populate(prev_analyzer); - curr.cset.Populate(curr_analyer); + curr.cset.Populate(curr_analyzer); for (unsigned idx = 0; idx != 3; ++idx) { Var prev_var = prev.threads[prev.threads.size() + idx - 3]->var; Var curr_var = curr.threads[curr.threads.size() + idx - 3]->var; auto prev_bound = prev_analyzer.const_int_bound(prev_var); - auto curr_bound = curr_analyer.const_int_bound(curr_var); + auto curr_bound = curr_analyzer.const_int_bound(curr_var); if (prev_bound->min_value != curr_bound->min_value || prev_bound->max_value != curr_bound->max_value) { range_is_equal = false; From a8de275e49bc267bc10f659fe9af378901efbf4f Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Mon, 12 Jan 2026 11:26:02 +0800 Subject: [PATCH 20/27] remove debugging log --- 
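
The let-node change from PATCH 18 matters for the proofs above: routing LetStmt through the constraint guard records the binding var = value, so an index written through the let variable can still be bounded. A minimal sketch of the effect, with hypothetical names:

// Sketch: a recorded let binding keeps derived indices provable.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>

void LetBindingDemo() {
  using namespace tvm;
  tir::Var tid("tid", DataType::Int(32));
  tir::Var off("off", DataType::Int(32));
  arith::Analyzer analyzer;
  analyzer.Bind(tid, Range::FromMinExtent(0, 128));
  analyzer.Bind(off, tid * 4);          // the let binding: off = tid * 4
  ICHECK(analyzer.CanProve(off < 512)); // provable only through the binding
}
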
src/transform/thread_storage_sync.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 742dc33c2..7b99c0556 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1353,8 +1353,8 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { provably_disjoint = analyzer.CanProve(tir::NE(prev_indice_bytes, curr_indice_bytes)); if (!provably_disjoint) { - LOG(WARNING) << analyzer.z3_prover.GetModel( - tir::EQ(prev_indice_bytes, curr_indice_bytes)); + // LOG(WARNING) << analyzer.z3_prover.GetModel( + // tir::EQ(prev_indice_bytes, curr_indice_bytes)); } } else { try { From 35a8c8e26273bbb56c757c7128cf3e328cc265b0 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Mon, 12 Jan 2026 11:51:37 +0800 Subject: [PATCH 21/27] consider while op in thread_sync & fix bug about letnode --- src/transform/thread_storage_sync.cc | 30 +++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 7b99c0556..30408df7e 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -566,7 +566,10 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { curr_stmt_.access.clear(); allow_append_ = false; // traverse body block - ConstrVisitor::VisitStmt_(op); + { + auto guard = MakeGuard(op->var, op->value); + this->VisitStmt(op->body); + } } void VisitStmt_(const BlockNode *op) final { auto block = Downcast(op); @@ -725,13 +728,26 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { } void VisitStmt_(const WhileNode *op) final { - this->VisitExpr(op->condition); - scope_.push_back(std::vector()); - this->VisitStmt(op->body); StmtEntry s; - s.stmt = op; - s.access = Summarize(std::move(scope_.back()), nullptr); - scope_.pop_back(); + { + auto guard = MakeGuard(op->condition); + allow_append_ = true; + this->VisitExpr(op->condition); + std::vector cond_access = std::move(curr_stmt_.access); + allow_append_ = false; + + scope_.push_back(std::vector()); + { + this->VisitStmt(op->body); + } + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + if (!cond_access.empty()) { + s.access.insert(s.access.begin(), cond_access.begin(), + cond_access.end()); + } + } scope_.back().emplace_back(std::move(s)); } From 2d7db7d0dcac7e366a6d9d7fe7098a6a6e2fd43d Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Mon, 12 Jan 2026 12:57:57 +0800 Subject: [PATCH 22/27] add testing for thread_sync issues --- .../python/issue/test_tilelang_issue_1026.py | 26 +++++++++++++ .../python/issue/test_tilelang_issue_1106.py | 38 +++++++++++++++++++ .../python/issue/test_tilelang_issue_1604.py | 36 ++++++++++++++++++ 3 files changed, 100 insertions(+) create mode 100644 testing/python/issue/test_tilelang_issue_1026.py create mode 100644 testing/python/issue/test_tilelang_issue_1106.py create mode 100644 testing/python/issue/test_tilelang_issue_1604.py diff --git a/testing/python/issue/test_tilelang_issue_1026.py b/testing/python/issue/test_tilelang_issue_1026.py new file mode 100644 index 000000000..07bd86292 --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1026.py @@ -0,0 +1,26 @@ +import tilelang +import tilelang.testing +from tilelang import language as T + + +@tilelang.jit +def get_shared_kernel(): + @T.prim_func + def shared_kernel(): + with T.Kernel(1, threads=32): + tx = T.get_thread_binding()
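
Returning to the While handling added in PATCH 21 above: the loop condition is itself an access producer, so its reads are collected under the condition's guard and summarized together with the body. A simplified visitor sketch of that shape, not the pass's exact code, with the AccessEntry bookkeeping reduced to a counter:

// Sketch: attribute condition reads to the While statement itself.
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

struct WhileAccessSketch : public tvm::tir::StmtExprVisitor {
  void VisitStmt_(const tvm::tir::WhileNode *op) final {
    this->VisitExpr(op->condition); // condition loads count as accesses
    this->VisitStmt(op->body);      // body runs under the condition guard
  }
  void VisitExpr_(const tvm::tir::BufferLoadNode *op) final {
    ++num_loads; // stand-in for recording an AccessEntry
    tvm::tir::StmtExprVisitor::VisitExpr_(op);
  }
  int num_loads{0};
};
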
shared_mem = T.alloc_shared((32), dtype="float32", scope="shared") + if tx % 2 == 0: + a = shared_mem[tx] + shared_mem[tx ^ 1] = a + + return shared_kernel + + +def test_issue_1026(): + kernel = get_shared_kernel() + assert "__syncthreads" not in kernel.get_kernel_source() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/issue/test_tilelang_issue_1106.py b/testing/python/issue/test_tilelang_issue_1106.py new file mode 100644 index 000000000..2669acbd4 --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1106.py @@ -0,0 +1,38 @@ +import tilelang +import tilelang.testing +from tilelang import language as T + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + }, +) +def get_kernel(m: int): + dtype = "int32" + + @T.prim_func + def test_kernel(a: T.Tensor[(m,), dtype], b: T.Tensor[(m,), dtype]): + with T.Kernel(1, threads=64) as (bx): + shared = T.alloc_shared((64,), dtype) + tx = T.get_thread_binding(0) + tid = tx + bx * 64 + + for i in T.serial((m // 2 - tx) // 64 + 1): + for j in T.vectorized(2): + shared[tx] += a[(i * 64 + tid) * 2 + j] + + b[tid] = shared[tx] + + return test_kernel + + +def test_issue_1106(): + m = 200 + kernel = get_kernel(m) + assert "__syncthreads" not in kernel.get_kernel_source() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/issue/test_tilelang_issue_1604.py b/testing/python/issue/test_tilelang_issue_1604.py new file mode 100644 index 000000000..215d479c7 --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1604.py @@ -0,0 +1,36 @@ +import tilelang +import tilelang.testing +import tilelang.language as T +import re + + +@tilelang.jit +def qwq(): + dtype = "float32" + + @T.prim_func + def main(out: T.Tensor[(512,), dtype]): + with T.Kernel(1, threads=512): + A = T.alloc_shared((32,), dtype) + B = T.alloc_shared((32,), dtype) + + tid = T.get_thread_binding() + if tid < 32: + A[tid] = tid + B[tid] = tid + + out[tid] = A[tid % 32] + + return main + + +def test_issue_1604(): + kernel = qwq() + print(kernel.get_kernel_source()) + target = "__syncthreads" + pattern = r"if [^{]*{[^}]*\b" + re.escape(target) + r"\b[^}]*}" + assert len(re.findall(pattern, kernel.get_kernel_source())) == 0 + + +if __name__ == "__main__": + tilelang.testing.main() From c3991e98873ec6021d0e7f4b40d6b2aa0df38197 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Mon, 12 Jan 2026 14:45:18 +0800 Subject: [PATCH 23/27] update testing for thread_sync --- .../transform/test_tilelang_transform_thread_sync.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/testing/python/transform/test_tilelang_transform_thread_sync.py b/testing/python/transform/test_tilelang_transform_thread_sync.py index 024c3026b..2e8b1c8c5 100644 --- a/testing/python/transform/test_tilelang_transform_thread_sync.py +++ b/testing/python/transform/test_tilelang_transform_thread_sync.py @@ -124,8 +124,8 @@ def func(A: T.Buffer((16 * 512), "float32")): ty = T.launch_thread("threadIdx.y", 1) tz = T.launch_thread("threadIdx.z", 1) A_shared_1 = T.Buffer((512,), data=A_shared, scope="shared") - ax0 = threadIdx_x - A_shared_1[ax0] = A[blockIdx_x * 512 + ax0] + for ax0 in range(512): + A_shared_1[ax0] = A[blockIdx_x * 512 + ax0] in_thread_A_temp_1 = T.Buffer((1,), data=in_thread_A_temp, scope="local") in_thread_A_temp_1[0] = T.float32(0) with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp: @@ -160,8 +160,8 
@@ def expected(A: T.Buffer((8192,), "float32")): ty = T.launch_thread("threadIdx.y", 1) tz = T.launch_thread("threadIdx.z", 1) A_shared_1_1 = T.Buffer((512,), data=A_shared_1, scope="shared") - ax0 = threadIdx_x - A_shared_1_1[ax0] = A[blockIdx_x * 512 + ax0] + for ax0 in range(512): + A_shared_1_1[ax0] = A[blockIdx_x * 512 + ax0] in_thread_A_temp_1_1 = T.Buffer((1,), data=in_thread_A_temp_1, scope="local") in_thread_A_temp_1_1[0] = T.float32(0) T.tvm_storage_sync("shared") @@ -233,6 +233,5 @@ def func(): # Ensure the sync appears before the unrolled loop assert s.index('T.tvm_storage_sync("shared.dyn")') < s.index("for i in T.unroll(8)") - if __name__ == "__main__": tilelang.testing.main() From 6353023adc7e9027b4022f386834e4f90bbb15d6 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Mon, 12 Jan 2026 14:47:18 +0800 Subject: [PATCH 24/27] format --- testing/python/transform/test_tilelang_transform_thread_sync.py | 1 + 1 file changed, 1 insertion(+) diff --git a/testing/python/transform/test_tilelang_transform_thread_sync.py b/testing/python/transform/test_tilelang_transform_thread_sync.py index 2e8b1c8c5..8d94d9049 100644 --- a/testing/python/transform/test_tilelang_transform_thread_sync.py +++ b/testing/python/transform/test_tilelang_transform_thread_sync.py @@ -233,5 +233,6 @@ def func(): # Ensure the sync appears before the unrolled loop assert s.index('T.tvm_storage_sync("shared.dyn")') < s.index("for i in T.unroll(8)") + if __name__ == "__main__": tilelang.testing.main() From f3d377f9a519148ef1525c2498b24c438209c69a Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Tue, 13 Jan 2026 10:43:49 +0800 Subject: [PATCH 25/27] consider T.Ramp expr --- src/transform/thread_storage_sync.cc | 30 ++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 30408df7e..f24375a8d 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1349,12 +1349,32 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { prev_cset.Substitute(prev_sub).Populate(analyzer); curr_cset.Substitute(curr_sub).Populate(analyzer); bool provably_disjoint = false; + + prev_indice_bytes = + analyzer.Simplify(Substitute(prev_indice_bytes, prev_sub)); + curr_indice_bytes = + analyzer.Simplify(Substitute(curr_indice_bytes, curr_sub)); + + // Handle Ramp expressions by creating a new index variable + // Check if prev_indice_bytes is a Ramp expression + if (const RampNode *prev_ramp = prev_indice_bytes.as()) { + // Create index variable for prev Ramp + Var prev_idx("prev_idx", DataType::Int(32)); + analyzer.Bind(prev_idx, Range::FromMinExtent(0, prev_ramp->lanes)); + prev_indice_bytes = prev_ramp->base + prev_idx * prev_ramp->stride; + } + + // Check if curr_indice_bytes is a Ramp expression + if (const RampNode *curr_ramp = curr_indice_bytes.as()) { + // Create index variable for curr Ramp + Var curr_idx("curr_idx", DataType::Int(32)); + analyzer.Bind(curr_idx, Range::FromMinExtent(0, curr_ramp->lanes)); + curr_indice_bytes = curr_ramp->base + curr_idx * curr_ramp->stride; + } + + // Now handle the simplified expressions if (prev_indice_bytes.dtype().is_scalar() && curr_indice_bytes.dtype().is_scalar()) { - prev_indice_bytes = - analyzer.Simplify(Substitute(prev_indice_bytes, prev_sub)); - curr_indice_bytes = - analyzer.Simplify(Substitute(curr_indice_bytes, curr_sub)); if (prev_indice_bytes.dtype() != curr_indice_bytes.dtype()) { if (prev_indice_bytes.dtype().bits() < 
curr_indice_bytes.dtype().bits()) { @@ -1373,6 +1393,8 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { // tir::EQ(prev_indice_bytes, curr_indice_bytes)); } } else { + LOG(WARNING) << "Unscalar: " << prev_indice_bytes << "; " + << curr_indice_bytes; try { auto prev_min = analyzer.Simplify( Substitute(prev.touched[i].min() * prev_dtype.bytes(), prev_sub)); From 12839adad10d654bb4ea05ff12b33e99f70b56d2 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Tue, 13 Jan 2026 10:59:34 +0800 Subject: [PATCH 26/27] remove useless header --- src/transform/thread_storage_sync.cc | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index f24375a8d..5b5fdc141 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -27,50 +27,34 @@ #include #include #include +#include "./common/thread_sync_types.h" +#include "arith/ir_mutator_with_analyzer.h" #include #include -#include #include "../op/builtin.h" #include "./common/constr_visitor.h" -#include "./common/thread_sync_types.h" -#include "arith/ir_mutator_with_analyzer.h" #include "runtime/thread_storage_scope.h" #include "tir/transforms/ir_utils.h" #include #include -#include -#include -#include -#include - -#include "arith/ir_visitor_with_analyzer.h" -#include "runtime/thread_storage_scope.h" +#include #include #include #include - -#include -#include -#include - -#include "../op/builtin.h" -#include "tir/transforms/ir_utils.h" +#include namespace tvm { namespace tl { using namespace tir; using namespace ffi; -using arith::IRVisitorWithAnalyzer; +using arith::IRMutatorWithAnalyzer; using runtime::StorageRank; using runtime::StorageScope; -using namespace tir; -using arith::IRMutatorWithAnalyzer; - // There are cases where necessary syncthreads is not inserted by // ThreadSyncInserter. For example, syncthreads is needed after async_wait_queue // in the second loop below, but since ThreadSyncInserter is not aware of the From 4b41d4ca6d94ed9b4bc34ff3e669a6b3b88370d8 Mon Sep 17 00:00:00 2001 From: silentCoder-dev Date: Tue, 13 Jan 2026 11:01:14 +0800 Subject: [PATCH 27/27] format --- src/transform/thread_storage_sync.cc | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 5b5fdc141..1fd0f10a7 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -20,30 +20,27 @@ /*! * \file thread_storage_sync.cc */ +#include "../op/builtin.h" +#include "./common/constr_visitor.h" +#include "./common/thread_sync_types.h" +#include "arith/ir_mutator_with_analyzer.h" +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" +#include +#include +#include #include #include +#include +#include #include #include #include +#include #include #include -#include "./common/thread_sync_types.h" -#include "arith/ir_mutator_with_analyzer.h" - #include #include - -#include "../op/builtin.h" -#include "./common/constr_visitor.h" -#include "runtime/thread_storage_scope.h" -#include "tir/transforms/ir_utils.h" -#include -#include - -#include -#include -#include -#include #include namespace tvm {
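
One more note on the Ramp handling from PATCH 25: a vectorized index Ramp(base, stride, lanes) covers lanes addresses at once, so before asking the prover the planner replaces it with the scalar base + i * stride and binds i to [0, lanes). A standalone sketch of that rewrite, with illustrative names:

// Sketch: scalarize a Ramp index so the disjointness query stays scalar.
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>

tvm::PrimExpr ScalarizeRamp(const tvm::PrimExpr &index, const char *lane_name,
                            tvm::arith::Analyzer *analyzer) {
  using namespace tvm;
  if (const tir::RampNode *ramp = index.as<tir::RampNode>()) {
    tir::Var lane(lane_name, DataType::Int(32));
    // One scalar variable ranging over [0, lanes) covers every vector lane.
    analyzer->Bind(lane, Range::FromMinExtent(0, ramp->lanes));
    return ramp->base + lane * ramp->stride;
  }
  return index; // already scalar
}
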