diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py
index b92b822c1d19..f9b33fef45bb 100644
--- a/vllm/attention/utils/fa_utils.py
+++ b/vllm/attention/utils/fa_utils.py
@@ -43,20 +43,13 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
 
         # 2. override if passed by environment
         if envs.VLLM_FLASH_ATTN_VERSION is not None:
-            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
+            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3, 4]
             fa_version = envs.VLLM_FLASH_ATTN_VERSION
 
-        # 3. fallback for unsupported combinations
-        if device_capability.major == 10 and fa_version == 3:
+        if requires_alibi and fa_version in [3, 4]:
             logger.warning_once(
-                "Cannot use FA version 3 on Blackwell platform "
-                "defaulting to FA version 2."
-            )
-            fa_version = 2
-
-        if requires_alibi and fa_version == 3:
-            logger.warning_once(
-                "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
+                "Cannot use FA version %d with ALiBi, defaulting to FA version 2.",
+                fa_version,
             )
             fa_version = 2
 
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 0252c3acb08c..ac1d65dccf5e 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -374,6 +374,12 @@ def get_attn_backend_cls(
             # Default backends for V1 engine
             # Prefer FlashInfer for Blackwell GPUs if installed
             if cls.is_device_capability(100):
+                if envs.VLLM_FLASH_ATTN_VERSION == 4 and is_attn_backend_supported(
+                    FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
+                ):
+                    logger.info_once("Using Flash Attention 4 backend on V1 engine.")
+                    return FLASH_ATTN_V1
+
                 if is_default_backend_supported := is_attn_backend_supported(
                     FLASHINFER_V1, head_size, dtype
                 ):
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index fa4e34536135..744a8d9d368f 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -57,10 +57,15 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
+        # FIXME (zyongye): relax this once FA4 supports more head_dim values
+        if envs.VLLM_FLASH_ATTN_VERSION == 4:
+            return [64, 96, 128]
         return [32, 64, 96, 128, 160, 192, 224, 256]
 
     @staticmethod
     def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        if envs.VLLM_FLASH_ATTN_VERSION == 4:
+            return [128]
         return [MultipleOf(16)]
 
     @classmethod
@@ -486,8 +491,8 @@ def __init__(
 
         self.sinks = sinks
         if self.sinks is not None:
-            assert self.vllm_flash_attn_version == 3, (
-                "Sinks are only supported in FlashAttention 3"
+            assert self.vllm_flash_attn_version in [3, 4], (
+                "Sinks are only supported in FlashAttention 3 and 4"
             )
             assert self.sinks.shape[0] == num_heads, (
                 "Sinks must have the same number of heads as the number of "
@@ -606,6 +611,9 @@ def forward(
             descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
             if self.dcp_world_size > 1:
+                assert get_flash_attn_version() != 4, (
+                    "Distributed Context Parallel doesn't support FA4 yet"
+                )
                 self._forward_with_dcp(
                     query[:num_actual_tokens],
                     key[:num_actual_tokens],
@@ -620,29 +628,60 @@ def forward(
                 )
                 return output
             else:
-                flash_attn_varlen_func(
-                    q=query[:num_actual_tokens],
-                    k=key_cache,
-                    v=value_cache,
-                    out=output[:num_actual_tokens],
-                    cu_seqlens_q=cu_seqlens_q,
-                    max_seqlen_q=max_seqlen_q,
-                    seqused_k=seqused_k,
-                    max_seqlen_k=max_seqlen_k,
-                    softmax_scale=self.scale,
-                    causal=attn_metadata.causal,
-                    alibi_slopes=self.alibi_slopes,
-                    window_size=self.sliding_window,
-                    block_table=block_table,
-                    softcap=self.logits_soft_cap,
-                    scheduler_metadata=scheduler_metadata,
-                    fa_version=self.vllm_flash_attn_version,
-                    q_descale=layer._q_scale.expand(descale_shape),
-                    k_descale=layer._k_scale.expand(descale_shape),
-                    v_descale=layer._v_scale.expand(descale_shape),
-                    num_splits=attn_metadata.max_num_splits,
-                    s_aux=self.sinks,
-                )
+                if envs.VLLM_FLASH_ATTN_VERSION == 4:
+                    from flash_attn.cute.interface import _flash_attn_fwd
+
+                    window_size = (
+                        None
+                        if self.sliding_window[0] == -1
+                        else self.sliding_window[0],
+                        None
+                        if self.sliding_window[1] == -1
+                        else self.sliding_window[1],
+                    )
+                    # Discard the return value so `output` is not rebound before the return below.
+                    _flash_attn_fwd(
+                        q=query[:num_actual_tokens],
+                        k=key_cache,
+                        v=value_cache,
+                        cu_seqlens_q=cu_seqlens_q,
+                        cu_seqlens_k=None,
+                        seqused_q=None,
+                        seqused_k=seqused_k,
+                        page_table=block_table,
+                        softmax_scale=self.scale,
+                        causal=attn_metadata.causal,
+                        window_size_left=window_size[0],
+                        window_size_right=window_size[1],
+                        learnable_sink=self.sinks,
+                        softcap=self.logits_soft_cap,
+                        return_lse=False,
+                        out=output[:num_actual_tokens],
+                    )
+                else:
+                    flash_attn_varlen_func(
+                        q=query[:num_actual_tokens],
+                        k=key_cache,
+                        v=value_cache,
+                        out=output[:num_actual_tokens],
+                        cu_seqlens_q=cu_seqlens_q,
+                        max_seqlen_q=max_seqlen_q,
+                        seqused_k=seqused_k,
+                        max_seqlen_k=max_seqlen_k,
+                        softmax_scale=self.scale,
+                        causal=attn_metadata.causal,
+                        alibi_slopes=self.alibi_slopes,
+                        window_size=self.sliding_window,
+                        block_table=block_table,
+                        softcap=self.logits_soft_cap,
+                        scheduler_metadata=scheduler_metadata,
+                        fa_version=self.vllm_flash_attn_version,
+                        q_descale=layer._q_scale.expand(descale_shape),
+                        k_descale=layer._k_scale.expand(descale_shape),
+                        v_descale=layer._v_scale.expand(descale_shape),
+                        num_splits=attn_metadata.max_num_splits,
+                        s_aux=self.sinks,
+                    )
                 return output
 
         # Cascade attention (rare case).