diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 0361dd3bd4ea..24a222bac501 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -9,8 +9,6 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata)
 
-_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]
-
 
 class FlashAttentionBackend(AttentionBackend):
 
@@ -60,6 +58,10 @@ def copy_blocks(
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
         cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
 
 @dataclass
 class FlashAttentionMetadata(AttentionMetadata):
@@ -237,10 +239,11 @@ def __init__(
             # paged KV cache.
             raise ValueError(
                 "Sliding window is not supported in FlashAttention.")
-        if head_size not in _SUPPORTED_HEAD_SIZES:
+        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by FlashAttention. "
-                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")
+                f"Supported head sizes are: {supported_head_sizes}.")
 
     def forward(
         self,
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 5140c3cc86a3..0ac760f6fc79 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -34,9 +34,23 @@ def get_attn_backend(
                                 sliding_window, dtype, kv_cache_dtype,
                                 block_size)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention-2 backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
+
+        # We check this here rather than in _which_attn_to_use because we
+        # cannot know the supported head sizes until we import FlashAttentionBackend.
+        flash_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size not in flash_head_sizes:
+            from vllm.attention.backends.xformers import (  # noqa: F401
+                XFormersBackend)
+            logger.info(
+                "The model requires head size %d, which is not compatible "
+                "with FlashAttention's supported head sizes %s. "
+                "Using XFormers backend instead.", head_size,
+                str(flash_head_sizes))
+            return XFormersBackend
+
+        logger.info("Using FlashAttention-2 backend.")
         return FlashAttentionBackend
     elif backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
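
For reference, a minimal standalone sketch of the fallback behavior the selector.py hunk introduces. This is illustrative only and not part of the patch: FlashAttentionBackend, XFormersBackend, and pick_backend below are simplified stand-ins for the real vLLM classes and selection code.

# Illustrative sketch, not part of the patch.
from typing import List, Type


class FlashAttentionBackend:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]


class XFormersBackend:
    pass


def pick_backend(head_size: int) -> Type:
    # Same check the patch performs after importing FlashAttentionBackend:
    # fall back to XFormers when the model's head size is unsupported.
    if head_size not in FlashAttentionBackend.get_supported_head_sizes():
        return XFormersBackend
    return FlashAttentionBackend


assert pick_backend(128) is FlashAttentionBackend
assert pick_backend(80) is XFormersBackend  # 80 is not in the supported list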