[Bugfix] Fix thread storage sync conflict detection for loop carry write-after-read #1781
LeiWang1999 merged 16 commits into main
Conversation
…to correctly identify conflicts between read and write operations based on loop carry conditions.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the required checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
Adds loop-aware and runtime-dependent analysis to thread-storage sync planning: Z3 AllSAT-based thread-extent counting with range fallback, loop-aware substitution for precise loop-carried conflict detection, hoisting decisions for IfThenElse based on runtime-dependent conditions, removal of a double-buffer flag, and an exposed constraint accessor; expands CUDA tests and updates examples.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Runner as "Pass Runner"
    participant ThreadSync as "ThreadSync Pass\n(src/transform/thread_storage_sync.cc)"
    participant Constr as "Constraint Engine\n(ConstrVisitor / Z3 / AllSAT)"
    participant IR as "IR (For / If / AccessEntries)"

    Runner->>ThreadSync: invoke ThreadSync(mod)
    ThreadSync->>IR: traverse For/If, collect AccessEntries
    ThreadSync->>Constr: GetConstrSet() and request runtime checks
    ThreadSync->>Constr: run AllSAT for thread-index extents
    Constr-->>ThreadSync: extent / satisfiability / equivalence results
    ThreadSync->>IR: apply loop-aware substitution, decide hoist vs keep sync
    ThreadSync-->>Runner: return transformed IR
```
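To make the walkthrough's "Z3 AllSAT-based thread-extent counting with range fallback" concrete, here is a minimal sketch in Python using the `z3-solver` package. The function name and structure are illustrative assumptions, not the pass's actual C++ API; it only shows the idea of enumerating satisfying thread-index values and falling back to the raw range when enumeration is cut off.

```python
# Illustrative sketch of AllSAT-style thread-extent counting with a range fallback.
# Assumes the `z3-solver` package; this mirrors the idea of CountSatisfyingValues,
# not the actual C++ implementation in thread_storage_sync.cc.
from z3 import Int, Solver, And, sat

def count_satisfying_values(extent, constraint_fn, enum_limit=1024):
    """Count values of tx in [0, extent) satisfying constraint_fn(tx)."""
    tx = Int("tx")
    solver = Solver()
    solver.add(And(tx >= 0, tx < extent, constraint_fn(tx)))
    count = 0
    while solver.check() == sat:
        if count >= enum_limit:
            return extent  # fallback: conservative range-based answer
        value = solver.model()[tx]
        count += 1
        solver.add(tx != value)  # block this model and search for another (AllSAT)
    return count

# Example: threads 0..127 under the guard tx % 4 == 0 -> 32 satisfying threads,
# which const_int_bound alone cannot determine.
print(count_satisfying_values(128, lambda tx: tx % 4 == 0))  # 32
```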
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/transform/thread_storage_sync.cc`:
- Around line 1441-1444: The existing comment about double-buffer reads is
stale; update the comment above the branch that checks if (curr.type == kWrite
&& prev.type == kRead && loop_carry) to reflect that this is a conservative
loop-carried WAR detection: explain that when a read in the prior iteration
(prev.type == kRead) precedes a write in the current iteration (curr.type ==
kWrite) and loop_carry is true, the code unconditionally treats it as a conflict
(regardless of range_is_overlap). Replace the old double-buffer wording with a
concise note such as: "Loop-carried WAR: treat a read in the prior iteration
followed by a write in the current iteration as a conflict (conservative
check)."
…ared memory for `sorted_token_ids` instead of local memory, improving thread synchronization. Adjust default argument values for M, N, and K in the main function for better testing scenarios.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@testing/python/transform/test_tilelang_transform_thread_sync.py`:
- Around line 331-363: The test test_loop_carry_modulo_buffering currently only
prints the transformed IR (s) and lacks a validation; replace the print with an
assertion that the ThreadSync transform did not insert a synchronization call —
e.g., after running mod = tilelang.transform.ThreadSync("shared")(mod) and
computing s = str(mod), assert that the serialized module does not contain
synchronization markers like "thread_sync" (or "tir.thread_sync") to validate
that func's double-buffering pattern avoided a barrier; use the existing symbols
test_loop_carry_modulo_buffering, func, ThreadSync("shared"), mod, and s to
locate and modify the code.
🧹 Nitpick comments (1)
testing/python/transform/test_tilelang_transform_thread_sync.py (1)
292-329: Documentation is misleading - this is a same-iteration cross-thread dependency, not loop-carried. The test correctly expects a barrier, but the documentation is inaccurate. The dependency here is within the same iteration: thread `tx` writes `A[tx]` while thread `(tx+1)%128` reads `A[tx]` in the same iteration. This is a cross-thread WAR hazard that requires synchronization regardless of the loop. The docstring and test name suggest this tests a "loop-carried dependency," but the actual scenario being tested is same-iteration cross-thread access. Consider renaming to `test_same_iteration_cross_thread_dependency` and updating the docstring.

📝 Suggested documentation fix

```diff
 @tilelang.testing.requires_cuda
-def test_loop_carry_with_cross_thread_dependency():
-    """Test loop-carried dependency where different threads access overlapping locations.
+def test_same_iteration_cross_thread_dependency():
+    """Test same-iteration dependency where different threads access overlapping locations.

     In this test:
     - Thread tx writes to A[tx]
     - Then reads from A[(tx + 127) % 128] (neighbor's data from previous iteration)
-
-    After iteration shift analysis, we compare:
-    - Iteration i: thread tx writes A[tx]
-    - Iteration i+1: thread tx reads A[(tx + 127) % 128]
-
-    This creates a cross-thread dependency where thread tx+1's write conflicts
-    with thread tx's read in the next iteration, requiring a barrier.
+
+    This creates a cross-thread WAR hazard within the same iteration:
+    thread (tx+1)%128 writes A[(tx+1)%128] while thread tx reads A[(tx+127)%128] = A[tx-1 mod 128].
+    Since different threads access overlapping locations, a barrier is required.
     """
```
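To ground the reviewer's claim, a small standalone check (plain Python, not part of the test suite) that the hazard already exists within one iteration: every thread's read location coincides with another thread's write location, so a barrier is needed with or without any loop-carry analysis.

```python
# Same-iteration cross-thread WAR check: does any thread's read location
# coincide with a different thread's write location in the same iteration?
NUM_THREADS = 128

writes = {tx: tx for tx in range(NUM_THREADS)}                        # thread tx writes A[tx]
reads = {tx: (tx + 127) % NUM_THREADS for tx in range(NUM_THREADS)}   # thread tx reads A[(tx+127)%128]

cross_thread_conflicts = [
    (reader, writer)
    for reader, loc in reads.items()
    for writer, wloc in writes.items()
    if writer != reader and wloc == loc
]
# Every read hits a location written by a *different* thread, so the hazard
# exists within a single iteration and a barrier is required.
assert len(cross_thread_conflicts) == NUM_THREADS
print(f"{len(cross_thread_conflicts)} cross-thread conflicts in one iteration")
```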
```python
@tilelang.testing.requires_cuda
def test_loop_carry_modulo_buffering():
    """Test that A[i%2] write followed by A[i%2] read does NOT need barrier (double buffering).

    After iteration shift analysis:
    - Iteration i writes A[i%2]
    - Iteration i+1 reads A[(i+1)%2] (shifted from A[i%2])
    - A[i%2] vs A[(i+1)%2] are disjoint (0 vs 1 or 1 vs 0), so no dependency
    """

    @T.prim_func(private=True)
    def func():
        temp_shared = T.alloc_buffer([2, 64], dtype="float32", scope="shared")
        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
        bx = T.launch_thread("blockIdx.x", 1)
        tx = T.launch_thread("threadIdx.x", 64)
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        result_local[0] = T.float32(0)
        for i in range(10):
            # Double buffering pattern: write to buffer[i%2], read from buffer[i%2]
            # After shift: write buffer[i%2], read buffer[(i+1)%2]
            # These are different buffers, so no conflict
            temp_shared[i % 2, tx] = T.float32(i)
            result_local[0] = result_local[0] + temp_shared[i % 2, tx]

    mod = tvm.IRModule({"main": func})
    mod = tilelang.transform.ThreadSync("shared")(mod)
    s = str(mod)
    # Should NOT have sync inside loop due to modulo buffering analysis
    # Note: This test verifies the modulo analysis capability
    print(f"Modulo buffering result:\n{s}")
```
Missing assertion - test does not validate expected behavior.
The test documents the expected behavior (no sync due to modulo buffering) in the docstring but only has a print statement at line 362 without any assertion. This means the test will always pass regardless of actual synchronization behavior.
🧪 Add assertion to validate expected behavior
```diff
 mod = tvm.IRModule({"main": func})
 mod = tilelang.transform.ThreadSync("shared")(mod)
 s = str(mod)
 # Should NOT have sync inside loop due to modulo buffering analysis
-# Note: This test verifies the modulo analysis capability
-print(f"Modulo buffering result:\n{s}")
+assert 'T.tvm_storage_sync("shared")' not in s, f"Unexpected sync with modulo buffering:\n{s}"
```

🤖 Prompt for AI Agents
In `@testing/python/transform/test_tilelang_transform_thread_sync.py` around lines
331 - 363, The test test_loop_carry_modulo_buffering currently only prints the
transformed IR (s) and lacks a validation; replace the print with an assertion
that the ThreadSync transform did not insert a synchronization call — e.g.,
after running mod = tilelang.transform.ThreadSync("shared")(mod) and computing s
= str(mod), assert that the serialized module does not contain synchronization
markers like "thread_sync" (or "tir.thread_sync") to validate that func's
double-buffering pattern avoided a barrier; use the existing symbols
test_loop_carry_modulo_buffering, func, ThreadSync("shared"), mod, and s to
locate and modify the code.
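As a standalone sanity check of the behavior this assertion would lock in (plain Python, independent of TVM and the test harness), the iteration-shift argument can be enumerated directly: the slot written in iteration i, `i % 2`, never equals the slot read by the shifted next iteration, `(i + 1) % 2`.

```python
# Iteration-shift check for the double-buffering pattern:
# iteration i writes slot i % 2, iteration i + 1 reads slot (i + 1) % 2.
TRIP_COUNT = 10

loop_carried_conflicts = [
    i for i in range(TRIP_COUNT - 1)   # compare iteration i with iteration i + 1
    if (i % 2) == ((i + 1) % 2)        # write slot vs. shifted read slot
]
assert not loop_carried_conflicts  # slots alternate 0/1, so they never collide
print("no loop-carried conflict for the A[i % 2] double-buffering pattern")
```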
Introduce the UniformExprChecker class to determine if expressions are uniform across threads, crucial for safe synchronization in conditional statements. Update the TileLangThreadSyncPlanner to hoist synchronization points out of non-uniform if-statements to prevent potential deadlocks. Enhance tests to validate sync hoisting behavior for various non-uniform conditions involving thread indices and shared memory access.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)
739-831: ⚠️ Potential issue | 🟠 Major
Hoisted syncs leave stale sync markers in `s.access`. When non-uniform hoisting triggers, you erase syncs from `syncs_inserted_` but keep `s.access` computed earlier with those syncs included. That can suppress downstream conflict detection because outer summaries will treat the if as containing a sync that no longer exists. This is a correctness risk (missed sync insertion). Please recompute the access summary after hoisting, or ensure the hoisted path strips only the inserted sync markers from `s.access` before propagating it.
🤖 Fix all issues with AI agents
In `@testing/python/transform/test_tilelang_transform_thread_sync.py`:
- Around line 365-395: Add an assertion to fail the test if the ThreadSync
transform incorrectly inserts synchronization for the disjoint-index case: after
creating mod = tilelang.transform.ThreadSync("shared")(mod) and computing s =
str(mod), assert that the generated IR does NOT contain the thread-sync/barrier
markers (for example assert "tvm_thread_sync" not in s and assert "tir.barrier"
not in s); this verifies test_loop_carry_different_indices and the
ThreadSync("shared") output instead of just printing the IR.
- Around line 544-571: The test test_sync_hoist_non_uniform_if_in_loop currently
only asserts a sync exists; update it to also assert the sync is placed inside
the loop before the non-uniform if by checking the string position: find the
index of 'for k in range(2):' (or locate the loop body start in s), then ensure
s.index('T.tvm_storage_sync("shared")') is greater than the loop start index but
less than s.index('if token_ids[tx]') (or directly assert the sync index is less
than the if index), so the storage sync for "shared" is hoisted into the loop
and appears before the non-uniform if that references token_ids and data_shared.
🧹 Nitpick comments (1)
src/transform/thread_storage_sync.cc (1)
1354-1608: Consider shifting loop-dependent constraints alongside indices. The loop-carry path shifts `curr` indices but leaves `curr.cset` (and `curr.touched` in the non-scalar fallback) unshifted. If a guard depends on the loop var (e.g., `i % 2`), this can over-report conflicts and reduce the precision you just gained from index shifting. A light refactor to substitute `loop_var -> loop_var + step` into `curr` constraints (and the touched-range fallback) will keep the analysis consistent.

♻️ Suggested refinement

```diff
-  PrimExpr curr_constr = curr.cset.ToConjunction();
+  PrimExpr curr_constr = curr.cset.ToConjunction();
+  if (loop != nullptr) {
+    curr_constr = Substitute(curr_constr, loop_shift_sub);
+  }

-  auto curr_min = analyzer.Simplify(
-      Substitute(curr.touched[i].min() * curr_dtype.bytes(), curr_sub));
-  auto curr_max = analyzer.Simplify(
-      Substitute(curr.touched[i].max() * curr_dtype.bytes(), curr_sub));
+  auto curr_min = analyzer.Simplify(Substitute(
+      curr.touched[i].min() * curr_dtype.bytes(), curr_sub));
+  auto curr_max = analyzer.Simplify(Substitute(
+      curr.touched[i].max() * curr_dtype.bytes(), curr_sub));
+  if (loop != nullptr) {
+    curr_min = Substitute(curr_min, loop_shift_sub);
+    curr_max = Substitute(curr_max, loop_shift_sub);
+  }
```
```python
@tilelang.testing.requires_cuda
def test_loop_carry_different_indices():
    """Test that A[i] write followed by A[i+1] read does NOT need barrier.

    After iteration shift analysis:
    - Iteration i writes A[i]
    - Iteration i+1 reads A[i+2] (shifted from A[i+1], becomes A[(i+1)+1] = A[i+2])
    - A[i] vs A[i+2] are disjoint, so no loop-carried dependency
    """

    @T.prim_func(private=True)
    def func():
        temp_shared = T.alloc_buffer([128], dtype="float32", scope="shared")
        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
        bx = T.launch_thread("blockIdx.x", 1)
        tx = T.launch_thread("threadIdx.x", 1)
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        result_local[0] = T.float32(0)
        for i in range(10):
            # Write to A[i], read from A[i+1]
            # After shift: comparing A[i] (write) vs A[i+2] (read from i+1 shifted)
            # No overlap, no dependency
            temp_shared[i] = T.float32(i)
            result_local[0] = result_local[0] + temp_shared[i + 1]

    mod = tvm.IRModule({"main": func})
    mod = tilelang.transform.ThreadSync("shared")(mod)
    s = str(mod)
    print(f"Different indices result:\n{s}")
```
Add an assertion to validate the “different indices” case.
The test currently only prints, so it can’t fail if the transform regresses.
🧪 Add assertion
```diff
 mod = tvm.IRModule({"main": func})
 mod = tilelang.transform.ThreadSync("shared")(mod)
 s = str(mod)
-print(f"Different indices result:\n{s}")
+assert 'T.tvm_storage_sync("shared")' not in s, f"Unexpected sync:\n{s}"
```

Based on learnings, tests in testing/python/transform should assert structural patterns in generated IR rather than rely on prints or numeric literals.
🤖 Prompt for AI Agents
In `@testing/python/transform/test_tilelang_transform_thread_sync.py` around lines
365 - 395, Add an assertion to fail the test if the ThreadSync transform
incorrectly inserts synchronization for the disjoint-index case: after creating
mod = tilelang.transform.ThreadSync("shared")(mod) and computing s = str(mod),
assert that the generated IR does NOT contain the thread-sync/barrier markers
(for example assert "tvm_thread_sync" not in s and assert "tir.barrier" not in
s); this verifies test_loop_carry_different_indices and the ThreadSync("shared")
output instead of just printing the IR.
```python
@tilelang.testing.requires_cuda
def test_sync_hoist_non_uniform_if_in_loop():
    """Test sync hoisting when non-uniform if is inside a loop."""

    @T.prim_func(private=True)
    def func():
        token_ids = T.alloc_buffer([128], dtype="int32", scope="shared")
        data_shared = T.alloc_buffer([128], dtype="float32", scope="shared")
        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
        bx = T.launch_thread("blockIdx.x", 1)
        tx = T.launch_thread("threadIdx.x", 128)
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        result_local[0] = T.float32(0)
        for k in range(2):
            # Write to shared memory
            data_shared[tx] = T.float32(tx + k)
            # Non-uniform if inside loop
            if token_ids[tx] != -1:
                result_local[0] = result_local[0] + data_shared[tx]

    mod = tvm.IRModule({"main": func})
    mod = tilelang.transform.ThreadSync("shared")(mod)
    s = str(mod)
    assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}"
    # Sync should be before the if inside the loop, not inside the if
    # This ensures all threads can reach the sync point
```
Add a position check to confirm the sync is hoisted inside the loop.
This test asserts a sync exists but doesn’t verify it is placed before the non-uniform if inside the loop (vs. outside the loop). A simple index check will tighten it.
🧪 Add position assertion
```diff
 mod = tvm.IRModule({"main": func})
 mod = tilelang.transform.ThreadSync("shared")(mod)
 s = str(mod)
 assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}"
 # Sync should be before the if inside the loop, not inside the if
 # This ensures all threads can reach the sync point
+loop_pos = s.index("for k in range(2)")
+sync_pos = s.index('T.tvm_storage_sync("shared")')
+if_pos = s.index("if token_ids")
+assert loop_pos < sync_pos < if_pos, f"Sync should be inside loop and before if:\n{s}"
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
@tilelang.testing.requires_cuda
def test_sync_hoist_non_uniform_if_in_loop():
    """Test sync hoisting when non-uniform if is inside a loop."""

    @T.prim_func(private=True)
    def func():
        token_ids = T.alloc_buffer([128], dtype="int32", scope="shared")
        data_shared = T.alloc_buffer([128], dtype="float32", scope="shared")
        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
        bx = T.launch_thread("blockIdx.x", 1)
        tx = T.launch_thread("threadIdx.x", 128)
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        result_local[0] = T.float32(0)
        for k in range(2):
            # Write to shared memory
            data_shared[tx] = T.float32(tx + k)
            # Non-uniform if inside loop
            if token_ids[tx] != -1:
                result_local[0] = result_local[0] + data_shared[tx]

    mod = tvm.IRModule({"main": func})
    mod = tilelang.transform.ThreadSync("shared")(mod)
    s = str(mod)
    assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}"
    # Sync should be before the if inside the loop, not inside the if
    # This ensures all threads can reach the sync point
    loop_pos = s.index("for k in range(2)")
    sync_pos = s.index('T.tvm_storage_sync("shared")')
    if_pos = s.index("if token_ids")
    assert loop_pos < sync_pos < if_pos, f"Sync should be inside loop and before if:\n{s}"
```
🤖 Prompt for AI Agents
In `@testing/python/transform/test_tilelang_transform_thread_sync.py` around lines
544 - 571, The test test_sync_hoist_non_uniform_if_in_loop currently only
asserts a sync exists; update it to also assert the sync is placed inside the
loop before the non-uniform if by checking the string position: find the index
of 'for k in range(2):' (or locate the loop body start in s), then ensure
s.index('T.tvm_storage_sync("shared")') is greater than the loop start index but
less than s.index('if token_ids[tx]') (or directly assert the sync index is less
than the if index), so the storage sync for "shared" is hoisted into the loop
and appears before the non-uniform if that references token_ids and data_shared.
… disabling and kernel source printing for debugging. Update thread synchronization logic in `thread_storage_sync.cc` to check for runtime-dependent conditions, preventing potential deadlocks by hoisting sync points as necessary.
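A toy sketch of the decision this commit describes, written in plain Python over made-up expression-node classes (the real `RuntimeDependentConditionChecker` walks TIR; nothing here is the actual API): a condition is treated as runtime-dependent, and the sync hoisted, as soon as it contains any buffer load, while conditions built only from thread indices and constants remain analyzable at compile time.

```python
# Toy model of the "is this if-condition runtime-dependent?" decision.
# The real pass inspects TIR nodes; these dataclasses are illustrative stand-ins.
from dataclasses import dataclass

@dataclass
class ThreadIdx:          # e.g. threadIdx.x
    name: str

@dataclass
class BufferLoad:         # e.g. token_ids[tx]
    buffer: str

@dataclass
class Compare:            # e.g. lhs != rhs, lhs >= rhs
    lhs: object
    rhs: object

def depends_on_runtime_value(expr) -> bool:
    """Conservative rule: any buffer load makes the condition runtime-dependent."""
    if isinstance(expr, BufferLoad):
        return True
    if isinstance(expr, Compare):
        return depends_on_runtime_value(expr.lhs) or depends_on_runtime_value(expr.rhs)
    return False  # thread indices and constants are analyzable at compile time

# `threadIdx.x >= 512`: thread count is knowable, the sync may stay inside the if.
print(depends_on_runtime_value(Compare(ThreadIdx("threadIdx.x"), 512)))   # False
# `token_ids[tx] != -1`: runtime value, hoist the sync before the if.
print(depends_on_runtime_value(Compare(BufferLoad("token_ids"), -1)))     # True
```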
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)
1362-1526: ⚠️ Potential issue | 🔴 Critical
Loop-carry constraints aren't shifted, which can miss real conflicts. You shift `curr` indices for loop-carry, but `curr.cset` and the fallback `curr.touched` expressions are still evaluated in the current iteration. If branch guards depend on `loop_var` (e.g., `i % 2`), the equivalence and overlap checks run against the wrong iteration and can miss dependencies, leading to missing syncs.

🩹 Minimal fix: apply the loop shift to constraints and touched ranges

```diff
-  PrimExpr prev_constr = prev.cset.ToConjunction();
-  PrimExpr curr_constr = curr.cset.ToConjunction();
+  PrimExpr prev_constr = prev.cset.ToConjunction();
+  ConstrSet curr_cset = curr.cset;
+  if (loop != nullptr) {
+    curr_cset = curr_cset.Substitute(loop_shift_sub);
+  }
+  PrimExpr curr_constr = curr_cset.ToConjunction();
...
-  ConstrSet curr_cset{curr.cset};
+  ConstrSet curr_cset{curr.cset};
+  if (loop != nullptr) {
+    curr_cset = curr_cset.Substitute(loop_shift_sub);
+  }
...
-  auto curr_min = analyzer.Simplify(
-      Substitute(curr.touched[i].min() * curr_dtype.bytes(), curr_sub));
-  auto curr_max = analyzer.Simplify(
-      Substitute(curr.touched[i].max() * curr_dtype.bytes(), curr_sub));
+  auto curr_min = analyzer.Simplify(Substitute(
+      Substitute(curr.touched[i].min() * curr_dtype.bytes(),
+                 loop_shift_sub),
+      curr_sub));
+  auto curr_max = analyzer.Simplify(Substitute(
+      Substitute(curr.touched[i].max() * curr_dtype.bytes(),
+                 loop_shift_sub),
+      curr_sub));
```
🤖 Fix all issues with AI agents
In `@src/transform/thread_storage_sync.cc`:
- Around line 742-786: The hoist path removes entries from syncs_inserted_ (via
syncs_in_then/syncs_in_else) but does not update the previously computed access
summary s.access, leaving stale kSync entries added during Summarize and causing
incorrect later conflict suppression; after erasing the sync pointers (and
before calling insert_syncs(op)) update the summary to reflect the hoist by
either (a) recomputing the summary via the Summarize routine for the affected
statement/op, or (b) explicitly removing kSync entries in s.access that
correspond to the removed sync objects (syncs_in_then and syncs_in_else), so the
access summary matches the actual sync placement.
```cpp
// Check if any syncs were inserted inside the if-then-else
std::vector<const Object *> syncs_in_then;
std::vector<const Object *> syncs_in_else;

for (const auto &sync : syncs_inserted_) {
  if (syncs_before_then.count(sync) == 0 &&
      syncs_before_else.count(sync) != 0) {
    // Sync was inserted during then branch processing
    syncs_in_then.push_back(sync);
  } else if (syncs_before_else.count(sync) == 0) {
    // Sync was inserted during else branch processing
    syncs_in_else.push_back(sync);
  }
}

bool has_syncs_inside = !syncs_in_then.empty() || !syncs_in_else.empty();

if (has_syncs_inside) {
  // Check if the condition depends on runtime values (e.g., shared memory
  // loads). If so, we cannot determine at compile time how many threads
  // will enter the if, so we must hoist the sync to before the if to avoid
  // potential deadlock.
  //
  // If the condition only depends on threadIdx (e.g., `threadIdx.x >=
  // 512`), ThreadPartialSyncRewriter can compute the exact thread count at
  // compile time, so the sync can safely remain inside the if.
  RuntimeDependentConditionChecker checker;
  bool depends_on_runtime = checker.DependsOnRuntimeValue(op->condition);

  if (depends_on_runtime) {
    // Condition depends on runtime values - must hoist sync
    LOG(WARNING)
        << "[ThreadSync] Hoisting sync from inside if to before if. "
        << "Condition depends on runtime value: " << op->condition;

    for (const auto &sync : syncs_in_then) {
      syncs_inserted_.erase(sync);
    }
    for (const auto &sync : syncs_in_else) {
      syncs_inserted_.erase(sync);
    }

    // Insert sync before the if-statement itself
    insert_syncs(op);
  }
```
Hoisting removes sync markers but keeps stale access summaries.
After hoisting, syncs_inserted_ entries from branches are erased, yet s.access still contains the kSync entries that were added during Summarize. This can incorrectly suppress later conflict detection (the outer planner thinks a sync still exists inside the if). Consider recomputing s.access after hoist, or filtering out sync entries corresponding to removed syncs_in_then/else so the summary matches actual sync placement.
🤖 Prompt for AI Agents
In `@src/transform/thread_storage_sync.cc` around lines 742 - 786, The hoist path
removes entries from syncs_inserted_ (via syncs_in_then/syncs_in_else) but does
not update the previously computed access summary s.access, leaving stale kSync
entries added during Summarize and causing incorrect later conflict suppression;
after erasing the sync pointers (and before calling insert_syncs(op)) update the
summary to reflect the hoist by either (a) recomputing the summary via the
Summarize routine for the affected statement/op, or (b) explicitly removing
kSync entries in s.access that correspond to the removed sync objects
(syncs_in_then and syncs_in_else), so the access summary matches the actual sync
placement.
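A simplified model of option (b) from the prompt above, in plain Python; `AccessEntry`, the `kSync` kind, and the summary list are hypothetical stand-ins for the pass's internal C++ structures, shown only to illustrate pruning the entries whose sync objects were removed during hoisting.

```python
# Hypothetical model of pruning stale sync entries from an access summary after
# hoisting; mirrors option (b) in the suggested fix, not the real C++ code.
from dataclasses import dataclass
from typing import Optional

K_READ, K_WRITE, K_SYNC = "kRead", "kWrite", "kSync"

@dataclass
class AccessEntry:
    kind: str
    sync_obj: Optional[int] = None  # identity of the inserted sync, if any

def prune_hoisted_syncs(access, removed_syncs):
    """Drop kSync entries whose sync objects were erased during hoisting."""
    removed = set(removed_syncs)
    return [e for e in access
            if not (e.kind == K_SYNC and e.sync_obj in removed)]

summary = [AccessEntry(K_WRITE), AccessEntry(K_SYNC, sync_obj=1), AccessEntry(K_READ)]
hoisted = prune_hoisted_syncs(summary, removed_syncs=[1])
assert all(e.kind != K_SYNC for e in hoisted)  # summary now matches actual placement
```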
…e_gqa_decode_varlen_logits_paged.py` file. Refactor `example_gqa_decode_varlen_logits.py` to enhance performance and maintainability by removing unused imports and optimizing shared memory usage. Adjust test cases to reflect the removal of the paged example.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (1)
309-312: ⚠️ Potential issue | 🟡 Minor
Hardcoded argument overrides defeat argparse purpose. These lines override user-provided CLI arguments, making the `--dtype`, `--test_varlen`, `--test_sink`, and `--num_split` flags ineffective. This appears to be leftover debug configuration.

Suggested fix: Remove hardcoded overrides

```diff
 args = parser.parse_args()
-args.test_sink = True
-args.test_varlen = True
-args.dtype = T.float16
-args.num_split = 1
```
🤖 Fix all issues with AI agents
In `@examples/flash_decoding/example_gqa_decode_varlen_logits.py`:
- Around line 314-319: Restore the original conditional that respects the
args.benchmark flag: revert the unconditional call to
speed_benchmark_decode_comparison and reintroduce the if/else using
args.benchmark to call speed_benchmark_decode_comparison(args) when true and
test_varlen_decode_main(args) when false; specifically update the block using
the symbols args.benchmark, speed_benchmark_decode_comparison, and
test_varlen_decode_main so the CLI flag controls which function runs.
🧹 Nitpick comments (2)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (2)
63-63: Unused loop variable `bz` should be prefixed with underscore. The variable `bz` is unpacked but never used in the kernel body. Since `num_split` defaults to 1, this appears to be placeholder code for future split functionality.

Suggested fix

```diff
-    with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bid, hid, bz):
+    with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bid, hid, _bz):
```

262-262: Variable `head_size` shadows parameter from outer scope. The variable `head_size` unpacked here shadows the `head_size` from `args.head_size` assigned on line 224. While they should have the same value, this creates potential confusion.

Suggested fix

```diff
-    _, q_h, head_size = q_decode.shape
+    _, q_h, _ = q_decode.shape
```
```python
# if args.benchmark:
#     speed_benchmark_decode_comparison(args)
# else:
#     test_varlen_decode_main(args)

speed_benchmark_decode_comparison(args)
```
Commented-out code breaks the --benchmark CLI flag functionality.
The conditional logic is commented out, making the --benchmark argument (defined at line 306) non-functional. The script now unconditionally runs the benchmark instead of allowing users to choose between test and benchmark modes.
Suggested fix: Restore conditional logic
```diff
-# if args.benchmark:
-#     speed_benchmark_decode_comparison(args)
-# else:
-#     test_varlen_decode_main(args)
-
-speed_benchmark_decode_comparison(args)
+if args.benchmark:
+    speed_benchmark_decode_comparison(args)
+else:
+    test_varlen_decode_main(args)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if args.benchmark:
    speed_benchmark_decode_comparison(args)
else:
    test_varlen_decode_main(args)
```
🤖 Prompt for AI Agents
In `@examples/flash_decoding/example_gqa_decode_varlen_logits.py` around lines 314
- 319, Restore the original conditional that respects the args.benchmark flag:
revert the unconditional call to speed_benchmark_decode_comparison and
reintroduce the if/else using args.benchmark to call
speed_benchmark_decode_comparison(args) when true and
test_varlen_decode_main(args) when false; specifically update the block using
the symbols args.benchmark, speed_benchmark_decode_comparison, and
test_varlen_decode_main so the CLI flag controls which function runs.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@src/transform/thread_storage_sync.cc`:
- Around line 380-391: The code dereferences as_const_int(iv->dom->extent)
unsafely; first check whether as_const_int(iv->dom->extent) has a value before
using it (e.g., auto opt_extent = as_const_int(iv->dom->extent); if
(!opt_extent) { /* fallback */ }). If opt_extent is present, use *opt_extent for
extent and call analyzer_->z3_prover.CountSatisfyingValues(iv->var, extent);
otherwise skip Z3 enumeration and fall back to the range-based calculation
(using bound->max_value and bound->min_value) to avoid undefined behavior.
- Around line 468-483: In DependsOnRuntimeValue, avoid the unsafe dereference of
as_const_int(iv->dom->extent) — check the result (e.g., ICHECK(result) or handle
the null case and set depends_on_runtime_ accordingly) before casting to int64_t
and using thread_extent; reference the symbol as_const_int and iv->dom->extent.
Replace the hard-coded warp_size = 32 with a query to the target attribute (use
target->GetAttr<Integer>("thread_warp_size") as done in
lower_thread_allreduce.cc), falling back to a sensible default if the attribute
is missing, and then pass that warp_size into
analyzer_->z3_prover.CountSatisfyingValues; reference
analyzer_->z3_prover.CountSatisfyingValues and warp_size in your changes.
- Around line 1416-1422: The loop shift currently always uses a hardcoded step
of 1 which ignores explicit loop strides; update the logic in the block that
constructs loop_shift_sub so that you read loop->step when defined and fall back
to make_const(loop->loop_var.dtype(), 1) otherwise (i.e., compute a PrimExpr
step = loop->step.defined() ? loop->step : make_const(...)), then use that step
for the substitution of loop->loop_var -> loop->loop_var + step; ensure you
reference the existing loop, loop->loop_var, loop->step, loop_shift_sub and
make_const symbols when making the change.
🧹 Nitpick comments (2)
src/transform/thread_storage_sync.cc (2)
506-507: Duplicate `private:` specifier. The `private:` access specifier appears twice (lines 485 and 506). This is harmless but suggests copy-paste; consider removing the redundant one.

1650-1651: Extra semicolon. Line 1651 has an unnecessary semicolon after the return statement.

```diff
   return tl::TileLangThreadSync(std::move(f), storage_scope);
-  ;
```
```cpp
ffi::Map<Var, PrimExpr> loop_shift_sub;
if (loop != nullptr) {
  // Get loop step, default to 1 if not specified
  PrimExpr step = make_const(loop->loop_var.dtype(), 1);
  // Substitute loop_var -> loop_var + step for the "next iteration"
  loop_shift_sub.Set(loop->loop_var, loop->loop_var + step);
}
```
Use actual loop step value instead of hardcoding 1.
The code should check loop->step and use it if defined, falling back to 1 for unit-stride loops. Currently, step is always created as a constant 1, ignoring the loop's actual step value. This causes incorrect iteration shift analysis for loops with explicit non-unit strides.
Replace the hardcoded step assignment:

```cpp
PrimExpr step = loop->step.defined() ? loop->step : make_const(loop->loop_var.dtype(), 1);
```

🤖 Prompt for AI Agents
In `@src/transform/thread_storage_sync.cc` around lines 1416 - 1422, The loop
shift currently always uses a hardcoded step of 1 which ignores explicit loop
strides; update the logic in the block that constructs loop_shift_sub so that
you read loop->step when defined and fall back to
make_const(loop->loop_var.dtype(), 1) otherwise (i.e., compute a PrimExpr step =
loop->step.defined() ? loop->step : make_const(...)), then use that step for the
substitution of loop->loop_var -> loop->loop_var + step; ensure you reference
the existing loop, loop->loop_var, loop->step, loop_shift_sub and make_const
symbols when making the change.
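A tiny worked check (plain Python, illustrative only) of why the shift must use the real stride for the "write A[i], read A[i + 1]" pattern: a later iteration's write lands on an earlier iteration's read only when the step is 1, so hardcoding a step of 1 models a stride-2 loop incorrectly.

```python
# Loop-carried WAR for "write A[i], read A[i + 1]" under different strides:
# does any later iteration's write A[j] hit an earlier iteration's read A[i + 1]?
def loop_carried_war_pairs(step, trip_count=10):
    iters = list(range(0, trip_count, step))
    return [(i, j) for i in iters for j in iters if j > i and j == i + 1]

print(loop_carried_war_pairs(step=1))  # [(0, 1), (1, 2), ...] -> conflict, barrier needed
print(loop_carried_war_pairs(step=2))  # [] -> writes stay on even slots, no conflict
```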
@regression-perf
…dding a configurable warp size parameter to `RuntimeDependentConditionChecker` and `TileLangThreadSyncPlanner`. This allows for better adaptability to different target architectures. Update the logic to ensure thread extent is a constant and improve handling of runtime-dependent conditions.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/transform/thread_storage_sync.cc`:
- Around line 573-581: GetThreadVar currently calls LOG(FATAL) when env_threads_
doesn't contain the requested tag, which can crash callers (e.g., when
requesting "threadIdx.x" outside a device context); change GetThreadVar to
return an Optional/nullable IterVar (e.g., Optional<IterVar> or
std::optional<IterVar>) instead of aborting, returning empty when not found, and
update all call sites that assumed a non-null return (notably the place using
"threadIdx.x") to check has_value()/operator bool and handle the missing case
(skip hoisting/return early) before dereferencing the value.
🧹 Nitpick comments (1)
src/transform/thread_storage_sync.cc (1)
490-514: Duplicate `private:` access specifier. There are two `private:` labels (lines 490 and 511) in the class. The second one is redundant and should be removed for cleaner code.

🧹 Proposed fix

```diff
 private:
   PrimExpr VisitExpr_(const BufferLoadNode *op) final {
     // Any buffer load introduces runtime dependency
     // (we don't know the buffer contents at compile time)
     depends_on_runtime_ = true;
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }

   PrimExpr VisitExpr_(const CallNode *op) final {
     // Check tvm_access_ptr and address_of - if used in condition, it's reading
     // memory
     if (op->op.same_as(builtin::tvm_access_ptr()) ||
         op->op.same_as(builtin::address_of())) {
       depends_on_runtime_ = true;
       return IRMutatorWithAnalyzer::VisitExpr_(op);
     }
     // Other calls might also introduce runtime dependency
     // but we'll be conservative and check children
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }

-private:
   bool depends_on_runtime_{false};
   int warp_size_;
 };
```
```cpp
IterVar GetThreadVar(const std::string &tag) const {
  for (const auto &iv : env_threads_) {
    if (iv->thread_tag == tag) {
      return iv;
    }
  }
  LOG(FATAL) << "Thread variable " << tag << " not found";
  return IterVar();
}
```
Potential crash if thread variable not found.
GetThreadVar will call LOG(FATAL) if the requested thread tag is not found in env_threads_. This is called at line 826 with "threadIdx.x" - if the if-statement is not inside a device context with thread extents defined, this will crash.
Consider returning an Optional<IterVar> and handling the missing case gracefully, or add a check before calling this method.
🛡️ Proposed defensive approach

```diff
-  IterVar GetThreadVar(const std::string &tag) const {
+  std::optional<IterVar> GetThreadVar(const std::string &tag) const {
     for (const auto &iv : env_threads_) {
       if (iv->thread_tag == tag) {
         return iv;
       }
     }
-    LOG(FATAL) << "Thread variable " << tag << " not found";
-    return IterVar();
+    return std::nullopt;
   }
```

Then at the call site (line 826):

```cpp
auto tx_opt = GetThreadVar("threadIdx.x");
if (!tx_opt.has_value()) {
  // Cannot determine thread-dependency without threadIdx.x, skip hoisting
  continue;  // or return early
}
IterVar tx = tx_opt.value();
```
IterVar tx = tx_opt.value();🤖 Prompt for AI Agents
In `@src/transform/thread_storage_sync.cc` around lines 573 - 581, GetThreadVar
currently calls LOG(FATAL) when env_threads_ doesn't contain the requested tag,
which can crash callers (e.g., when requesting "threadIdx.x" outside a device
context); change GetThreadVar to return an Optional/nullable IterVar (e.g.,
Optional<IterVar> or std::optional<IterVar>) instead of aborting, returning
empty when not found, and update all call sites that assumed a non-null return
(notably the place using "threadIdx.x") to check has_value()/operator bool and
handle the missing case (skip hoisting/return early) before dereferencing the
value.
Performance Regression Test Report
Triggered by: @LeiWang1999

Results
Artifacts
…pointer checks instead of optional values. This change improves clarity and ensures that the thread extent is correctly validated as a constant.
@regression-perf
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/transform/thread_storage_sync.cc`:
- Around line 380-391: The code dereferences as_const_int(iv->dom->extent)
unsafely; change the logic in the block around as_const_int, iv->dom->extent and
analyzer_->z3_prover.CountSatisfyingValues so you first check the
optional/nullable result of as_const_int before using it (e.g., auto
maybe_extent = as_const_int(iv->dom->extent); if not present, skip Z3
enumeration and fall back to the range-based calculation using bound->max_value
and bound->min_value); only call CountSatisfyingValues when you have a valid
extent and keep the existing fallback to return
static_cast<size_t>(bound->max_value - bound->min_value + 1).
🧹 Nitpick comments (1)
src/transform/thread_storage_sync.cc (1)
1667-1668: Minor: remove extra semicolon. Line 1668 has a double semicolon `;;` which is harmless but unnecessary.

🧹 Proposed fix

```diff
   return tl::TileLangThreadSync(std::move(f), storage_scope);
-  ;
```
```cpp
auto extent = *as_const_int(iv->dom->extent);
// Always use Z3 enumeration to count satisfying values.
// This handles constraints like `tx % 4 == 0` that const_int_bound cannot
// detect. Z3 enumeration will return the exact count of satisfying values.
int64_t z3_count =
    analyzer_->z3_prover.CountSatisfyingValues(iv->var, extent);
if (z3_count > 0) {
  return static_cast<size_t>(z3_count);
}

// Fallback to range-based calculation if Z3 enumeration failed
return static_cast<size_t>(bound->max_value - bound->min_value + 1);
```
Unsafe dereference of as_const_int result.
Line 380 dereferences as_const_int(iv->dom->extent) without checking if it returns a valid value. If the extent is not a compile-time constant, this causes undefined behavior.
🐛 Proposed fix

```diff
-  auto extent = *as_const_int(iv->dom->extent);
+  auto extent_opt = as_const_int(iv->dom->extent);
+  if (!extent_opt) {
+    // Fallback to range-based calculation if extent is not const
+    return static_cast<size_t>(bound->max_value - bound->min_value + 1);
+  }
+  auto extent = *extent_opt;
```
In `@src/transform/thread_storage_sync.cc` around lines 380 - 391, The code
dereferences as_const_int(iv->dom->extent) unsafely; change the logic in the
block around as_const_int, iv->dom->extent and
analyzer_->z3_prover.CountSatisfyingValues so you first check the
optional/nullable result of as_const_int before using it (e.g., auto
maybe_extent = as_const_int(iv->dom->extent); if not present, skip Z3
enumeration and fall back to the range-based calculation using bound->max_value
and bound->min_value); only call CountSatisfyingValues when you have a valid
extent and keep the existing fallback to return
static_cast<size_t>(bound->max_value - bound->min_value + 1).
Performance Regression Test Report
Triggered by: @LeiWang1999

Results
Artifacts
@regression-perf
Performance Regression Test Report
Triggered by: @LeiWang1999

Results
Artifacts
…-carry analysis by modifying the extent calculation. This change ensures valid iteration comparisons by reducing the extent by one, allowing for accurate analysis of loop iterations.
@regression-perf
Performance Regression Test Report
Triggered by: @LeiWang1999

Results
Artifacts
Summary
- Fix conflict detection in `thread_storage_sync.cc` to correctly identify conflicts between read and write operations based on loop carry conditions

Test plan
🤖 Generated with Claude Code
Summary by CodeRabbit
Bug Fixes
New Features
Tests
Chores