Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)

_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]


class FlashAttentionBackend(AttentionBackend):

Expand Down Expand Up @@ -60,6 +58,10 @@ def copy_blocks(
value_caches = [kv_cache[1] for kv_cache in kv_caches]
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]


@dataclass
class FlashAttentionMetadata(AttentionMetadata):
Expand Down Expand Up @@ -237,10 +239,11 @@ def __init__(
# paged KV cache.
raise ValueError(
"Sliding window is not supported in FlashAttention.")
if head_size not in _SUPPORTED_HEAD_SIZES:
supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")
f"Supported head sizes are: {supported_head_sizes}.")

def forward(
self,
Expand Down
16 changes: 15 additions & 1 deletion vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,23 @@ def get_attn_backend(
sliding_window, dtype, kv_cache_dtype,
block_size)
if backend == _Backend.FLASH_ATTN:
logger.info("Using FlashAttention-2 backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)

# We check it here not in _which_attn_to_use because we cannot know
# the head size until we import FlashAttentionBackend.
flash_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in flash_head_sizes:
from vllm.attention.backends.xformers import ( # noqa: F401
XFormersBackend)
logger.info(
"The model requires head size %d that's not "
"compatible with flash attention's head size %s. "
"Using XFormers backend instead.", head_size,
str(flash_head_sizes))
return XFormersBackend

logger.info("Using FlashAttention-2 backend.")
return FlashAttentionBackend
elif backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
Expand Down