[Attention] relax the head dim 512 and paged kv for sm90+FA4 #38835

Open
IwakuraRein wants to merge 6 commits into vllm-project:main from IwakuraRein:update-sm90-fa4
Conversation


@IwakuraRein IwakuraRein commented Apr 2, 2026

Purpose

This PR relaxes the FA4+SM90 checks in order to unblock head dim 512 and paged KV on SM90.
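The relaxed guard can be sketched as follows. This is a hedged illustration only; `supports_fa4_config` and its parameters are hypothetical names, not vLLM's actual API.

```python
# Hypothetical sketch of the relaxed FA4+SM90 check; names are illustrative,
# not vLLM's actual API.
def supports_fa4_config(head_size: int, uses_paged_kv: bool, is_sm90: bool) -> bool:
    """Return True if an FA4 attention config is supported on this device.

    Before this PR, SM90 (Hopper) rejected head dims above 256 and paged KV
    with FA4; after it, FA4 on SM90 accepts head dims up to 512 and paged KV.
    """
    if not is_sm90:
        # Path for other architectures is assumed unchanged here.
        return head_size <= 256
    # Relaxed SM90 path: paged KV no longer blocks, head dim limit raised to 512.
    return head_size <= 512
```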

vLLM Flash Attention PR dependency: vllm-project/flash-attention#130

Related Flash Attention PRs:

Test Plan

Tested with google/gemma-4-31B-it on a single H200:

```shell
vllm serve google/gemma-4-31B-it --attention-backend FLASH_ATTN --attention-config.flash_attn_version=4
```

Test Result

Accuracy Test

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9340|±  |0.0068|
|     |       |strict-match    |     5|exact_match|↑  |0.9295|±  |0.0071|

Performance Benchmark

FlashAttention 4 is faster than Triton during prefill, and shows better performance at high concurrency and with long input sequence lengths.

(Benchmark charts: output throughput and total throughput)
Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the v1 label Apr 2, 2026
@IwakuraRein IwakuraRein changed the title relax the head dim 512 and paged kv for sm90+FA4 [Attention] relax the head dim 512 and paged kv for sm90+FA4 Apr 2, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables support for head sizes up to 512 on Hopper (SM90) architectures when using Flash Attention version 4, and removes the previous restriction on FA4 with paged KV for these devices. A review comment suggests passing the 'head_size' argument to 'get_flash_attn_version()' to ensure the correct version is detected for larger head dimensions, as the current implementation might default to version 3 and fail the condition.

```python
if (
    current_platform.is_cuda()
    and current_platform.is_device_capability_family(90)
    and get_flash_attn_version() == 4
```
Severity: high

The call to get_flash_attn_version() without the head_size argument will return the default version for the platform (which is version 3 for SM90/Hopper). This causes the condition get_flash_attn_version() == 4 to be false by default, effectively blocking automatic support for head dimension 512 on Hopper unless the user manually overrides the version via environment variables. To enable this automatically, you should pass head_size=head_size to the version check and ensure that the logic in fa_utils.py is updated to prefer FA4 when the head dimension exceeds 256 on SM90.

Suggested change:

```diff
- and get_flash_attn_version() == 4
+ and get_flash_attn_version(head_size=head_size) == 4
```

@mergify mergify bot added the ci/build label Apr 2, 2026
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>

mergify bot commented Apr 2, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @IwakuraRein.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 2, 2026
-s

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
IwakuraRein and others added 3 commits April 2, 2026 18:20
…tn PR is merged

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
…r large head sizes

- `supports_head_size` now returns True for head_size <= 512 when FA4 is
  available (regardless of current default FA version), allowing the backend
  selector to pick FLASH_ATTN for models like Gemma4 with global_head_dim=512
- When head_size > 256 on SM90+ (Hopper), the impl auto-upgrades from FA3
  to FA4 since FA3 doesn't support head sizes above 256
- Add `is_fa_version_supported` wrapper in fa_utils for safe import

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
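The FA3-to-FA4 auto-upgrade described in the commit message above can be sketched roughly as follows; the function and parameter names are assumptions for illustration, not the actual vLLM code.

```python
# Hedged sketch of the FA3 -> FA4 auto-upgrade; names are illustrative.
def resolve_fa_version(default_version: int, head_size: int, is_sm90_plus: bool) -> int:
    """FA3 does not support head sizes above 256, so on SM90+ (Hopper)
    the impl switches to FA4 for such models (e.g. a 512 global head dim)."""
    if default_version == 3 and is_sm90_plus and head_size > 256:
        return 4  # auto-upgrade: FA3 cannot serve this head dim
    return default_version
```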

@LucasWilkinson LucasWilkinson left a comment


LGTM, thanks for the contribution!

@LucasWilkinson LucasWilkinson added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 3, 2026
@LucasWilkinson LucasWilkinson enabled auto-merge (squash) April 3, 2026 22:25

Labels

ci/build ready ONLY add when PR is ready to merge/full CI is needed v1
