
[Attention] Allow using system FA4 #38823

Closed
IwakuraRein wants to merge 5 commits into vllm-project:main from IwakuraRein:use-system-fa4

Conversation

@IwakuraRein
Contributor

@IwakuraRein IwakuraRein commented Apr 2, 2026

Purpose

This PR aims to allow users to choose the system FA4 library in order to unblock head dim 512 and paged KV for SM90.

Related Flash Attention PRs:

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@mergify mergify bot added the v1 label Apr 2, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the use_system_flash_attn configuration, enabling vLLM to use a system-installed Flash Attention 4 library. The implementation includes updates to the attention configuration, backend selection logic, and the Flash Attention interface to support head sizes up to 512. However, several critical issues were identified:
  • The MLA implementation hardcodes Flash Attention version 3, which conflicts with the system Flash Attention requirement for version 4.
  • Performance overhead is introduced by performing configuration lookups and imports within the hot path of attention operations.
  • The Flash Attention 4 interface is missing the q_v parameter necessary for MLA.
Additionally, it is recommended to verify the availability of the system library within the configuration check to ensure accurate backend capability reporting.

cp_world_size=self.dcp_world_size,
cp_rank=self.dcp_rank,
cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
use_system_flash_attn=self.use_system_flash_attn,
Contributor


critical

In forward_mqa, fa_version is hardcoded to 3 (at line 359). However, the PR adds use_system_flash_attn=self.use_system_flash_attn. If use_system_flash_attn is True, flash_attn_varlen_func will raise a NotImplementedError because it explicitly forbids using system Flash Attention with fa_version=3. To support system Flash Attention 4 for MLA, the fa_version parameter must be updated to 4 when use_system_flash_attn is enabled. Additionally, the FA4 interface in flash_attn_interface.py must be updated to pass the q_v parameter required by MLA.
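A minimal sketch of the version-selection fix this comment describes. The helper below is illustrative, not vLLM's actual code; only the rule it encodes (system Flash Attention implies FA4, otherwise keep the existing default of 3) comes from the review.

```python
def pick_fa_version(use_system_flash_attn: bool, default_version: int = 3) -> int:
    """Select the Flash Attention version for the MQA path.

    The system flash-attn library only provides the FA4 kernels, so
    requesting it must force version 4; otherwise keep the default
    version the code previously hardcoded.
    """
    return 4 if use_system_flash_attn else default_version
```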

Comment on lines +56 to +65
    def should_use_system_flash_attn() -> bool:
        """Check if the system flash-attn library should be used based on config/env."""
        if not current_platform.is_cuda():
            return False
        from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if vllm_config is not None:
            return vllm_config.attention_config.use_system_flash_attn
        return False
Contributor


high

The should_use_system_flash_attn function currently only checks the configuration but does not verify if the system Flash Attention library is actually installed. This can lead to supports_head_size returning True for head sizes up to 512 even when the required library is missing. While is_flash_attn_varlen_func_available performs this check later, it is more robust to include the availability check here to ensure consistency and prevent incorrect backend selection for unsupported head sizes.

Suggested change

Before:

    def should_use_system_flash_attn() -> bool:
        """Check if the system flash-attn library should be used based on config/env."""
        if not current_platform.is_cuda():
            return False
        from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if vllm_config is not None:
            return vllm_config.attention_config.use_system_flash_attn
        return False

After:

    def should_use_system_flash_attn() -> bool:
        """Check if the system flash-attn library should be used based on config/env."""
        if not current_platform.is_cuda():
            return False
        from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if vllm_config is not None and vllm_config.attention_config.use_system_flash_attn:
            import importlib.util

            return importlib.util.find_spec("flash_attn.cute") is not None
        return False

Comment on lines +2208 to +2217
    if should_use_system_flash_attn():
        if self.vllm_flash_attn_version != 4:
            logger.warning_once(
                f"System Flash Attention is only compatible with Flash Attention 4 "
                f"but the detected version is {self.vllm_flash_attn_version}. "
                f"Disabling system flash attention.",
                scope="local",
            )
        else:
            kwargs["use_system_flash_attn"] = True
Contributor


high

Performing should_use_system_flash_attn() and version checks inside _flash_attn_varlen_diff_headdims is inefficient as this method is called during the attention forward pass (prefill). should_use_system_flash_attn() involves a global configuration lookup which should be avoided in the hot path. This logic should be moved to the __init__ method of the implementation class, and a boolean flag (e.g., self.use_system_flash_attn) should be used here instead to minimize overhead.
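A self-contained sketch of the pattern this comment suggests: resolve the decision once at construction time and read a plain attribute in the forward pass. The class and injected callable here are illustrative stand-ins, not vLLM's actual implementation.

```python
class FlashAttnImplSketch:
    """Illustrative only: cache the system-FA decision in __init__.

    `should_use_system_flash_attn` stands in for the global config lookup
    the reviewer wants hoisted out of the hot path.
    """

    def __init__(self, should_use_system_flash_attn, flash_attn_version: int):
        # Resolve once; the forward pass reads a cached boolean instead of
        # re-querying global configuration on every attention call.
        self.use_system_flash_attn = (
            should_use_system_flash_attn() and flash_attn_version == 4
        )

    def forward_kwargs(self) -> dict:
        # Hot path: no config lookup, no version check, just the flag.
        kwargs = {}
        if self.use_system_flash_attn:
            kwargs["use_system_flash_attn"] = True
        return kwargs
```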

Comment on lines +382 to +395
    if use_system_flash_attn:
        try:
            from flash_attn.cute.interface import _flash_attn_fwd
        except ImportError:
            use_system_flash_attn = False
    if not use_system_flash_attn:
        if block_table is not None and current_platform.is_device_capability_family(
            90
        ):
            raise NotImplementedError(
                "FA4 with paged KV is not supported on SM90 (Hopper). "
                "Use FA3 or upgrade to Blackwell (SM100+)."
            )
        from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
Contributor


high

Importing modules inside flash_attn_varlen_func introduces unnecessary overhead in the critical path of every attention operation. While Python caches imports, the repeated string lookups and try-except blocks can impact performance in high-throughput scenarios. These imports should be performed at the module level or cached. Furthermore, the FA4 implementation of _flash_attn_fwd (lines 397-415) is missing the q_v parameter, which is essential for MLA support. Without passing q_v, MLA models will produce incorrect results when using FA4.
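One way to move the import cost out of the per-call path, as this comment suggests, is to resolve the kernel entry point once and cache it at module scope. This sketch is illustrative: the generic `get_flash_attn_fwd` helper is not part of the PR, and the caching idea (not the names) is what the review asks for.

```python
import importlib

# Module-level cache: the import and attribute lookup happen once, on the
# first call, instead of inside every attention forward pass.
_FA_FWD = None


def get_flash_attn_fwd(module_name: str, attr: str = "_flash_attn_fwd"):
    """Return the cached kernel entry point, importing it on first use."""
    global _FA_FWD
    if _FA_FWD is None:
        _FA_FWD = getattr(importlib.import_module(module_name), attr)
    return _FA_FWD
```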

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@IwakuraRein
Contributor Author

Closing this PR as it's replaced by #38835.

@IwakuraRein IwakuraRein closed this Apr 2, 2026