[Example] Add Seesaw Sparse MLA Forward Kernel for DeepSeek-V3.2#1636

Merged
LeiWang1999 merged 1 commit into tile-ai:main from hammersam:main
Jan 8, 2026
Conversation

@hammersam
Contributor

@hammersam hammersam commented Jan 8, 2026

Add a new sparse MLA forward kernel implementation using the "Seesaw" synchronization pattern, an alternative to the existing pipelined approach.

Dual-Consumer Parallel Architecture:

  • Unlike the pipelined version, where WG1 depends on WG0's S_shared/alpha_shared, Seesaw lets both consumers work independently on different KV blocks
  • Consumer 0 (WG0): Processes even blocks (BI_2*i), computes O_L (left half)
  • Consumer 1 (WG1): Processes odd blocks (BI_2*i+1), computes O_R (right half)
  • Each consumer maintains its own softmax statistics (m_i, sumexp)
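As a minimal sketch of the partitioning (illustrative Python, not the kernel's actual TileLang index math):

```python
# With 2*NI index blocks of size BI, WG0 takes even blocks 2*i and
# WG1 takes odd blocks 2*i + 1; together they cover every block once.
NI, BI = 3, 64  # illustrative sizes
wg0_blocks = [2 * i for i in range(NI)]        # even blocks -> computes O_L
wg1_blocks = [2 * i + 1 for i in range(NI)]    # odd blocks  -> computes O_R
assert sorted(wg0_blocks + wg1_blocks) == list(range(2 * NI))  # disjoint cover
```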

Seesaw Synchronization Mechanism:

  • Consumers exchange local row_max via bar_stats_0/1_ready barriers
  • Both compute global max by taking max of local and peer's max
  • S matrices exchanged via bar_S_0/1_ready for cross-attention:
    • O_L += P0 @ V0_L (self) + P1 @ V1_L (from peer)
    • O_R += P1 @ V1_R (self) + P0 @ V0_R (from peer)
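The exchange above can be demonstrated with a NumPy stand-in (a toy model, not the TileLang kernel; tile sizes are illustrative): after swapping row maxima, both consumers' exp() tiles share one normalizer, so each side can accumulate its half of the output from both P tiles.

```python
import numpy as np

rng = np.random.default_rng(0)
H, BI, D = 4, 8, 6                      # illustrative tile sizes
S0 = rng.standard_normal((H, BI))       # WG0: scores for its even block
S1 = rng.standard_normal((H, BI))       # WG1: scores for its odd block
V0 = rng.standard_normal((BI, 2 * D))   # V for block 0, halves [V0_L | V0_R]
V1 = rng.standard_normal((BI, 2 * D))   # V for block 1, halves [V1_L | V1_R]

# Stats exchange: both sides end up holding the same global row max.
m = np.maximum(S0.max(-1), S1.max(-1))
P0 = np.exp(S0 - m[:, None])
P1 = np.exp(S1 - m[:, None])
sumexp = P0.sum(-1) + P1.sum(-1)

# S exchange: each consumer applies its own and the peer's P to its V half.
O_L = P0 @ V0[:, :D] + P1 @ V1[:, :D]   # WG0's half of the output
O_R = P0 @ V0[:, D:] + P1 @ V1[:, D:]   # WG1's half of the output
O = np.concatenate([O_L, O_R], -1) / sumexp[:, None]

# Reference: plain softmax-attention over the concatenated blocks.
S = np.concatenate([S0, S1], -1)
P = np.exp(S - S.max(-1, keepdims=True))
ref = (P / P.sum(-1, keepdims=True)) @ np.concatenate([V0, V1], 0)
assert np.allclose(O, ref)
```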

Memory Optimizations:

  • Reuses K_tail_shared_0/1 as S_shared_0/1 to save shared memory
  • Double-buffered is_kv_valid[2, BI] mask to avoid race conditions
  • Index prefetching in producer to hide memory latency
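The double-buffered mask can be sketched in plain Python (names mirror the PR; sizes and values are illustrative): the producer fills buffer `it % 2` for iteration `it`, so filling the next iteration's mask never clobbers the one consumers are still reading.

```python
BI = 4
is_kv_valid = [[False] * BI for _ in range(2)]   # two buffers of width BI

def producer_fill(it, mask):
    is_kv_valid[it % 2] = list(mask)             # write this iteration's buffer

def consumer_read(it):
    return is_kv_valid[it % 2]                   # read the buffer for iteration it

producer_fill(0, [True, True, False, True])
producer_fill(1, [True, False, True, True])      # does not clobber iteration 0's mask
assert consumer_read(0) == [True, True, False, True]
```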
| Aspect | Pipelined | Seesaw |
|--------|-----------|--------|
| Consumer dependency | WG1 waits for WG0's S/alpha | Independent parallel compute |
| S matrix | Single S_shared | Dual S_shared_0/1 (reused) |
| Softmax stats | Single m_i, sumexp | Per-consumer stats with exchange |
| KV valid mask | Single buffer [BI] | Double buffer [2, BI] |
| Index prefetch | None | Async prefetch next iteration |
| Register alloc | WG0:240, WG1:168, Prod:80 | WG0:216, WG1:216, Prod:72 |

Prefill Benchmark (B=2, S=4096, SKV=8192, H=128, topk=2048):

  • Average time: 10.276 ms
  • IO bandwidth: 1.88 TB/s
  • TFLOPS: 454.76

Decode Benchmark (B=2048, S=2, SKV=8192, H=128, topk=2048):

  • Average time: 5.554 ms
  • IO bandwidth: 1.74 TB/s
  • TFLOPS: 420.68
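The prefill TFLOPS figure is reproducible from the stated shapes under the usual attention FLOP model (the 2·B·S·H·topk·(576+512) accounting below is an assumption, not stated in the PR):

```python
# Prefill shapes from the benchmark above.
B, S, H, topk = 2, 4096, 128, 2048
d_qk, d_v = 576, 512            # 512 + 64-dim tail for QK^T; 512 for P@V
avg_time_s = 10.276e-3

# 2 FLOPs per multiply-add, for both the QK^T and the P@V matmuls.
flops = 2 * B * S * H * topk * (d_qk + d_v)
tflops = flops / avg_time_s / 1e12
print(f"{tflops:.2f}")          # ~454.7, in line with the reported 454.76
```

Under the same accounting the decode run comes out at roughly twice the reported 420.68, so the benchmark presumably counts decode FLOPs differently (e.g. due to causal masking or kv_stride); treat this as a prefill-only sanity check.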

Summary by CodeRabbit

Release Notes

  • New Features
    • Added a new sparse multi-head attention forward kernel example with support for efficient computation on specialized hardware architectures
    • Includes reference implementation for validation and comprehensive test utilities with profiling capabilities



🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@github-actions

github-actions bot commented Jan 8, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Jan 8, 2026

📝 Walkthrough

Introduces a new TileLang-based sparse multi-head attention forward kernel with seesaw producer-consumer synchronization, including kernel implementation, public execution interfaces, reference validation function, and test harness with profiling capabilities.

Changes

Cohort / File(s): Sparse MLA Forward Kernel — examples/deepseek_v32/sparse_mla_fwd_seesaw.py
Summary: New file containing a complete sparse MLA forward implementation: a sparse_mla_fwd() kernel creator with shape/dtype/parameter validation; a complex multi-phase kernel with shared memory, barrier synchronization, and seesaw producer-consumer flow; sparse_mla_fwd_interface() for kernel execution and optional zeroing; a ref_sparse_mla_fwd_interface() PyTorch reference; and a test_sparse_mla_fwd_pipelined() test harness with correctness checking and profiling metrics.

Sequence Diagram(s)

sequenceDiagram
    participant ThreadBlocks as Block Groups
    participant SharedMem as Shared Memory
    participant Barrier as Barrier Sync
    participant Softmax as Softmax Engine
    participant Output as Output Buffer
    
    rect rgb(240, 248, 255)
    Note over ThreadBlocks,Output: Phase 0 (Even Blocks)
    ThreadBlocks->>SharedMem: Prefetch KV data
    ThreadBlocks->>Barrier: Block barrier enter
    Barrier-->>ThreadBlocks: Proceed when ready
    ThreadBlocks->>SharedMem: Load Q, compute QK prod
    ThreadBlocks->>Softmax: Per-block softmax + max-trace
    end
    
    rect rgb(255, 240, 245)
    Note over ThreadBlocks,Output: Seesaw Switch
    Barrier->>Barrier: Producer-consumer flip
    end
    
    rect rgb(240, 255, 240)
    Note over ThreadBlocks,Output: Phase 1 (Odd Blocks)
    ThreadBlocks->>SharedMem: Async load AV pairs
    ThreadBlocks->>Barrier: Block barrier wait
    Barrier-->>ThreadBlocks: Previous phase complete
    ThreadBlocks->>Output: Compute V aggregation
    ThreadBlocks->>Output: Store final output
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • LeiWang1999

Poem

🐰 A kernel hops through phases bright,
Seesaw blocks synchronize just right,
Barriers dance with Q and K,
Sparse attention leads the way—
Attention, efficient, oh what a sight!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title accurately describes the main change: adding a Seesaw sparse MLA forward kernel for DeepSeek-V3.2, which is fully supported by the 644 lines of new code in the sparse_mla_fwd_seesaw.py file.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In @examples/deepseek_v32/sparse_mla_fwd_seesaw.py:
- Around line 470-493: The parameter name is_casual in sparse_mla_fwd_interface
is a typo and should be is_causal; rename the parameter in the function
signature and all uses inside sparse_mla_fwd_interface (including the boolean
CP0 computation and the argument passed to sparse_mla_fwd) from is_casual to
is_causal so the API and the kernel call (sparse_mla_fwd) use the correct
spelling and avoid confusion.
- Line 511: The parameter name in ref_sparse_mla_fwd_interface is misspelled:
change the parameter from is_casual to is_causal and update all references
inside ref_sparse_mla_fwd_interface and any call sites to use is_causal; also
ensure the kernel invocation and any wrapper/forwarding functions (e.g., the
corresponding kernel function name referenced in this file) use the same
is_causal name so signatures match and tests/linters stop complaining about the
typo.
- Around line 576-582: The sentinel value SKV currently equals 8192 and can be
mistaken as a valid index when computed max_kv_i exceeds SKV; change the
sentinel approach by initializing indices with -1 (ensure dtype supports signed
ints) or add explicit bounds checks before any KV access: replace any truth
check like index <= max_kv_i with a compound check index >= 0 && index < SKV,
and ensure downstream code that uses indices treats negative sentinels as
"unused" (e.g., skip/mask accesses). Update the generation site (indices =
torch.full(..., SKV, ...)) and the validation sites referenced in the diff (the
checks currently at lines handling max_kv_i) to use the new sentinel/bounds
logic so KV[..., index, ...] is only accessed for 0 <= index < SKV.
🧹 Nitpick comments (1)
examples/deepseek_v32/sparse_mla_fwd_seesaw.py (1)

340-344: Consider removing commented-out code or adding clarifying documentation.

Line 342 contains commented-out barrier wait code. While the comment at lines 313-315 explains the implicit synchronization through stats barriers, keeping dead code can be confusing for future maintainers. Either remove the line entirely or convert it to a more explicit documentation note explaining why this synchronization point is unnecessary.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aca9218 and 3e821d2.

📒 Files selected for processing (1)
  • examples/deepseek_v32/sparse_mla_fwd_seesaw.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/deepseek_v32/sparse_mla_fwd_seesaw.py (5)
tilelang/math/__init__.py (1)
  • next_power_of_2 (1-2)
tilelang/language/allocate.py (3)
  • alloc_shared (39-54)
  • alloc_fragment (71-82)
  • alloc_barrier (150-159)
tilelang/language/copy_op.py (1)
  • copy (14-116)
tilelang/language/builtin.py (2)
  • barrier_arrive (649-656)
  • barrier_wait (634-646)
tilelang/language/reduce_op.py (2)
  • reduce_max (107-125)
  • reduce_sum (144-166)
🔇 Additional comments (2)
examples/deepseek_v32/sparse_mla_fwd_seesaw.py (2)

127-130: Verify dimension match for S_shared aliasing K_tail_shared.

The aliasing S_shared_0 = K_tail_shared_0 assumes K_tail_shared shape [BI, D_tail] matches the needed S matrix shape [H_per_block, BI]. This requires H_per_block == D_tail. If these dimensions diverge in future configurations, this aliasing will cause incorrect behavior or crashes.

Consider adding an assertion to enforce this invariant:

assert H_per_block == D_tail, f"S_shared aliasing requires H_per_block ({H_per_block}) == D_tail ({D_tail})"

378-378: Remove this review comment; the concern is incorrect.

The LSE computation at line 378 using m_i_0 is correct. Through iterative per-iteration synchronization with clear=False accumulation in the loop, m_i_0 contains the global maximum across all blocks from both consumers by loop termination, not just Consumer 0's local maximum. Both consumers converge to identical m_i and sumexp values, so LSE written by Consumer 0 is consistent.

Comment on lines +470 to +493
def sparse_mla_fwd_interface(
    q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False
):
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    batch, seq_len, heads, dim_plus_tail_dim = q.shape
    _, seq_len_kv, kv_group, _ = kv.shape

    assert dim_plus_tail_dim == 576, "you should assign dim otherwise"
    dim = 512

    assert kv.shape[-1] == dim_plus_tail_dim
    tail_dim = dim_plus_tail_dim - dim
    assert kv.shape[0] == batch
    _, _, _, topk = indices.shape
    assert indices.shape == (batch, seq_len, kv_group, topk)

    if q_start_index_s != 0:
        assert q_start_index_s > kv_stride, (
            "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)"
        )
    CP0 = q_start_index_s == 0

    # Compile the kernel
    kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0)

⚠️ Potential issue | 🟡 Minor

Typo: is_casual should be is_causal.

The parameter name is_casual (line 471) is a typo that should be is_causal. This inconsistency propagates to line 493 where it's passed to the kernel function that expects is_causal. While Python's keyword argument passing will still work due to positional ordering, this creates confusing API inconsistency with the kernel parameter name.

🐛 Proposed fix
 def sparse_mla_fwd_interface(
-    q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False
+    q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_causal=True, return_kernel=False, print_kernel=False
 ):
     # ... lines 473-492 ...
-    kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0)
+    kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_causal, CP0)

return out, lse


def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=1, sm_scale=None, is_casual=True):

⚠️ Potential issue | 🟡 Minor

Same typo: is_casual should be is_causal.

For consistency with the kernel function and standard terminology, rename is_casual to is_causal.

🐛 Proposed fix
-def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=1, sm_scale=None, is_casual=True):
+def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=1, sm_scale=None, is_causal=True):

Committable suggestion skipped: line range outside the PR's diff.


Comment on lines +576 to +582
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda")
for b in range(B):
    for t in range(S):
        for h in range(HKV):
            # Add offset q_start_s_index to convert to global sequence position
            i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk]
            indices[b, t, h, : len(i_i)] = i_i

⚠️ Potential issue | 🔴 Critical

Sentinel index vulnerability exists for certain parameter configurations.

The test initializes indices with sentinel value SKV=8192 for unfilled positions (line 576). While sentinel indices are not created in the current decode test (line 643) due to topk=2048 being smaller than the valid sequence range at all positions, the vulnerability exists in the code logic.

In the decode test configuration where q_start_s_index=6144, the kernel computes max_kv_i up to 10239, which exceeds SKV. If sentinel indices are created (by increasing topk or using smaller q_start_s_index), the validity check index <= max_kv_i at lines 218 and 253 would incorrectly allow the sentinel value to pass, causing out-of-bounds access to KV[..., 8192, ...] when valid indices are [0, 8191].

Use -1 as the sentinel value instead, or add explicit bounds checking before KV access to ensure index < SKV.
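The suggested fix can be sketched with a NumPy stand-in for the torch test harness (names mirror the harness; sizes are illustrative): a -1 sentinel in a signed dtype can never pass an in-range check, whereas the old SKV sentinel passes `index <= max_kv_i` as soon as max_kv_i reaches SKV.

```python
import numpy as np

SKV, topk, n_filled = 8192, 2048, 1024
max_kv_i = 10239                    # decode configuration: exceeds SKV

# The old scheme: unfilled slots hold SKV and validity is index <= max_kv_i.
# Once max_kv_i >= SKV, the sentinel passes and KV[..., 8192, ...] is read OOB.
assert SKV <= max_kv_i              # the failure condition described above

# The suggested fix: a -1 sentinel (signed dtype) plus a two-sided bounds check.
indices = np.full(topk, -1, dtype=np.int32)
indices[:n_filled] = np.random.default_rng(0).permutation(SKV)[:n_filled]
is_kv_valid = (indices >= 0) & (indices < SKV)
assert is_kv_valid[:n_filled].all() and not is_kv_valid[n_filled:].any()
```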


@LeiWang1999 LeiWang1999 merged commit d5503cd into tile-ai:main Jan 8, 2026
4 checks passed
@hammersam
Contributor Author

Thanks for the merge!
