diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py index 8e4322ea335d..455d23080168 100644 --- a/vllm/model_executor/models/whisper_causal.py +++ b/vllm/model_executor/models/whisper_causal.py @@ -30,14 +30,6 @@ CommonAttentionMetadata, subclass_attention_backend_with_overrides, ) -from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend - -try: - from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend -except ImportError: - AiterFlashAttentionBackend = None -from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend -from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend from vllm.v1.attention.selector import get_attn_backend from vllm.v1.kv_cache_interface import AttentionSpec @@ -211,36 +203,6 @@ def forward( output_block_scale, ) - _SUPPORTED_BACKENDS = tuple( - b - for b in ( - AiterFlashAttentionBackend, - FlashAttentionBackend, - RocmAttentionBackend, - TritonAttentionBackend, - ) - if b is not None - ) - - if not issubclass(underlying_attn_backend, _SUPPORTED_BACKENDS): - raise NotImplementedError( - f"{underlying_attn_backend} is not yet supported." - "Contributions to support more backends are much " - "appreciated." - ) - - if not issubclass(underlying_attn_backend, FlashAttentionBackend): - logger.info( - "Using %s for Whisper causal attention with block pooling. " - "This backend was recently enabled for this model. " - "If you encounter any accuracy or performance issues, " - "please open an issue at " - "https://github.com/vllm-project/vllm/issues " - "with the [ROCm] tag so it can be triaged by the " - "appropriate team.", - underlying_attn_backend.get_name(), - ) - attn_backend = subclass_attention_backend_with_overrides( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, @@ -251,13 +213,14 @@ def forward( block_size, num_kv_heads, head_size, - cache_dtype_str: underlying_attn_backend.get_kv_cache_shape( - num_blocks, - # we stretch each block by `block_pool_size` - block_size * block_pool_size, - num_kv_heads // block_pool_size, - head_size, - cache_dtype_str, + cache_dtype_str: ( + underlying_attn_backend.get_kv_cache_shape( + num_blocks, + block_size * block_pool_size, + num_kv_heads // block_pool_size, + head_size, + cache_dtype_str, + ) ), "forward_includes_kv_cache_update": True, },