diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 4f919e1e57c8..bdbe46ad9177 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -174,7 +174,7 @@ Priority is **1 = highest** (tried first). | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | | `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ✅ | ❌ | ✅ | All | ≥10.0 | | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | -| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | +| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | | `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A | | `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A | diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index b70902478e8f..1de6eb408ae2 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -35,6 +35,7 @@ AttentionMetadataBuilder, AttentionType, CommonAttentionMetadata, + MultipleOf, ) from vllm.v1.kv_cache_interface import AttentionSpec, EncoderOnlyAttentionSpec @@ -134,6 +135,10 @@ def use_cascade_attention(*args, **kwargs) -> bool: def get_supported_head_sizes(cls) -> list[int]: return [] + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [MultipleOf(16)] + # @torch.compile(fullgraph=True, mode="reduce-overhead") def physical_to_logical_mapping(