From 4ea6b8529cf8078afe3e203de1ee28083a47a8d6 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Tue, 10 Mar 2026 21:35:36 +0100
Subject: [PATCH] [Core] Remove FlashAttention block size restriction for hybrid models

The restriction limiting FA block sizes to [16, 32, 64] for hybrid models
with float32 Mamba cache is no longer needed. PR #35219 introduced
KVBlockZeroer which zeros freshly allocated KV cache blocks, preventing
NaN propagation from stale fp32 data in reused blocks.

Co-Authored-By: Claude Opus 4.6
Signed-off-by: Thomas Parnell
---
 vllm/v1/attention/backends/flash_attn.py | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 81d62629d85e..0ad211bda217 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -35,7 +35,6 @@
 )
 from vllm.config import (
     VllmConfig,
-    get_current_vllm_config,
     get_current_vllm_config_or_none,
     get_layers_from_vllm_config,
 )
@@ -67,22 +66,6 @@ class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
-        vllm_config = get_current_vllm_config()
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
-        if (
-            model_config
-            and model_config.is_hybrid
-            and (
-                cache_config.mamba_ssm_cache_dtype == "float32"
-                or cache_config.mamba_cache_dtype == "float32"
-            )
-        ):
-            # NOTE(tdoublep): while in principle, FA supports
-            # MultipleOf(16), these are the block sizes that do not
-            # suffer from the NaN propagation problem described here:
-            # https://github.com/Dao-AILab/flash-attention/issues/1974
-            return [16, 32, 64]
         return [MultipleOf(16)]
 
     forward_includes_kv_cache_update: bool = False