diff --git a/vllm/config/attention.py b/vllm/config/attention.py
index 014bb9b22601..4044a4bd68b5 100644
--- a/vllm/config/attention.py
+++ b/vllm/config/attention.py
@@ -20,6 +20,10 @@ class AttentionConfig:
     """Force vllm to use a specific flash-attention version (2, 3, or 4).
     Only valid when using the flash-attention backend."""
 
+    use_system_flash_attn: bool = False
+    """Force vllm to use the Flash Attention library installed in the system instead of
+    the one bundled with vllm. Currently only supports Flash Attention 4"""
+
     use_prefill_decode_attention: bool = False
     """Use separate prefill and decode kernels for attention instead of
     the unified triton kernel."""
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index 0be46fbbc5a4..4cd8acb02890 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -254,7 +254,10 @@
     MLAAttentionImpl,
     SparseMLAAttentionImpl,
 )
-from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
+from vllm.v1.attention.backends.fa_utils import (
+    get_flash_attn_version,
+    should_use_system_flash_attn,
+)
 from vllm.v1.attention.backends.utils import (
     get_dcp_local_seq_lens,
     get_per_layer_parameters,
@@ -2202,6 +2205,7 @@ def _flash_attn_varlen_diff_headdims(
             kwargs["return_attn_probs"] = return_softmax_lse
         if envs.VLLM_BATCH_INVARIANT:
             kwargs["num_splits"] = 1
+        kwargs["use_system_flash_attn"] = should_use_system_flash_attn()
 
         attn_out = self.flash_attn_varlen_func(
             q=q,
diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py
index a4423b301d69..c766e97156ed 100644
--- a/vllm/v1/attention/backends/fa_utils.py
+++ b/vllm/v1/attention/backends/fa_utils.py
@@ -154,6 +154,27 @@ def get_flash_attn_version(
     return None
 
 
+def should_use_system_flash_attn() -> bool:
+    """Check if the system flash-attn library should be used based on config/env."""
+    if not current_platform.is_cuda():
+        return False
+    from vllm.config import get_current_vllm_config_or_none
+
+    vllm_config = get_current_vllm_config_or_none()
+    if vllm_config is not None and vllm_config.attention_config.use_system_flash_attn:
+        if get_flash_attn_version() != 4:
+            logger.warning_once(
+                f"System Flash Attention is only compatible with Flash Attention 4 "
+                f"but the detected version is {get_flash_attn_version()}. "
+                f"Disabling system flash attention.",
+                scope="local",
+            )
+            return False
+        else:
+            return True
+    return False
+
+
 def flash_attn_supports_fp8() -> bool:
     return (
         get_flash_attn_version() == 3
@@ -207,6 +228,18 @@ def is_flash_attn_varlen_func_available() -> bool:
     Returns:
         bool: True if a working flash_attn_varlen_func implementation is available.
     """
+    if should_use_system_flash_attn():
+        # Attempt to import flash_attn_varlen_func from system flash-attn
+        # Currently only supports Flash Attention 4
+        import importlib.util
+
+        if not importlib.util.find_spec("flash_attn"):
+            logger.warning_once(
+                "attention-config.use_system_flash_attn is set, but "
+                "failed to import system flash_attn. "
+            )
+            return False
+        return True
     if current_platform.is_cuda() or current_platform.is_xpu():
         # CUDA and XPU always have flash_attn_varlen_func available
         return True
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index d72c2aeb6161..e2a6a3f17474 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -21,6 +21,7 @@
     flash_attn_supports_fp8,
     get_flash_attn_version,
     is_flash_attn_varlen_func_available,
+    should_use_system_flash_attn,
 )
 from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
 from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
@@ -171,7 +172,11 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
 
     @classmethod
     def supports_head_size(cls, head_size: int) -> bool:
-        return head_size % 8 == 0 and head_size <= 256
+        if head_size % 8 != 0:
+            return False
+        if should_use_system_flash_attn():
+            return head_size <= 512
+        return head_size <= 256
 
     @classmethod
     def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
@@ -624,6 +629,7 @@ def __init__(
                 self.vllm_flash_attn_version,
                 scope="local",
             )
+        self.use_system_flash_attn = should_use_system_flash_attn()
 
         # Cache the batch invariant result for use in forward passes
         self.batch_invariant_enabled = envs.VLLM_BATCH_INVARIANT
@@ -790,6 +796,7 @@ def forward(
                 v_descale=v_descale,
                 num_splits=attn_metadata.max_num_splits,
                 s_aux=self.sinks,
+                use_system_flash_attn=self.use_system_flash_attn,
             )
 
         return output
@@ -1017,6 +1024,7 @@ def _forward_encoder_attention(
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
             num_splits=1 if self.batch_invariant_enabled else 0,
+            use_system_flash_attn=self.use_system_flash_attn,
         )
 
         return output
@@ -1125,6 +1133,7 @@ def cascade_attention(
     k_descale: torch.Tensor | None = None,
     v_descale: torch.Tensor | None = None,
     s_aux: torch.Tensor | None = None,
+    use_system_flash_attn: bool = False,
 ) -> torch.Tensor:
     assert alibi_slopes is None, "Cascade attention does not support ALiBi."
     # TODO: Support sliding window.
@@ -1163,6 +1172,7 @@ def cascade_attention(
         # enabling its effect during the final attention merge.
         s_aux=s_aux,
         num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
+        use_system_flash_attn=use_system_flash_attn,
     )
 
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@@ -1188,6 +1198,7 @@ def cascade_attention(
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
         num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
+        use_system_flash_attn=use_system_flash_attn,
     )
 
     # Merge prefix and suffix outputs, and store the result in output.
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index f58d9aeb302b..2a95d40827b8 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -30,6 +30,7 @@
 from vllm.v1.attention.backends.fa_utils import (
     flash_attn_supports_mla,
     get_flash_attn_version,
+    should_use_system_flash_attn,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
@@ -301,6 +302,7 @@ def __init__(
             raise NotImplementedError(
                 "FlashAttnMLA V1 with FP8 KV cache not yet supported"
             )
+        self.use_system_flash_attn = should_use_system_flash_attn()
 
     def forward_mqa(
         self,
@@ -349,6 +351,7 @@ def forward_mqa(
             cp_world_size=self.dcp_world_size,
             cp_rank=self.dcp_rank,
             cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
+            use_system_flash_attn=self.use_system_flash_attn,
         )
 
         if self.need_to_return_lse_for_decode:
diff --git a/vllm/vllm_flash_attn/flash_attn_interface.py b/vllm/vllm_flash_attn/flash_attn_interface.py
index 9d9a9be2f316..a64d42aa9e7d 100644
--- a/vllm/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm/vllm_flash_attn/flash_attn_interface.py
@@ -206,6 +206,7 @@ def flash_attn_varlen_func(
     cp_world_size=1,
     cp_rank=0,
     cp_tot_seqused_k=None,
+    use_system_flash_attn: bool = False,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -253,6 +254,7 @@ def flash_attn_varlen_func(
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
+        use_system_flash_attn: bool. Whether to use the Flash Attention library installed in the system.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
@@ -283,6 +285,10 @@ def flash_attn_varlen_func(
         dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
 
     if fa_version == 2:
+        if use_system_flash_attn:
+            raise NotImplementedError(
+                "Using system flash-attn is not supported for FA2. Please use FA4."
+            )
         if (
             scheduler_metadata is not None
             and q_descale is not None
@@ -324,6 +330,10 @@ def flash_attn_varlen_func(
             None,
         )
     elif fa_version == 3:
+        if use_system_flash_attn:
+            raise NotImplementedError(
+                "Using system flash-attn is not supported for FA3. Please use FA4."
+            )
         assert alibi_slopes is None, "Alibi is not supported in FA3"
         out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
             q,
@@ -369,12 +379,20 @@ def flash_attn_varlen_func(
         # FA4 on SM90 doesn't support paged KV; SM100+ does
         from vllm.platforms import current_platform
 
-        if block_table is not None and current_platform.is_device_capability_family(90):
-            raise NotImplementedError(
-                "FA4 with paged KV is not supported on SM90 (Hopper). "
-                "Use FA3 or upgrade to Blackwell (SM100+)."
-            )
-        from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
+        if use_system_flash_attn:
+            try:
+                from flash_attn.cute.interface import _flash_attn_fwd
+            except ImportError:
+                use_system_flash_attn = False
+        if not use_system_flash_attn:
+            if block_table is not None and current_platform.is_device_capability_family(
+                90
+            ):
+                raise NotImplementedError(
+                    "FA4 with paged KV is not supported on SM90 (Hopper). "
+                    "Use FA3 or upgrade to Blackwell (SM100+)."
+                )
+            from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
 
         out, softmax_lse = _flash_attn_fwd(
             q,