diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index a8a1d59f1bf0..6df47767a9f3 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -321,6 +321,11 @@ def get_attn_backend_cls(
             return AttentionBackendEnum.ROCM_ATTN.get_path()
 
         if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
+            if attn_selector_config.has_sink:
+                raise ValueError(
+                    f"The selected backend, {selected_backend.name}, "
+                    "does not support sinks."
+                )
             if on_gfx9():
                 logger.info("Using Aiter Flash Attention backend.")
                 return AttentionBackendEnum.ROCM_AITER_FA.get_path()
@@ -343,7 +348,12 @@ def get_attn_backend_cls(
 
         # Priority 2: Check for AITER MHA (Flash Attention)
         # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
-        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+        if (
+            envs.VLLM_ROCM_USE_AITER
+            and envs.VLLM_ROCM_USE_AITER_MHA
+            and on_gfx9()
+            and not attn_selector_config.has_sink
+        ):
             logger.info("Using Aiter Flash Attention backend.")
             return AttentionBackendEnum.ROCM_AITER_FA.get_path()
 
@@ -363,7 +373,8 @@ def get_attn_backend_cls(
         if (
             envs.VLLM_ROCM_USE_AITER
             and on_gfx9()
-            and envs.VLLM_ROCM_USE_AITER_MHA is not False
+            and envs.VLLM_ROCM_USE_AITER_MHA
+            and not attn_selector_config.has_sink
         ):
             logger.info("Using Aiter Flash Attention backend.")
             return AttentionBackendEnum.ROCM_AITER_FA.get_path()
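
For context, a minimal standalone sketch of the selection rule these hunks implement: AITER FA is rejected (explicit selection) or skipped (automatic selection) whenever the model uses attention sinks. The names pick_backend, AttnSelectorConfig, and the TRITON_ATTN fallback are illustrative stand-ins for this sketch, not vLLM APIs.

# Illustrative sketch only; not vLLM code.
from dataclasses import dataclass

@dataclass
class AttnSelectorConfig:
    has_sink: bool  # True when the model uses attention sinks

def pick_backend(use_aiter: bool, use_aiter_mha: bool, gfx9: bool,
                 cfg: AttnSelectorConfig) -> str:
    # Mirrors the patched condition: AITER FA is only chosen when the
    # model has no sinks; otherwise selection falls through to a backend
    # that supports them (hypothetical fallback name here).
    if use_aiter and use_aiter_mha and gfx9 and not cfg.has_sink:
        return "ROCM_AITER_FA"
    return "TRITON_ATTN"

assert pick_backend(True, True, True, AttnSelectorConfig(has_sink=True)) == "TRITON_ATTN"
assert pick_backend(True, True, True, AttnSelectorConfig(has_sink=False)) == "ROCM_AITER_FA"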