[Attention] Allow using system FA4 #38823
IwakuraRein wants to merge 5 commits into vllm-project:main from
Conversation
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Code Review
This pull request introduces the use_system_flash_attn configuration, enabling vLLM to utilize a system-installed Flash Attention 4 library. The implementation includes updates to the attention configuration, backend selection logic, and the Flash Attention interface to support head sizes up to 512. However, several critical issues were identified: the MLA implementation hardcodes Flash Attention version 3, which conflicts with the system Flash Attention requirement for version 4; performance overhead is introduced by performing configuration lookups and imports within the hot path of attention operations; and the Flash Attention 4 interface is missing the q_v parameter necessary for MLA. Additionally, it is recommended to verify the availability of the system library within the configuration check to ensure accurate backend capability reporting.
    cp_world_size=self.dcp_world_size,
    cp_rank=self.dcp_rank,
    cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
    use_system_flash_attn=self.use_system_flash_attn,
In forward_mqa, fa_version is hardcoded to 3 (at line 359). However, the PR adds use_system_flash_attn=self.use_system_flash_attn. If use_system_flash_attn is True, flash_attn_varlen_func will raise a NotImplementedError because it explicitly forbids using system Flash Attention with fa_version=3. To support system Flash Attention 4 for MLA, the fa_version parameter must be updated to 4 when use_system_flash_attn is enabled. Additionally, the FA4 interface in flash_attn_interface.py must be updated to pass the q_v parameter required by MLA.
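The version-selection fix the comment asks for can be sketched as a small helper; a minimal illustration, assuming the `use_system_flash_attn` flag and `fa_version` semantics described above (the helper name `pick_fa_version` is hypothetical, not part of the PR):

```python
def pick_fa_version(use_system_flash_attn: bool) -> int:
    """Choose the Flash Attention version for the MLA varlen call.

    The system flash-attn library only exposes the FA4 (cute) interface,
    so the call must be tagged with version 4 when it is enabled;
    otherwise the existing hardcoded version 3 is kept.
    """
    return 4 if use_system_flash_attn else 3
```

This would replace the hardcoded `fa_version=3` at the call site flagged in the comment.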
    def should_use_system_flash_attn() -> bool:
        """Check if the system flash-attn library should be used based on config/env."""
        if not current_platform.is_cuda():
            return False
        from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if vllm_config is not None:
            return vllm_config.attention_config.use_system_flash_attn
        return False
The should_use_system_flash_attn function currently only checks the configuration but does not verify if the system Flash Attention library is actually installed. This can lead to supports_head_size returning True for head sizes up to 512 even when the required library is missing. While is_flash_attn_varlen_func_available performs this check later, it is more robust to include the availability check here to ensure consistency and prevent incorrect backend selection for unsupported head sizes.
Suggested change:

    def should_use_system_flash_attn() -> bool:
        """Check if the system flash-attn library should be used based on config/env."""
        if not current_platform.is_cuda():
            return False
        from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if vllm_config is not None and vllm_config.attention_config.use_system_flash_attn:
            import importlib.util

            return importlib.util.find_spec("flash_attn.cute") is not None
        return False
    if should_use_system_flash_attn():
        if self.vllm_flash_attn_version != 4:
            logger.warning_once(
                f"System Flash Attention is only compatible with Flash Attention 4 "
                f"but the detected version is {self.vllm_flash_attn_version}. "
                f"Disabling system flash attention.",
                scope="local",
            )
        else:
            kwargs["use_system_flash_attn"] = True
Performing should_use_system_flash_attn() and version checks inside _flash_attn_varlen_diff_headdims is inefficient as this method is called during the attention forward pass (prefill). should_use_system_flash_attn() involves a global configuration lookup which should be avoided in the hot path. This logic should be moved to the __init__ method of the implementation class, and a boolean flag (e.g., self.use_system_flash_attn) should be used here instead to minimize overhead.
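The hoisting pattern the comment recommends can be sketched as follows; a minimal illustration only, where the class name `FlashAttnImplSketch`, the stand-in `_should_use_system_flash_attn`, and the simplified method body are all assumptions, not the actual vLLM implementation:

```python
class FlashAttnImplSketch:
    def __init__(self, vllm_flash_attn_version: int) -> None:
        self.vllm_flash_attn_version = vllm_flash_attn_version
        # Resolve the config flag once at construction time, so the
        # forward pass only reads a precomputed boolean.
        self.use_system_flash_attn = (
            self._should_use_system_flash_attn()
            and vllm_flash_attn_version == 4
        )

    @staticmethod
    def _should_use_system_flash_attn() -> bool:
        # Stand-in for the real global-config lookup; always False here.
        return False

    def _flash_attn_varlen_diff_headdims(self, **kwargs):
        # Hot path: no config lookup, just the cached boolean.
        if self.use_system_flash_attn:
            kwargs["use_system_flash_attn"] = True
        return kwargs
```

The version check moves into `__init__` as well, so the warning (if any) fires once rather than on every prefill step.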
    if use_system_flash_attn:
        try:
            from flash_attn.cute.interface import _flash_attn_fwd
        except ImportError:
            use_system_flash_attn = False
    if not use_system_flash_attn:
        if block_table is not None and current_platform.is_device_capability_family(
            90
        ):
            raise NotImplementedError(
                "FA4 with paged KV is not supported on SM90 (Hopper). "
                "Use FA3 or upgrade to Blackwell (SM100+)."
            )
        from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
Importing modules inside flash_attn_varlen_func introduces unnecessary overhead in the critical path of every attention operation. While Python caches imports, the repeated string lookups and try-except blocks can impact performance in high-throughput scenarios. These imports should be performed at the module level or cached. Furthermore, the FA4 implementation of _flash_attn_fwd (lines 397-415) is missing the q_v parameter, which is essential for MLA support. Without passing q_v, MLA models will produce incorrect results when using FA4.
Closing this PR as it has been replaced by #38835.
Purpose
This PR allows users to opt into the system FA4 library in order to unblock head dim 512 and paged KV on SM90.
Related Flash Attention PRs:
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.