vllm/model_executor/models/whisper_causal.py (53 changes: 8 additions & 45 deletions)
@@ -30,14 +30,6 @@
     CommonAttentionMetadata,
     subclass_attention_backend_with_overrides,
 )
-from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
-
-try:
-    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
-except ImportError:
-    AiterFlashAttentionBackend = None
-from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
-from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
 from vllm.v1.attention.selector import get_attn_backend
 from vllm.v1.kv_cache_interface import AttentionSpec

@@ -211,36 +203,6 @@ def forward(
                 output_block_scale,
             )
 
-    _SUPPORTED_BACKENDS = tuple(
-        b
-        for b in (
-            AiterFlashAttentionBackend,
-            FlashAttentionBackend,
-            RocmAttentionBackend,
-            TritonAttentionBackend,
-        )
-        if b is not None
-    )
-
-    if not issubclass(underlying_attn_backend, _SUPPORTED_BACKENDS):
-        raise NotImplementedError(
-            f"{underlying_attn_backend} is not yet supported."
-            "Contributions to support more backends are much "
-            "appreciated."
-        )
-
-    if not issubclass(underlying_attn_backend, FlashAttentionBackend):
-        logger.info(
-            "Using %s for Whisper causal attention with block pooling. "
-            "This backend was recently enabled for this model. "
-            "If you encounter any accuracy or performance issues, "
-            "please open an issue at "
-            "https://github.com/vllm-project/vllm/issues "
-            "with the [ROCm] tag so it can be triaged by the "
-            "appropriate team.",
-            underlying_attn_backend.get_name(),
-        )
-
     attn_backend = subclass_attention_backend_with_overrides(
         name_prefix=prefix,
         attention_backend_cls=underlying_attn_backend,
@@ -251,13 +213,14 @@
             block_size,
             num_kv_heads,
             head_size,
-            cache_dtype_str: underlying_attn_backend.get_kv_cache_shape(
-                num_blocks,
-                # we stretch each block by `block_pool_size`
-                block_size * block_pool_size,
-                num_kv_heads // block_pool_size,
-                head_size,
-                cache_dtype_str,
+            cache_dtype_str: (
+                underlying_attn_backend.get_kv_cache_shape(
+                    num_blocks,
+                    block_size * block_pool_size,
+                    num_kv_heads // block_pool_size,
+                    head_size,
+                    cache_dtype_str,
+                )
             ),
             "forward_includes_kv_cache_update": True,
         },
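
Editor's note on the override above: it is the core of the block-pooling trick. Each KV-cache block is stretched by `block_pool_size` along the token axis while the KV-head axis shrinks by the same factor, so the per-block memory footprint is unchanged. Below is a minimal sketch of that arithmetic, not code from the PR: the helper name `pooled_block_dims` is hypothetical, and real backends return richer shapes from `get_kv_cache_shape` (e.g. with a leading K/V axis); only the stretching rule comes from the diff.

def pooled_block_dims(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    block_pool_size: int,
) -> tuple[int, int, int, int]:
    """Stretch each block by `block_pool_size` token slots while dividing
    the KV heads by the same factor; the element count per block,
    block_size * num_kv_heads * head_size, stays constant."""
    # Hypothetical helper for illustration only; the pooled head count
    # must divide evenly, matching the integer division in the override.
    assert num_kv_heads % block_pool_size == 0
    return (
        num_blocks,
        block_size * block_pool_size,
        num_kv_heads // block_pool_size,
        head_size,
    )

# Worked example: block_size=16, num_kv_heads=8, block_pool_size=2
# -> blocks of 32 token slots across 4 KV heads; 32 * 4 == 16 * 8,
#    so each pooled block occupies the same memory as before pooling.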