[Feature] Reimplement Threadsync with ConstrVisitor#1631
[Feature] Reimplement Threadsync with ConstrVisitor#1631LeiWang1999 merged 33 commits intotile-ai:mainfrom
Threadsync with ConstrVisitor#1631Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
📝 WalkthroughWalkthroughAdds textual formatting and guarded traversal to constraint visitors; replaces the TileLang storage-access analysis with a multi-pass thread-sync planner/rewriter pipeline; removes legacy storage_access files and headers; and adds tests ensuring no unintended __syncthreads emission in several TileLang scenarios. Changes
Sequence DiagramsequenceDiagram
participant Client
participant TileLangThreadSyncPlanner
participant ThreadSyncAfterWaitQueueInserter
participant ThreadSyncInserter
participant ThreadPartialSyncRewriter
Client->>TileLangThreadSyncPlanner: Traverse IR (ConstrVisitor)
TileLangThreadSyncPlanner->>TileLangThreadSyncPlanner: Collect AccessEntry / StmtEntry summaries
TileLangThreadSyncPlanner->>ThreadSyncAfterWaitQueueInserter: Provide summaries
ThreadSyncAfterWaitQueueInserter->>ThreadSyncAfterWaitQueueInserter: Insert syncs after async wait-queues
ThreadSyncAfterWaitQueueInserter->>ThreadSyncInserter: Emit IR with inserted sync points
ThreadSyncInserter->>ThreadSyncInserter: Generate per-scope barrier placements
ThreadSyncInserter->>ThreadPartialSyncRewriter: Emit IR with barrier metadata
ThreadPartialSyncRewriter->>ThreadPartialSyncRewriter: Rewrite barrier IDs and propagate params through calls
ThreadPartialSyncRewriter->>Client: Return transformed IR
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)
1281-1282: Remove stray semicolon.Line 1282 has an orphan semicolon after the return statement.
🧹 Proposed fix
return tl::TileLangThreadSync(std::move(f), storage_scope); - ; };
🤖 Fix all issues with AI agents
In @src/transform/thread_storage_sync.cc:
- Around line 1032-1034: The snippet creates an AccessEntry named esync
(AccessEntry esync{.cset = constr_stack_};) but contains an extra stray
semicolon on its own line; remove that orphan semicolon so the block reads with
only the declaration and the subsequent assignment to esync.type = kSync; which
ensures no empty statement remains between the declaration and the following
code referencing esync.
- Around line 425-432: The code dereferences min_node and extent_node without
null checks which can crash if iv->dom->min or iv->dom->extent are not
IntImmNode; update the block around the uses of min_node and extent_node (the
lines that call iv->dom->min.as<IntImmNode>() and
iv->dom->extent.as<IntImmNode>() and the subsequent min/extent/max computation)
to check that both min_node and extent_node are non-null before accessing
->value, and if either is null (symbolic or non-IntImm) return false (or
otherwise handle the non-IntImm case consistently with surrounding logic) so no
null pointer is dereferenced when comparing to bound->min_value /
bound->max_value.
- Around line 1189-1200: The loop currently calls ConstrSet::Substitute but
discards its return value, so substitutions on prev_cset and curr_cset are lost;
update the loop to capture the returned ConstrSet (e.g., prev_cset =
prev_cset.Substitute(...) and curr_cset = curr_cset.Substitute(...)) for each
Var substitution while keeping the existing handling for prev_indice_bytes and
curr_indice_bytes, then call Populate(analyzer) as before on the substituted
sets.
🧹 Nitpick comments (2)
src/transform/thread_storage_sync.cc (2)
36-59: Remove duplicate includes.Several headers are included twice, likely from merging or incomplete cleanup:
runtime/thread_storage_scope.h(lines 39, 50)tvm/tir/stmt_functor.h(lines 28, 44)tvm/tir/expr.h(lines 27, 43)../op/builtin.h(lines 35, 58)tir/transforms/ir_utils.h(lines 40, 59)<utility>(lines 33, 56)<unordered_map>(lines 31, 46)
1136-1160: Consider removing or documenting commented-out code.Several blocks of commented-out code remain (lines 1136-1141, 1147-1160). If these represent deprecated approaches, consider removing them. If they're placeholders for future work, add a TODO comment explaining the intent.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/transform/common/constr_visitor.hsrc/transform/thread_storage_sync.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/thread_storage_sync.cc (1)
tilelang/language/tir/op.py (1)
tvm_storage_sync(534-547)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (5)
src/transform/common/constr_visitor.h (1)
117-119: LGTM!The
usingdeclarations correctly expose the base-class visitor methods, enabling derived classes likeTileLangThreadSyncPlannerto invoke them directly. This aligns with the architectural refactor inthread_storage_sync.cc.src/transform/thread_storage_sync.cc (4)
106-129: LGTM!The
ThreadSyncAfterWaitQueueInsertercorrectly handles async wait queue semantics by inserting synchronization after the wait boundary. The restructuring ofAttrStmtnodes preserves the async scoping correctly.
212-225: Verify: Both read and write counts incremented unconditionally foraddress_of.Lines 214-217 and 218-221 have identical conditions, causing both
read_countandwrite_countto always increment together. Ifaddress_ofis intentionally treated as both read and write, the duplicate condition check is redundant:♻️ Suggested simplification if intentional
if (auto load = op->args[0].as<BufferLoadNode>()) { Var buffer_var(Downcast<Var>(load->buffer->data)); if (sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[buffer_var].read_count; - } - if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[buffer_var].write_count; }
446-500: LGTM!The
TileLangThreadSyncPlannerclass structure withAccessEntryandStmtEntryprovides a clean model for tracking buffer accesses with their constraints, thread context, and async copy semantics. The integration withConstrVisitorproperly inherits constraint tracking.
689-732: LGTM!The
IfThenElsehandling correctly:
- Tracks thread-invariant conditions for the condition counter
- Preserves condition expression accesses for dependency analysis
- Summarizes both branches with appropriate constraints
This ensures proper sync insertion when shared memory is accessed in the condition.
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @src/transform/thread_storage_sync.cc:
- Around line 413-420: The code dereferences min_node and extent_node without
checking the result of as<IntImmNode>(), which can be null if iv->dom->min or
iv->dom->extent are non-integer expressions; update the block that uses
min_node/extent_node (the variables min_node, extent_node and the expression
iv->dom->min / iv->dom->extent) to first check if min_node and extent_node are
non-null and, if either is null, return false (or otherwise handle the symbolic
case), otherwise compute min/extent/max and compare to bound->min_value /
bound->max_value as before.
🧹 Nitpick comments (2)
src/transform/thread_storage_sync.cc (2)
35-59: Remove duplicate includes.Several header files are included multiple times, which is redundant and can cause maintenance issues:
../op/builtin.h(lines 35 and 58)tir/transforms/ir_utils.h(lines 40 and 59)<tvm/tir/expr.h>(lines 27 and 43)<tvm/tir/stmt_functor.h>(lines 28 and 44)<unordered_map>(lines 31 and 46)🧹 Proposed cleanup
Remove the duplicate includes on lines 41-59 that are already present earlier in the file.
#include "../op/builtin.h" +#include "./common/constr_visitor.h" #include "./common/thread_sync_types.h" #include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" #include "runtime/thread_storage_scope.h" #include "tir/transforms/ir_utils.h" -#include <tvm/arith/int_set.h> -#include <tvm/ir/attrs.h> -#include <tvm/tir/expr.h> -#include <tvm/tir/stmt_functor.h> - -#include <unordered_map> -#include <vector> - -#include "arith/ir_visitor_with_analyzer.h" -#include "runtime/thread_storage_scope.h" #include <tvm/arith/analyzer.h> +#include <tvm/arith/int_set.h> +#include <tvm/ir/attrs.h> #include <tvm/target/target_info.h> #include <tvm/tir/op.h> - +#include <vector> #include <string> #include <utility> - -#include "../op/builtin.h" -#include "tir/transforms/ir_utils.h"
1284-1286: Remove stray semicolon after return statement.Line 1286 contains an extra semicolon after the return statement on line 1285.
🧹 Proposed fix
return tl::TileLangThreadSync(std::move(f), storage_scope); - ; };
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/thread_storage_sync.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/thread_storage_sync.cc (1)
src/transform/storage_access.h (3)
double_buffer_write_(172-179)condition_counter_(170-170)tma_depth_(168-168)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (5)
src/transform/thread_storage_sync.cc (5)
94-117: LGTM!The
ThreadSyncAfterWaitQueueInserterclass correctly addresses the need to insert synchronization afterasync_wait_queueoperations, which the standardThreadSyncInsertercannot detect due to asynchronous semantics.
119-288: LGTM!The
ThreadSyncInserterclass properly inserts synchronization barriers at the identified sync points and correctly handles read/write statistics for global barrier management. The logic for handlingtvm_access_ptr,address_of, and thread extent tracking is well-structured.
1034-1038: Past issue resolved: Stray semicolon removed.The stray semicolon that was previously flagged after the
AccessEntry esyncdeclaration has been removed. The code now correctly flows from the declaration to the field assignments.
1192-1203: Past issue resolved: Substitute return values now captured.The previous issue where
ConstrSet::Substitutereturn values were discarded has been fixed. Lines 1197 and 1200 now properly capture the returnedConstrSetobjects:
- Line 1197:
prev_cset = prev_cset.Substitute(...)- Line 1200:
curr_cset = curr_cset.Substitute(...)
691-734: LGTM! Well-documented condition handling.The detailed documentation for
VisitStmt_handling ofIfThenElseNodeclearly explains the condition-aware synchronization logic, including thread-invariant checks, condition counter management, and access merging from both branches.
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In @src/transform/thread_storage_sync.cc:
- Around line 346-353: The current code in the block using
GetOrCreateBarrier(key, extent_tx, extent_ty, extent_tz) returns an empty Stmt()
when thread_count % 32 != 0, silently dropping the sync; instead, emit a warning
including barrier_id and thread_count (e.g., via the project's logging facility)
and preserve the original synchronization IR node rather than returning Stmt();
locate the place referencing barrier_id and thread_count and replace the early
return with a warning log plus returning/keeping the existing sync statement so
the barrier remains in the IR.
- Around line 1237-1249: The loop indexing into prev.threads and curr.threads
can go out of bounds; before using prev.threads[prev.threads.size() + idx - 3]
and curr.threads[curr.threads.size() + idx - 3] guard with a bounds check or
reduce the loop count to the available elements (e.g. compute count =
std::min<unsigned>(3, prev.threads.size(), curr.threads.size()) and iterate idx
< count), and only perform the substitutions on prev_indice_bytes, prev_cset,
curr_indice_bytes, curr_cset when the corresponding old_prev_var / old_curr_var
are valid; this preserves the existing Var/Substitute logic (references:
prev.threads, curr.threads, thread_vars, old_prev_var, old_curr_var,
prev_indice_bytes, prev_cset, curr_indice_bytes, curr_cset).
- Around line 1280-1292: The variable range_is_equal is never set to true
because the comparison loop is commented out; restore and implement the range
equality check: initialize range_is_equal = true, iterate over prev.thread_range
entries and for each key look up curr.thread_range (handle missing keys by
setting range_is_equal = false), use StructuralEqual()(kv.second,
curr.thread_range.at(kv.first)) (or equivalent safe lookup) to compare values
and set range_is_equal = false and break on mismatch, then remove the TODO and
commented-out lines so the existing condition using has_same_index &&
range_is_equal can work as intended.
🧹 Nitpick comments (5)
src/transform/common/constr_visitor.h (1)
189-194: Consider whether the condition should be evaluated before guarding.The
WhileNodevisitor applies the condition as a guard and visits the body. However, unlikeIfThenElseNode(lines 157-166), the condition expression itself is not visited before creating the guard. If the condition contains buffer accesses or other side-effect-relevant expressions, they won't be captured.Compare with
IfThenElseNodehandling which separates condition visiting from body visiting. If consistency is desired, consider visiting the condition expression first.♻️ Optional: Visit condition before body
void VisitStmt_(const tir::WhileNode *op) override { + Base::VisitExpr(op->condition); { auto guard = MakeGuard(op->condition); Base::VisitStmt(op->body); } }src/transform/thread_storage_sync.cc (4)
36-59: Remove duplicate includes.Several headers are included multiple times:
tvm/tir/expr.h(lines 27, 43)tvm/tir/stmt_functor.h(lines 28, 44)runtime/thread_storage_scope.h(lines 39, 50)../op/builtin.h(lines 35, 58)tir/transforms/ir_utils.h(lines 40, 59)<unordered_map>(lines 31, 46)<utility>(lines 33, 56)♻️ Consolidate includes
Remove the duplicate block (lines 41-59) and keep only the organized includes at the top (lines 23-40).
64-71: Remove duplicateusing namespace tir;declaration.
using namespace tir;appears twice (lines 64 and 70), which is redundant.♻️ Remove duplicate
using namespace tir; using namespace ffi; using arith::IRVisitorWithAnalyzer; using runtime::StorageRank; using runtime::StorageScope; -using namespace tir; using arith::IRMutatorWithAnalyzer;
1085-1170: Consider removing or conditionally compiling debug function.
print_access_tentryappears to be a debugging helper that logs access entry information viaLOG(WARNING). This function is defined but may not be called in production code. Consider:
- Removing it if unused
- Wrapping in
#ifndef NDEBUGif needed only for debugging- Using
DLOGinstead ofLOG(WARNING)for debug-only output
1351-1352: Remove extra semicolon.Line 1352 has a trailing semicolon after the return statement, which is harmless but unnecessary.
♻️ Remove extra semicolon
return tl::TileLangThreadSync(std::move(f), storage_scope); - ;
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
src/transform/common/constr_visitor.hsrc/transform/eliminate_storage_sync_for_mbarrier.ccsrc/transform/inject_ptx_async_copy.ccsrc/transform/storage_access.ccsrc/transform/storage_access.hsrc/transform/thread_storage_sync.cc
💤 Files with no reviewable changes (4)
- src/transform/eliminate_storage_sync_for_mbarrier.cc
- src/transform/inject_ptx_async_copy.cc
- src/transform/storage_access.cc
- src/transform/storage_access.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/transform/common/constr_visitor.h
🧬 Code graph analysis (1)
src/transform/thread_storage_sync.cc (2)
tilelang/language/tir/op.py (1)
tvm_storage_sync(534-547)src/transform/storage_rewrite.cc (19)
buffer_var(100-105)buffer_var(100-100)buffer_var(107-112)buffer_var(107-107)buffer_var(700-710)buffer_var(700-700)e(862-904)e(862-862)scope(678-683)scope(678-678)buf(249-256)buf(249-249)buf(509-526)buf(509-509)buf(1761-1786)buf(1761-1761)buffer(1425-1458)buffer(1426-1427)i(720-860)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (4)
src/transform/common/constr_visitor.h (2)
108-117: LGTM! Well-designed RAII guard pattern.The
Guardstruct correctly implements scoped constraint management using RAII - the destructor automatically pops the constraint when the guard goes out of scope. MakingMakeGuardprotected appropriately allows derived classes (likeTileLangThreadSyncPlanner) to use this facility.
119-121: Appropriate exposure of base class visitor methods.The
usingdeclarations correctly expose the base classVisitExpr_andVisitStmt_overload sets, enabling derived classes to call specific overloads or extend the visitor pattern without hiding base implementations.src/transform/thread_storage_sync.cc (2)
94-117: LGTM! Correct async wait queue synchronization insertion.The class correctly inserts a
tvm_storage_synccall afterasync_wait_queue_scopeattributes, preserving the nested attribute structure while adding the synchronization barrier.
200-213: Verify:address_ofincrements both read and write counts.When handling
address_of, bothread_countandwrite_countare incremented with identical conditions (lines 202-209). This appears intentional for conservative pointer aliasing, but confirm this is the desired behavior rather than a copy-paste oversight where the second block should check for write intent.
Threadsync with ConstrVisitorThreadsync with ConstrVisitor
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (2)
src/transform/thread_storage_sync.cc (2)
23-59: Remove duplicate includes.Several headers are included twice:
tvm/tir/expr.h(lines 27, 43)tvm/tir/stmt_functor.h(lines 28, 44)<unordered_map>(lines 31, 46)runtime/thread_storage_scope.h(lines 39, 50)../op/builtin.h(lines 35, 58)tir/transforms/ir_utils.h(lines 40, 59)Consolidate these to a single include per header.
♻️ Proposed cleanup
#include <tvm/ffi/function.h> #include <tvm/ffi/reflection/registry.h> #include <tvm/tir/analysis.h> +#include <tvm/arith/analyzer.h> +#include <tvm/arith/int_set.h> #include <tvm/tir/builtin.h> #include <tvm/tir/expr.h> +#include <tvm/tir/op.h> #include <tvm/tir/stmt_functor.h> #include <tvm/tir/transform.h> +#include <tvm/target/target_info.h> +#include <tvm/ir/attrs.h> #include <unordered_map> #include <unordered_set> +#include <vector> +#include <string> #include <utility> #include "../op/builtin.h" #include "./common/constr_visitor.h" #include "./common/thread_sync_types.h" #include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" #include "runtime/thread_storage_scope.h" #include "tir/transforms/ir_utils.h" -#include <tvm/arith/int_set.h> -#include <tvm/ir/attrs.h> -#include <tvm/tir/expr.h> -#include <tvm/tir/stmt_functor.h> - -#include <unordered_map> -#include <vector> - -#include "arith/ir_visitor_with_analyzer.h" -#include "runtime/thread_storage_scope.h" -#include <tvm/arith/analyzer.h> -#include <tvm/target/target_info.h> -#include <tvm/tir/op.h> - -#include <string> -#include <utility> - -#include "../op/builtin.h" -#include "tir/transforms/ir_utils.h"
1396-1398: Remove stray semicolon.Line 1397 has an extra semicolon after the return statement.
🧹 Proposed fix
return tl::TileLangThreadSync(std::move(f), storage_scope); - ; };
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/thread_storage_sync.cc
🧰 Additional context used
🪛 Cppcheck (2.19.0)
src/transform/thread_storage_sync.cc
[error] 1344-1344: syntax error
(syntaxError)
Threadsync with ConstrVisitorThreadsync with ConstrVisitor
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
testing/python/transform/test_tilelang_transform_thread_sync.py (1)
25-40: BindthreadIdx.y/zconsistently (don’t mixT.env_threadwithT.launch_thread("...")for the same tag).In
test_sync_if_with_same_index,threadIdx_y = T.env_thread("threadIdx.y")is used in theif, but the launched thread env is created viaT.launch_thread("threadIdx.y", 1). Depending on how TVMScript binds env threads, this can accidentally produce two different IterVars for the same CUDA tag and make the test misleading/fragile.Proposed fix
- threadIdx_y = T.env_thread("threadIdx.y") + threadIdx_y = T.env_thread("threadIdx.y") blockIdx_x = T.env_thread("blockIdx.x") @@ - ty = T.launch_thread("threadIdx.y", 1) - tz = T.launch_thread("threadIdx.z", 1) + T.launch_thread(threadIdx_y, 1) + threadIdx_z = T.env_thread("threadIdx.z") + T.launch_thread(threadIdx_z, 1)src/transform/thread_storage_sync.cc (1)
291-423: Critical: ThreadPartialSyncRewriter likely drops barriers and may not compile as written.
auto bound_tx = analyzer_->const_int_bound(tx_);passes anIterVarwhere the analyzer typically expects aPrimExpr/Var(likely intended:tx_->var).if (thread_count % 32 != 0) return Stmt();removes the original sync entirely → data race. Fallback should keep the originaltvm_storage_sync(...)when you can’t lower to a named barrier.IsFullThreadExtentassumesiv->dom->min/extentareIntImmand dereferences without null checks; dynamic extents will crash.
🤖 Fix all issues with AI agents
In @src/transform/thread_storage_sync.cc:
- Around line 1090-1135: PointerAccessIsDisjoint assumes lhs.threads and
rhs.threads contain three entries and indexes them with
lhs.threads[lhs.threads.size() + idx - 3] / rhs.threads[rhs.threads.size() + idx
- 3], which can underflow; update the function to first validate thread vector
sizes or locate thread vars by tag: either add an explicit check like
ICHECK_GE(lhs.threads.size(), 3) and ICHECK_GE(rhs.threads.size(), 3) (or return
false/conservative conflict if the invariant is not met), or replace the
fixed-index logic with a lookup that finds the tx/ty/tz entries by their
thread_tag names and if any are missing bail out conservatively (return not
provably disjoint). Ensure you update the same pattern in the other affected
region (around the 1245-1383 range) and reference the symbols
lhs_min/lhs_max/rhs_min/rhs_max, prev_cset/curr_cset, and the thread_vars loop
when making the change.
- Around line 450-570: The LetStmt override records accesses in op->value then
visits op->body without applying the var->value constraint, which weakens
disjointness proofs; fix VisitStmt_(const LetStmtNode *op) in
TileLangThreadSyncPlanner by establishing the let binding before visiting the
body (e.g., call MakeGuard(op->var, op->value) to push the binding around the
VisitStmt(op->body) call, or simply delegate to ConstrVisitor::VisitStmt_(op) so
the base class applies the guard), and ensure the existing push/clear of
curr_stmt_.access and allow_append_ semantics remain correct when you introduce
the guard.
- Around line 492-510: The code uses C++20 designated initializer syntax (e.g.,
AccessEntry e{.cset = {constr_stack_}};) which is incompatible with the
project's C++17 setting; replace each designated-initializer usage (notably the
AccessEntry construction in VisitExpr_ (BufferLoadNode) and the other
occurrences referenced around lines 524 and 840) by default-constructing the
object (e.g., AccessEntry e;) and then assigning fields explicitly (e.cset =
{constr_stack_}; e.threads = env_threads(); etc.), ensuring all fields set via
designated init are instead assigned after construction.
🧹 Nitpick comments (6)
testing/python/transform/test_tilelang_transform_thread_sync.py (2)
56-58: Drop unusedty/tzbindings (or name them_ty/_tz) to avoid “used-for-side-effect” confusion.These frames aren’t referenced later; a plain
T.launch_thread(...)call reads clearer and avoids suggesting thatty/tzare meaningful values.Also applies to: 79-81, 98-100, 124-126, 160-162, 201-203
127-138: The let-stmt test now reads shared indices that are never written (undefined semantics).
A_shared_1[ax0]only initializes indices[0..127], but the test still reads[threadIdx_x + 128/256/384]. Even if the test is “structural”, this weakens its validity as a sync test (it no longer models the common “load full tile into shared, then read it” pattern). Based on learnings, it’s fine to keep assertions structural, but the IR should still be semantically well-formed.Suggestion: either (a) have each thread write the 4 elements it later reads, or (b) change the subsequent reads to only touch
A_shared_1[threadIdx_x].Also applies to: 163-175
src/transform/common/constr_visitor.h (3)
20-80: Avoid UB:Constr()default-constructs with uninitializedkind, butformat()now switches on it.If any code path default-constructs a
Constrand later logs/formats it, this is undefined behavior. Consider initializingkind(and maybevar/range/value) in the default member declarations or deleting the default ctor if it’s not intended.
141-158: MakeGuard should use perfect forwarding (and avoid extra copies).Current signature
template <typename... Args> Guard MakeGuard(const Args... args)copies args. PreferArgs&&...+std::forward(esp. as this is a hot visitor utility).
215-235: Deduplicate the identicalForNodeif/else branch; also consider visitingWhilecondition if that was intended.The
ifandelsebodies inVisitStmt_(ForNode*)are identical now; can be collapsed. ForWhileNode, you add a guard for the condition but don’t traverse the condition expression; that may be OK (consistent withIfThenElsehere), but worth confirming for downstream users.src/transform/thread_storage_sync.cc (1)
35-61: Header/include hygiene: lots of duplicate includes /using namespaceclutter (harder reviews + slower builds).Not blocking, but worth pruning duplicates (
<unordered_map>,<utility>,tir/expr.h,tir/stmt_functor.h,../op/builtin.h,tir/transforms/ir_utils.h, repeatedusing namespace tir;, etc.) once the WIP stabilizes.Also applies to: 65-73
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/transform/common/constr_visitor.hsrc/transform/thread_storage_sync.cctesting/python/transform/test_tilelang_transform_thread_sync.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/transform/common/constr_visitor.h
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: For Python tests of the tilelang transform passes, focus assertions on structural patterns in the generated kernel source (e.g., hoisting behavior) rather than specific numeric literals. Do not rely on particular values (like 430); the test should validate behavior, not numerical precision or actual stored values. This guideline can be applied to all tests in the testing/python/transform directory.
Applied to files:
testing/python/transform/test_tilelang_transform_thread_sync.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/transform/test_tilelang_transform_thread_sync.py
🧬 Code graph analysis (2)
src/transform/thread_storage_sync.cc (1)
src/transform/storage_rewrite.cc (19)
buffer_var(100-105)buffer_var(100-100)buffer_var(107-112)buffer_var(107-107)buffer_var(700-710)buffer_var(700-700)e(862-904)e(862-862)scope(678-683)scope(678-678)buf(249-256)buf(249-249)buf(509-526)buf(509-509)buf(1761-1786)buf(1761-1761)buffer(1425-1458)buffer(1426-1427)i(720-860)
testing/python/transform/test_tilelang_transform_thread_sync.py (1)
tilelang/language/ast/ir.py (1)
launch_thread(1205-1237)
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In @src/transform/thread_storage_sync.cc:
- Around line 1090-1135: PointerAccessIsDisjoint (and the similar code in
FindConflict) performs lhs.threads[lhs.threads.size() + idx - 3] and
rhs.threads[rhs.threads.size() + idx - 3] with idx in 0..2 which underflows when
threads.size() < 3; fix by first validating the invariant (e.g., if
(lhs.threads.size() < 3 || rhs.threads.size() < 3) return false; or
assert/handle error) before the loop, or compute the base index as size_t base =
lhs.threads.size() >= 3 ? lhs.threads.size() - 3 : 0 (and similarly for rhs) and
ensure you never perform negative arithmetic on unsigned sizes; apply the same
guard/update in FindConflict at the referenced locations and prefer explicit
checks rather than relying on implicit invariants.
🧹 Nitpick comments (5)
src/transform/thread_storage_sync.cc (5)
36-60: Remove duplicate includes.Multiple headers are included twice in this section, which is unnecessary:
"../op/builtin.h"(lines 36, 59)"runtime/thread_storage_scope.h"(lines 39, 50)"tir/transforms/ir_utils.h"(lines 40, 60)- Several TVM headers (lines 41-44, 51-53)
♻️ Consolidate includes
#include "../op/builtin.h" -#include "./common/constr_visitor.h" #include "./common/thread_sync_types.h" #include "arith/ir_mutator_with_analyzer.h" #include "runtime/thread_storage_scope.h" #include "tir/transforms/ir_utils.h" #include <tvm/arith/int_set.h> #include <tvm/ir/attrs.h> #include <tvm/tir/expr.h> #include <tvm/tir/stmt_functor.h> - #include <unordered_map> #include <vector> - #include "arith/ir_visitor_with_analyzer.h" -#include "runtime/thread_storage_scope.h" #include <tvm/arith/analyzer.h> #include <tvm/target/target_info.h> #include <tvm/tir/op.h> - #include <fstream> #include <string> #include <utility> - -#include "../op/builtin.h" -#include "tir/transforms/ir_utils.h" +#include "./common/constr_visitor.h"
607-609: Remove or uncomment the analyzer binding code.These lines are commented out without explanation. If the analyzer binding is no longer needed, remove the commented code. If it's still needed, uncomment and document why.
env_threads_.push_back(iv); ICHECK_NE(iv->thread_tag.length(), 0U); - // analyzer_.Bind( - // iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), - // op->value)); -
1136-1223: Consider using appropriate log level for debug output.The
print_access_entrymethod usesLOG(WARNING)for diagnostic output (line 1222). Consider usingLOG(INFO)orDLOG(INFO)instead, or make it conditional on a debug flag to avoid cluttering logs in production.
1325-1328: Remove commented-out debug code.These commented-out
LOG(WARNING)statements should be removed to keep the codebase clean. If this debugging capability is needed, consider adding it behind a compile-time or runtime debug flag.- // if (!provably_disjoint) { - // LOG(WARNING) << analyzer.z3_prover.GetModel( - // tir::EQ(prev_indice_bytes, curr_indice_bytes)); - // }- // if (!provably_disjoint) { - // LOG(WARNING) << analyzer.z3_prover.GetStats(); - // LOG(WARNING) << - // analyzer.z3_prover.GetSMTLIB2(tir::Not(tir::Or(prev_min > - // curr_max, curr_min > prev_max))); - // }Also applies to: 1353-1358
1448-1449: Remove extraneous semicolon.Line 1449 has an unnecessary semicolon after the return statement.
return tl::TileLangThreadSync(std::move(f), storage_scope); - ;
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/thread_storage_sync.cc
🔇 Additional comments (7)
src/transform/thread_storage_sync.cc (7)
95-118: LGTM!The
ThreadSyncAfterWaitQueueInsertercorrectly identifiesasync_wait_queue_scopeattributes and inserts synchronization barriers after them. The implementation properly reconstructs the nested attribute structure.
120-289: LGTM!The
ThreadSyncInserterimplementation correctly:
- Inserts barriers at marked synchronization points
- Tracks read/write statistics for global memory
- Handles special cases like global barriers with proper initialization
- Manages thread extent scoping
435-578: LGTM!The
TileLangThreadSyncPlannerdata structures and basic visitor methods are well-designed:
AccessEntryandStmtEntryprovide comprehensive tracking of memory accesses- Buffer load/store visitors correctly populate access entries with threads, indices, and constraints
- Proper scoping and state management with
allow_append_flag
632-736: LGTM!The control flow visitor methods correctly handle loop and conditional constructs:
ForNodeproperly relaxes touched sets to cover the entire loop rangeIfThenElseNodecorrectly applies constraint guards for both branches and preserves condition access for dependency analysisWhileNodeappropriately summarizes body accessesThe detailed documentation on
VisitStmt_(const IfThenElseNode*)is particularly helpful.
738-875: LGTM!The
CallNodevisitor comprehensively handles various built-in operations:
- TMA load context tracking prevents spurious conflicts between async writes
address_ofandtvm_access_ptrproperly extract buffer ranges- Linear-to-indices conversion correctly computes multi-dimensional indices from flat offsets
- Synchronization operations are properly recognized and recorded
881-1058: LGTM!The
Summarizemethod implements a sound conflict detection and synchronization placement algorithm:
- Properly unifies
shared.dynaccesses for coordinated planning- Detects RAW, WAR, and WAW hazards between unsynced operations
- Intelligently hoists loop synchronization when only writes are present (avoiding per-iteration barriers)
- Correctly builds exposed access summaries for parent scopes
The optimization at lines 1001-1010 for hoisting synchronization before write-only loops is a good performance improvement.
1416-1432: LGTM!The
TileLangThreadSyncentry point correctly orchestrates the three-stage pipeline:
ThreadSyncAfterWaitQueueInserterfor async wait queuesTileLangThreadSyncPlannerto analyze and plan sync pointsThreadSyncInserterto insert the barriersThreadPartialSyncRewriterto optimize partial synchronizationThe buffer map initialization ensures the planner has access to all buffer metadata.
There was a problem hiding this comment.
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)
1453-1454: Double semicolon.There's an extra semicolon at line 1454.
return tl::TileLangThreadSync(std::move(f), storage_scope); - ;
🤖 Fix all issues with AI agents
In @src/transform/thread_storage_sync.cc:
- Around line 414-419: The code dereferences min_node and extent_node obtained
via iv->dom->min.as<IntImmNode>() / iv->dom->extent.as<IntImmNode>() without
null checks; add checks that both min_node and extent_node are non-null before
accessing ->value, and handle the non-IntImm case (e.g., return/skip, emit an
error/log, or use an alternative extraction path) so the code does not crash
when iv->dom->min or iv->dom->extent are not IntImm nodes.
- Line 1277: Rename the misspelled local variable curr_analyer to curr_analyzer
in the declaration alongside prev_analyzer (arith::Analyzer prev_analyzer,
curr_analyzer;) and update all usages where curr_analyer is referenced (e.g.,
the subsequent calls/assignments that currently use curr_analyer) to use
curr_analyzer so the identifier is consistent.
- Around line 727-736: VisitStmt_(const WhileNode *op) currently just calls
VisitExpr(op->condition) and misses the guarded-access collection done in
VisitStmt_(IfThenElseNode*). Update VisitStmt_(const WhileNode*) to mirror the
IfThenElseNode handling: call MakeGuard(op->condition) before visiting, set
allow_append_ = true while visiting the condition expression, collect the
condition-related entries from scope_ (as done for IfThenElse), then set
allow_append_ = false and use Summarize on those collected entries for s.access;
keep the existing body handling (pushing a new scope, VisitStmt(op->body), pop,
and emplace_back the StmtEntry). Ensure you reference and use MakeGuard,
allow_append_, scope_, VisitExpr(op->condition), Summarize(...), and the
StmtEntry s to store the computed access summary.
- Around line 349-354: The current early return when (thread_count % 32 != 0)
drops the sync by returning Stmt(); instead fall back to emitting the original
full barrier instead of an empty statement: locate the branch that checks
thread_count and replace the return Stmt() with code that emits the full
__syncthreads()/thread barrier (the same sync node used elsewhere in this pass),
and also add a warning log (or an assert, configurable) so this non-warp-aligned
case is detectable at runtime or in debug builds; keep the existing TODO comment
if desired.
- Around line 1360-1382: The catch-all in the try block around
analyzer.Simplify()/analyzer.CanProve(...) can hide real errors and lead to
incorrect range_is_overlap=false decisions; replace catch(...) with catching
specific exception types thrown by the analyzer (e.g., analysis/solver
exceptions) or at minimum log the exception details before falling back, and if
the exception is not an expected solver/analysis failure rethrow it so only
anticipated proof failures use the const_int_bound fallback (update the block
around analyzer.Simplify/CanProve and the fallback that uses
const_int_bound(prev_indice_bytes)/const_int_bound(curr_indice_bytes)
accordingly, keeping the existing range_is_overlap logic only for expected
analysis failures).
- Around line 816-834: The lambda linear_to_indices currently does PrimExpr
remaining = std::move(offset), which moves-from offset and leaves it unusable
for subsequent calls (used later for start_indices and end_indices); change this
to copy the value instead (e.g., PrimExpr remaining = offset or make a local
copy before moving) so offset remains valid for the two calls
linear_to_indices(offset, buffer_shape) and linear_to_indices(offset + extent,
buffer_shape) and ensure remaining is modified only on the local copy.
🧹 Nitpick comments (4)
src/transform/thread_storage_sync.cc (4)
23-60: Remove duplicate includes.There are duplicate
#includedirectives. Many headers appear twice (e.g.,<tvm/tir/expr.h>,<tvm/tir/stmt_functor.h>,<unordered_map>,"runtime/thread_storage_scope.h","../op/builtin.h","tir/transforms/ir_utils.h"). This adds unnecessary noise and can slow down compilation.♻️ Suggested consolidation
#include <tvm/ffi/function.h> #include <tvm/ffi/reflection/registry.h> +#include <tvm/arith/analyzer.h> +#include <tvm/arith/int_set.h> +#include <tvm/ir/attrs.h> +#include <tvm/target/target_info.h> #include <tvm/tir/analysis.h> #include <tvm/tir/builtin.h> #include <tvm/tir/expr.h> +#include <tvm/tir/op.h> #include <tvm/tir/stmt_functor.h> #include <tvm/tir/transform.h> +#include <fstream> +#include <string> #include <unordered_map> #include <unordered_set> #include <utility> +#include <vector> #include "../op/builtin.h" #include "./common/constr_visitor.h" #include "./common/thread_sync_types.h" #include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" #include "runtime/thread_storage_scope.h" #include "tir/transforms/ir_utils.h" -#include <tvm/arith/int_set.h> -#include <tvm/ir/attrs.h> -#include <tvm/tir/expr.h> -#include <tvm/tir/stmt_functor.h> - -#include <unordered_map> -#include <vector> - -#include "arith/ir_visitor_with_analyzer.h" -#include "runtime/thread_storage_scope.h" -#include <tvm/arith/analyzer.h> -#include <tvm/target/target_info.h> -#include <tvm/tir/op.h> - -#include <fstream> -#include <string> -#include <utility> - -#include "../op/builtin.h" -#include "tir/transforms/ir_utils.h"
65-72: Remove duplicateusingdeclarations.
using namespace tir;appears at lines 65 and 71. Similarly, other using declarations may be redundant.♻️ Consolidate using declarations
namespace tvm { namespace tl { using namespace tir; -using namespace ffi; using arith::IRVisitorWithAnalyzer; +using arith::IRMutatorWithAnalyzer; +using namespace ffi; using runtime::StorageRank; using runtime::StorageScope; -using namespace tir; -using arith::IRMutatorWithAnalyzer;
1355-1358: Debug logging left in production code path.
LOG(WARNING)at line 1356 is called every time the analyzer cannot prove disjointness. This will spam logs in production and may leak internal state viaGetModel(). Remove or guard with a debug flag.♻️ Remove or guard debug logging
provably_disjoint = analyzer.CanProve(tir::NE(prev_indice_bytes, curr_indice_bytes)); - if (!provably_disjoint) { - LOG(WARNING) << analyzer.z3_prover.GetModel( - tir::EQ(prev_indice_bytes, curr_indice_bytes)); - } + // Debug: uncomment to diagnose conflict detection + // if (!provably_disjoint) { + // DLOG(INFO) << analyzer.z3_prover.GetModel( + // tir::EQ(prev_indice_bytes, curr_indice_bytes)); + // }
1136-1223: Consider removing or guarding unused debug function.
print_access_tentryappears to be a debug helper that is never called in the production code paths. If needed for development, consider:
- Guarding with
#ifndef NDEBUGor similar- Using
DLOGorVLOGinstead ofLOG(WARNING)- Removing entirely if no longer needed
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/thread_storage_sync.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/thread_storage_sync.cc (2)
tilelang/language/tir/op.py (1)
tvm_storage_sync(534-547)src/transform/common/constr_visitor.h (1)
Substitute(81-83)
🔇 Additional comments (1)
src/transform/thread_storage_sync.cc (1)
95-118: LGTM!The
ThreadSyncAfterWaitQueueInserterclass correctly inserts synchronization afterasync_wait_queueoperations as documented in the comments. The implementation properly reconstructs the attribute statement structure with the sync barrier inserted.
| auto linear_to_indices = [this](PrimExpr offset, | ||
| const Array<PrimExpr> &shape) { | ||
| Array<PrimExpr> indices; | ||
| PrimExpr remaining = std::move(offset); | ||
| for (size_t i = 0; i < shape.size(); ++i) { | ||
| PrimExpr stride = make_const(DataType::Int(32), 1); | ||
| for (size_t j = i + 1; j < shape.size(); ++j) { | ||
| stride = stride * shape[j]; | ||
| } | ||
| PrimExpr idx = FloorDiv(remaining, stride); | ||
| remaining = FloorMod(remaining, stride); | ||
| indices.push_back(idx); | ||
| } | ||
| return indices; | ||
| }; | ||
| Array<PrimExpr> start_indices = | ||
| linear_to_indices(offset, buffer_shape); | ||
| Array<PrimExpr> end_indices = | ||
| linear_to_indices(offset + extent, buffer_shape); |
There was a problem hiding this comment.
Use-after-move of offset.
At line 819, offset is moved into remaining: PrimExpr remaining = std::move(offset);. However, offset is used again at lines 832 and 834 (linear_to_indices(offset, ...) and linear_to_indices(offset + extent, ...)). After the move, offset is in an unspecified state.
🐛 Remove the std::move
auto linear_to_indices = [this](PrimExpr offset,
const Array<PrimExpr> &shape) {
Array<PrimExpr> indices;
- PrimExpr remaining = std::move(offset);
+ PrimExpr remaining = offset;
for (size_t i = 0; i < shape.size(); ++i) {📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| auto linear_to_indices = [this](PrimExpr offset, | |
| const Array<PrimExpr> &shape) { | |
| Array<PrimExpr> indices; | |
| PrimExpr remaining = std::move(offset); | |
| for (size_t i = 0; i < shape.size(); ++i) { | |
| PrimExpr stride = make_const(DataType::Int(32), 1); | |
| for (size_t j = i + 1; j < shape.size(); ++j) { | |
| stride = stride * shape[j]; | |
| } | |
| PrimExpr idx = FloorDiv(remaining, stride); | |
| remaining = FloorMod(remaining, stride); | |
| indices.push_back(idx); | |
| } | |
| return indices; | |
| }; | |
| Array<PrimExpr> start_indices = | |
| linear_to_indices(offset, buffer_shape); | |
| Array<PrimExpr> end_indices = | |
| linear_to_indices(offset + extent, buffer_shape); | |
| auto linear_to_indices = [this](PrimExpr offset, | |
| const Array<PrimExpr> &shape) { | |
| Array<PrimExpr> indices; | |
| PrimExpr remaining = offset; | |
| for (size_t i = 0; i < shape.size(); ++i) { | |
| PrimExpr stride = make_const(DataType::Int(32), 1); | |
| for (size_t j = i + 1; j < shape.size(); ++j) { | |
| stride = stride * shape[j]; | |
| } | |
| PrimExpr idx = FloorDiv(remaining, stride); | |
| remaining = FloorMod(remaining, stride); | |
| indices.push_back(idx); | |
| } | |
| return indices; | |
| }; | |
| Array<PrimExpr> start_indices = | |
| linear_to_indices(offset, buffer_shape); | |
| Array<PrimExpr> end_indices = | |
| linear_to_indices(offset + extent, buffer_shape); |
🤖 Prompt for AI Agents
In @src/transform/thread_storage_sync.cc around lines 816 - 834, The lambda
linear_to_indices currently does PrimExpr remaining = std::move(offset), which
moves-from offset and leaves it unusable for subsequent calls (used later for
start_indices and end_indices); change this to copy the value instead (e.g.,
PrimExpr remaining = offset or make a local copy before moving) so offset
remains valid for the two calls linear_to_indices(offset, buffer_shape) and
linear_to_indices(offset + extent, buffer_shape) and ensure remaining is
modified only on the local copy.
There was a problem hiding this comment.
@silentCoder-dev, I think there may be a misunderstanding here. The bot's concern was about C++ move semantics inside the linear_to_indices lambda, not about thread index bounds or array access safety.
However, I should clarify that the bot's analysis is actually incorrect in this case. Here's why the code is fine as written:
- The lambda
linear_to_indicestakesPrimExpr offsetas a parameter by value (line 816) - Inside the lambda,
PrimExpr remaining = std::move(offset);moves from the lambda's local parameter copy, not the outer scope'soffsetvariable - When you call
linear_to_indices(offset, buffer_shape)at lines 832 and 834, each call passes a fresh argument to the lambda - The outer
offsetvariable (from line 804) remains completely unaffected by what happens inside the lambda's parameter
So there's no use-after-move bug here. The std::move inside the lambda only affects that particular invocation's parameter copy, and each call to the lambda gets its own independent copy.
That said, the std::move inside the lambda is somewhat unnecessary since offset is already passed by value (so you're moving from a copy that's about to go out of scope anyway), but it's harmless.
The code is correct as-is and doesn't need the change the bot suggested. The bot's static analysis appears to have confused the lambda's parameter scope with the outer scope.
There was a problem hiding this comment.
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)
1455-1456: Remove extraneous semicolon.return tl::TileLangThreadSync(std::move(f), storage_scope); - ;
🤖 Fix all issues with AI agents
In @src/transform/thread_storage_sync.cc:
- Line 1277: Rename the misspelled variable curr_analyer to curr_analyzer in the
declaration arith::Analyzer prev_analyzer, curr_analyzer; and update all
subsequent usages of curr_analyer to curr_analyzer (search for curr_analyer in
thread_storage_sync.cc and replace every occurrence, including in any
comparisons, assignments, or function calls that reference it).
- Around line 349-354: Currently when (thread_count % 32 != 0) the code returns
an empty Stmt() which can break downstream consumers; instead return the
original statement unchanged and/or emit a diagnostic. Replace the return Stmt()
in the thread_count % 32 check with a return of the original input statement
(the same Stmt passed into this function) and optionally call the project's
logging/diagnostic API to record that the pass skipped transformation due to
non-multiple-of-32 thread_count; keep the thread_count and conditional
(thread_count % 32 != 0) as the guard.
- Around line 414-419: The code dereferences min_node and extent_node from
iv->dom using as<IntImmNode>() which can return nullptr for non-constant ranges;
add null checks for min_node and extent_node before reading ->value and handle
the symbolic case (e.g., bail out/skip this transformation, compute bounds via a
safe API, or use conservative sentinel values) so min, extent, and derived max
are not read from a null pointer; update the logic around iv->dom, min_node,
extent_node, and the computed min/extent/max to return early or use a safe
fallback when either as<IntImmNode() returns nullptr.
- Around line 201-214: The code handling address_of (BufferLoadNode via
op->args[0].as<BufferLoadNode>()) wrongly increments both
rw_stats_[buffer_var].read_count and .write_count because the two if conditions
are identical; fix this by only incrementing read_count for address_of, or by
mirroring the tvm_access_ptr logic and guard the write_count increment with the
buffer load's write flag (e.g., check the BufferLoadNode's access/write
indicator on the `load` object) while keeping the existing scope checks using
sync_scope_ and GetScope(buffer_var) before updating rw_stats_.
- Around line 1110-1122: The loop indexes lhs.threads and rhs.threads with
lhs.threads.size() + idx - 3 which underflows when threads.size() < 3; add a
guard at the start of the function to verify both lhs.threads.size() and
rhs.threads.size() are >= 3 (or adjust the loop to iterate only up to the
available threads) and return/skip the loop early if not, so the body that
constructs old_prev_var/old_curr_var and substitutes prev_var/curr_var cannot
perform an out-of-bounds access; apply the same guard/adjustment for the other
occurrences that use the same indexing pattern around the code that manipulates
thread_vars, lhs_min/lhs_max/rhs_min/rhs_max and prev_cset/curr_cset.
- Around line 727-736: VisitStmt_(const WhileNode *op) is visiting op->condition
but not preserving the accesses it produces; after this VisitExpr(op->condition)
capture the entries in scope_.back(), create a StmtEntry (e.g., cond_entry) with
stmt = op (or a marker for the condition) and access =
Summarize(std::move(scope_.back()), nullptr), clear or pop that temporary
entries vector, then proceed to push a new scope for the loop body,
VisitStmt(op->body) and summarize body as you already do; finally ensure you
append both the condition StmtEntry and the loop body StmtEntry(s) back into the
outer scope so condition accesses are considered for sync placement (use the
same patterns and helper Summarize/StmtEntry manipulations used by the
IfThenElseNode handler).
- Around line 1355-1358: The LOG(WARNING) call that prints
analyzer.z3_prover.GetModel(tir::EQ(prev_indice_bytes, curr_indice_bytes))
triggers on every non-provably_disjoint case and spams logs; replace or guard it
with less noisy logging (e.g., change LOG(WARNING) to DLOG(WARNING) or
VLOG(<level>), or wrap the call in a verbosity check like if
(VLOG_IS_ON(<level>)) before invoking GetModel) so the expensive/model output
only appears under verbose/diagnostic runs; keep the same message/context but
ensure GetModel() is called only when the chosen verbose level is enabled to
avoid runtime spam.
🧹 Nitpick comments (9)
src/transform/thread_storage_sync.cc (9)
36-60: Remove duplicate#includedirectives.Multiple headers are included twice, likely from merging different code paths:
<tvm/tir/expr.h>(lines 27, 43)<tvm/tir/stmt_functor.h>(lines 28, 44)<unordered_map>(lines 31, 46)runtime/thread_storage_scope.h(lines 39, 50)../op/builtin.h(lines 35, 59)tir/transforms/ir_utils.h(lines 40, 60)<utility>(lines 33, 57)Proposed cleanup
#include "../op/builtin.h" #include "./common/constr_visitor.h" #include "./common/thread_sync_types.h" #include "arith/ir_mutator_with_analyzer.h" #include "runtime/thread_storage_scope.h" #include "tir/transforms/ir_utils.h" -#include <tvm/arith/int_set.h> -#include <tvm/ir/attrs.h> -#include <tvm/tir/expr.h> -#include <tvm/tir/stmt_functor.h> - -#include <unordered_map> -#include <vector> - #include "arith/ir_visitor_with_analyzer.h" -#include "runtime/thread_storage_scope.h" #include <tvm/arith/analyzer.h> +#include <tvm/arith/int_set.h> +#include <tvm/ir/attrs.h> #include <tvm/target/target_info.h> #include <tvm/tir/op.h> #include <fstream> #include <string> -#include <utility> - -#include "../op/builtin.h" -#include "tir/transforms/ir_utils.h" +#include <vector>
55-56: Remove unused<fstream>include.The
<fstream>header appears to be unused in this file. It may be a leftover from debugging code.
65-72: Remove duplicateusing namespace tir;directive.
using namespace tir;appears twice (lines 65 and 71).Proposed fix
using namespace tir; using namespace ffi; using arith::IRVisitorWithAnalyzer; using runtime::StorageRank; using runtime::StorageScope; -using namespace tir; using arith::IRMutatorWithAnalyzer;
458-459: Consider renamingbuffer_namefor clarity.The field
Buffer buffer_namestores the fullBufferobject, not a name string. This could be confusing sincebuffer(aVar) stores the data pointer. Consider renaming tobuffer_obj,source_buffer, orfull_bufferto clarify the distinction.
493-494: Simplify:bufis already aVar.The expression
tvm::ffi::GetRef<Var>(buf.get())is equivalent tobufsincebufis already aVar. This applies to multiple places in the file.- buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer); + buffer_data_to_buffer_.Set(buf, op->buffer);
558-570: Inconsistent handling: pushesStmtEntryeven if no accesses were collected.Unlike
EvaluateNode(line 551) which checks!curr_stmt_.access.empty()before pushing,LetStmtNodeunconditionally pushes toscope_.back(). Consider adding the same check for consistency and to avoid empty entries.Suggested fix
this->VisitExpr(op->value); // push to the scope - scope_.back().push_back(curr_stmt_); + if (!curr_stmt_.access.empty()) { + scope_.back().push_back(curr_stmt_); + } // clear access entry. curr_stmt_.access.clear();
607-609: Remove or implement commented-outanalyzer_.Bind()code.This commented code appears to be leftover from development. Either remove it or implement it with a TODO explaining when it should be enabled.
1136-1223: Remove unused debugging functionprint_access_tentry.This function appears to be dead code - it's defined but never called. If needed for future debugging, consider:
- Guarding with
#ifndef NDEBUG- Moving to a separate debug utilities file
- Converting
LOG(WARNING)toDLOGorVLOG
1372-1384: Consider narrowing the exception catch scope.Catching
std::exceptionis broad and may mask unrelated errors. If specific exceptions from the analyzer are expected, catch those specifically.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/thread_storage_sync.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/thread_storage_sync.cc (2)
tilelang/language/tir/op.py (1)
tvm_storage_sync(534-547)src/transform/common/constr_visitor.h (1)
Substitute(81-83)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (2)
src/transform/thread_storage_sync.cc (2)
95-118: LGTM!The
ThreadSyncAfterWaitQueueInsertercorrectly handles async_wait_queue synchronization with appropriate validation checks.
1423-1439: LGTM!The
TileLangThreadSyncfunction correctly orchestrates the multi-stage synchronization pipeline: async wait queue insertion → planning → barrier insertion → partial sync rewriting.
There was a problem hiding this comment.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)
1471-1472: Minor: Double semicolon typo.Line 1472 has an extra semicolon after the return statement.
🔧 Fix
- return tl::TileLangThreadSync(std::move(f), storage_scope); - ; + return tl::TileLangThreadSync(std::move(f), storage_scope);
🤖 Fix all issues with AI agents
In @src/transform/thread_storage_sync.cc:
- Around line 414-419: The code dereferences min_node and extent_node from
iv->dom->min.as<IntImmNode>() / iv->dom->extent.as<IntImmNode>() without
checking for nullptr, which can crash for symbolic domain bounds; update the
logic in thread_storage_sync.cc (around the block computing int64_t
min/extent/max) to first test that min_node and extent_node are non-null and
handle the symbolic case (e.g., abort the transformation, return/skip, or emit a
diagnostic) instead of dereferencing; ensure any downstream use of
min/extent/max only occurs after these checks and clearly document the chosen
fallback behavior.
- Around line 201-211: The code unconditionally increments both
rw_stats_[buffer_var].read_count and .write_count for an address_of that wraps a
BufferLoadNode, which forces volatile_scope via InitGlobalBarrier; change the
logic to mirror tvm_access_ptr: inspect the address_of call's intent/flags (the
same flag bits used by tvm_access_ptr) and only increment read_count and/or
write_count according to those flags (or, if flags are absent, increment only
read_count), using sync_scope_, GetScope(buffer_var) and StorageRank::kGlobal to
gate the update to rw_stats_; update the address_of handling in the
Visit/transform function that processes BufferLoadNode/address_of so it no
longer unconditionally sets both counters.
- Around line 1126-1138: The loop assumes the vectors lhs.threads and
rhs.threads have at least 3 elements and indexes them with lhs.threads.size() +
idx - 3, which can underflow; add a precondition and guard: before the for-loop
in the function FindConflict ensure lhs.threads.size() >= 3 &&
rhs.threads.size() >= 3 (or compute a safe start index like size_t start =
(threads.size() > 3 ? threads.size() - 3 : 0) and iterate from 0 to min(3,
threads.size())), and apply the same safe-indexing change to the other similar
blocks referenced in FindConflict (the occurrences that index with size() + idx
- 3) so you never compute a negative/underflowing index and you only substitute
when the corresponding thread entries exist.
🧹 Nitpick comments (7)
src/transform/thread_storage_sync.cc (7)
23-71: Remove duplicate includes and using declarations.Several headers and using declarations are included multiple times:
<tvm/tir/expr.h>(lines 27, 43)<tvm/tir/stmt_functor.h>(lines 28, 44)<unordered_map>(lines 31, 46)runtime/thread_storage_scope.h(lines 39, 50)../op/builtin.h(lines 35, 59)tir/transforms/ir_utils.h(lines 40, 60)<utility>(lines 33, 57)using namespace tir;(lines 65, 71)This appears to be leftover from the refactoring. Please consolidate the includes.
♻️ Suggested cleanup
#include <tvm/ffi/function.h> #include <tvm/ffi/reflection/registry.h> +#include <tvm/arith/analyzer.h> +#include <tvm/arith/int_set.h> +#include <tvm/ir/attrs.h> +#include <tvm/target/target_info.h> #include <tvm/tir/analysis.h> #include <tvm/tir/builtin.h> #include <tvm/tir/expr.h> +#include <tvm/tir/op.h> #include <tvm/tir/stmt_functor.h> #include <tvm/tir/transform.h> +#include <fstream> +#include <string> #include <unordered_map> #include <unordered_set> #include <utility> +#include <vector> #include "../op/builtin.h" #include "./common/constr_visitor.h" #include "./common/thread_sync_types.h" #include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" #include "runtime/thread_storage_scope.h" #include "tir/transforms/ir_utils.h" -#include <tvm/arith/int_set.h> -#include <tvm/ir/attrs.h> -#include <tvm/tir/expr.h> -#include <tvm/tir/stmt_functor.h> - -#include <unordered_map> -#include <vector> - -#include "arith/ir_visitor_with_analyzer.h" -#include "runtime/thread_storage_scope.h" -#include <tvm/arith/analyzer.h> -#include <tvm/target/target_info.h> -#include <tvm/tir/op.h> - -#include <fstream> -#include <string> -#include <utility> - -#include "../op/builtin.h" -#include "tir/transforms/ir_utils.h"
458-459: Confusing field naming:buffervsbuffer_name.The field
bufferholds aVar(the buffer's data variable), whilebuffer_nameholds the actualBufferobject. This naming is counterintuitive and could confuse maintainers.Consider renaming to clarify their purposes:
buffer→buffer_varordata_varbuffer_name→bufferorbuffer_obj
493-494: RedundantGetRefwhenVaris already available.
bufis already aVarobject, sotvm::ffi::GetRef<Var>(buf.get())is unnecessary. You can usebufdirectly.This pattern appears in multiple places (lines 494, 521, 779, 818, 859).
♻️ Suggested fix
- buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer); + buffer_data_to_buffer_.Set(buf, op->buffer);
1279-1290: Unused variablesprev_dtypeandcurr_dtype.These variables are declared at lines 1280-1281 but never used before the loop potentially breaks at line 1288. The same variables are re-declared in the subsequent loop at lines 1312-1313.
♻️ Remove unused declarations
for (size_t i = 0; i < prev.buffer_indices.size(); i++) { - auto prev_dtype = prev.dtype; - auto curr_dtype = curr.dtype; - const auto &prev_indice = prev.buffer_indices[i]; const auto &curr_indice = curr.buffer_indices[i]; if (!ExprDeepEqual()(prev_indice, curr_indice)) {
897-1074: Consider extracting helper functions fromSummarize.The
Summarizefunction spans ~180 lines with multiple distinct phases:
- Shared dynamic buffer normalization (lines 901-913)
- Forward dependency analysis (lines 920-967)
- Loop-carried dependency detection (lines 969-1029)
- Exposed entry computation (lines 1031-1073)
While the logic is correct, extracting these into helper functions would improve readability and testability.
610-612: Commented-out code: Consider removing or documenting why it's disabled.The
analyzer_.Bindcall is commented out. If this was intentionally disabled, please add a comment explaining why, or remove it if it's obsolete.
55-56: Remove unused include:<fstream>.The
<fstream>header is included but never used in this file. No file stream types (ifstream, ofstream, fstream) appear anywhere in the code.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/thread_storage_sync.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/thread_storage_sync.cc (2)
src/transform/storage_rewrite.cc (19)
buffer_var(100-105)buffer_var(100-100)buffer_var(107-112)buffer_var(107-107)buffer_var(700-710)buffer_var(700-700)e(862-904)e(862-862)scope(678-683)scope(678-678)buf(249-256)buf(249-249)buf(509-526)buf(509-509)buf(1761-1786)buf(1761-1761)buffer(1425-1458)buffer(1426-1427)i(720-860)src/transform/common/constr_visitor.h (1)
Substitute(81-83)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (3)
src/transform/thread_storage_sync.cc (3)
95-118: LGTM!The
ThreadSyncAfterWaitQueueInsertercorrectly handles the async pipeline semantics by insertingsyncthreadsafterasync_wait_queueoperations. The comment block above clearly explains the motivation for this approach.
558-573: LGTM!The
LetStmtNodehandling correctly visits the value expression first, pushes the statement entry, then usesMakeGuardto establish the variable binding before visiting the body. This ensures proper constraint tracking for the let-bound variable.
1152-1239: [Rewritten review comment]
[Classification tag]
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (3)
testing/python/issue/test_tilelang_issue_1026.py (1)
12-12: Inconsistent shape syntax: use tuple(32,)instead of(32).
(32)evaluates to an integer32, not a single-element tuple. For consistency with other tests in this PR (e.g.,test_tilelang_issue_1604.pyline 14 uses(32,)), use a proper tuple.Suggested fix
- shared_mem = T.alloc_shared((32), dtype="float32", scope="shared") + shared_mem = T.alloc_shared((32,), dtype="float32", scope="shared")testing/python/issue/test_tilelang_issue_1604.py (2)
7-8: Consider a more descriptive function name.The function name
qwqis not descriptive. For consistency with other tests in this directory (e.g.,get_shared_kernel,get_kernel), consider renaming to something likeget_issue_1604_kernelorget_conditional_shared_kernel.
29-29: Remove debug print statement.Debug
Suggested fix
def test_issue_1604(): kernel = qwq() - print(kernel.get_kernel_source()) target = "__syncthreads"
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
testing/python/issue/test_tilelang_issue_1026.pytesting/python/issue/test_tilelang_issue_1106.pytesting/python/issue/test_tilelang_issue_1604.py
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
testing/python/issue/test_tilelang_issue_1106.pytesting/python/issue/test_tilelang_issue_1604.pytesting/python/issue/test_tilelang_issue_1026.py
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
testing/python/issue/test_tilelang_issue_1106.pytesting/python/issue/test_tilelang_issue_1604.pytesting/python/issue/test_tilelang_issue_1026.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/issue/test_tilelang_issue_1106.pytesting/python/issue/test_tilelang_issue_1604.py
📚 Learning: 2025-12-26T06:45:51.789Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1483
File: tilelang/jit/adapter/cutedsl/adapter.py:93-95
Timestamp: 2025-12-26T06:45:51.789Z
Learning: For the CuTeDSL backend in tilelang/jit/adapter/cutedsl/adapter.py, the host_kernel_source and device_kernel_source have the same value.
Applied to files:
testing/python/issue/test_tilelang_issue_1604.pytesting/python/issue/test_tilelang_issue_1026.py
🧬 Code graph analysis (2)
testing/python/issue/test_tilelang_issue_1604.py (3)
tilelang/language/proxy.py (1)
Tensor(233-233)tilelang/language/kernel.py (1)
threads(215-219)tilelang/language/ast/ir.py (1)
target(1677-1707)
testing/python/issue/test_tilelang_issue_1026.py (3)
tilelang/language/kernel.py (1)
threads(215-219)tilelang/language/allocate.py (1)
alloc_shared(39-54)testing/python/issue/test_tilelang_issue_1604.py (1)
main(12-22)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (4)
testing/python/issue/test_tilelang_issue_1026.py (1)
20-22: LGTM!The test correctly validates that the reimplemented thread sync planner does not insert unnecessary
__syncthreadsfor this conditional shared memory access pattern.testing/python/issue/test_tilelang_issue_1604.py (1)
30-32: Test logic is appropriate for the issue.The regex pattern correctly validates that
__syncthreadsdoes not appear inside single-level if-blocks. Note that this pattern won't handle nested braces, but that's likely acceptable for the generated CUDA code being tested.testing/python/issue/test_tilelang_issue_1106.py (2)
6-28: LGTM!The kernel definition with disabled warp specialization and TMA lowering, along with the loop-based shared memory accumulation pattern, correctly exercises the thread sync planner. The test validates that the reimplemented
ThreadsyncwithConstrVisitorcorrectly determines that no synchronization is needed when each thread only accesses its own shared memory element.
31-34: LGTM!Test correctly validates the absence of
__syncthreadsfor this kernel pattern.
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @src/transform/thread_storage_sync.cc:
- Around line 349-354: The current branch where thread_count % 32 != 0 silently
returns Stmt(), dropping the sync; change this to either emit a full
__syncthreads() fallback or at minimum log/error instead of returning an empty
statement: locate the check using thread_count and the return Stmt() (and the
surrounding TODO) and replace the empty return with code that constructs and
returns a regular synchronization statement (e.g., a call node representing
__syncthreads()) or invoke the project's warning/error logger before returning,
so synchronization is not silently removed.
🧹 Nitpick comments (6)
src/transform/thread_storage_sync.cc (6)
41-60: Remove duplicate includes.This section duplicates many headers already included above (lines 23-40), including
<tvm/tir/expr.h>,<tvm/tir/stmt_functor.h>,<unordered_map>,"runtime/thread_storage_scope.h","../op/builtin.h","tir/transforms/ir_utils.h", and<utility>. While harmless due to include guards, these duplicates clutter the file.Proposed fix
#include "./common/constr_visitor.h" #include "./common/thread_sync_types.h" #include "arith/ir_mutator_with_analyzer.h" #include "runtime/thread_storage_scope.h" #include "tir/transforms/ir_utils.h" -#include <tvm/arith/int_set.h> -#include <tvm/ir/attrs.h> -#include <tvm/tir/expr.h> -#include <tvm/tir/stmt_functor.h> - -#include <unordered_map> -#include <vector> - -#include "arith/ir_visitor_with_analyzer.h" -#include "runtime/thread_storage_scope.h" +#include <tvm/arith/analyzer.h> +#include <tvm/arith/int_set.h> +#include <tvm/ir/attrs.h> #include <tvm/target/target_info.h> #include <tvm/tir/op.h> -#include <tvm/arith/analyzer.h> +#include <vector> #include <fstream> #include <string> -#include <utility> -#include "../op/builtin.h" -#include "tir/transforms/ir_utils.h" +#include "arith/ir_visitor_with_analyzer.h"
65-72: Remove duplicateusing namespace tir;declaration.Line 71 duplicates line 65.
Proposed fix
using namespace tir; using namespace ffi; using arith::IRVisitorWithAnalyzer; using runtime::StorageRank; using runtime::StorageScope; -using namespace tir; using arith::IRMutatorWithAnalyzer;
1279-1290: Unused variablesprev_dtypeandcurr_dtype.These variables are declared on lines 1280-1281 but never used within this loop iteration.
Proposed fix
for (size_t i = 0; i < prev.buffer_indices.size(); i++) { - auto prev_dtype = prev.dtype; - auto curr_dtype = curr.dtype; - const auto &prev_indice = prev.buffer_indices[i]; const auto &curr_indice = curr.buffer_indices[i];
1396-1397: Consider reducing log verbosity for non-scalar indices.This
LOG(WARNING)is emitted whenever indices are non-scalar, which may be common in certain workloads and produce excessive log noise in production.Proposed fix: use DLOG or VLOG for debug-only logging
} else { - LOG(WARNING) << "Unscalar: " << prev_indice_bytes << "; " - << curr_indice_bytes; + DLOG(INFO) << "Unscalar indices in conflict detection: " + << prev_indice_bytes << "; " << curr_indice_bytes;
1152-1239: Debug utility usesLOG(WARNING)for output.
print_access_tentryappears to be a debugging utility, but it usesLOG(WARNING)which will emit output in production builds. Consider usingDLOG(INFO)or gating the output behind a debug flag.Proposed fix
- LOG(WARNING) << output.str(); + DLOG(INFO) << output.str();
1493-1494: Extra semicolon.Line 1494 has a double semicolon
;;after the return statement.Proposed fix
return tl::TileLangThreadSync(std::move(f), storage_scope); - ;
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/thread_storage_sync.cc
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1631
File: src/transform/thread_storage_sync.cc:1126-1137
Timestamp: 2026-01-12T07:25:31.685Z
Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.
📚 Learning: 2026-01-12T07:25:31.685Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1631
File: src/transform/thread_storage_sync.cc:1126-1137
Timestamp: 2026-01-12T07:25:31.685Z
Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.
Applied to files:
src/transform/thread_storage_sync.cc
🧬 Code graph analysis (1)
src/transform/thread_storage_sync.cc (2)
tilelang/language/tir/op.py (1)
tvm_storage_sync(534-547)src/transform/common/constr_visitor.h (1)
Substitute(81-83)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (7)
src/transform/thread_storage_sync.cc (7)
95-118: LGTM!The
ThreadSyncAfterWaitQueueInserterclass is well-documented and correctly handles async wait queue semantics by inserting synchronization afterasync_wait_queue_scopeattributes.
201-214: Suspicious:address_ofincrements both read and write counts unconditionally.When handling
address_ofwith aBufferLoadNode, the code increments bothread_count(line 205) andwrite_count(line 209) under identical conditions. Typically, taking the address of a buffer for a load operation should only count as a read unless there's a known pattern where the address is used for both reading and writing.Is this intentional? If so, consider adding a clarifying comment. If not, the write increment may be erroneous.
435-486: LGTM!The
TileLangThreadSyncPlannerclass structure and data model (AccessEntry,StmtEntry) are well-designed for tracking buffer accesses with thread, constraint, and scope information.
663-728: LGTM!The
IfThenElsehandling is well-documented and correctly:
- Preserves condition expression accesses for dependency analysis.
- Applies constraints via
MakeGuardfor both branches.- Summarizes and merges accesses from both branches.
1106-1151: LGTM!The
PointerAccessIsDisjointfunction correctly renames thread variables to avoid conflicts and uses the analyzer to prove range disjointness. Based on learnings, the threads array access pattern at lines 1128-1129 is safe as the array is guaranteed to contain all three thread dimensions at this IR level.
1311-1314: Duplicate unused variable declarations.
prev_dtypeandcurr_dtypeare declared again and are used later in this loop (lines 1318-1319, 1400-1405), so these are valid. However, they shadow the identical declarations in the previous loop (lines 1280-1281) which are unused.
1461-1477: LGTM!The
TileLangThreadSyncfunction correctly orchestrates the multi-pass synchronization pipeline: inserting syncs after wait queues, planning sync points, inserting barriers, and rewriting partial syncs.
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In @src/transform/thread_storage_sync.cc:
- Around line 398-403: The code dereferences min_node and extent_node obtained
via iv->dom->min.as<IntImmNode>() / iv->dom->extent.as<IntImmNode>() without
null checks; add checks after these casts (e.g., test if min_node or extent_node
is nullptr) and handle non-IntImm cases (return early, emit a diagnostic, or
compute/propagate symbolic bounds) before using min/extent to compute max,
updating the logic in the function that contains this snippet (referencing iv,
iv->dom, min_node, extent_node) so you never dereference a null pointer.
- Around line 333-338: The early return that yields an empty Stmt() when
thread_count % 32 != 0 silently removes the synchronization and can drop needed
syncs; instead preserve/emit the original sync: detect this branch in the block
that checks "if (thread_count % 32 != 0)" and return the original sync statement
(or re-emit the input Stmt representing the sync) rather than Stmt(), or at
minimum log/warn about dropping the sync before returning; locate the code that
constructs the sync (the surrounding transform that uses thread_count and
currently returns Stmt()) and change the return to the original sync Stmt (or
emit a warning then return it) so synchronization is not silently lost.
🧹 Nitpick comments (4)
src/transform/thread_storage_sync.cc (4)
1263-1274: Unused variablesprev_dtypeandcurr_dtypein loop.These variables are declared but never used in this loop iteration. They appear again in the subsequent loop starting at line 1295 where they're also unused.
Proposed cleanup
for (size_t i = 0; i < prev.buffer_indices.size(); i++) { - auto prev_dtype = prev.dtype; - auto curr_dtype = curr.dtype; - const auto &prev_indice = prev.buffer_indices[i]; const auto &curr_indice = curr.buffer_indices[i];
1136-1223: Debug functionprint_access_tentryappears unused.This function is defined but never called anywhere in the file. If it's intended for debugging, consider removing it or guarding it with a debug flag to avoid code bloat.
816-830: Lambda capturesthisunnecessarily.The
linear_to_indiceslambda capturesthisbut doesn't use any member variables or methods. This can be simplified to a capture-less lambda.Proposed fix
- auto linear_to_indices = [this](PrimExpr offset, + auto linear_to_indices = [](PrimExpr offset, const Array<PrimExpr> &shape) {
1477-1478: Extra semicolon.Minor typo: double semicolon at line 1478.
Proposed fix
return tl::TileLangThreadSync(std::move(f), storage_scope); - ;
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/thread_storage_sync.cc
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1631
File: src/transform/thread_storage_sync.cc:1126-1137
Timestamp: 2026-01-12T07:25:31.685Z
Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.
📚 Learning: 2026-01-12T07:25:31.685Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1631
File: src/transform/thread_storage_sync.cc:1126-1137
Timestamp: 2026-01-12T07:25:31.685Z
Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.
Applied to files:
src/transform/thread_storage_sync.cc
🔇 Additional comments (2)
src/transform/thread_storage_sync.cc (2)
1445-1461: LGTM!The pass composition correctly chains the synchronization analysis and insertion steps:
ThreadSyncAfterWaitQueueInserter→TileLangThreadSyncPlanner→ThreadSyncInserter→ThreadPartialSyncRewriter. The buffer map initialization before planning ensures proper buffer resolution.
88-90: Potential null pointer dereference before ICHECK.If
op->body.as<AttrStmtNode>()returnsnullptr, accessinginner->attr_keyon line 90 is undefined behavior beforeICHECKcan evaluate.Proposed fix
auto inner = op->body.as<AttrStmtNode>(); - ICHECK(inner && - inner->attr_key == tvm::tir::attr::async_wait_inflight_count); + ICHECK(inner) << "Expected AttrStmtNode inside async_wait_queue_scope"; + ICHECK(inner->attr_key == tvm::tir::attr::async_wait_inflight_count);⛔ Skipped due to learnings
Learnt from: silentCoder-dev Repo: tile-ai/tilelang PR: 1631 File: src/transform/thread_storage_sync.cc:1126-1137 Timestamp: 2026-01-12T07:25:31.685Z Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.
| if (thread_count % 32 != 0) { | ||
| // TODO(lei): This is a workaround for the case where the thread count is | ||
| // not a multiple of 32. we should enhance the pass to analysis index | ||
| // instead of buffer expression etc. | ||
| return Stmt(); | ||
| } |
There was a problem hiding this comment.
Returning empty Stmt() may silently remove necessary synchronization.
When thread_count % 32 != 0, returning an empty statement silently drops the sync call. This could lead to race conditions if the sync was actually needed. Consider keeping the original sync call or emitting a warning.
Proposed fix
if (thread_count % 32 != 0) {
// TODO(lei): This is a workaround for the case where the thread count is
// not a multiple of 32. we should enhance the pass to analysis index
// instead of buffer expression etc.
- return Stmt();
+ // Fall back to the original sync call to preserve correctness
+ return Evaluate(IRMutatorWithAnalyzer::VisitExpr_(op));
}🤖 Prompt for AI Agents
In @src/transform/thread_storage_sync.cc around lines 333 - 338, The early
return that yields an empty Stmt() when thread_count % 32 != 0 silently removes
the synchronization and can drop needed syncs; instead preserve/emit the
original sync: detect this branch in the block that checks "if (thread_count %
32 != 0)" and return the original sync statement (or re-emit the input Stmt
representing the sync) rather than Stmt(), or at minimum log/warn about dropping
the sync before returning; locate the code that constructs the sync (the
surrounding transform that uses thread_count and currently returns Stmt()) and
change the return to the original sync Stmt (or emit a warning then return it)
so synchronization is not silently lost.
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)
1474-1476: Minor: Double semicolon at end of return statement.There's a redundant semicolon after the return statement.
🔧 Fix
return tl::TileLangThreadSync(std::move(f), storage_scope); - ; };
🧹 Nitpick comments (3)
src/transform/thread_storage_sync.cc (3)
1377-1378: Consider reducing log verbosity in conflict detection hot path.The
LOG(WARNING)for unscalar indices could be noisy in production. Consider usingDLOG(debug-only logging) or gating behind a verbosity flag to avoid performance impact in release builds.♻️ Suggested change
} else { - LOG(WARNING) << "Unscalar: " << prev_indice_bytes << "; " - << curr_indice_bytes; + DLOG(INFO) << "Unscalar indices in conflict detection: " + << prev_indice_bytes << "; " << curr_indice_bytes;
1133-1220: Debug utility function uses WARNING log level and has naming inconsistency.The
print_access_tentryfunction:
- Uses
LOG(WARNING)which will emit in production - considerDLOGorVLOG- Has a naming inconsistency: "tentry" vs "entry" in the function name
Since this appears to be a debug utility, these are minor polish items.
813-827: Lambda capturesthisbut doesn't use instance members.The
linear_to_indiceslambda capturesthisbut only usesmake_const, which is a free function. Consider using[]capture for clarity.♻️ Suggested change
- auto linear_to_indices = [this](PrimExpr offset, + auto linear_to_indices = [](PrimExpr offset, const Array<PrimExpr> &shape) {
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/thread_storage_sync.cc
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1631
File: src/transform/thread_storage_sync.cc:1126-1137
Timestamp: 2026-01-12T07:25:31.685Z
Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.
📚 Learning: 2026-01-12T07:25:31.685Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1631
File: src/transform/thread_storage_sync.cc:1126-1137
Timestamp: 2026-01-12T07:25:31.685Z
Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.
Applied to files:
src/transform/thread_storage_sync.cc
🧬 Code graph analysis (1)
src/transform/thread_storage_sync.cc (2)
tilelang/language/tir/op.py (1)
tvm_storage_sync(534-547)src/transform/common/constr_visitor.h (1)
Substitute(81-83)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (8)
src/transform/thread_storage_sync.cc (8)
76-99: LGTM!The
ThreadSyncAfterWaitQueueInsertercorrectly handles async wait queue synchronization by inserting storage sync calls after async wait operations. The nested attribute handling is properly structured.
182-191: Conservative read/write tracking foraddress_ofis reasonable.The code increments both
read_countandwrite_countwhen taking the address of a buffer. This is conservative since a pointer obtained viaaddress_ofcould be used for either reading or writing. The approach ensures correctness at the cost of potentially more synchronization.
416-467: LGTM!The
AccessEntryandStmtEntrydata structures are well-designed for tracking per-statement memory access patterns. The fields comprehensively capture buffer access metadata needed for conflict detection.
607-613: Verify:hand_threadedattribute handler skips all child processing.The handler for
hand_threadedblocks has an empty body, which means neither the attribute's body nor its children are visited. This is intentional per the comment, but could miss synchronization requirements if hand-threaded kernels interact with automatic threading.Ensure this behavior is tested and documented for mixed hand-threaded/automatic threading scenarios.
1107-1119: Thread array indexing relies on guaranteed dimensions.The indexing
threads[threads.size() + idx - 3]accesses the last 3 elements of the threads array. Based on learnings, the threads array is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z) at this IR level, making this access pattern safe.
1442-1458: LGTM!The
TileLangThreadSyncfunction cleanly orchestrates the multi-pass synchronization pipeline: async wait queue insertion, planning viaTileLangThreadSyncPlanner, sync insertion viaThreadSyncInserter, and final rewriting viaThreadPartialSyncRewriter.
878-1055: LGTM!The
Summarizefunction implements sophisticated logic for determining sync placement, including:
- Shared dynamic buffer normalization
- Loop-carried dependency detection
- Intelligent sync hoisting (outside loop when no in-scope reads exist)
- Proper propagation of exposed access entries
The approach correctly balances correctness with performance by minimizing redundant per-iteration synchronization.
668-733: LGTM!The
IfThenElseNodeandWhileNodevisitors correctly:
- Use constraint guards to track branch conditions
- Preserve condition expression accesses for dependency analysis
- Merge access summaries from all branches
- Ensure syncs are inserted before the statement when condition reads conflict with prior writes
| const auto *min_node = iv->dom->min.as<IntImmNode>(); | ||
| const auto *extent_node = iv->dom->extent.as<IntImmNode>(); | ||
|
|
||
| int64_t min = min_node->value; | ||
| int64_t extent = extent_node->value; | ||
| int64_t max = min + extent - 1; |
There was a problem hiding this comment.
Potential null pointer dereference if domain bounds are not integer immediates.
The code uses as<IntImmNode>() without checking for null before dereferencing. If iv->dom->min or iv->dom->extent are not IntImmNode (e.g., symbolic expressions), this will cause a null pointer dereference.
🐛 Proposed fix
const auto *min_node = iv->dom->min.as<IntImmNode>();
const auto *extent_node = iv->dom->extent.as<IntImmNode>();
+ if (min_node == nullptr || extent_node == nullptr) {
+ // Cannot determine bounds statically; assume full extent conservatively
+ return true;
+ }
+
int64_t min = min_node->value;
int64_t extent = extent_node->value;
int64_t max = min + extent - 1;📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| const auto *min_node = iv->dom->min.as<IntImmNode>(); | |
| const auto *extent_node = iv->dom->extent.as<IntImmNode>(); | |
| int64_t min = min_node->value; | |
| int64_t extent = extent_node->value; | |
| int64_t max = min + extent - 1; | |
| const auto *min_node = iv->dom->min.as<IntImmNode>(); | |
| const auto *extent_node = iv->dom->extent.as<IntImmNode>(); | |
| if (min_node == nullptr || extent_node == nullptr) { | |
| // Cannot determine bounds statically; assume full extent conservatively | |
| return true; | |
| } | |
| int64_t min = min_node->value; | |
| int64_t extent = extent_node->value; | |
| int64_t max = min + extent - 1; |
|
@regression-perf |
Performance Regression Test ReportTriggered by: @LeiWang1999 Results
Artifacts
|
Summary by CodeRabbit
New Features
Bug Fixes
Refactor
Tests
✏️ Tip: You can customize this high-level summary in your review settings.