
fix(fmha_v2): enable flash_attention for Q_PAGED_KV regardless of s_kv #3106

Open

blake-snc wants to merge 1 commit into flashinfer-ai:main from blake-snc:fix/fmha-v2-q-paged-kv-flash-attention-gate

Conversation

blake-snc (Contributor) commented Apr 17, 2026

Problem

`determine_launch_params()` in `csrc/fmha_v2_run.cu` disables flash_attention when `max_kv_len < 16` via the `s >= 16` check. For Q_PAGED_KV layouts this falls back to a non-flash-attention kernel that does not correctly read paged KV — producing wrong output (max_diff 4+ vs reference, or NaN) for any prefill whose total KV length is under 16 tokens.

Reproducer

import torch
from flashinfer.prefill import trtllm_fmha_v2_prefill

device = torch.device("cuda")
torch.manual_seed(42)
# Qwen2.5-0.5B-style GQA: 14 Q heads, 2 KV heads, D=64, page_size=16
# seq_len=9 (< page_size), q_len=8 (prefix-cached scenario)
seq_len, q_len, PS = 9, 8, 16
full_k = torch.randn(seq_len, 2, 64, device=device, dtype=torch.float16)
full_v = torch.randn(seq_len, 2, 64, device=device, dtype=torch.float16)
q_full = torch.randn(seq_len, 14, 64, device=device, dtype=torch.float16)

kv = torch.zeros(1, 2, PS, 2, 64, device=device, dtype=torch.float16)
kv[0, 0, :seq_len] = full_k
kv[0, 1, :seq_len] = full_v
# ... (call trtllm_fmha_v2_prefill with Q_PAGED_KV_NHD)
# Result before fix: max_diff = 2.27 vs reference attention
# Result after fix:  max_diff = 0.002
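The reference attention used for the max_diff comparison is not shown above; a minimal sketch of one way to compute it follows. `reference_attention` is an illustrative helper (not part of flashinfer), assuming causal masking over the full sequence and standard GQA head repetition:

```python
import torch

def reference_attention(q, k, v):
    """Illustrative reference: causal attention with GQA head repetition.

    q: [S, Hq, D]; k, v: [S, Hkv, D] with Hq divisible by Hkv.
    Returns [S, Hq, D].
    """
    rep = q.shape[1] // k.shape[1]
    # Expand each KV head to cover its group of query heads (GQA).
    k = k.repeat_interleave(rep, dim=1)
    v = v.repeat_interleave(rep, dim=1)
    # SDPA expects [..., heads, seq, dim]; move heads to the front.
    q, k, v = (t.transpose(0, 1) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(0, 1)
```

For the prefix-cached scenario above (q_len=8 of seq_len=9), the last q_len rows of this output would be compared against the kernel's output.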

Root cause

The `s >= 16` gate was added for pre-flash-attention fused kernels that required a minimum sequence length for tiling. Q_PAGED_KV layouts only have a flash-attention code path, so this gate incorrectly disqualifies them. Per-request bounds are already enforced via `seq_lens`; `s` is the padded `max_kv_len` used for tile scheduling, not a correctness bound.

Fix

Exempt Q_PAGED_KV layouts from the `s >= 16` check:

bool const is_paged_kv = input_layout == Attention_input_layout::Q_PAGED_KV;
launch_params.flash_attention =
    (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) &&
    (is_paged_kv || (s >= 16 && d >= 16)) && !force_non_flash_attention;

Validation

Validated on SM121a (NVIDIA GB10, DGX Spark):

  • Reproducer (seq_len=9, page_size=16): max_diff 2.27 → 0.002 vs reference
  • seq_len in {1, 5, 9, 15, 16, 17, 25, 33}: all match reference (max_diff < 0.01)
  • GQA (14 Q heads, 2 KV heads): max_diff 0.002 vs reference
  • vLLM end-to-end (Qwen2.5-0.5B, trl-internal-testing/tiny-GptOssForCausalLM) with prefix caching: deterministic output across repeated calls, no regressions on prior page-aligned cases

Contributed by Second Nature Computing (https://joinsecondnature.com)

Summary by CodeRabbit

  • Bug Fixes
    • Improved flash attention dispatch for paged-KV attention layouts, removing sequence length and dimension constraints.

determine_launch_params() disables flash_attention when max_kv_len < 16
(via the `s >= 16` check). For Q_PAGED_KV layouts this falls back to a
non-flash-attention kernel that does not correctly read paged KV —
producing wrong output (max_diff 4+ vs reference, or NaN) for any
prefill whose total KV length is under 16 tokens.

The `s >= 16` gate was added for pre-flash-attention fused kernels that
required a minimum sequence length for tiling. Q_PAGED_KV layouts only
have a flash-attention code path, so this gate incorrectly disqualifies
them. Per-request bounds are already enforced via seq_lens; s is the
padded max_kv_len used for tile scheduling, not a correctness bound.

I validated on SM121a (DGX Spark GB10):
- Repro (seq_len=9, page_size=16): max_diff 2.27 -> 0.002 vs reference
- seq_len in {1, 5, 9, 15, 16, 17, 25, 33}: all match reference
- GQA (14 Q heads, 2 KV heads): matches reference (max_diff 0.002)
- vLLM end-to-end (Qwen2.5-0.5B, gpt_oss-tiny) with prefix caching:
  deterministic output, no regressions on prior page-aligned cases

Contributed by Second Nature Computing (https://joinsecondnature.com)

Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
coderabbitai bot commented Apr 17, 2026

No actionable comments were generated in the recent review. 🎉

Walkthrough

Modified the flash_attention dispatch selection in determine_launch_params to unconditionally enable flash_attention for paged-KV layouts, bypassing the sequence length and dimension threshold constraints that normally apply.

Changes

  • csrc/fmha_v2_run.cu — Introduced an `is_paged_kv` boolean to detect the Q_PAGED_KV layout and updated the flash_attention condition to enable kernel selection for paged KV regardless of the `(s >= 16 && d >= 16)` threshold.



gemini-code-assist bot left a comment


Code Review

This pull request modifies the determine_launch_params function in csrc/fmha_v2_run.cu to ensure that flash attention is always enabled for Q_PAGED_KV layouts, regardless of the sequence length. This change addresses an issue where falling back to non-flash paths for small sequence lengths resulted in incorrect outputs, as the paged KV dispatch path only supports flash attention kernels. I have no feedback to provide.

blake-snc added a commit to blake-snc/vllm that referenced this pull request Apr 17, 2026
The prior CONTIGUOUS_Q_KV path read only the current step's raw K/V
tensors, so any prefill with prefix-cached tokens or multi-step prefill
chunking missed the cached KV and produced incorrect attention.

Switch to Q_PAGED_KV_NHD and read directly from vLLM's paged KV cache.
FlashInfer allocates kv_cache in [num_pages, 2, page_size, num_kv_heads,
D] format — already Q_PAGED_KV_NHD layout, no permute needed.
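As a hedged illustration of that layout (pack_paged_kv is a made-up helper for this sketch, not a vLLM or FlashInfer API), contiguous K/V can be scattered into [num_pages, 2, page_size, num_kv_heads, D] pages like so:

```python
import math
import torch

def pack_paged_kv(k, v, page_size):
    """Scatter contiguous K/V into a paged cache (illustrative helper).

    k, v: [seq_len, num_kv_heads, head_dim]
    Returns kv: [num_pages, 2, page_size, num_kv_heads, head_dim],
    zero-padded in the tail of the last page.
    """
    seq_len, num_kv_heads, head_dim = k.shape
    num_pages = math.ceil(seq_len / page_size)
    kv = k.new_zeros(num_pages, 2, page_size, num_kv_heads, head_dim)
    for p in range(num_pages):
        lo = p * page_size
        hi = min(lo + page_size, seq_len)
        kv[p, 0, : hi - lo] = k[lo:hi]  # K plane of page p
        kv[p, 1, : hi - lo] = v[lo:hi]  # V plane of page p
    return kv
```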

trtllm_fmha_v2_prefill expects cum_seq_lens_kv as cumulative TOKEN
lengths starting from 0. TRTLLMPrefill metadata stores PAGE indptrs in
that field (trtllm-gen uses them differently), so I build the proper
token cumsum from seq_lens.
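A minimal sketch of that token-cumsum construction (the helper name is illustrative, not the actual vLLM code):

```python
import torch

def token_indptr_from_seq_lens(seq_lens):
    """Cumulative token counts starting at 0: one entry per request, plus one.

    seq_lens: int tensor [batch] of per-request KV token counts.
    """
    return torch.cat([seq_lens.new_zeros(1),
                      seq_lens.cumsum(0).to(seq_lens.dtype)])
```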

Requires flashinfer-ai/flashinfer#3106, which fixes a latent gate in
determine_launch_params() that disabled flash_attention for Q_PAGED_KV
when max_kv_len < 16 — producing silently wrong output for short
prompts.

I validated on SM121a (DGX Spark GB10):
- Qwen2.5-0.5B deterministic batched prompts (2, 11, 22, 44 chars): PASS
- Qwen2.5-0.5B with prefix caching (warm + hit): deterministic, coherent
- gpt_oss-tiny with sinks + prefix caching: deterministic
- Qwen2.5-0.5B vs triton_attn: both produce coherent continuations
  (minor numerical differences expected across kernels)

Contributed by Second Nature Computing (https://joinsecondnature.com)

Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
blake-snc added a commit to blake-snc/vllm that referenced this pull request Apr 17, 2026
…ackend

SM120 (consumer Blackwell — RTX PRO 6000, DGX Spark) lacks the trtllm-gen
precompiled cubins that SM100 uses for attention sinks. This adds an
SM120-specific prefill path using flashinfer's fmha_v2 HMMA kernels,
which support sinks natively via JIT compilation.

- vllm/utils/flashinfer.py: add supports_fmha_v2_sm120_attention() for
  SM120 detection; extend can_use_trtllm_attention / use_trtllm_attention
  to cover the SM120 fmha_v2 path.
- vllm/v1/attention/backends/flashinfer.py: update supports_sink() to
  return True for SM120. Add SM120 fmha_v2 prefill dispatch using
  Q_PAGED_KV_NHD so the kernel reads directly from vLLM's paged KV cache
  (correct under prefix caching and chunked prefill). Use self.bmm2_scale
  for FP8 KV dequant. SM120 decode uses standard flashinfer (not
  trtllm-gen).

trtllm_fmha_v2_prefill expects cum_seq_lens_kv as cumulative TOKEN
lengths starting from 0. TRTLLMPrefill metadata stores PAGE indptrs in
that field (trtllm-gen uses them differently), so I build the proper
token cumsum from seq_lens.

Requires flashinfer-ai/flashinfer#3016 (sinks passthrough in
trtllm_fmha_v2_prefill) and flashinfer-ai/flashinfer#3106 (latent gate
in determine_launch_params that disabled flash_attention for
Q_PAGED_KV when max_kv_len < 16 — producing silently wrong output for
short prompts).

I validated on SM121a (DGX Spark GB10):
- Qwen/Qwen2.5-0.5B-Instruct batched prompts (2, 11, 22, 44 chars):
  deterministic, coherent
- Qwen2.5-0.5B with prefix caching (warm + hit): deterministic
- trl-internal-testing/tiny-GptOssForCausalLM (gpt_oss with built-in
  sinks + sliding window + GQA): deterministic
- Qwen2.5-0.5B vs triton_attn baseline: both produce coherent
  continuations (minor numerical differences expected across kernels)
- supports_sink() returns True on SM120

Contributed by Second Nature Computing (https://joinsecondnature.com)

Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
