[Attention] relax the head dim 512 and paged kv for sm90+FA4#38835
IwakuraRein wants to merge 6 commits into vllm-project:main from
Conversation
Code Review
This pull request enables support for head sizes up to 512 on Hopper (SM90) architectures when using Flash Attention version 4, and removes the previous restriction on FA4 with paged KV for these devices. A review comment suggests passing the `head_size` argument to `get_flash_attn_version()` to ensure the correct version is detected for larger head dimensions, as the current implementation might default to version 3 and fail the condition.
```python
if (
    current_platform.is_cuda()
    and current_platform.is_device_capability_family(90)
    and get_flash_attn_version() == 4
```
The call to `get_flash_attn_version()` without the `head_size` argument returns the default version for the platform (version 3 for SM90/Hopper). This makes the condition `get_flash_attn_version() == 4` false by default, effectively blocking automatic support for head dimension 512 on Hopper unless the user manually overrides the version via environment variables. To enable this automatically, pass `head_size=head_size` to the version check and update the logic in `fa_utils.py` to prefer FA4 when the head dimension exceeds 256 on SM90.
```diff
-    and get_flash_attn_version() == 4
+    and get_flash_attn_version(head_size=head_size) == 4
```
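The reviewer's suggestion can be illustrated with a minimal, self-contained sketch of head-size-aware version selection. This is not vLLM's actual `fa_utils.py` code; the function body, the SM90 default of FA3, and the 256 head-size cap for FA3 are stated in the review above, while everything else here is an illustrative assumption.

```python
from typing import Optional

def get_flash_attn_version(head_size: Optional[int] = None,
                           device_capability: int = 90) -> int:
    """Illustrative sketch: pick an FA version for a given device.

    On SM90 (Hopper) the default is FA3, but FA3 only supports head
    sizes up to 256, so when the caller passes a larger head_size we
    prefer FA4 instead. Thresholds here mirror the review comment;
    the real selection logic in vLLM is more involved.
    """
    if device_capability == 90:
        if head_size is not None and head_size > 256:
            return 4  # FA3 caps at head_size 256, so upgrade to FA4
        return 3  # Hopper default per the review comment
    return 2  # fallback for older architectures (assumption)

# Without head_size, SM90 defaults to FA3, so an `== 4` check fails:
print(get_flash_attn_version())              # 3
# Passing head_size lets the FA4 condition in the PR succeed:
print(get_flash_attn_version(head_size=512)) # 4
```

This shows why the suggested one-line change matters: the bare call can never equal 4 on Hopper, while the head-size-aware call can.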
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Force-pushed 8d7102a to 75ea3b3
…tn PR is merged Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
…r large head sizes
- `supports_head_size` now returns True for head_size <= 512 when FA4 is available (regardless of the current default FA version), allowing the backend selector to pick FLASH_ATTN for models like Gemma4 with global_head_dim=512
- When head_size > 256 on SM90+ (Hopper), the impl auto-upgrades from FA3 to FA4, since FA3 doesn't support head sizes above 256
- Add `is_fa_version_supported` wrapper in fa_utils for safe import

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
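The commit message above describes two pieces of logic: a `supports_head_size` check that accepts head sizes up to 512 whenever FA4 is available, and an `is_fa_version_supported` wrapper. The sketch below is a hypothetical reconstruction of that behavior, not vLLM's actual implementation; the FA4 availability stub and the 256/512 thresholds are taken from the commit text, everything else is assumed for illustration.

```python
def is_fa_version_supported(version: int) -> bool:
    """Stand-in for the safe-import wrapper from the commit message.

    For this sketch we simply assume FA3 and FA4 kernels are built;
    the real wrapper would probe the installed flash-attention package.
    """
    return version in (3, 4)

def supports_head_size(head_size: int) -> bool:
    """Illustrative backend check for a given attention head size.

    Per the commit message: return True for head_size <= 512 when FA4
    is available, regardless of the current default FA version, so the
    backend selector can still pick FLASH_ATTN for models with
    global_head_dim=512.
    """
    if head_size <= 256:
        return True  # already supported by FA3 on SM90
    # Above 256 we rely on FA4 being available (auto-upgrade path)
    return head_size <= 512 and is_fa_version_supported(4)

print(supports_head_size(512))   # True  (Gemma-style global_head_dim)
print(supports_head_size(1024))  # False (beyond the FA4 relaxation)
```

Splitting the availability probe from the head-size policy keeps the backend selector free of import-time failures when FA4 is absent, which appears to be the point of the `is_fa_version_supported` wrapper.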
LucasWilkinson left a comment
LGTM, thanks for the contribution!
Purpose
This PR updates the FA4+SM90 checks to unblock head dim 512 and paged KV for SM90.
vLLM Flash Attention PR dependency: vllm-project/flash-attention#130
Related Flash Attention PRs:
Test Plan
Tested with google/gemma-4-31B-it on 1 H200.
Test Result
Accuracy Test
Performance Benchmark
FlashAttention 4 is faster than the Triton backend during prefill, with larger gains at high concurrency and long input sequence lengths.
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.