
[Core] Remove FlashAttention block size restriction for hybrid models #36701

Open
tdoublep wants to merge 1 commit into vllm-project:main from tdoublep:remove-hybrid-flash-attn-block-size-restriction

Conversation

@tdoublep
Member

Summary

The restriction limiting FlashAttention kernel block sizes to [16, 32, 64] for hybrid models with a float32 Mamba cache is no longer needed: PR vllm-project#35219 introduced KVBlockZeroer, which zeros freshly allocated KV cache blocks and prevents NaN propagation from stale fp32 data in reused blocks. This PR removes the now-obsolete conditional.

Test plan

Verified on H100 with nvidia/NVIDIA-Nemotron-Nano-9B-v2 (hybrid Mamba model) using the same reproduction script from #27753.

Test script

from vllm import LLM, SamplingParams
import os

# Force the FlashAttention backend so the removed block-size restriction is exercised.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

prompts = ["Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.\n\nConsider the paths of length $16$ that follow the lines from the lower left corner to the upper right corner on an $8\\times 8$ grid. Find the number of such paths that change direction exactly four times, as in the examples shown below.\n\nRemember to put your answer on its own line after \"Answer:\"."]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# A tiny block pool (num_gpu_blocks_override=10) plus prefix caching forces
# aggressive KV block reuse, the condition under which stale fp32 data
# previously produced NaNs.
llm = LLM(
    model="nvidia/NVIDIA-Nemotron-Nano-9B-v2",
    trust_remote_code=True,
    num_gpu_blocks_override=10,
    compilation_config={"cudagraph_capture_sizes": [1]},
    mamba_ssm_cache_dtype="float32",
    max_num_seqs=1,
    enable_prefix_caching=True,
)

# Run 10 generations and inspect each one for signs of cache corruption.
outputs = []
for i in range(10):
    result = llm.generate(prompts, sampling_params)[0]
    outputs.append(result)
    generated_text = result.outputs[0].text
    token_ids = result.outputs[0].token_ids
    print(f"--- Iteration {i+1} ---")
    print(f"Token IDs (first 20): {token_ids[:20]}")
    print(f"Generated text (first 200 chars): {generated_text[:200]!r}")
    print()

# A corrupted cache typically shows up as empty output or leading zero tokens.
all_ok = all(len(o.outputs[0].token_ids) > 0 and o.outputs[0].token_ids[0] != 0 for o in outputs)
print(f"All iterations produced meaningful output: {all_ok}")

Test output

All 10 iterations produced meaningful, coherent output:

--- Iteration 1 ---
Token IDs (first 20): [1885, 74045, 1561, 1784, 4127, 10867, 13170, 1278, 2782, 1307, 22344, 1307, 5592, 1032, 1049, 1054]
Generated text (first 200 chars): '</think>\nThe problem requires finding the number of paths of length 16'

--- Iteration 2 ---
Token IDs (first 20): [1032, 1267, 20396, 22344, 1877, 1045, 17669, 1032, 1049, 1058, 21285, 1044, 21285, 1044, 16999, 1044]
Generated text (first 200 chars): ' \n\nExample paths:\n- Path 1: Right, Right, Down,'

--- Iteration 3 ---
Token IDs (first 20): [1885, 74045, 1561, 1784, 2782, 1307, 22344, 1307, 5592, 1032, 1049, 1054, 1562, 1278, 4953, 3979]
Generated text (first 200 chars): '</think>\nThe number of paths of length 16 from the lower left'

--- Iteration 4 ---
Token IDs (first 20): [1885, 74045, 1561, 1784, 4127, 19263, 13170, 1278, 2782, 1307, 22344, 1307, 5592, 1032, 1049, 1054]
Generated text (first 200 chars): '</think>\nThe problem involves finding the number of paths of length 16'

--- Iteration 5 ---
Token IDs (first 20): [4848, 1058, 3870, 15047, 1278, 4127, 1307, 13170, 1278, 2782, 1307, 22344, 1307, 5592, 1032, 1049]
Generated text (first 200 chars): ' output: To solve the problem of finding the number of paths of length 1'

--- Iteration 6 ---
Token IDs (first 20): [1032, 1010, 1885, 74045, 1561, 1784, 4127, 10867, 13170, 1278, 2782, 1307, 22344, 1307, 5592, 1032]
Generated text (first 200 chars): ' \n</think>\nThe problem requires finding the number of paths of length '

--- Iteration 7 ---
Token IDs (first 20): [1885, 74045, 1561, 1784, 4127, 10867, 13170, 1278, 2782, 1307, 22344, 1307, 5592, 1032, 1049, 1054]
Generated text (first 200 chars): '</think>\nThe problem requires finding the number of paths of length 16'

--- Iteration 8 ---
Token IDs (first 20): [1032, 1010, 1885, 74045, 1561, 1784, 22344, 1307, 5592, 1032, 1049, 1054, 1562, 1278, 4953, 3979]
Generated text (first 200 chars): ' \n</think>\nThe paths of length 16 from the lower left'

--- Iteration 9 ---
Token IDs (first 20): [1032, 1267, 49250, 2077, 1561, 44053, 1044, 2878, 1681, 3219, 1046, 1362, 2534, 1317, 3081, 1278]
Generated text (first 200 chars): " \n\n<think>\nOkay, let's see. I need to find the"

--- Iteration 10 ---
Token IDs (first 20): [1060, 74045, 1561, 44053, 1044, 1878, 1362, 2534, 1317, 3081, 1278, 2782, 1307, 22344, 1408, 1420]
Generated text (first 200 chars): '<think>\nOkay, so I need to find the number of paths on an'

All iterations produced meaningful output: True

No NaN, no zero tokens, no empty strings across all 10 iterations.

🤖 Generated with Claude Code

The restriction limiting FA block sizes to [16, 32, 64] for hybrid
models with float32 Mamba cache is no longer needed. PR vllm-project#35219
introduced KVBlockZeroer which zeros freshly allocated KV cache blocks,
preventing NaN propagation from stale fp32 data in reused blocks.
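To illustrate why zeroing matters: bytes left over in a reused block can decode to NaN when reinterpreted as floats, while a zeroed block cannot. The NumPy sketch below is purely illustrative of the bit-level effect; it does not reflect vLLM's actual cache layout or the KVBlockZeroer implementation.

```python
import numpy as np

# A reused cache block may contain arbitrary stale bytes. All-ones bytes
# decode as float32 NaN (exponent all ones, nonzero mantissa).
stale_block = np.full(64, 0xFF, dtype=np.uint8).view(np.float32)
print(np.isnan(stale_block).any())  # True: stale bytes decode to NaN

# Zeroing the block on allocation guarantees every float decodes to 0.0,
# so no NaN can propagate from previously freed data.
zeroed_block = np.zeros(64, dtype=np.uint8).view(np.float32)
print(np.isnan(zeroed_block).any())  # False: zeroed block is NaN-free
```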

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@mergify mergify bot added the v1 label Mar 10, 2026
@tdoublep tdoublep marked this pull request as ready for review March 10, 2026 20:53
@tdoublep
Member Author

tdoublep commented Mar 10, 2026

cc @NickLucche

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request removes a block size restriction in FlashAttentionBackend.get_supported_kernel_block_sizes() that was a workaround for a previously fixed bug. The change simplifies the code by removing the now-obsolete conditional logic, which improves maintainability. The provided test plan confirms that removing this restriction does not reintroduce the original issue. The changes are correct and well-justified.
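For context, the change has the shape of dropping a dtype-conditional narrowing from the supported-block-size query. The sketch below uses hypothetical names and values to show that shape only; it is not vLLM's actual implementation of get_supported_kernel_block_sizes().

```python
# Illustrative sketch only: names, sizes, and structure are hypothetical,
# not vLLM's real FlashAttentionBackend code.

RESTRICTED_SIZES = [16, 32, 64]
ALL_SIZES = [16, 32, 64, 128, 256]

def get_supported_kernel_block_sizes(has_fp32_mamba_cache: bool,
                                     zeroing_available: bool = True):
    # Before KVBlockZeroer (vllm-project#35219), hybrid models with a
    # float32 Mamba cache were pinned to RESTRICTED_SIZES as a workaround
    # for NaNs from stale reused blocks. With zeroing in place, the guard
    # is obsolete and the full set is safe.
    if has_fp32_mamba_cache and not zeroing_available:
        return RESTRICTED_SIZES
    return ALL_SIZES
```

With the workaround removed, both branches collapse to returning the full list unconditionally, which is the simplification the review describes.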
