[Example] Add Seesaw Sparse MLA Forward Kernel for DeepSeek-V3.2 (#1636)

LeiWang1999 merged 1 commit into tile-ai:main
Conversation
Add a new sparse MLA forward kernel implementation using the "Seesaw" synchronization pattern, an alternative to the existing pipelined approach.

**Dual-Consumer Parallel Architecture:**
- Unlike the pipelined version, where WG1 depends on WG0's S_shared/alpha_shared, Seesaw lets both consumers work independently on different KV blocks
- Consumer 0 (WG0): processes even blocks (BI_2*i) and computes O_L (left half)
- Consumer 1 (WG1): processes odd blocks (BI_2*i+1) and computes O_R (right half)
- Each consumer maintains its own softmax statistics (m_i, sumexp)

**Seesaw Synchronization Mechanism:**
- Consumers exchange local row_max via the bar_stats_0/1_ready barriers
- Both compute the global max by taking the max of their local value and the peer's
- S matrices are exchanged via bar_S_0/1_ready for cross-attention:
  - O_L += P0 @ V0_L (self) + P1 @ V1_L (from peer)
  - O_R += P1 @ V1_R (self) + P0 @ V0_R (from peer)

**Memory Optimizations:**
- Reuses K_tail_shared_0/1 as S_shared_0/1 to save shared memory
- Double-buffered is_kv_valid[2, BI] mask to avoid race conditions
- Index prefetching in the producer to hide memory latency

| Aspect | Pipelined | Seesaw |
|--------|-----------|--------|
| Consumer dependency | WG1 waits for WG0's S/alpha | Independent parallel compute |
| S matrix | Single S_shared | Dual S_shared_0/1 (reused) |
| Softmax stats | Single m_i, sumexp | Per-consumer stats with exchange |
| KV valid mask | Single buffer [BI] | Double buffer [2, BI] |
| Index prefetch | None | Async prefetch of next iteration |
| Register alloc | WG0: 240, WG1: 168, Prod: 80 | WG0: 216, WG1: 216, Prod: 72 |

**Prefill Benchmark** (B=2, S=4096, SKV=8192, H=128, topk=2048):
- Average time: 10.276 ms
- IO bandwidth: 1.88 TB/s
- TFLOPS: 454.76

**Decode Benchmark** (B=2048, S=2, SKV=8192, H=128, topk=2048):
- Average time: 5.554 ms
- IO bandwidth: 1.74 TB/s
- TFLOPS: 420.68

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
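As a sanity check on the recurrence described above, the dual-consumer online softmax can be modeled in plain NumPy. This is an illustrative sketch, not the kernel: names such as `m`, `l`, `O_L`/`O_R` mirror the description, block sizes are arbitrary, and both consumers are folded into one loop since (as noted in the review below) their statistics converge identically after each stats exchange.

```python
# Toy model of the Seesaw recurrence: per iteration, consumer 0 scores the even
# block, consumer 1 the odd block; they exchange row maxima, then each output
# half accumulates contributions from BOTH probability matrices.
import numpy as np

rng = np.random.default_rng(0)
H, BI, D, n_pairs = 4, 8, 16, 3          # heads, block size, head dim, block pairs
q = rng.standard_normal((H, D))
k = rng.standard_normal((2 * n_pairs * BI, D))
v = rng.standard_normal((2 * n_pairs * BI, D))

# Reference: dense softmax attention over all KV rows.
s = q @ k.T
p = np.exp(s - s.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ v

m = np.full(H, -np.inf)      # running row max (identical in both consumers)
l = np.zeros(H)              # running sumexp (identical in both consumers)
O_L = np.zeros((H, D // 2))  # consumer 0 owns the left output half
O_R = np.zeros((H, D // 2))  # consumer 1 owns the right output half

for i in range(n_pairs):
    ev = slice(2 * i * BI, (2 * i + 1) * BI)        # consumer 0's block
    od = slice((2 * i + 1) * BI, (2 * i + 2) * BI)  # consumer 1's block
    S0, S1 = q @ k[ev].T, q @ k[od].T
    # Stats exchange: each consumer takes max(local max, peer's max, running max).
    m_new = np.maximum(m, np.maximum(S0.max(axis=1), S1.max(axis=1)))
    alpha = np.exp(m - m_new)                        # rescale old accumulators
    P0, P1 = np.exp(S0 - m_new[:, None]), np.exp(S1 - m_new[:, None])
    l = l * alpha + P0.sum(axis=1) + P1.sum(axis=1)
    # S-matrix exchange: each half sees both blocks (self + from peer).
    O_L = O_L * alpha[:, None] + P0 @ v[ev, : D // 2] + P1 @ v[od, : D // 2]
    O_R = O_R * alpha[:, None] + P0 @ v[ev, D // 2 :] + P1 @ v[od, D // 2 :]
    m = m_new

out = np.concatenate([O_L, O_R], axis=1) / l[:, None]
assert np.allclose(out, ref)
```

Splitting V along the head dimension is what lets each consumer own one output half while still consuming both P matrices, which is the cross-attention exchange described above.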
📝 Walkthrough

Introduces a new TileLang-based sparse multi-head attention forward kernel with seesaw producer-consumer synchronization, including the kernel implementation, public execution interfaces, a reference validation function, and a test harness with profiling capabilities.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant ThreadBlocks as Block Groups
    participant SharedMem as Shared Memory
    participant Barrier as Barrier Sync
    participant Softmax as Softmax Engine
    participant Output as Output Buffer
    rect rgb(240, 248, 255)
        Note over ThreadBlocks,Output: Phase 0 (Even Blocks)
        ThreadBlocks->>SharedMem: Prefetch KV data
        ThreadBlocks->>Barrier: Block barrier enter
        Barrier-->>ThreadBlocks: Proceed when ready
        ThreadBlocks->>SharedMem: Load Q, compute QK prod
        ThreadBlocks->>Softmax: Per-block softmax + max-trace
    end
    rect rgb(255, 240, 245)
        Note over ThreadBlocks,Output: Seesaw Switch
        Barrier->>Barrier: Producer-consumer flip
    end
    rect rgb(240, 255, 240)
        Note over ThreadBlocks,Output: Phase 1 (Odd Blocks)
        ThreadBlocks->>SharedMem: Async load AV pairs
        ThreadBlocks->>Barrier: Block barrier wait
        Barrier-->>ThreadBlocks: Previous phase complete
        ThreadBlocks->>Output: Compute V aggregation
        ThreadBlocks->>Output: Store final output
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In @examples/deepseek_v32/sparse_mla_fwd_seesaw.py:
- Around line 470-493: The parameter name is_casual in sparse_mla_fwd_interface
is a typo and should be is_causal; rename the parameter in the function
signature and all uses inside sparse_mla_fwd_interface (including the boolean
CP0 computation and the argument passed to sparse_mla_fwd) from is_casual to
is_causal so the API and the kernel call (sparse_mla_fwd) use the correct
spelling and avoid confusion.
- Line 511: The parameter name in ref_sparse_mla_fwd_interface is misspelled:
change the parameter from is_casual to is_causal and update all references
inside ref_sparse_mla_fwd_interface and any call sites to use is_causal; also
ensure the kernel invocation and any wrapper/forwarding functions (e.g., the
corresponding kernel function name referenced in this file) use the same
is_causal name so signatures match and tests/linters stop complaining about the
typo.
- Around line 576-582: The sentinel value SKV currently equals 8192 and can be
mistaken as a valid index when computed max_kv_i exceeds SKV; change the
sentinel approach by initializing indices with -1 (ensure dtype supports signed
ints) or add explicit bounds checks before any KV access: replace any truth
check like index <= max_kv_i with a compound check index >= 0 && index < SKV,
and ensure downstream code that uses indices treats negative sentinels as
"unused" (e.g., skip/mask accesses). Update the generation site (indices =
torch.full(..., SKV, ...)) and the validation sites referenced in the diff (the
checks currently at lines handling max_kv_i) to use the new sentinel/bounds
logic so KV[..., index, ...] is only accessed for 0 <= index < SKV.
🧹 Nitpick comments (1)
examples/deepseek_v32/sparse_mla_fwd_seesaw.py (1)
340-344: Consider removing the commented-out code or adding clarifying documentation. Line 342 contains commented-out barrier-wait code. While the comment at lines 313-315 explains the implicit synchronization through the stats barriers, keeping dead code can confuse future maintainers. Either remove the line entirely or convert it into an explicit documentation note explaining why this synchronization point is unnecessary.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/deepseek_v32/sparse_mla_fwd_seesaw.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/deepseek_v32/sparse_mla_fwd_seesaw.py (5)
tilelang/math/__init__.py (1)
  next_power_of_2 (1-2)
tilelang/language/allocate.py (3)
  alloc_shared (39-54), alloc_fragment (71-82), alloc_barrier (150-159)
tilelang/language/copy_op.py (1)
  copy (14-116)
tilelang/language/builtin.py (2)
  barrier_arrive (649-656), barrier_wait (634-646)
tilelang/language/reduce_op.py (2)
  reduce_max (107-125), reduce_sum (144-166)
🔇 Additional comments (2)
examples/deepseek_v32/sparse_mla_fwd_seesaw.py (2)
127-130: Verify dimension match for S_shared aliasing K_tail_shared. The aliasing `S_shared_0 = K_tail_shared_0` assumes the `K_tail_shared` shape `[BI, D_tail]` matches the needed S-matrix shape `[H_per_block, BI]`. This requires `H_per_block == D_tail`. If these dimensions diverge in future configurations, the aliasing will cause incorrect behavior or crashes. Consider adding an assertion to enforce this invariant:

```python
assert H_per_block == D_tail, f"S_shared aliasing requires H_per_block ({H_per_block}) == D_tail ({D_tail})"
```
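To illustrate why the invariant matters, here is a toy NumPy sketch of buffer aliasing: reinterpreting the same storage under a second shape is only valid when the element counts line up, which is exactly the `H_per_block == D_tail` condition. The sizes below are illustrative, not the kernel's actual configuration.

```python
# Reusing one flat buffer for two logical shapes, as the kernel does with
# K_tail_shared / S_shared. The reshape is a zero-copy view only because
# H_per_block * BI == BI * D_tail, i.e. H_per_block == D_tail.
import numpy as np

BI, D_tail, H_per_block = 64, 64, 64
k_tail_shared = np.zeros(BI * D_tail, dtype=np.float16)  # flat "shared" buffer

assert H_per_block * BI == BI * D_tail   # the aliasing precondition
s_shared = k_tail_shared.reshape(H_per_block, BI)  # view over the same storage
assert s_shared.base is k_tail_shared    # no copy was made
```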
378-378: Remove this review comment; the concern is incorrect. The LSE computation at line 378 using `m_i_0` is correct. Through per-iteration synchronization with `clear=False` accumulation in the loop, `m_i_0` contains the global maximum across all blocks from both consumers by loop termination, not just Consumer 0's local maximum. Both consumers converge to identical m_i and sumexp values, so the LSE written by Consumer 0 is consistent.
```python
def sparse_mla_fwd_interface(
    q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False
):
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    batch, seq_len, heads, dim_plus_tail_dim = q.shape
    _, seq_len_kv, kv_group, _ = kv.shape

    assert dim_plus_tail_dim == 576, "you should assign dim otherwise"
    dim = 512

    assert kv.shape[-1] == dim_plus_tail_dim
    tail_dim = dim_plus_tail_dim - dim
    assert kv.shape[0] == batch
    _, _, _, topk = indices.shape
    assert indices.shape == (batch, seq_len, kv_group, topk)

    if q_start_index_s != 0:
        assert q_start_index_s > kv_stride, (
            "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)"
        )
    CP0 = q_start_index_s == 0

    # Compile the kernel
    kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0)
```
Typo: is_casual should be is_causal.
The parameter name is_casual (line 471) is a typo that should be is_causal. This inconsistency propagates to line 493 where it's passed to the kernel function that expects is_causal. While Python's keyword argument passing will still work due to positional ordering, this creates confusing API inconsistency with the kernel parameter name.
🐛 Proposed fix

```diff
 def sparse_mla_fwd_interface(
-    q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False
+    q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_causal=True, return_kernel=False, print_kernel=False
 ):
     # ... lines 473-492 ...
-    kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0)
+    kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_causal, CP0)
```
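The API hazard is easy to reproduce in isolation. The sketch below uses a deliberately simplified stand-in signature, not the real interface: positional calls still work, but any caller who spells the keyword correctly gets a TypeError.

```python
# Minimal reproduction of the misspelled-keyword hazard (simplified signature).
def sparse_mla_fwd_interface(q, kv, indices, is_casual=True):  # typo'd name
    return is_casual

assert sparse_mla_fwd_interface(1, 2, 3, True) is True  # positional: works
try:
    sparse_mla_fwd_interface(1, 2, 3, is_causal=True)   # correct spelling: rejected
    raised = False
except TypeError:
    raised = True
assert raised  # "got an unexpected keyword argument 'is_causal'"
```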
```python
    return out, lse


def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=1, sm_scale=None, is_casual=True):
```
Same typo: is_casual should be is_causal.
For consistency with the kernel function and standard terminology, rename is_casual to is_causal.
🐛 Proposed fix

```diff
-def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=1, sm_scale=None, is_casual=True):
+def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=1, sm_scale=None, is_causal=True):
```

Committable suggestion skipped: line range outside the PR's diff.
```python
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda")
for b in range(B):
    for t in range(S):
        for h in range(HKV):
            # Add offset q_start_s_index to convert to global sequence position
            i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk]
            indices[b, t, h, : len(i_i)] = i_i
```
🧩 Analysis chain

🏁 Scripts executed (repository: tile-ai/tilelang):

```shell
# First, let's look at the file structure and find the kernel definition
cat -n examples/deepseek_v32/sparse_mla_fwd_seesaw.py | head -50

# Now search for is_kv_valid and max_kv_i to understand the validity checks
rg -n "is_kv_valid|max_kv_i" examples/deepseek_v32/ -A 2 -B 2

# Look at lines 576-582 in context and surrounding code
cat -n examples/deepseek_v32/sparse_mla_fwd_seesaw.py | sed -n '560,600p'

# Check the producer kernel code around lines 218-219, 253-254 mentioned in the review
cat -n examples/deepseek_v32/sparse_mla_fwd_seesaw.py | sed -n '210,260p'

# Search for how indices are actually used/accessed in the kernel
rg -n "indices\[" examples/deepseek_v32/sparse_mla_fwd_seesaw.py -B 2 -A 2

# Check the function signature and default parameters
cat -n examples/deepseek_v32/sparse_mla_fwd_seesaw.py | sed -n '545,575p'

# Look for how this test function is called and what values are passed
rg -n "test_sparse_mla_fwd_seesaw" examples/deepseek_v32/sparse_mla_fwd_seesaw.py -B 5 -A 15

# Check if there's any actual OOB protection after validity check
cat -n examples/deepseek_v32/sparse_mla_fwd_seesaw.py | sed -n '215,245p'

# Check how max_kv_i is computed in sparse_mla_fwd_pipelined.py (the kernel being tested)
cat -n examples/deepseek_v32/sparse_mla_fwd_pipelined.py | sed -n '125,135p'
```
🏁 Script executed:

```shell
# Verify the relationship: with test defaults, what is the maximum max_kv_i value?
python3 << 'EOF'
# Test parameters from the function
B = 1
S = 4096  # sequence length (num query positions)
SKV = 8192  # KV sequence length (size of KV tensor dimension 1)
q_start_s_index = 2048
# In kernel, q_i = q_start_index_s[0] + s_i, where s_i is the sequence index (0 to S-1)
# With q_start_index_s[0] = q_start_s_index = 2048, and s_i in range(S)
# q_i ranges from 2048 to 2048 + 4095 = 6143
# max_kv_i = (q_i + 1 - KV_stride) // KV_stride, with KV_stride=1
# max_kv_i = q_i (since (q_i + 1 - 1) // 1 = q_i)
min_q_i = q_start_s_index  # 2048
max_q_i = q_start_s_index + S - 1  # 2048 + 4095 = 6143
max_kv_i_min = min_q_i  # 2048
max_kv_i_max = max_q_i  # 6143
print(f"q_start_s_index: {q_start_s_index}")
print(f"S (num query positions): {S}")
print(f"SKV (KV sequence length): {SKV}")
print(f"q_i range: [{min_q_i}, {max_q_i}]")
print(f"max_kv_i range: [{max_kv_i_min}, {max_kv_i_max}]")
print(f"\nSentinel value (SKV): {SKV}")
print(f"Highest max_kv_i: {max_kv_i_max}")
print(f"SKV > max_kv_i? {SKV > max_kv_i_max}")
print(f"Would sentinel pass validity check 'index <= max_kv_i'? {SKV <= max_kv_i_max}")
EOF
```

🏁 Script executed:

```shell
# Check if there are any other test calls with different parameters that could trigger the issue
rg -n "test_sparse_mla_fwd_pipelined\(" examples/deepseek_v32/ -A 5
```
🏁 Script executed:

```shell
# Calculate for the test case at line 643 with q_start_s_index=6144
python3 << 'EOF'
# Test parameters from line 634-636 and 642-644
# Prefill test: q_start_s_index=4096
B = 2
S = 4096
SKV = 8192
topk = 2048
print("=== Prefill test (line 634-636) ===")
print(f"q_start_s_index=4096, S={S}, SKV={SKV}")
min_q_i = 4096
max_q_i = 4096 + S - 1  # 4096 + 4095 = 8191
max_kv_i_max = max_q_i
print(f"q_i range: [{min_q_i}, {max_q_i}]")
print(f"max_kv_i range: [{min_q_i}, {max_kv_i_max}]")
print(f"SKV (valid KV indices): 0 to {SKV-1}")
print(f"Sentinel value: {SKV}")
print(f"Would SKV pass 'index <= max_kv_i'? {SKV <= max_kv_i_max}")
print("\n=== Decode test (line 642-644) ===")
print(f"q_start_s_index=6144, S={S}, SKV={SKV}")
min_q_i = 6144
max_q_i = 6144 + S - 1  # 6144 + 4095 = 10239
max_kv_i_max = max_q_i
print(f"q_i range: [{min_q_i}, {max_q_i}]")
print(f"max_kv_i range: [{min_q_i}, {max_kv_i_max}]")
print(f"SKV (valid KV indices): 0 to {SKV-1}")
print(f"Sentinel value: {SKV}")
print(f"Would SKV pass 'index <= max_kv_i'? {SKV <= max_kv_i_max}")
print(f"This WOULD cause OOB access to KV[..., {SKV}, ...] when max_kv_i={max_kv_i_max}")
print("\n=== Checking when sentinel indices are created ===")
# At line 581: i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk]
# Sentinel indices remain when len(i_i) < topk
# This happens when min(max(1, ...), SKV) < topk
# For decode test, at various t positions:
q_start_s_index = 6144
KV_stride = 1
topk = 2048
for t in [0, 100, 2000, 4095]:
    max_range = min(max(1, ((t + q_start_s_index) // KV_stride)), SKV)
    num_filled = min(max_range, topk)
    num_unfilled = topk - num_filled
    print(f"t={t}: range={max_range}, filled={num_filled}, sentinel_positions={num_unfilled}")
    if num_unfilled > 0:
        print(f"  -> Sentinel indices (value={SKV}) ARE created and passed to kernel!")
EOF
```
🏁 Script executed:

```shell
# Verify that sentinel indices are actually in the indices tensor passed to kernel
cat -n examples/deepseek_v32/sparse_mla_fwd_seesaw.py | sed -n '576,590p'

# Final check: show the actual out-of-bounds condition with decode test parameters
python3 << 'EOF'
# Decode test case: lines 642-644
q_start_s_index = 2048 + 4096  # 6144
S = 4096
SKV = 8192
topk = 2048
print("Decode Test Case Analysis:")
print(f"q_start_s_index = {q_start_s_index}")
print(f"Query sequence length S = {S}")
print(f"KV sequence length SKV = {SKV}")
print(f"topk = {topk}")
# At t=0
t = 0
q_i = q_start_s_index + t
max_kv_i = q_i
range_size = min(max(1, ((t + q_start_s_index) // 1)), SKV)
num_filled = min(range_size, topk)
num_sentinel = topk - num_filled
print(f"\nAt t=0:")
print(f"  q_i = {q_i}, max_kv_i = {max_kv_i}")
print(f"  Range size = min(max(1, {t + q_start_s_index}), {SKV}) = {range_size}")
print(f"  Filled positions: {num_filled}, Sentinel positions: {num_sentinel}")
if num_sentinel > 0:
    print(f"\n*** SENTINEL INDICES (value={SKV}) CREATED ***")
    print(f"Kernel validity check: is_kv_valid = index <= max_kv_i = {SKV} <= {max_kv_i}?")
    print(f"Result: {SKV <= max_kv_i} (KV access allowed)")
    if SKV <= max_kv_i:
        print(f"*** OUT-OF-BOUNDS ACCESS: KV[..., {SKV}, ...] but valid range is [0, {SKV-1}] ***")
# At last position
t = S - 1
q_i = q_start_s_index + t
max_kv_i = q_i
range_size = min(max(1, ((t + q_start_s_index) // 1)), SKV)
num_filled = min(range_size, topk)
num_sentinel = topk - num_filled
print(f"\nAt t={t}:")
print(f"  q_i = {q_i}, max_kv_i = {max_kv_i}")
print(f"  Range size = min(max(1, {t + q_start_s_index}), {SKV}) = {range_size}")
print(f"  Filled positions: {num_filled}, Sentinel positions: {num_sentinel}")
if SKV <= max_kv_i:
    print(f"*** SENTINEL INDEX {SKV} WOULD PASS VALIDITY CHECK: {SKV} <= {max_kv_i} ***")
EOF
```
Sentinel index vulnerability exists for certain parameter configurations.

The test initializes indices with the sentinel value `SKV=8192` for unfilled positions (line 576). While sentinel indices are not created in the current decode test (line 643), because `topk=2048` is smaller than the valid sequence range at all positions, the vulnerability exists in the code logic.

In the decode test configuration, where `q_start_s_index=6144`, the kernel computes `max_kv_i` up to 10239, which exceeds `SKV`. If sentinel indices are created (by increasing `topk` or using a smaller `q_start_s_index`), the validity check `index <= max_kv_i` at lines 218 and 253 would incorrectly allow the sentinel value to pass, causing out-of-bounds access to `KV[..., 8192, ...]` when valid indices are [0, 8191].

Use -1 as the sentinel value instead, or add explicit bounds checking before KV access to ensure `index < SKV`.
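A minimal NumPy sketch of the suggested fix, with illustrative sizes (not the test's real B/S/SKV): a -1 sentinel in a signed dtype plus a compound bounds check guarantees only in-range rows are ever gathered, even when `max_kv_i` legitimately exceeds `SKV`.

```python
# -1 sentinel + compound bounds check, so sentinels can never alias a KV row.
import numpy as np

SKV, topk = 16, 8
kv = np.arange(SKV, dtype=np.float32)
indices = np.full(topk, -1, dtype=np.int32)   # -1 sentinel needs a signed dtype
indices[:5] = [3, 7, 1, 15, 9]                # only 5 of 8 slots filled

max_kv_i = SKV + 100                          # may legitimately exceed SKV
valid = (indices >= 0) & (indices < SKV) & (indices <= max_kv_i)
out = np.zeros(topk, dtype=np.float32)
out[valid] = kv[indices[valid]]               # only in-bounds rows are read

assert valid.sum() == 5                       # the three sentinels are masked out
```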
Thanks for the merge!