From c23242388b12bd77f33a818d455e3b9745e868fe Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Thu, 2 Apr 2026 15:52:54 +0000 Subject: [PATCH 1/5] add attention-config.use_system_flash_attn Signed-off-by: Siyuan Fu --- vllm/config/attention.py | 4 +++ .../layers/attention/mla_attention.py | 15 +++++++++- vllm/v1/attention/backends/fa_utils.py | 24 +++++++++++++++ vllm/v1/attention/backends/flash_attn.py | 23 +++++++++++++- .../attention/backends/mla/flashattn_mla.py | 16 +++++++++- vllm/vllm_flash_attn/flash_attn_interface.py | 30 +++++++++++++++---- 6 files changed, 103 insertions(+), 9 deletions(-) diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 014bb9b22601..4044a4bd68b5 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -20,6 +20,10 @@ class AttentionConfig: """Force vllm to use a specific flash-attention version (2, 3, or 4). Only valid when using the flash-attention backend.""" + use_system_flash_attn: bool = False + """Force vllm to use the Flash Attention library installed in the system instead of + the one bundled with vllm. Currently only supports Flash Attention 4""" + use_prefill_decode_attention: bool = False """Use separate prefill and decode kernels for attention instead of the unified triton kernel.""" diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 0be46fbbc5a4..0709e400a4b7 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -254,7 +254,10 @@ MLAAttentionImpl, SparseMLAAttentionImpl, ) -from vllm.v1.attention.backends.fa_utils import get_flash_attn_version +from vllm.v1.attention.backends.fa_utils import ( + get_flash_attn_version, + should_use_system_flash_attn, +) from vllm.v1.attention.backends.utils import ( get_dcp_local_seq_lens, get_per_layer_parameters, @@ -2202,6 +2205,16 @@ def _flash_attn_varlen_diff_headdims( kwargs["return_attn_probs"] = return_softmax_lse if envs.VLLM_BATCH_INVARIANT: kwargs["num_splits"] = 1 + if should_use_system_flash_attn(): + if self.vllm_flash_attn_version != 4: + logger.warning_once( + f"System Flash Attention is only compatible with Flash Attention 4 " + f"but the detected version is {self.vllm_flash_attn_version}. " + f"Disabling system flash attention.", + scope="local", + ) + else: + kwargs["use_system_flash_attn"] = True attn_out = self.flash_attn_varlen_func( q=q, diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index a4423b301d69..ae6b34b9c562 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -53,6 +53,18 @@ def get_scheduler_metadata(*args: Any, **kwargs: Any) -> None: # type: ignore[m reshape_and_cache_flash = ops.reshape_and_cache_flash +def should_use_system_flash_attn() -> bool: + """Check if the system flash-attn library should be used based on config/env.""" + if not current_platform.is_cuda(): + return False + from vllm.config import get_current_vllm_config_or_none + + vllm_config = get_current_vllm_config_or_none() + if vllm_config is not None: + return vllm_config.attention_config.use_system_flash_attn + return False + + def get_flash_attn_version( requires_alibi: bool = False, head_size: int | None = None ) -> int | None: @@ -207,6 +219,18 @@ def is_flash_attn_varlen_func_available() -> bool: Returns: bool: True if a working flash_attn_varlen_func implementation is available. """ + if should_use_system_flash_attn(): + # Attempt to import flash_attn_varlen_func from system flash-attn + # Currently only supports Flash Attention 4 + import importlib.util + + if not importlib.util.find_spec("flash_attn.cute"): + logger.warning( + "attention-config.use_system_flash_attn is set, but " + "failed to import system flash-attn. " + ) + return False + return True if current_platform.is_cuda() or current_platform.is_xpu(): # CUDA and XPU always have flash_attn_varlen_func available return True diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d72c2aeb6161..12ca648f93a5 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -21,6 +21,7 @@ flash_attn_supports_fp8, get_flash_attn_version, is_flash_attn_varlen_func_available, + should_use_system_flash_attn, ) from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens from vllm.v1.attention.ops.common import cp_lse_ag_out_rs @@ -171,7 +172,11 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: @classmethod def supports_head_size(cls, head_size: int) -> bool: - return head_size % 8 == 0 and head_size <= 256 + if head_size % 8 != 0: + return False + if should_use_system_flash_attn(): + return head_size <= 512 + return head_size <= 256 @classmethod def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: @@ -624,6 +629,17 @@ def __init__( self.vllm_flash_attn_version, scope="local", ) + self.use_system_flash_attn = False + if should_use_system_flash_attn(): + if self.vllm_flash_attn_version != 4: + logger.warning_once( + f"System Flash Attention is only compatible with Flash Attention 4 " + f"but the detected version is {self.vllm_flash_attn_version}. " + f"Disabling system flash attention.", + scope="local", + ) + else: + self.use_system_flash_attn = True # Cache the batch invariant result for use in forward passes self.batch_invariant_enabled = envs.VLLM_BATCH_INVARIANT @@ -790,6 +806,7 @@ def forward( v_descale=v_descale, num_splits=attn_metadata.max_num_splits, s_aux=self.sinks, + use_system_flash_attn=self.use_system_flash_attn, ) return output @@ -1017,6 +1034,7 @@ def _forward_encoder_attention( k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), num_splits=1 if self.batch_invariant_enabled else 0, + use_system_flash_attn=self.use_system_flash_attn, ) return output @@ -1125,6 +1143,7 @@ def cascade_attention( k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, s_aux: torch.Tensor | None = None, + use_system_flash_attn: bool = False, ) -> torch.Tensor: assert alibi_slopes is None, "Cascade attention does not support ALiBi." # TODO: Support sliding window. @@ -1163,6 +1182,7 @@ def cascade_attention( # enabling its effect during the final attention merge. s_aux=s_aux, num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits, + use_system_flash_attn=use_system_flash_attn, ) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) @@ -1188,6 +1208,7 @@ def cascade_attention( k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits, + use_system_flash_attn=use_system_flash_attn, ) # Merge prefix and suffix outputs, and store the result in output. diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index f58d9aeb302b..6f2c7f3f36e9 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -30,6 +30,7 @@ from vllm.v1.attention.backends.fa_utils import ( flash_attn_supports_mla, get_flash_attn_version, + should_use_system_flash_attn, ) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.vllm_flash_attn import ( # type: ignore[attr-defined] @@ -121,7 +122,19 @@ def __init__( supports_dcp_with_varlen=(interleave_size == 1), ) self.max_num_splits = 0 # No upper bound on the number of splits. - self.fa_aot_schedule = get_flash_attn_version() == 3 + self.vllm_flash_attn_version = get_flash_attn_version() + self.use_system_flash_attn = False + if should_use_system_flash_attn(): + if self.vllm_flash_attn_version != 4: + logger.warning_once( + f"System Flash Attention is only compatible with Flash Attention 4 " + f"but the detected version is {self.vllm_flash_attn_version}. " + f"Disabling system flash attention.", + scope="local", + ) + else: + self.use_system_flash_attn = True + self.fa_aot_schedule = self.vllm_flash_attn_version == 3 self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() @@ -349,6 +362,7 @@ def forward_mqa( cp_world_size=self.dcp_world_size, cp_rank=self.dcp_rank, cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens, + use_system_flash_attn=self.use_system_flash_attn, ) if self.need_to_return_lse_for_decode: diff --git a/vllm/vllm_flash_attn/flash_attn_interface.py b/vllm/vllm_flash_attn/flash_attn_interface.py index 9d9a9be2f316..a64d42aa9e7d 100644 --- a/vllm/vllm_flash_attn/flash_attn_interface.py +++ b/vllm/vllm_flash_attn/flash_attn_interface.py @@ -206,6 +206,7 @@ def flash_attn_varlen_func( cp_world_size=1, cp_rank=0, cp_tot_seqused_k=None, + use_system_flash_attn: bool = False, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads @@ -253,6 +254,7 @@ def flash_attn_varlen_func( return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + use_system_flash_attn: bool. Whether to use the Flash Attention library installed in the system. Return: out: (total, nheads, headdim). softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The @@ -283,6 +285,10 @@ def flash_attn_varlen_func( dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) if fa_version == 2: + if use_system_flash_attn: + raise NotImplementedError( + "Using system flash-attn is not supported for FA2. Please use FA4." + ) if ( scheduler_metadata is not None and q_descale is not None @@ -324,6 +330,10 @@ def flash_attn_varlen_func( None, ) elif fa_version == 3: + if use_system_flash_attn: + raise NotImplementedError( + "Using system flash-attn is not supported for FA3. Please use FA4." + ) assert alibi_slopes is None, "Alibi is not supported in FA3" out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd( q, @@ -369,12 +379,20 @@ def flash_attn_varlen_func( # FA4 on SM90 doesn't support paged KV; SM100+ does from vllm.platforms import current_platform - if block_table is not None and current_platform.is_device_capability_family(90): - raise NotImplementedError( - "FA4 with paged KV is not supported on SM90 (Hopper). " - "Use FA3 or upgrade to Blackwell (SM100+)." - ) - from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd + if use_system_flash_attn: + try: + from flash_attn.cute.interface import _flash_attn_fwd + except ImportError: + use_system_flash_attn = False + if not use_system_flash_attn: + if block_table is not None and current_platform.is_device_capability_family( + 90 + ): + raise NotImplementedError( + "FA4 with paged KV is not supported on SM90 (Hopper). " + "Use FA3 or upgrade to Blackwell (SM100+)." + ) + from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd out, softmax_lse = _flash_attn_fwd( q, From d377ec4ed1f91329bfd33652191947e174b04a9f Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Thu, 2 Apr 2026 16:15:21 +0000 Subject: [PATCH 2/5] move version check to should_use_system_flash_attn() Signed-off-by: Siyuan Fu --- .../layers/attention/mla_attention.py | 11 +----- vllm/v1/attention/backends/fa_utils.py | 37 +++++++++++-------- vllm/v1/attention/backends/flash_attn.py | 12 +----- .../attention/backends/mla/flashattn_mla.py | 16 +------- 4 files changed, 26 insertions(+), 50 deletions(-) diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 0709e400a4b7..4cd8acb02890 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -2205,16 +2205,7 @@ def _flash_attn_varlen_diff_headdims( kwargs["return_attn_probs"] = return_softmax_lse if envs.VLLM_BATCH_INVARIANT: kwargs["num_splits"] = 1 - if should_use_system_flash_attn(): - if self.vllm_flash_attn_version != 4: - logger.warning_once( - f"System Flash Attention is only compatible with Flash Attention 4 " - f"but the detected version is {self.vllm_flash_attn_version}. " - f"Disabling system flash attention.", - scope="local", - ) - else: - kwargs["use_system_flash_attn"] = True + kwargs["use_system_flash_attn"] = should_use_system_flash_attn() attn_out = self.flash_attn_varlen_func( q=q, diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index ae6b34b9c562..a81be13a1cae 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -53,18 +53,6 @@ def get_scheduler_metadata(*args: Any, **kwargs: Any) -> None: # type: ignore[m reshape_and_cache_flash = ops.reshape_and_cache_flash -def should_use_system_flash_attn() -> bool: - """Check if the system flash-attn library should be used based on config/env.""" - if not current_platform.is_cuda(): - return False - from vllm.config import get_current_vllm_config_or_none - - vllm_config = get_current_vllm_config_or_none() - if vllm_config is not None: - return vllm_config.attention_config.use_system_flash_attn - return False - - def get_flash_attn_version( requires_alibi: bool = False, head_size: int | None = None ) -> int | None: @@ -166,6 +154,25 @@ def get_flash_attn_version( return None +def should_use_system_flash_attn() -> bool: + """Check if the system flash-attn library should be used based on config/env.""" + if not current_platform.is_cuda(): + return False + if get_flash_attn_version() != 4: + logger.warning_once( + f"System Flash Attention is only compatible with Flash Attention 4 " + f"but the detected version is {get_flash_attn_version()}. " + f"Disabling system flash attention.", + scope="local", + ) + from vllm.config import get_current_vllm_config_or_none + + vllm_config = get_current_vllm_config_or_none() + if vllm_config is not None: + return vllm_config.attention_config.use_system_flash_attn + return False + + def flash_attn_supports_fp8() -> bool: return ( get_flash_attn_version() == 3 @@ -224,10 +231,10 @@ def is_flash_attn_varlen_func_available() -> bool: # Currently only supports Flash Attention 4 import importlib.util - if not importlib.util.find_spec("flash_attn.cute"): - logger.warning( + if not importlib.util.find_spec("flash_attn"): + logger.warning_once( "attention-config.use_system_flash_attn is set, but " - "failed to import system flash-attn. " + "failed to import system flash_attn. " ) return False return True diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 12ca648f93a5..e2a6a3f17474 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -629,17 +629,7 @@ def __init__( self.vllm_flash_attn_version, scope="local", ) - self.use_system_flash_attn = False - if should_use_system_flash_attn(): - if self.vllm_flash_attn_version != 4: - logger.warning_once( - f"System Flash Attention is only compatible with Flash Attention 4 " - f"but the detected version is {self.vllm_flash_attn_version}. " - f"Disabling system flash attention.", - scope="local", - ) - else: - self.use_system_flash_attn = True + self.use_system_flash_attn = should_use_system_flash_attn() # Cache the batch invariant result for use in forward passes self.batch_invariant_enabled = envs.VLLM_BATCH_INVARIANT diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 6f2c7f3f36e9..db7e150a6acb 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -29,7 +29,6 @@ ) from vllm.v1.attention.backends.fa_utils import ( flash_attn_supports_mla, - get_flash_attn_version, should_use_system_flash_attn, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -122,19 +121,7 @@ def __init__( supports_dcp_with_varlen=(interleave_size == 1), ) self.max_num_splits = 0 # No upper bound on the number of splits. - self.vllm_flash_attn_version = get_flash_attn_version() - self.use_system_flash_attn = False - if should_use_system_flash_attn(): - if self.vllm_flash_attn_version != 4: - logger.warning_once( - f"System Flash Attention is only compatible with Flash Attention 4 " - f"but the detected version is {self.vllm_flash_attn_version}. " - f"Disabling system flash attention.", - scope="local", - ) - else: - self.use_system_flash_attn = True - self.fa_aot_schedule = self.vllm_flash_attn_version == 3 + self.fa_aot_schedule = should_use_system_flash_attn() == 3 self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() @@ -314,6 +301,7 @@ def __init__( raise NotImplementedError( "FlashAttnMLA V1 with FP8 KV cache not yet supported" ) + self.use_system_flash_attn = should_use_system_flash_attn() def forward_mqa( self, From f253ae9f3382c4de1f64114b5c6d84b0cb23e218 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Thu, 2 Apr 2026 16:16:51 +0000 Subject: [PATCH 3/5] fix typo Signed-off-by: Siyuan Fu --- vllm/v1/attention/backends/mla/flashattn_mla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index db7e150a6acb..2a95d40827b8 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -29,6 +29,7 @@ ) from vllm.v1.attention.backends.fa_utils import ( flash_attn_supports_mla, + get_flash_attn_version, should_use_system_flash_attn, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -121,7 +122,7 @@ def __init__( supports_dcp_with_varlen=(interleave_size == 1), ) self.max_num_splits = 0 # No upper bound on the number of splits. - self.fa_aot_schedule = should_use_system_flash_attn() == 3 + self.fa_aot_schedule = get_flash_attn_version() == 3 self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() From 0e95198c1ad62b2f8c7c7e205ecb4440c5e353f6 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Thu, 2 Apr 2026 16:18:33 +0000 Subject: [PATCH 4/5] minor fix Signed-off-by: Siyuan Fu --- vllm/v1/attention/backends/fa_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index a81be13a1cae..416402deff28 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -165,6 +165,7 @@ def should_use_system_flash_attn() -> bool: f"Disabling system flash attention.", scope="local", ) + return False from vllm.config import get_current_vllm_config_or_none vllm_config = get_current_vllm_config_or_none() From fe7731e3c0e4fb06b957f2a3fdbd43245aaeb475 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Thu, 2 Apr 2026 16:20:46 +0000 Subject: [PATCH 5/5] minor fix Signed-off-by: Siyuan Fu --- vllm/v1/attention/backends/fa_utils.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index 416402deff28..c766e97156ed 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -158,19 +158,20 @@ def should_use_system_flash_attn() -> bool: """Check if the system flash-attn library should be used based on config/env.""" if not current_platform.is_cuda(): return False - if get_flash_attn_version() != 4: - logger.warning_once( - f"System Flash Attention is only compatible with Flash Attention 4 " - f"but the detected version is {get_flash_attn_version()}. " - f"Disabling system flash attention.", - scope="local", - ) - return False from vllm.config import get_current_vllm_config_or_none vllm_config = get_current_vllm_config_or_none() - if vllm_config is not None: - return vllm_config.attention_config.use_system_flash_attn + if vllm_config is not None and vllm_config.attention_config.use_system_flash_attn: + if get_flash_attn_version() != 4: + logger.warning_once( + f"System Flash Attention is only compatible with Flash Attention 4 " + f"but the detected version is {get_flash_attn_version()}. " + f"Disabling system flash attention.", + scope="local", + ) + return False + else: + return True return False