[Attention] Allow using system FA4 #38823
```diff
@@ -206,6 +206,7 @@ def flash_attn_varlen_func(
     cp_world_size=1,
     cp_rank=0,
     cp_tot_seqused_k=None,
+    use_system_flash_attn: bool = False,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
```
```diff
@@ -253,6 +254,7 @@ def flash_attn_varlen_func(
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        use_system_flash_attn: bool. Whether to use the Flash Attention library installed in the system.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
```
```diff
@@ -283,6 +285,10 @@ def flash_attn_varlen_func(
         dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)

     if fa_version == 2:
+        if use_system_flash_attn:
+            raise NotImplementedError(
+                "Using system flash-attn is not supported for FA2. Please use FA4."
+            )
         if (
             scheduler_metadata is not None
             and q_descale is not None
```
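The FA2/FA3 guards added in this PR share the same shape; they can be condensed into a small helper for illustration — a hedged sketch (the function name `check_system_fa_support` is hypothetical, not part of vLLM):

```python
def check_system_fa_support(fa_version: int, use_system_flash_attn: bool) -> None:
    """Raise if the system flash-attn library is requested for FA2/FA3.

    Mirrors the guards this PR adds inside the fa_version == 2 and
    fa_version == 3 branches; only FA4 supports the system library.
    """
    if use_system_flash_attn and fa_version in (2, 3):
        raise NotImplementedError(
            f"Using system flash-attn is not supported for FA{fa_version}. "
            "Please use FA4."
        )
```

In the actual diff the check is duplicated per branch, which keeps each branch self-contained at the cost of repetition.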
```diff
@@ -324,6 +330,10 @@ def flash_attn_varlen_func(
             None,
         )
     elif fa_version == 3:
+        if use_system_flash_attn:
+            raise NotImplementedError(
+                "Using system flash-attn is not supported for FA3. Please use FA4."
+            )
         assert alibi_slopes is None, "Alibi is not supported in FA3"
         out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
             q,
```
```diff
@@ -369,12 +379,20 @@ def flash_attn_varlen_func(
         # FA4 on SM90 doesn't support paged KV; SM100+ does
         from vllm.platforms import current_platform

-        if block_table is not None and current_platform.is_device_capability_family(90):
-            raise NotImplementedError(
-                "FA4 with paged KV is not supported on SM90 (Hopper). "
-                "Use FA3 or upgrade to Blackwell (SM100+)."
-            )
-        from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
+        if use_system_flash_attn:
+            try:
+                from flash_attn.cute.interface import _flash_attn_fwd
+            except ImportError:
+                use_system_flash_attn = False
+        if not use_system_flash_attn:
+            if block_table is not None and current_platform.is_device_capability_family(
+                90
+            ):
+                raise NotImplementedError(
+                    "FA4 with paged KV is not supported on SM90 (Hopper). "
+                    "Use FA3 or upgrade to Blackwell (SM100+)."
+                )
+            from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd

         out, softmax_lse = _flash_attn_fwd(
             q,
```

**Comment on lines +382 to +395** (Contributor)

Importing modules inside …
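The FA4 branch above follows a "prefer the system library, silently fall back to the vendored copy" import pattern. A generic, self-contained sketch of that pattern (module names here are illustrative, not the vLLM ones):

```python
import importlib


def load_impl(prefer_system: bool, system_mod: str, bundled_mod: str):
    """Return (module, used_system).

    Try the system-installed module first when requested; on ImportError,
    fall back silently to the bundled module, mirroring the diff's
    try/except around the flash_attn import.
    """
    if prefer_system:
        try:
            return importlib.import_module(system_mod), True
        except ImportError:
            pass  # silent fallback, as in the PR
    return importlib.import_module(bundled_mod), False
```

One design consequence, worth noting: because the fallback is silent, a user who sets `use_system_flash_attn=True` without the system library installed gets the bundled kernel with no warning.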
**Contributor**

In `forward_mqa`, `fa_version` is hardcoded to `3` (at line 359). However, the PR adds `use_system_flash_attn=self.use_system_flash_attn`. If `use_system_flash_attn` is `True`, `flash_attn_varlen_func` will raise a `NotImplementedError`, because it explicitly forbids using system Flash Attention with `fa_version=3`. To support system Flash Attention 4 for MLA, the `fa_version` parameter must be updated to `4` when `use_system_flash_attn` is enabled. Additionally, the FA4 interface in `flash_attn_interface.py` must be updated to pass the `q_v` parameter required by MLA.
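The version-selection part of the reviewer's suggestion can be sketched as a tiny pure function — hypothetical (`resolve_fa_version` is not a real vLLM function, and this does not address the separate `q_v` plumbing the reviewer also asks for):

```python
def resolve_fa_version(default_version: int, use_system_flash_attn: bool) -> int:
    """Pick the FA interface version.

    The system flash-attn path is only supported via the FA4 interface,
    so it overrides any hardcoded default such as the MLA path's 3.
    """
    if use_system_flash_attn:
        return 4
    return default_version
```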