diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 48d1aacba185..8262495e7697 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -226,15 +226,21 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" - if selected_backend == _Backend.FLEX_ATTENTION: + elif selected_backend == _Backend.FLEX_ATTENTION: logger.info("Using FlexAttenion backend on V1 engine.") return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 - if selected_backend == _Backend.TRITON_ATTN_VLLM_V1: + elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1: logger.info_once("Using Triton backend on V1 engine.") return ("vllm.v1.attention.backends." "triton_attn.TritonAttentionBackend") + elif selected_backend == _Backend.FLASH_ATTN: + logger.info_once("Using Flash Attention backend on V1 engine.") + return ("vllm.v1.attention.backends." + "flash_attn.FlashAttentionBackend") + + # Default backends for V1 engine + # Prefer FlashInfer for Blackwell GPUs if installed if cls.is_device_capability(100): - # Prefer FlashInfer for V1 on Blackwell GPUs if installed try: import flashinfer # noqa: F401 logger.info_once( @@ -248,10 +254,13 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, "Blackwell (SM 10.0) GPUs; it is recommended to " "install FlashInfer for better performance.") pass - if cls.has_device_capability(80): + # FlashAttention is the default for SM 8.0+ GPUs + elif cls.has_device_capability(80): logger.info_once("Using Flash Attention backend on V1 engine.") return ("vllm.v1.attention.backends." "flash_attn.FlashAttentionBackend") + + # Backends for V0 engine if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") return "vllm.attention.backends.flashinfer.FlashInferBackend"