Skip to content
Merged
2 changes: 1 addition & 1 deletion cmake/external_projects/vllm_flash_attn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG c0ec424fd8a546d0cbbf4bf050bbcfe837c55afb
GIT_TAG f5bc33cfc02c744d24a2e9d50e6db656de40611c
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
11 changes: 11 additions & 0 deletions vllm/v1/attention/backends/fa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,17 @@ def get_flash_attn_version(
return None


def is_fa_version_supported(fa_version: int) -> bool:
    """Report whether the installed vllm flash-attn build supports *fa_version*.

    Thin wrapper over the extension module's own capability probe; returns
    False when the flash-attn interface cannot be imported at all (extension
    not built or not installed).
    """
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            is_fa_version_supported as _probe,
        )
    except ImportError:
        # No flash-attn extension available -> no version is supported.
        return False
    return _probe(fa_version)


def flash_attn_supports_fp8() -> bool:
return (
get_flash_attn_version() == 3
Expand Down
19 changes: 17 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch

from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionBackend,
Expand All @@ -20,6 +21,7 @@
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_fp8,
get_flash_attn_version,
is_fa_version_supported,
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
Expand All @@ -45,7 +47,6 @@
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv, round_up
from vllm.v1.attention.backend import (
Expand Down Expand Up @@ -170,7 +171,13 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:

@classmethod
def supports_head_size(cls, head_size: int) -> bool:
    """Return True if this backend can handle attention heads of *head_size*.

    The head size must be a multiple of 8. Sizes up to 256 are always
    supported; sizes up to 512 are supported only when an FA4 build is
    available.
    """
    # Non-multiple-of-8 head sizes are never supported.
    if head_size % 8 != 0:
        return False
    # <=256 works on every FA version; 257..512 needs FA4.
    return head_size <= 256 or (is_fa_version_supported(4) and head_size <= 512)

@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
Expand Down Expand Up @@ -618,6 +625,14 @@ def __init__(
requires_alibi=alibi_slopes is not None,
head_size=head_size,
)
# head_size > 256 requires FA4 on SM90+; force upgrade from FA3
if (
head_size > 256
and self.vllm_flash_attn_version == 3
and current_platform.is_cuda()
and current_platform.is_device_capability_family(90)
):
self.vllm_flash_attn_version = 4
logger.info_once(
"Using FlashAttention version %s",
self.vllm_flash_attn_version,
Expand Down
7 changes: 0 additions & 7 deletions vllm/vllm_flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,7 @@ def flash_attn_varlen_func(
)
elif fa_version == 4:
assert alibi_slopes is None, "Alibi is not supported in FA4"
# FA4 on SM90 doesn't support paged KV; SM100+ does
from vllm.platforms import current_platform

if block_table is not None and current_platform.is_device_capability_family(90):
raise NotImplementedError(
"FA4 with paged KV is not supported on SM90 (Hopper). "
"Use FA3 or upgrade to Blackwell (SM100+)."
)
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd

out, softmax_lse = _flash_attn_fwd(
Expand Down
Loading