
[Bugfix] Fix thread storage sync conflict detection for loop carry write-after-read #1781

Merged
LeiWang1999 merged 16 commits into main from sync_0203_ on Feb 4, 2026

Conversation

@LeiWang1999 (Member) commented Feb 3, 2026

Summary

  • Fix thread storage synchronization logic in thread_storage_sync.cc to correctly identify conflicts between read and write operations based on loop carry conditions
  • The previous logic checked for a double-buffer write paired with a read when no loop carry was present; the correct behavior is to flag write-after-read conflicts when a loop carry is present (sketched below)
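For illustration, a minimal, self-contained sketch of the corrected predicate; the enum and struct below are stand-ins for the real AccessEntry in thread_storage_sync.cc, and only the boolean condition mirrors the fix described above:

// Illustrative stand-ins for the real AccessEntry in thread_storage_sync.cc.
enum AccessType { kRead, kWrite, kSync };
struct AccessEntry {
  AccessType type;
};

// Corrected behavior: a read in a prior iteration followed by a write in the
// current iteration is a conflict whenever a loop carry is present.
// (The old check instead looked for a double-buffer write paired with a read
// and no loop carry.)
bool IsLoopCarryWARConflict(const AccessEntry &curr, const AccessEntry &prev,
                            bool loop_carry) {
  return curr.type == kWrite && prev.type == kRead && loop_carry;
}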

Test plan

  • Verify existing thread sync tests pass
  • Confirm the fix addresses the synchronization issue in loop carry scenarios

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes

    • More accurate placement and hoisting of thread-storage synchronizations, including loop-aware and runtime-dependent condition handling to avoid incorrect barriers.
  • New Features

    • Loop-aware conflict analysis and target-specific warp-size support; ability to detect runtime-dependent conditions for hoisting.
    • Added a pure-Python reference attention implementation for decoding comparisons.
  • Tests

    • New CUDA-focused tests for loop-carried sync patterns, hoisting behavior, and buffering.
  • Chores

    • Reduced example matrix defaults, removed a paged decoding example and its test, updated third-party submodule, and exposed a constraint-accessor helper.


github-actions bot commented Feb 3, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Feb 3, 2026

📝 Walkthrough

Adds loop-aware and runtime-dependent analysis to thread-storage sync planning: Z3 AllSAT-based thread-extent counting with a range fallback, loop-aware substitution for precise loop-carried conflict detection, and hoisting decisions for IfThenElse based on runtime-dependent conditions. Also removes a double-buffer flag, exposes a constraint accessor, expands CUDA tests, and updates examples.
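To make the AllSAT idea concrete, a generic counting loop against the plain Z3 C++ API; this is a sketch of the technique only, since the project wraps it behind z3_prover.CountSatisfyingValues, and the example constraint is an assumption:

#include <cstdint>
#include <z3++.h>

// Count how many tx in [0, extent) satisfy a constraint such as tx % 4 == 0,
// by repeatedly asking Z3 for a model and then blocking that model (AllSAT).
int64_t CountSatisfyingValues(int64_t extent) {
  z3::context ctx;
  z3::solver solver(ctx);
  z3::expr tx = ctx.int_const("tx");
  solver.add(tx >= 0 && tx < ctx.int_val(static_cast<int>(extent)));
  solver.add(z3::mod(tx, 4) == 0);  // example constraint a range bound misses
  int64_t count = 0;
  while (solver.check() == z3::sat) {
    z3::expr val = solver.get_model().eval(tx, /*model_completion=*/true);
    ++count;
    solver.add(tx != val);  // block this model and search for the next
  }
  return count;
}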

Changes

  • Thread storage sync logic (src/transform/thread_storage_sync.cc): Added RuntimeDependentConditionChecker; CalculateThreadExtent now uses Z3 AllSAT with a range fallback; FindConflict signature changed to accept const ForNode *loop; implemented loop-aware substitution, loop-carry conflict checks, and hoisting logic; removed AccessEntry.double_buffer_write; added GetThreadVar and a warp_size constructor parameter.
  • Constraint utilities (src/transform/common/constr_visitor.h): Added a public ConstrSet GetConstrSet() const accessor to expose the current accumulated constraints.
  • Tests — thread sync (testing/python/transform/test_tilelang_transform_thread_sync.py, testing/python/transform/test_tile_sync_thread.py): Added multiple CUDA-targeted tests for loop-carried dependency scenarios and if-hoisting behavior; assertions check presence/absence of T.tvm_storage_sync("shared") and print IR for inspection.
  • Examples — dequantize (examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py): Reduced default CLI matrix sizes (M, N, K) to 256 and made a minor formatting tweak.
  • Examples — flash decoding (examples/flash_decoding/example_gqa_decode_varlen_logits.py, examples/flash_decoding/test_example_flash_decoding.py; deleted examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py): Removed Triton-specific blocks and the paged example; added a Python reference ref_attention(...); adjusted control flow to always run the comparison/benchmark; removed the paged example's test import.
  • Submodule (3rdparty/tvm): Updated the submodule pointer (commit bump only).

Sequence Diagram(s)

sequenceDiagram
    participant Runner as "Pass Runner"
    participant ThreadSync as "ThreadSync Pass\n(src/transform/thread_storage_sync.cc)"
    participant Constr as "Constraint Engine\n(ConstrVisitor / Z3 / AllSAT)"
    participant IR as "IR (For / If / AccessEntries)"

    Runner->>ThreadSync: invoke ThreadSync(mod)
    ThreadSync->>IR: traverse For/If, collect AccessEntries
    ThreadSync->>Constr: GetConstrSet() and request runtime checks
    ThreadSync->>Constr: run AllSAT for thread-index extents
    Constr-->>ThreadSync: extent / satisfiability / equivalence results
    ThreadSync->>IR: apply loop-aware substitution, decide hoist vs keep sync
    ThreadSync-->>Runner: return transformed IR

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested reviewers

  • kurisu6912
  • Rachmanino

Poem

🐇 I hopped through loops and counted threads with glee,
I checked runtime paths and set the syncs free.
I dropped a flag, shifted indices in time,
Hoisted safe barriers and kept the IR in line.
Tiny rabbit cheers: "Syncs hum—code sings with me!"

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 35.19%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title directly describes the main bugfix in thread_storage_sync.cc, addressing conflict detection for loop-carry write-after-read scenarios, and aligns with the core changes and objectives.



@coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `src/transform/thread_storage_sync.cc`:
- Around line 1441-1444: The existing comment about double-buffer reads is
stale; update the comment above the branch that checks if (curr.type == kWrite
&& prev.type == kRead && loop_carry) to reflect that this is a conservative
loop-carried WAR detection: explain that when a read in the prior iteration
(prev.type == kRead) precedes a write in the current iteration (curr.type ==
kWrite) and loop_carry is true, the code unconditionally treats it as a conflict
(regardless of range_is_overlap). Replace the old double-buffer wording with a
concise note such as: "Loop-carried WAR: treat a read in the prior iteration
followed by a write in the current iteration as a conflict (conservative
check)."

…ared memory for `sorted_token_ids` instead of local memory, improving thread synchronization. Adjust default argument values for M, N, and K in the main function for better testing scenarios.
@coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `testing/python/transform/test_tilelang_transform_thread_sync.py`:
- Around line 331-363: The test test_loop_carry_modulo_buffering currently only
prints the transformed IR (s) and lacks a validation; replace the print with an
assertion that the ThreadSync transform did not insert a synchronization call —
e.g., after running mod = tilelang.transform.ThreadSync("shared")(mod) and
computing s = str(mod), assert that the serialized module does not contain
synchronization markers like "thread_sync" (or "tir.thread_sync") to validate
that func's double-buffering pattern avoided a barrier; use the existing symbols
test_loop_carry_modulo_buffering, func, ThreadSync("shared"), mod, and s to
locate and modify the code.
🧹 Nitpick comments (1)
testing/python/transform/test_tilelang_transform_thread_sync.py (1)

292-329: Documentation is misleading - this is a same-iteration cross-thread dependency, not loop-carried.

The test correctly expects a barrier, but the documentation is inaccurate. The dependency here is within the same iteration: thread tx writes A[tx] while thread (tx+1)%128 reads A[tx] in the same iteration. This is a cross-thread WAR hazard that requires synchronization regardless of the loop.

The docstring and test name suggest this tests "loop-carried dependency," but the actual scenario being tested is same-iteration cross-thread access. Consider renaming to test_same_iteration_cross_thread_dependency and updating the docstring.

📝 Suggested documentation fix
 @tilelang.testing.requires_cuda
-def test_loop_carry_with_cross_thread_dependency():
-    """Test loop-carried dependency where different threads access overlapping locations.
+def test_same_iteration_cross_thread_dependency():
+    """Test same-iteration dependency where different threads access overlapping locations.
 
     In this test:
     - Thread tx writes to A[tx]
     - Then reads from A[(tx + 127) % 128] (neighbor's data from previous iteration)
-
-    After iteration shift analysis, we compare:
-    - Iteration i: thread tx writes A[tx]
-    - Iteration i+1: thread tx reads A[(tx + 127) % 128]
-
-    This creates a cross-thread dependency where thread tx+1's write conflicts
-    with thread tx's read in the next iteration, requiring a barrier.
+    
+    This creates a cross-thread WAR hazard within the same iteration:
+    thread (tx+1)%128 writes A[(tx+1)%128] while thread tx reads A[(tx+127)%128] = A[tx-1 mod 128].
+    Since different threads access overlapping locations, a barrier is required.
     """

Comment on lines +331 to +363
@tilelang.testing.requires_cuda
def test_loop_carry_modulo_buffering():
    """Test that A[i%2] write followed by A[i%2] read does NOT need barrier (double buffering).

    After iteration shift analysis:
    - Iteration i writes A[i%2]
    - Iteration i+1 reads A[(i+1)%2] (shifted from A[i%2])
    - A[i%2] vs A[(i+1)%2] are disjoint (0 vs 1 or 1 vs 0), so no dependency
    """

    @T.prim_func(private=True)
    def func():
        temp_shared = T.alloc_buffer([2, 64], dtype="float32", scope="shared")
        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
        bx = T.launch_thread("blockIdx.x", 1)
        tx = T.launch_thread("threadIdx.x", 64)
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        result_local[0] = T.float32(0)
        for i in range(10):
            # Double buffering pattern: write to buffer[i%2], read from buffer[i%2]
            # After shift: write buffer[i%2], read buffer[(i+1)%2]
            # These are different buffers, so no conflict
            temp_shared[i % 2, tx] = T.float32(i)
            result_local[0] = result_local[0] + temp_shared[i % 2, tx]

    mod = tvm.IRModule({"main": func})
    mod = tilelang.transform.ThreadSync("shared")(mod)
    s = str(mod)
    # Should NOT have sync inside loop due to modulo buffering analysis
    # Note: This test verifies the modulo analysis capability
    print(f"Modulo buffering result:\n{s}")


⚠️ Potential issue | 🟡 Minor

Missing assertion - test does not validate expected behavior.

The test documents the expected behavior (no sync due to modulo buffering) in the docstring but only has a print statement at line 362 without any assertion. This means the test will always pass regardless of actual synchronization behavior.

🧪 Add assertion to validate expected behavior
     mod = tvm.IRModule({"main": func})
     mod = tilelang.transform.ThreadSync("shared")(mod)
     s = str(mod)
     # Should NOT have sync inside loop due to modulo buffering analysis
-    # Note: This test verifies the modulo analysis capability
-    print(f"Modulo buffering result:\n{s}")
+    assert 'T.tvm_storage_sync("shared")' not in s, f"Unexpected sync with modulo buffering:\n{s}"

Introduce the UniformExprChecker class to determine if expressions are uniform across threads, crucial for safe synchronization in conditional statements. Update the TileLangThreadSyncPlanner to hoist synchronization points out of non-uniform if-statements to prevent potential deadlocks. Enhance tests to validate sync hoisting behavior for various non-uniform conditions involving thread indices and shared memory access.
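To see why the hoist matters, a minimal CUDA illustration of the hazard (an assumed example, not code from this PR):

// A barrier inside a branch that not all threads take can hang the block:
__global__ void barrier_in_divergent_branch(const int *token_ids, float *out) {
  int tx = threadIdx.x;
  if (token_ids[tx] != -1) {  // non-uniform: depends on runtime data
    __syncthreads();          // threads with token_ids[tx] == -1 never arrive
    out[tx] = 1.0f;
  }
}

// Hoisted form: every thread reaches the barrier, then diverges safely.
__global__ void barrier_hoisted(const int *token_ids, float *out) {
  int tx = threadIdx.x;
  __syncthreads();
  if (token_ids[tx] != -1) {
    out[tx] = 1.0f;
  }
}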
@coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)

739-831: ⚠️ Potential issue | 🟠 Major

Hoisted syncs leave stale sync markers in s.access.

When non-uniform hoisting triggers, you erase syncs from syncs_inserted_ but keep s.access computed earlier with those syncs included. That can suppress downstream conflict detection because outer summaries will treat the if as containing a sync that no longer exists. This is a correctness risk (missed sync insertion).

Please recompute the access summary after hoisting or ensure the hoisted path strips only the inserted sync markers from s.access before propagating it.

🤖 Fix all issues with AI agents
In `testing/python/transform/test_tilelang_transform_thread_sync.py`:
- Around line 365-395: Add an assertion to fail the test if the ThreadSync
transform incorrectly inserts synchronization for the disjoint-index case: after
creating mod = tilelang.transform.ThreadSync("shared")(mod) and computing s =
str(mod), assert that the generated IR does NOT contain the thread-sync/barrier
markers (for example assert "tvm_thread_sync" not in s and assert "tir.barrier"
not in s); this verifies test_loop_carry_different_indices and the
ThreadSync("shared") output instead of just printing the IR.
- Around line 544-571: The test test_sync_hoist_non_uniform_if_in_loop currently
only asserts a sync exists; update it to also assert the sync is placed inside
the loop before the non-uniform if by checking the string position: find the
index of 'for k in range(2):' (or locate the loop body start in s), then ensure
s.index('T.tvm_storage_sync("shared")') is greater than the loop start index but
less than s.index('if token_ids[tx]') (or directly assert the sync index is less
than the if index), so the storage sync for "shared" is hoisted into the loop
and appears before the non-uniform if that references token_ids and data_shared.
🧹 Nitpick comments (1)
src/transform/thread_storage_sync.cc (1)

1354-1608: Consider shifting loop-dependent constraints alongside indices.

The loop-carry path shifts curr indices but leaves curr.cset (and curr.touched in the non-scalar fallback) unshifted. If a guard depends on the loop var (e.g., i % 2), this can over-report conflicts and reduce the precision you just gained from index shifting.

A light refactor to substitute loop_var -> loop_var + step into curr constraints (and the touched-range fallback) will keep the analysis consistent.

♻️ Suggested refinement
-      PrimExpr curr_constr = curr.cset.ToConjunction();
+      PrimExpr curr_constr = curr.cset.ToConjunction();
+      if (loop != nullptr) {
+        curr_constr = Substitute(curr_constr, loop_shift_sub);
+      }
-          auto curr_min = analyzer.Simplify(
-              Substitute(curr.touched[i].min() * curr_dtype.bytes(), curr_sub));
-          auto curr_max = analyzer.Simplify(
-              Substitute(curr.touched[i].max() * curr_dtype.bytes(), curr_sub));
+          auto curr_min = analyzer.Simplify(Substitute(
+              curr.touched[i].min() * curr_dtype.bytes(), curr_sub));
+          auto curr_max = analyzer.Simplify(Substitute(
+              curr.touched[i].max() * curr_dtype.bytes(), curr_sub));
+          if (loop != nullptr) {
+            curr_min = Substitute(curr_min, loop_shift_sub);
+            curr_max = Substitute(curr_max, loop_shift_sub);
+          }

Comment on lines +365 to +395
@tilelang.testing.requires_cuda
def test_loop_carry_different_indices():
    """Test that A[i] write followed by A[i+1] read does NOT need barrier.

    After iteration shift analysis:
    - Iteration i writes A[i]
    - Iteration i+1 reads A[i+2] (shifted from A[i+1], becomes A[(i+1)+1] = A[i+2])
    - A[i] vs A[i+2] are disjoint, so no loop-carried dependency
    """

    @T.prim_func(private=True)
    def func():
        temp_shared = T.alloc_buffer([128], dtype="float32", scope="shared")
        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
        bx = T.launch_thread("blockIdx.x", 1)
        tx = T.launch_thread("threadIdx.x", 1)
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        result_local[0] = T.float32(0)
        for i in range(10):
            # Write to A[i], read from A[i+1]
            # After shift: comparing A[i] (write) vs A[i+2] (read from i+1 shifted)
            # No overlap, no dependency
            temp_shared[i] = T.float32(i)
            result_local[0] = result_local[0] + temp_shared[i + 1]

    mod = tvm.IRModule({"main": func})
    mod = tilelang.transform.ThreadSync("shared")(mod)
    s = str(mod)
    print(f"Different indices result:\n{s}")


⚠️ Potential issue | 🟡 Minor

Add an assertion to validate the “different indices” case.

The test currently only prints, so it can’t fail if the transform regresses.

🧪 Add assertion
     mod = tvm.IRModule({"main": func})
     mod = tilelang.transform.ThreadSync("shared")(mod)
     s = str(mod)
-    print(f"Different indices result:\n{s}")
+    assert 'T.tvm_storage_sync("shared")' not in s, f"Unexpected sync:\n{s}"

Based on learnings, tests in testing/python/transform should assert structural patterns in generated IR rather than rely on prints or numeric literals.


Comment on lines +544 to +571
@tilelang.testing.requires_cuda
def test_sync_hoist_non_uniform_if_in_loop():
    """Test sync hoisting when non-uniform if is inside a loop."""

    @T.prim_func(private=True)
    def func():
        token_ids = T.alloc_buffer([128], dtype="int32", scope="shared")
        data_shared = T.alloc_buffer([128], dtype="float32", scope="shared")
        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
        bx = T.launch_thread("blockIdx.x", 1)
        tx = T.launch_thread("threadIdx.x", 128)
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        result_local[0] = T.float32(0)
        for k in range(2):
            # Write to shared memory
            data_shared[tx] = T.float32(tx + k)
            # Non-uniform if inside loop
            if token_ids[tx] != -1:
                result_local[0] = result_local[0] + data_shared[tx]

    mod = tvm.IRModule({"main": func})
    mod = tilelang.transform.ThreadSync("shared")(mod)
    s = str(mod)
    assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}"
    # Sync should be before the if inside the loop, not inside the if
    # This ensures all threads can reach the sync point


⚠️ Potential issue | 🟡 Minor

Add a position check to confirm the sync is hoisted inside the loop.

This test asserts a sync exists but doesn’t verify it is placed before the non-uniform if inside the loop (vs. outside the loop). A simple index check will tighten it.

🧪 Add position assertion
     mod = tvm.IRModule({"main": func})
     mod = tilelang.transform.ThreadSync("shared")(mod)
     s = str(mod)
     assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}"
     # Sync should be before the if inside the loop, not inside the if
     # This ensures all threads can reach the sync point
+    loop_pos = s.index("for k in range(2)")
+    sync_pos = s.index('T.tvm_storage_sync("shared")')
+    if_pos = s.index("if token_ids")
+    assert loop_pos < sync_pos < if_pos, f"Sync should be inside loop and before if:\n{s}"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (final form):

@tilelang.testing.requires_cuda
def test_sync_hoist_non_uniform_if_in_loop():
    """Test sync hoisting when non-uniform if is inside a loop."""

    @T.prim_func(private=True)
    def func():
        token_ids = T.alloc_buffer([128], dtype="int32", scope="shared")
        data_shared = T.alloc_buffer([128], dtype="float32", scope="shared")
        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
        bx = T.launch_thread("blockIdx.x", 1)
        tx = T.launch_thread("threadIdx.x", 128)
        ty = T.launch_thread("threadIdx.y", 1)
        tz = T.launch_thread("threadIdx.z", 1)
        result_local[0] = T.float32(0)
        for k in range(2):
            # Write to shared memory
            data_shared[tx] = T.float32(tx + k)
            # Non-uniform if inside loop
            if token_ids[tx] != -1:
                result_local[0] = result_local[0] + data_shared[tx]

    mod = tvm.IRModule({"main": func})
    mod = tilelang.transform.ThreadSync("shared")(mod)
    s = str(mod)
    assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}"
    # Sync should be before the if inside the loop, not inside the if
    # This ensures all threads can reach the sync point
    loop_pos = s.index("for k in range(2)")
    sync_pos = s.index('T.tvm_storage_sync("shared")')
    if_pos = s.index("if token_ids")
    assert loop_pos < sync_pos < if_pos, f"Sync should be inside loop and before if:\n{s}"

… disabling and kernel source printing for debugging. Update thread synchronization logic in `thread_storage_sync.cc` to check for runtime-dependent conditions, preventing potential deadlocks by hoisting sync points as necessary.
@coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)

1362-1526: ⚠️ Potential issue | 🔴 Critical

Loop‑carry constraints aren’t shifted, which can miss real conflicts.

You shift curr indices for loop‑carry, but curr.cset and the fallback curr.touched expressions are still evaluated in the current iteration. If branch guards depend on loop_var (e.g., i % 2), the equivalence and overlap checks run against the wrong iteration and can miss dependencies, leading to missing syncs.

🩹 Minimal fix: apply the loop shift to constraints and touched ranges
-      PrimExpr prev_constr = prev.cset.ToConjunction();
-      PrimExpr curr_constr = curr.cset.ToConjunction();
+      PrimExpr prev_constr = prev.cset.ToConjunction();
+      ConstrSet curr_cset = curr.cset;
+      if (loop != nullptr) {
+        curr_cset = curr_cset.Substitute(loop_shift_sub);
+      }
+      PrimExpr curr_constr = curr_cset.ToConjunction();
...
-      ConstrSet curr_cset{curr.cset};
+      ConstrSet curr_cset{curr.cset};
+      if (loop != nullptr) {
+        curr_cset = curr_cset.Substitute(loop_shift_sub);
+      }
...
-          auto curr_min = analyzer.Simplify(
-              Substitute(curr.touched[i].min() * curr_dtype.bytes(), curr_sub));
-          auto curr_max = analyzer.Simplify(
-              Substitute(curr.touched[i].max() * curr_dtype.bytes(), curr_sub));
+          auto curr_min = analyzer.Simplify(Substitute(
+              Substitute(curr.touched[i].min() * curr_dtype.bytes(),
+                         loop_shift_sub),
+              curr_sub));
+          auto curr_max = analyzer.Simplify(Substitute(
+              Substitute(curr.touched[i].max() * curr_dtype.bytes(),
+                         loop_shift_sub),
+              curr_sub));
🤖 Fix all issues with AI agents
In `src/transform/thread_storage_sync.cc`:
- Around line 742-786: The hoist path removes entries from syncs_inserted_ (via
syncs_in_then/syncs_in_else) but does not update the previously computed access
summary s.access, leaving stale kSync entries added during Summarize and causing
incorrect later conflict suppression; after erasing the sync pointers (and
before calling insert_syncs(op)) update the summary to reflect the hoist by
either (a) recomputing the summary via the Summarize routine for the affected
statement/op, or (b) explicitly removing kSync entries in s.access that
correspond to the removed sync objects (syncs_in_then and syncs_in_else), so the
access summary matches the actual sync placement.

Comment on lines 742 to 786
// Check if any syncs were inserted inside the if-then-else
std::vector<const Object *> syncs_in_then;
std::vector<const Object *> syncs_in_else;

for (const auto &sync : syncs_inserted_) {
  if (syncs_before_then.count(sync) == 0 &&
      syncs_before_else.count(sync) != 0) {
    // Sync was inserted during then branch processing
    syncs_in_then.push_back(sync);
  } else if (syncs_before_else.count(sync) == 0) {
    // Sync was inserted during else branch processing
    syncs_in_else.push_back(sync);
  }
}

bool has_syncs_inside = !syncs_in_then.empty() || !syncs_in_else.empty();

if (has_syncs_inside) {
  // Check if the condition depends on runtime values (e.g., shared memory
  // loads). If so, we cannot determine at compile time how many threads
  // will enter the if, so we must hoist the sync to before the if to avoid
  // potential deadlock.
  //
  // If the condition only depends on threadIdx (e.g., `threadIdx.x >=
  // 512`), ThreadPartialSyncRewriter can compute the exact thread count at
  // compile time, so the sync can safely remain inside the if.
  RuntimeDependentConditionChecker checker;
  bool depends_on_runtime = checker.DependsOnRuntimeValue(op->condition);

  if (depends_on_runtime) {
    // Condition depends on runtime values - must hoist sync
    LOG(WARNING)
        << "[ThreadSync] Hoisting sync from inside if to before if. "
        << "Condition depends on runtime value: " << op->condition;

    for (const auto &sync : syncs_in_then) {
      syncs_inserted_.erase(sync);
    }
    for (const auto &sync : syncs_in_else) {
      syncs_inserted_.erase(sync);
    }

    // Insert sync before the if-statement itself
    insert_syncs(op);
  }

⚠️ Potential issue | 🟠 Major

Hoisting removes sync markers but keeps stale access summaries.

After hoisting, syncs_inserted_ entries from branches are erased, yet s.access still contains the kSync entries that were added during Summarize. This can incorrectly suppress later conflict detection (the outer planner thinks a sync still exists inside the if). Consider recomputing s.access after hoist, or filtering out sync entries corresponding to removed syncs_in_then/else so the summary matches actual sync placement.

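A sketch of option (b), filtering the stale kSync entries out of the summary after the hoist; the stmt field linking an entry to its sync object is an assumption about AccessEntry's layout, and the snippet needs <algorithm> and <vector>:

// Hedged sketch: after erasing hoisted syncs from syncs_inserted_, drop the
// matching kSync entries from s.access so downstream conflict detection does
// not see a sync that no longer exists. `e.stmt` is an assumed field name.
auto removed = [&](const AccessEntry &e) {
  if (e.type != kSync) return false;
  auto contains = [&](const std::vector<const Object *> &v) {
    return std::find(v.begin(), v.end(), e.stmt) != v.end();
  };
  return contains(syncs_in_then) || contains(syncs_in_else);
};
s.access.erase(std::remove_if(s.access.begin(), s.access.end(), removed),
               s.access.end());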

…e_gqa_decode_varlen_logits_paged.py` file. Refactor `example_gqa_decode_varlen_logits.py` to enhance performance and maintainability by removing unused imports and optimizing shared memory usage. Adjust test cases to reflect the removal of the paged example.
@coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (1)

309-312: ⚠️ Potential issue | 🟡 Minor

Hardcoded argument overrides defeat argparse purpose.

These lines override user-provided CLI arguments, making --dtype, --test_varlen, --test_sink, and --num_split flags ineffective. This appears to be leftover debug configuration.

Suggested fix: Remove hardcoded overrides
     args = parser.parse_args()
-    args.test_sink = True
-    args.test_varlen = True
-    args.dtype = T.float16
-    args.num_split = 1
🤖 Fix all issues with AI agents
In `examples/flash_decoding/example_gqa_decode_varlen_logits.py`:
- Around line 314-319: Restore the original conditional that respects the
args.benchmark flag: revert the unconditional call to
speed_benchmark_decode_comparison and reintroduce the if/else using
args.benchmark to call speed_benchmark_decode_comparison(args) when true and
test_varlen_decode_main(args) when false; specifically update the block using
the symbols args.benchmark, speed_benchmark_decode_comparison, and
test_varlen_decode_main so the CLI flag controls which function runs.
🧹 Nitpick comments (2)
examples/flash_decoding/example_gqa_decode_varlen_logits.py (2)

63-63: Unused loop variable bz should be prefixed with underscore.

The variable bz is unpacked but never used in the kernel body. Since num_split defaults to 1, this appears to be placeholder code for future split functionality.

Suggested fix
-        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bid, hid, bz):
+        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bid, hid, _bz):

262-262: Variable head_size shadows parameter from outer scope.

The variable head_size unpacked here shadows the head_size from args.head_size assigned on line 224. While they should have the same value, this creates potential confusion.

Suggested fix
-    _, q_h, head_size = q_decode.shape
+    _, q_h, _ = q_decode.shape

Comment on lines +314 to +319
    # if args.benchmark:
    #     speed_benchmark_decode_comparison(args)
    # else:
    #     test_varlen_decode_main(args)

    speed_benchmark_decode_comparison(args)

⚠️ Potential issue | 🟡 Minor

Commented-out code breaks the --benchmark CLI flag functionality.

The conditional logic is commented out, making the --benchmark argument (defined at line 306) non-functional. The script now unconditionally runs the benchmark instead of allowing users to choose between test and benchmark modes.

Suggested fix: Restore conditional logic
-    # if args.benchmark:
-    #     speed_benchmark_decode_comparison(args)
-    # else:
-    #     test_varlen_decode_main(args)
-
-    speed_benchmark_decode_comparison(args)
+    if args.benchmark:
+        speed_benchmark_decode_comparison(args)
+    else:
+        test_varlen_decode_main(args)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (final form):

    if args.benchmark:
        speed_benchmark_decode_comparison(args)
    else:
        test_varlen_decode_main(args)

@coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `src/transform/thread_storage_sync.cc`:
- Around line 380-391: The code dereferences as_const_int(iv->dom->extent)
unsafely; first check whether as_const_int(iv->dom->extent) has a value before
using it (e.g., auto opt_extent = as_const_int(iv->dom->extent); if
(!opt_extent) { /* fallback */ }). If opt_extent is present, use *opt_extent for
extent and call analyzer_->z3_prover.CountSatisfyingValues(iv->var, extent);
otherwise skip Z3 enumeration and fall back to the range-based calculation
(using bound->max_value and bound->min_value) to avoid undefined behavior.
- Around line 468-483: In DependsOnRuntimeValue, avoid the unsafe dereference of
as_const_int(iv->dom->extent) — check the result (e.g., ICHECK(result) or handle
the null case and set depends_on_runtime_ accordingly) before casting to int64_t
and using thread_extent; reference the symbol as_const_int and iv->dom->extent.
Replace the hard-coded warp_size = 32 with a query to the target attribute (use
target->GetAttr<Integer>("thread_warp_size") as done in
lower_thread_allreduce.cc), falling back to a sensible default if the attribute
is missing, and then pass that warp_size into
analyzer_->z3_prover.CountSatisfyingValues; reference
analyzer_->z3_prover.CountSatisfyingValues and warp_size in your changes.
- Around line 1416-1422: The loop shift currently always uses a hardcoded step
of 1 which ignores explicit loop strides; update the logic in the block that
constructs loop_shift_sub so that you read loop->step when defined and fall back
to make_const(loop->loop_var.dtype(), 1) otherwise (i.e., compute a PrimExpr
step = loop->step.defined() ? loop->step : make_const(...)), then use that step
for the substitution of loop->loop_var -> loop->loop_var + step; ensure you
reference the existing loop, loop->loop_var, loop->step, loop_shift_sub and
make_const symbols when making the change.
🧹 Nitpick comments (2)
src/transform/thread_storage_sync.cc (2)

506-507: Duplicate private: specifier.

The private: access specifier appears twice (lines 485 and 506). This is harmless but suggests copy-paste; consider removing the redundant one.


1650-1651: Extra semicolon.

Line 1651 has an unnecessary semicolon after the return statement.

     return tl::TileLangThreadSync(std::move(f), storage_scope);
-    ;

Comment on lines +1416 to +1422
ffi::Map<Var, PrimExpr> loop_shift_sub;
if (loop != nullptr) {
  // Get loop step, default to 1 if not specified
  PrimExpr step = make_const(loop->loop_var.dtype(), 1);
  // Substitute loop_var -> loop_var + step for the "next iteration"
  loop_shift_sub.Set(loop->loop_var, loop->loop_var + step);
}

⚠️ Potential issue | 🟡 Minor


Use actual loop step value instead of hardcoding 1.

The code should check loop->step and use it if defined, falling back to 1 for unit-stride loops. Currently, step is always created as a constant 1, ignoring the loop's actual step value. This causes incorrect iteration shift analysis for loops with explicit non-unit strides.

Replace the hardcoded step assignment:

PrimExpr step = loop->step.defined() ? loop->step : make_const(loop->loop_var.dtype(), 1);
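Putting the suggestion together, the substitution block might read as below; this is a sketch that assumes this fork's ForNode carries an optional step field, as the use of fnode->step in loop_vectorize.cc suggests:

ffi::Map<Var, PrimExpr> loop_shift_sub;
if (loop != nullptr) {
  // Use the loop's own step when defined; fall back to unit stride.
  PrimExpr step = loop->step.defined()
                      ? loop->step
                      : make_const(loop->loop_var.dtype(), 1);
  // Substitute loop_var -> loop_var + step for the "next iteration".
  loop_shift_sub.Set(loop->loop_var, loop->loop_var + step);
}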

@LeiWang1999 (Member, Author) commented:

@regression-perf

…dding a configurable warp size parameter to `RuntimeDependentConditionChecker` and `TileLangThreadSyncPlanner`. This allows for better adaptability to different target architectures. Update the logic to ensure thread extent is a constant and improve handling of runtime-dependent conditions.
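A minimal sketch of deriving the warp size from the target instead of hardcoding 32, following the thread_warp_size attribute pattern the earlier review pointed at in lower_thread_allreduce.cc (the exact wiring in the PR may differ):

// Fall back to 32 (the common NVIDIA warp size) when the target does not
// declare thread_warp_size.
int warp_size = 32;
if (auto attr = target->GetAttr<Integer>("thread_warp_size")) {
  warp_size = static_cast<int>(attr.value()->value);
}
// The value is then threaded into RuntimeDependentConditionChecker and
// TileLangThreadSyncPlanner via their constructor parameters.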
@coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `src/transform/thread_storage_sync.cc`:
- Around line 573-581: GetThreadVar currently calls LOG(FATAL) when env_threads_
doesn't contain the requested tag, which can crash callers (e.g., when
requesting "threadIdx.x" outside a device context); change GetThreadVar to
return an Optional/nullable IterVar (e.g., Optional<IterVar> or
std::optional<IterVar>) instead of aborting, returning empty when not found, and
update all call sites that assumed a non-null return (notably the place using
"threadIdx.x") to check has_value()/operator bool and handle the missing case
(skip hoisting/return early) before dereferencing the value.
🧹 Nitpick comments (1)
src/transform/thread_storage_sync.cc (1)

490-514: Duplicate private: access specifier.

There are two private: labels (lines 490 and 511) in the class. The second one is redundant and should be removed for cleaner code.

🧹 Proposed fix
 private:
   PrimExpr VisitExpr_(const BufferLoadNode *op) final {
     // Any buffer load introduces runtime dependency
     // (we don't know the buffer contents at compile time)
     depends_on_runtime_ = true;
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
   PrimExpr VisitExpr_(const CallNode *op) final {
     // Check tvm_access_ptr and address_of - if used in condition, it's reading
     // memory
     if (op->op.same_as(builtin::tvm_access_ptr()) ||
         op->op.same_as(builtin::address_of())) {
       depends_on_runtime_ = true;
       return IRMutatorWithAnalyzer::VisitExpr_(op);
     }
     // Other calls might also introduce runtime dependency
     // but we'll be conservative and check children
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
-private:
   bool depends_on_runtime_{false};
   int warp_size_;
 };

Comment on lines +573 to +581
IterVar GetThreadVar(const std::string &tag) const {
  for (const auto &iv : env_threads_) {
    if (iv->thread_tag == tag) {
      return iv;
    }
  }
  LOG(FATAL) << "Thread variable " << tag << " not found";
  return IterVar();
}

⚠️ Potential issue | 🟡 Minor

Potential crash if thread variable not found.

GetThreadVar will call LOG(FATAL) if the requested thread tag is not found in env_threads_. This is called at line 826 with "threadIdx.x" - if the if-statement is not inside a device context with thread extents defined, this will crash.

Consider returning an Optional<IterVar> and handling the missing case gracefully, or add a check before calling this method.

🛡️ Proposed defensive approach
-  IterVar GetThreadVar(const std::string &tag) const {
+  std::optional<IterVar> GetThreadVar(const std::string &tag) const {
     for (const auto &iv : env_threads_) {
       if (iv->thread_tag == tag) {
         return iv;
       }
     }
-    LOG(FATAL) << "Thread variable " << tag << " not found";
-    return IterVar();
+    return std::nullopt;
   }

Then at the call site (line 826):

auto tx_opt = GetThreadVar("threadIdx.x");
if (!tx_opt.has_value()) {
  // Cannot determine thread-dependency without threadIdx.x, skip hoisting
  continue;  // or return early
}
IterVar tx = tx_opt.value();


github-actions bot commented Feb 4, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21670836499

Results

File Original Latency Current Latency Speedup
example_tilelang_block_sparse_attn 0.0100722 0.0192853 0.52227
example_dequant_groupedgemm_bf16_mxfp4_hopper 3.47143 4.01814 0.863938
example_dequant_gemm_bf16_mxfp4_hopper 0.508233 0.577034 0.880768
example_dequant_gemm_bf16_fp4_hopper 0.575021 0.639979 0.8985
example_gqa_sink_bwd_bhsd 0.0415604 0.0448202 0.927269
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0144315 0.0154265 0.935501
example_tilelang_sparse_gqa_decode_varlen_mask 0.0233753 0.0244456 0.956216
example_warp_specialize_gemm_copy_0_gemm_1 0.0384 0.039776 0.965406
example_mha_inference 0.078848 0.081249 0.970449
example_warp_specialize_gemm_copy_1_gemm_0 0.036545 0.037249 0.9811
example_warp_specialize_gemm_softpipe_stage2 0.038144 0.038624 0.987572
sparse_mla_fwd 0.130815 0.131876 0.991952
example_dequant_gemm_fp4_hopper 1.05608 1.06174 0.994667
example_tilelang_sparse_gqa_decode_varlen_indice 0.0169878 0.0170776 0.994738
example_gqa_bwd_wgmma_pipelined 0.0696429 0.0699645 0.995403
example_mha_bwd_bshd 0.0412265 0.0413978 0.995863
example_linear_attn_bwd 0.153546 0.154177 0.995909
example_mha_fwd_varlen 0.0454851 0.0456552 0.996276
example_gemm_intrinsics 0.035041 0.035169 0.99636
example_linear_attn_fwd 0.0367728 0.0369011 0.996523
topk_selector 0.0535703 0.0537496 0.996664
example_mha_sink_bwd_bhsd_sliding_window 0.044962 0.0450965 0.997016
example_gemm_autotune 0.022369 0.022433 0.997147
example_tilelang_nsa_fwd 0.00684659 0.00686591 0.997187
example_mha_bwd_bshd_wgmma_pipelined 0.025665 0.025735 0.997278
example_elementwise_add 0.295386 0.296152 0.997413
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0156133 0.015652 0.997526
example_mha_sink_fwd_bhsd_sliding_window 0.0158044 0.0158404 0.997729
example_mha_bwd_bhsd 0.0406328 0.0407167 0.997939
fp8_lighting_indexer 0.0357333 0.0358056 0.997981
sparse_mla_bwd 0.38499 0.385588 0.998449
example_group_per_split_token_cast_to_fp8 0.0103736 0.0103879 0.998623
example_gqa_sink_bwd_bhsd_sliding_window 0.0255479 0.0255829 0.99863
sparse_mla_fwd_pipelined 0.0959783 0.0961005 0.998728
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0146474 0.0146661 0.998731
example_dynamic 0.655819 0.656619 0.998782
example_tilelang_gemm_splitk_vectorize_atomicadd 1.42352 1.42437 0.999404
example_gqa_bwd 0.0498103 0.0498301 0.999601
example_gqa_bwd_tma_reduce_varlen 0.052195 0.0522087 0.999738
example_gemv 0.284764 0.284815 0.999817
example_gemm_schedule 0.0325499 0.0325488 1.00003
example_convolution 1.33316 1.33305 1.00008
block_sparse_attn_tilelang 0.0102479 0.0102452 1.00026
example_blocksparse_gemm 0.0226856 0.0226752 1.00046
example_dequant_gemm_w4a8 5.39541 5.39207 1.00062
example_tilelang_gemm_splitk 1.42419 1.42324 1.00067
example_mha_sink_bwd_bhsd 0.0626584 0.0626078 1.00081
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0155808 0.0155662 1.00094
example_mla_decode 0.461319 0.460873 1.00097
example_vertical_slash_sparse_attn 0.23712 0.236812 1.0013
example_convolution_autotune 0.997967 0.996533 1.00144
example_tilelang_gemm_fp8_intrinsic 0.934734 0.93262 1.00227
example_tilelang_nsa_decode 0.00744042 0.00742253 1.00241
example_topk 0.01088 0.010848 1.00295
example_mha_sink_fwd_bhsd 0.0160486 0.0160008 1.00299
example_tilelang_gemm_fp8 0.322678 0.321703 1.00303
example_gqa_decode 0.048257 0.048096 1.00335
example_dequant_gemv_fp16xint4 0.0285176 0.02842 1.00344
tilelang_example_sparse_tensorcore 0.0150869 0.0150282 1.0039
example_per_token_cast_to_fp8 0.00744704 0.00740874 1.00517
example_tilelang_gemm_fp8_2xAcc 0.190306 0.188131 1.01156
example_gemm 0.023137 0.022624 1.02268
example_warp_specialize_gemm_barrierpipe_stage2 0.039424 0.038368 1.02752

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

…pointer checks instead of optional values. This change improves clarity and ensures that the thread extent is correctly validated as a constant.
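
For context on this commit, a minimal sketch of the pointer-check idiom it describes, assuming TVM's as_const_int, which returns a const int64_t* that is null for non-constant expressions; variable names mirror the snippet discussed in the review comment below, not necessarily the merged code:

const int64_t* extent_ptr = as_const_int(iv->dom->extent);
if (extent_ptr == nullptr) {
  // Not a compile-time constant: skip Z3 enumeration, use the range fallback.
  return static_cast<size_t>(bound->max_value - bound->min_value + 1);
}
int64_t extent = *extent_ptr;  // safe: constant-ness validated above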
@LeiWang1999
Member Author

@regression-perf

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/transform/thread_storage_sync.cc`:
- Around line 380-391: The code dereferences as_const_int(iv->dom->extent)
unsafely; change the logic in the block around as_const_int, iv->dom->extent and
analyzer_->z3_prover.CountSatisfyingValues so you first check the
optional/nullable result of as_const_int before using it (e.g., auto
maybe_extent = as_const_int(iv->dom->extent); if not present, skip Z3
enumeration and fall back to the range-based calculation using bound->max_value
and bound->min_value); only call CountSatisfyingValues when you have a valid
extent and keep the existing fallback to return
static_cast<size_t>(bound->max_value - bound->min_value + 1).
🧹 Nitpick comments (1)
src/transform/thread_storage_sync.cc (1)

1667-1668: Minor: Remove extra semicolon.

Line 1668 has a double semicolon ;; which is harmless but unnecessary.

🧹 Proposed fix
     return tl::TileLangThreadSync(std::move(f), storage_scope);
-    ;

Comment on lines +380 to +391
    auto extent = *as_const_int(iv->dom->extent);
    // Always use Z3 enumeration to count satisfying values.
    // This handles constraints like `tx % 4 == 0` that const_int_bound cannot
    // detect. Z3 enumeration will return the exact count of satisfying values.
    int64_t z3_count =
        analyzer_->z3_prover.CountSatisfyingValues(iv->var, extent);
    if (z3_count > 0) {
      return static_cast<size_t>(z3_count);
    }

    // Fallback to range-based calculation if Z3 enumeration failed
    return static_cast<size_t>(bound->max_value - bound->min_value + 1);
⚠️ Potential issue | 🟠 Major

Unsafe dereference of as_const_int result.

Line 380 dereferences as_const_int(iv->dom->extent) without checking if it returns a valid value. If the extent is not a compile-time constant, this causes undefined behavior.

🐛 Proposed fix
-    auto extent = *as_const_int(iv->dom->extent);
+    auto extent_opt = as_const_int(iv->dom->extent);
+    if (!extent_opt) {
+      // Fallback to range-based calculation if extent is not const
+      return static_cast<size_t>(bound->max_value - bound->min_value + 1);
+    }
+    auto extent = *extent_opt;

@github-actions

github-actions bot commented Feb 4, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21677167658

Results

File | Original Latency | Current Latency | Speedup (original / current; above 1 means the current run is faster)
example_dequant_groupedgemm_bf16_mxfp4_hopper 3.57714 4.1598 0.859933
example_dequant_gemm_bf16_mxfp4_hopper 0.505994 0.565863 0.894199
example_dequant_gemm_bf16_fp4_hopper 0.567342 0.626122 0.906121
example_warp_specialize_gemm_copy_0_gemm_1 0.038529 0.040032 0.962455
example_gemm_intrinsics 0.034465 0.034818 0.989862
example_linear_attn_bwd 0.151396 0.152445 0.993115
example_dequant_gemm_fp4_hopper 1.03967 1.04671 0.993274
example_tilelang_gemm_fp8_2xAcc 0.180868 0.181973 0.993933
example_gemm 0.022529 0.022624 0.995801
example_warp_specialize_gemm_barrierpipe_stage2 0.038592 0.038753 0.995845
example_gemm_schedule 0.0322207 0.0322549 0.998939
example_convolution_autotune 0.992215 0.993034 0.999175
fp8_lighting_indexer 0.0353641 0.0353867 0.999361
example_tilelang_gemm_fp8 0.318408 0.318597 0.999407
example_gqa_bwd_wgmma_pipelined 0.0687508 0.0687892 0.999441
example_tilelang_sparse_gqa_decode_varlen_indice 0.016892 0.0169005 0.999494
example_tilelang_gemm_splitk 1.40157 1.40227 0.999506
example_mha_bwd_bshd_wgmma_pipelined 0.025514 0.0255237 0.999616
example_blocksparse_gemm 0.0224169 0.022423 0.999729
example_elementwise_add 0.293856 0.293922 0.999775
example_tilelang_block_sparse_attn 0.0100695 0.0100717 0.99978
example_mla_decode 0.449193 0.449289 0.999786
example_gqa_bwd_tma_reduce_varlen 0.0513046 0.0513141 0.999816
example_dequant_gemv_fp16xint4 0.0283651 0.02837 0.99983
example_per_token_cast_to_fp8 0.00739462 0.00739476 0.999981
example_topk 0.01072 0.01072 1
block_sparse_attn_tilelang 0.0101571 0.0101567 1.00004
example_gqa_bwd 0.0490345 0.0490325 1.00004
example_linear_attn_fwd 0.0365529 0.0365497 1.00009
example_tilelang_gemm_splitk_vectorize_atomicadd 1.40167 1.40152 1.00011
example_tilelang_nsa_decode 0.0073083 0.00730725 1.00014
example_convolution 1.3086 1.30838 1.00017
tilelang_example_sparse_tensorcore 0.0148989 0.0148953 1.00025
example_tilelang_sparse_gqa_decode_varlen_mask 0.0231295 0.0231232 1.00027
example_gqa_sink_bwd_bhsd 0.0408332 0.0408213 1.00029
example_dynamic 0.651372 0.650986 1.00059
example_group_per_split_token_cast_to_fp8 0.0103255 0.0103166 1.00086
example_gemv 0.281559 0.281284 1.00098
example_gqa_sink_bwd_bhsd_sliding_window 0.0251476 0.0251223 1.00101
example_dequant_gemm_w4a8 5.30756 5.30174 1.0011
example_tilelang_gemm_fp8_intrinsic 0.911638 0.910444 1.00131
example_vertical_slash_sparse_attn 0.232037 0.231648 1.00168
example_mha_bwd_bhsd 0.0401083 0.0400393 1.00172
example_tilelang_nsa_fwd 0.00682512 0.00681247 1.00186
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0144391 0.0144118 1.00189
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0142922 0.0142615 1.00215
topk_selector 0.0532249 0.0530858 1.00262
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0153339 0.0152936 1.00264
sparse_mla_fwd 0.129952 0.129588 1.00281
example_mha_sink_bwd_bhsd 0.0616214 0.0614337 1.00306
example_mha_sink_bwd_bhsd_sliding_window 0.0445364 0.044398 1.00312
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0154022 0.0153499 1.00341
example_mha_bwd_bshd 0.0407729 0.0406233 1.00368
sparse_mla_bwd 0.378704 0.377263 1.00382
example_gemm_autotune 0.022176 0.02208 1.00435
example_mha_sink_fwd_bhsd_sliding_window 0.0156008 0.01553 1.00455
example_mha_sink_fwd_bhsd 0.0158154 0.0157431 1.00459
example_mha_fwd_varlen 0.0452186 0.044968 1.00557
sparse_mla_fwd_pipelined 0.0953161 0.0947067 1.00643
example_gqa_decode 0.047457 0.047137 1.00679
example_mha_inference 0.078946 0.078016 1.01192
example_warp_specialize_gemm_copy_1_gemm_0 0.036929 0.036384 1.01498
example_warp_specialize_gemm_softpipe_stage2 0.039264 0.038529 1.01908

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LeiWang1999
Member Author

@regression-perf

@github-actions

github-actions bot commented Feb 4, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21678454780

Results

File | Original Latency | Current Latency | Speedup (original / current; above 1 means the current run is faster)
example_dequant_groupedgemm_bf16_mxfp4_hopper 3.5139 4.17261 0.842134
example_dequant_gemm_bf16_mxfp4_hopper 0.505482 0.563816 0.896537
example_dequant_gemm_bf16_fp4_hopper 0.568394 0.62478 0.909751
example_warp_specialize_gemm_barrierpipe_stage2 0.037729 0.040033 0.942447
example_warp_specialize_gemm_copy_1_gemm_0 0.03552 0.037376 0.950342
example_mha_fwd_bhsd 0.010753 0.011105 0.968303
example_gqa_decode 0.047073 0.048546 0.969658
example_tilelang_gemm_fp8_2xAcc 0.181826 0.183169 0.992665
example_linear_attn_bwd 0.151425 0.152459 0.993219
example_gemm_intrinsics 0.03456 0.034721 0.995363
sparse_mla_fwd 0.129914 0.130398 0.996291
example_topk 0.010688 0.01072 0.997015
example_tilelang_gemm_fp8 0.318101 0.31879 0.997839
example_mha_fwd_bhsd_wgmma_pipelined 0.014209 0.014239 0.997893
sparse_mla_bwd 0.37854 0.379109 0.998497
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.015332 0.0153541 0.998559
example_tilelang_block_sparse_attn 0.010061 0.0100713 0.998976
example_tilelang_gemm_fp8_intrinsic 0.910616 0.911498 0.999032
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0154076 0.0154222 0.999057
example_mha_sink_fwd_bhsd 0.015759 0.0157735 0.999081
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0144343 0.014445 0.999257
example_tilelang_sparse_gqa_decode_varlen_indice 0.016878 0.0168871 0.999463
tilelang_example_sparse_tensorcore 0.0148972 0.0149043 0.99953
example_gqa_bwd_wgmma_pipelined 0.068857 0.0688876 0.999556
example_mha_inference 0.078498 0.078529 0.999605
example_group_per_split_token_cast_to_fp8 0.0103199 0.0103236 0.999642
example_dequant_gemm_w4a8 5.30422 5.306 0.999663
example_mla_decode 0.449094 0.449193 0.99978
example_tilelang_nsa_decode 0.0073042 0.00730553 0.999818
example_mha_fwd_varlen 0.0452333 0.0452399 0.999855
example_gqa_sink_bwd_bhsd 0.0408103 0.0408153 0.999877
example_mha_bwd_bshd_wgmma_pipelined 0.0255016 0.0255044 0.99989
example_mha_sink_bwd_bhsd_sliding_window 0.0444233 0.0444272 0.999913
example_per_token_cast_to_fp8 0.00739492 0.00739547 0.999925
example_gqa_bwd_tma_reduce_varlen 0.0512638 0.0512675 0.999929
example_gqa_sink_bwd_bhsd_sliding_window 0.0251556 0.025157 0.999947
fp8_lighting_indexer 0.0353705 0.0353703 1.00001
example_gemm_schedule 0.0322355 0.0322327 1.00009
example_tilelang_gemm_splitk 1.40185 1.40166 1.00014
example_mha_sink_bwd_bhsd 0.0616556 0.0616446 1.00018
example_gemv 0.281535 0.281481 1.00019
example_convolution 1.30911 1.30886 1.0002
example_linear_attn_fwd 0.0365515 0.036542 1.00026
example_gqa_bwd 0.0490396 0.0490228 1.00034
example_tilelang_sparse_gqa_decode_varlen_mask 0.0231372 0.0231275 1.00042
example_mha_bwd_bhsd 0.0401103 0.0400935 1.00042
example_tilelang_gemm_splitk_vectorize_atomicadd 1.40119 1.40051 1.00048
example_dynamic 0.652195 0.65185 1.00053
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0142997 0.0142915 1.00057
topk_selector 0.0532269 0.0531964 1.00057
example_elementwise_add 0.294113 0.293944 1.00057
example_mha_bwd_bshd 0.0407748 0.0407511 1.00058
example_dequant_gemv_fp16xint4 0.0283737 0.0283569 1.00059
example_mha_sink_fwd_bhsd_sliding_window 0.0156008 0.0155916 1.00059
example_blocksparse_gemm 0.0224169 0.0224035 1.0006
block_sparse_attn_tilelang 0.0101624 0.0101522 1.001
sparse_mla_fwd_pipelined 0.0949536 0.0948415 1.00118
example_gemm_autotune 0.022145 0.022113 1.00145
example_tilelang_nsa_fwd 0.00682057 0.00680954 1.00162
example_gqa_fwd_bshd 0.070849 0.07072 1.00182
example_convolution_autotune 0.99333 0.991433 1.00191
example_dequant_gemm_fp4_hopper 1.0371 1.03467 1.00235
example_vertical_slash_sparse_attn 0.232164 0.23161 1.00239
example_gemm 0.022688 0.022592 1.00425
example_warp_specialize_gemm_copy_0_gemm_1 0.038271 0.038081 1.00499
example_mha_fwd_bshd 0.025761 0.025568 1.00755
example_gqa_fwd_bshd_wgmma_pipelined 0.054753 0.054176 1.01065
example_mha_fwd_bshd_wgmma_pipelined 0.0144 0.014176 1.0158
example_warp_specialize_gemm_softpipe_stage2 0.039041 0.037857 1.03128

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

…-carry analysis by modifying the extent calculation. This change ensures valid iteration comparisons by reducing the extent by one, allowing for accurate analysis of loop iterations.
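
A hedged sketch of the extent adjustment this commit describes: the loop-carry check compares iteration i against iteration i + 1, and that pair only exists for the first extent - 1 iterations. Names here are illustrative, not taken from the merged code:

// Loop-carry write-after-read: a read at iteration i may conflict with a
// write at iteration i + 1, so the writer side is checked at loop_var + 1
// and the enumerated domain shrinks by one.
PrimExpr i = loop->loop_var;
PrimExpr writer_iter = i + 1;              // the next iteration's access
PrimExpr carry_extent = loop->extent - 1;  // last iteration has no successor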
@LeiWang1999
Member Author

@regression-perf

@github-actions

github-actions bot commented Feb 4, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21683345156

Results

File | Original Latency | Current Latency | Speedup (original / current; above 1 means the current run is faster)
example_dequant_gemm_bf16_mxfp4_hopper 0.496874 0.564177 0.880706
example_dequant_gemm_bf16_fp4_hopper 0.55518 0.625001 0.888287
example_dequant_groupedgemm_bf16_mxfp4_hopper 3.58554 3.85634 0.929778
example_warp_specialize_gemm_barrierpipe_stage2 0.038529 0.039649 0.971752
example_warp_specialize_gemm_softpipe_stage2 0.036673 0.037696 0.972862
example_dequant_gemm_fp4_hopper 1.02722 1.04346 0.98444
example_gqa_decode 0.047585 0.04832 0.984789
example_tilelang_gemm_fp8_2xAcc 0.182068 0.18471 0.985696
example_topk 0.010751 0.010848 0.991058
example_gemm_intrinsics 0.034497 0.034784 0.991749
example_linear_attn_bwd 0.151352 0.152448 0.992807
example_gemm_autotune 0.022177 0.022272 0.995735
example_dynamic 0.651312 0.653424 0.996768
sparse_mla_fwd 0.129967 0.13036 0.996983
example_gqa_bwd_wgmma_pipelined 0.068787 0.0688816 0.998626
example_tilelang_gemm_fp8 0.319367 0.319781 0.998705
example_tilelang_sparse_gqa_decode_varlen_indice 0.0168903 0.0169118 0.998732
sparse_mla_bwd 0.378625 0.379081 0.998795
example_per_token_cast_to_fp8 0.00739112 0.00739618 0.999316
example_mha_bwd_bshd_wgmma_pipelined 0.0255104 0.025525 0.999427
example_tilelang_nsa_decode 0.00730348 0.00730751 0.999448
sparse_mla_fwd_pipelined 0.0948644 0.0949149 0.999467
example_tilelang_gemm_splitk_vectorize_atomicadd 1.40055 1.40102 0.999667
example_gqa_bwd_tma_reduce_varlen 0.0513053 0.0513208 0.999698
example_mha_bwd_bhsd 0.0400992 0.0401091 0.999754
example_mla_decode 0.44913 0.449226 0.999786
example_tilelang_sparse_gqa_decode_varlen_mask 0.0231327 0.0231346 0.999918
example_tilelang_block_sparse_attn 0.0100692 0.0100689 1.00003
example_elementwise_add 0.293983 0.293966 1.00006
example_gqa_sink_bwd_bhsd 0.0408282 0.0408231 1.00012
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0142946 0.0142928 1.00013
example_tilelang_gemm_splitk 1.40228 1.40191 1.00026
example_gemm_schedule 0.0322513 0.0322427 1.00027
example_gqa_bwd 0.0490215 0.0490063 1.00031
tilelang_example_sparse_tensorcore 0.0149037 0.0148985 1.00035
example_mha_bwd_bshd 0.0407743 0.0407601 1.00035
example_gqa_sink_bwd_bhsd_sliding_window 0.0251492 0.0251393 1.0004
block_sparse_attn_tilelang 0.0101639 0.0101594 1.00044
example_linear_attn_fwd 0.0365525 0.036534 1.0005
fp8_lighting_indexer 0.0353907 0.0353695 1.0006
example_dequant_gemv_fp16xint4 0.0283642 0.0283461 1.00064
example_mha_inference 0.078562 0.078497 1.00083
topk_selector 0.0532339 0.0531885 1.00085
example_group_per_split_token_cast_to_fp8 0.0103364 0.0103273 1.00088
example_convolution 1.31004 1.30876 1.00098
example_tilelang_gemm_fp8_intrinsic 0.911318 0.910261 1.00116
example_gemv 0.281809 0.281441 1.00131
example_blocksparse_gemm 0.0224613 0.0224295 1.00142
example_mha_fwd_varlen 0.0451942 0.0451302 1.00142
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0144525 0.0144316 1.00145
example_gemm 0.022753 0.02272 1.00145
example_tilelang_nsa_fwd 0.00682168 0.00681168 1.00147
example_vertical_slash_sparse_attn 0.232069 0.231606 1.002
example_mha_sink_bwd_bhsd 0.0616855 0.0614848 1.00326
example_mha_sink_fwd_bhsd_sliding_window 0.0155945 0.0155427 1.00333
example_convolution_autotune 0.991989 0.988275 1.00376
example_gqa_fwd_bshd 0.070625 0.070273 1.00501
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0154258 0.0153463 1.00517
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0153532 0.0152703 1.00543
example_mha_sink_fwd_bhsd 0.0158014 0.0157162 1.00543
example_mha_sink_bwd_bhsd_sliding_window 0.0445732 0.0442705 1.00684
example_warp_specialize_gemm_copy_0_gemm_1 0.039168 0.038657 1.01322
example_gqa_fwd_bshd_wgmma_pipelined 0.055361 0.054594 1.01405
example_warp_specialize_gemm_copy_1_gemm_0 0.037984 0.036864 1.03038

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LeiWang1999 LeiWang1999 merged commit 191d879 into main Feb 4, 2026
6 checks passed