fix(fmha_v2): enable flash_attention for Q_PAGED_KV regardless of s_kv #3106
blake-snc wants to merge 1 commit into flashinfer-ai:main
Conversation
`determine_launch_params()` in `csrc/fmha_v2_run.cu` disables `flash_attention` when `max_kv_len < 16`
(via the `s >= 16` check). For Q_PAGED_KV layouts this falls back to a
non-flash-attention kernel that does not correctly read paged KV —
producing wrong output (max_diff 4+ vs reference, or NaN) for any
prefill whose total KV length is under 16 tokens.
The `s >= 16` gate was added for pre-flash-attention fused kernels that
required a minimum sequence length for tiling. Q_PAGED_KV layouts only
have a flash-attention code path, so this gate incorrectly disqualifies
them. Per-request bounds are already enforced via seq_lens; s is the
padded max_kv_len used for tile scheduling, not a correctness bound.
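The fix described above can be sketched as follows. This is a minimal illustration, not the actual fmha_v2 code: the enum and function names are stand-ins, assuming the pre-fix gate was effectively `use_flash_attention = (s >= 16)`.

```cpp
#include <cassert>

// Illustrative stand-in for fmha_v2's attention input-layout enum; the
// real definition lives in the fmha_v2 sources.
enum class AttentionInputLayout { PACKED_QKV, CONTIGUOUS_Q_KV, Q_PAGED_KV };

// Sketch of the dispatch decision: Q_PAGED_KV only has a flash-attention
// code path, so it is exempted from the minimum-sequence-length check that
// the pre-flash fused kernels needed for tiling.
bool select_flash_attention(int s, AttentionInputLayout layout) {
  bool const is_paged_kv = (layout == AttentionInputLayout::Q_PAGED_KV);
  return is_paged_kv || s >= 16;
}
```

With this shape, a paged-KV prefill with `s = 9` stays on the flash-attention path instead of falling back to a kernel that cannot read paged KV.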
I validated on SM121a (DGX Spark GB10):
- Repro (seq_len=9, page_size=16): max_diff 2.27 -> 0.002 vs reference
- seq_len in {1, 5, 9, 15, 16, 17, 25, 33}: all match reference
- GQA (14 Q heads, 2 KV heads): matches reference (max_diff 0.002)
- vLLM end-to-end (Qwen2.5-0.5B, gpt_oss-tiny) with prefix caching:
deterministic output, no regressions on prior page-aligned cases
Contributed by Second Nature Computing (https://joinsecondnature.com)
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Code Review
This pull request modifies the determine_launch_params function in csrc/fmha_v2_run.cu to ensure that flash attention is always enabled for Q_PAGED_KV layouts, regardless of the sequence length. This change addresses an issue where falling back to non-flash paths for small sequence lengths resulted in incorrect outputs, as the paged KV dispatch path only supports flash attention kernels. I have no feedback to provide.
The prior CONTIGUOUS_Q_KV path read only the current step's raw K/V tensors, so any prefill with prefix-cached tokens or multi-step prefill chunking missed the cached KV and produced incorrect attention. Switch to Q_PAGED_KV_NHD and read directly from vLLM's paged KV cache. FlashInfer allocates kv_cache in [num_pages, 2, page_size, num_kv_heads, D] format — already Q_PAGED_KV_NHD layout, no permute needed.
trtllm_fmha_v2_prefill expects cum_seq_lens_kv as cumulative TOKEN lengths starting from 0. TRTLLMPrefill metadata stores PAGE indptrs in that field (trtllm-gen uses them differently), so I build the proper token cumsum from seq_lens.
Requires flashinfer-ai/flashinfer#3106, which fixes a latent gate in determine_launch_params() that disabled flash_attention for Q_PAGED_KV when max_kv_len < 16 — producing silently wrong output for short prompts.
I validated on SM121a (DGX Spark GB10):
- Qwen2.5-0.5B deterministic batched prompts (2, 11, 22, 44 chars): PASS
- Qwen2.5-0.5B with prefix caching (warm + hit): deterministic, coherent
- gpt_oss-tiny with sinks + prefix caching: deterministic
- Qwen2.5-0.5B vs triton_attn: both produce coherent continuations (minor numerical differences expected across kernels)
Contributed by Second Nature Computing (https://joinsecondnature.com)
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
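The token-cumsum construction described above can be sketched as follows. This is an illustration on plain vectors, not vLLM's actual code; the assumption is that `seq_lens` holds each request's total KV token count.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of building cum_seq_lens_kv as cumulative TOKEN lengths starting
// at 0, from per-request KV token counts. This mirrors the description
// above: the kernel wants token prefix sums, not the PAGE indptrs that
// TRTLLMPrefill metadata stores in the same field.
std::vector<int> build_cum_seq_lens_kv(std::vector<int> const& seq_lens) {
  std::vector<int> cum(seq_lens.size() + 1, 0);
  for (std::size_t i = 0; i < seq_lens.size(); ++i) {
    cum[i + 1] = cum[i] + seq_lens[i];
  }
  return cum;
}
```

For example, requests with KV lengths {9, 16, 25} yield the prefix sums {0, 9, 25, 50}, which is the shape the prefill kernel expects.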
…ackend
SM120 (consumer Blackwell — RTX PRO 6000, DGX Spark) lacks the trtllm-gen precompiled cubins that SM100 uses for attention sinks. This adds an SM120-specific prefill path using flashinfer's fmha_v2 HMMA kernels, which support sinks natively via JIT compilation.
- vllm/utils/flashinfer.py: add supports_fmha_v2_sm120_attention() for SM120 detection; extend can_use_trtllm_attention / use_trtllm_attention to cover the SM120 fmha_v2 path.
- vllm/v1/attention/backends/flashinfer.py: update supports_sink() to return True for SM120. Add SM120 fmha_v2 prefill dispatch using Q_PAGED_KV_NHD so the kernel reads directly from vLLM's paged KV cache (correct under prefix caching and chunked prefill). Use self.bmm2_scale for FP8 KV dequant. SM120 decode uses standard flashinfer (not trtllm-gen).
trtllm_fmha_v2_prefill expects cum_seq_lens_kv as cumulative TOKEN lengths starting from 0. TRTLLMPrefill metadata stores PAGE indptrs in that field (trtllm-gen uses them differently), so I build the proper token cumsum from seq_lens.
Requires flashinfer-ai/flashinfer#3016 (sinks passthrough in trtllm_fmha_v2_prefill) and flashinfer-ai/flashinfer#3106 (latent gate in determine_launch_params that disabled flash_attention for Q_PAGED_KV when max_kv_len < 16 — producing silently wrong output for short prompts).
I validated on SM121a (DGX Spark GB10):
- Qwen/Qwen2.5-0.5B-Instruct batched prompts (2, 11, 22, 44 chars): deterministic, coherent
- Qwen2.5-0.5B with prefix caching (warm + hit): deterministic
- trl-internal-testing/tiny-GptOssForCausalLM (gpt_oss with built-in sinks + sliding window + GQA): deterministic
- Qwen2.5-0.5B vs triton_attn baseline: both produce coherent continuations (minor numerical differences expected across kernels)
- supports_sink() returns True on SM120
Contributed by Second Nature Computing (https://joinsecondnature.com)
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
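To make the [num_pages, 2, page_size, num_kv_heads, D] layout claim concrete, here is a sketch of computing a flat element offset into such a cache. All names here are hypothetical; the actual indexing lives inside the fmha_v2 kernels.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of indexing a paged KV cache laid out as
// [num_pages, 2, page_size, num_kv_heads, head_dim], where axis 1 selects
// K (0) or V (1). `page_table` maps a request's logical page index to a
// physical page id; `token` is the logical KV token position.
std::size_t paged_kv_offset(std::vector<int> const& page_table, int token,
                            int kv, int head, int dim, int page_size,
                            int num_kv_heads, int head_dim) {
  std::size_t const page = page_table[token / page_size];  // physical page
  std::size_t const slot = token % page_size;              // slot in page
  // Row-major flattening of the five axes listed above.
  return (((page * 2 + kv) * page_size + slot) * num_kv_heads + head) *
             head_dim +
         dim;
}
```

Because this is already the Q_PAGED_KV_NHD ordering (heads vary faster than slots, dims fastest), a kernel consuming that layout can read the cache with no permute, as noted above.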
Problem
`determine_launch_params()` in `csrc/fmha_v2_run.cu` disables `flash_attention` when `max_kv_len < 16` via the `s >= 16` check. For `Q_PAGED_KV` layouts this falls back to a non-flash-attention kernel that does not correctly read paged KV — producing wrong output (max_diff 4+ vs reference, or NaN) for any prefill whose total KV length is under 16 tokens.
Reproducer
seq_len=9 with page_size=16 reproduces the mismatch (max_diff 2.27 vs reference).
Root cause
The `s >= 16` gate was added for pre-flash-attention fused kernels that required a minimum sequence length for tiling. `Q_PAGED_KV` layouts only have a flash-attention code path, so this gate incorrectly disqualifies them. Per-request bounds are already enforced via `seq_lens`; `s` is the padded `max_kv_len` used for tile scheduling, not a correctness bound.
Fix
Exempt `Q_PAGED_KV` layouts from the `s >= 16` check.
Validation
Validated on SM121a (DGX Spark GB10, NVIDIA GB10):
- vLLM end-to-end (Qwen2.5-0.5B, trl-internal-testing/tiny-GptOssForCausalLM) with prefix caching: deterministic output across repeated calls, no regressions on prior page-aligned cases
Contributed by Second Nature Computing (https://joinsecondnature.com)