From 7ea9bd7b50fb01439f0998f47b6a1471fdfadabc Mon Sep 17 00:00:00 2001
From: mgoin
Date: Wed, 11 Jun 2025 10:33:28 -0400
Subject: [PATCH 1/2] [Bugfix] Allow manual FlashAttention for Blackwell V1

Signed-off-by: mgoin
---
 vllm/platforms/cuda.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 48d1aacba185..db271572e886 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -226,15 +226,21 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             if selected_backend == _Backend.FLASHINFER:
                 logger.info_once("Using FlashInfer backend on V1 engine.")
                 return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
-            if selected_backend == _Backend.FLEX_ATTENTION:
+            elif selected_backend == _Backend.FLEX_ATTENTION:
                 logger.info("Using FlexAttenion backend on V1 engine.")
                 return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+            elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "triton_attn.TritonAttentionBackend")
+            elif selected_backend in _Backend.FLASH_ATTN:
+                logger.info_once("Using Flash Attention backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "flash_attn.FlashAttentionBackend")
+
+            # Default backends for V1 engine
+            # Prefer FlashInfer for Blackwell GPUs if installed
             if cls.is_device_capability(100):
-                # Prefer FlashInfer for V1 on Blackwell GPUs if installed
                 try:
                     import flashinfer  # noqa: F401
                     logger.info_once(
@@ -248,10 +254,13 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                         "Blackwell (SM 10.0) GPUs; it is recommended to "
                         "install FlashInfer for better performance.")
                     pass
-            if cls.has_device_capability(80):
+            # FlashAttention is the default for SM 8.0+ GPUs
+            elif cls.has_device_capability(80):
                 logger.info_once("Using Flash Attention backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "flash_attn.FlashAttentionBackend")
+
+        # Backends for V0 engine
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"

From f849458105f362af2502c43aa0fedd1a882a5058 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Wed, 11 Jun 2025 10:41:43 -0400
Subject: [PATCH 2/2] Fix

Signed-off-by: mgoin
---
 vllm/platforms/cuda.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index db271572e886..8262495e7697 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -233,7 +233,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "triton_attn.TritonAttentionBackend")
-            elif selected_backend in _Backend.FLASH_ATTN:
+            elif selected_backend == _Backend.FLASH_ATTN:
                 logger.info_once("Using Flash Attention backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "flash_attn.FlashAttentionBackend")
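
Note (not part of the patches): a minimal usage sketch of what this series enables. With both patches applied, explicitly requesting FlashAttention is honored on Blackwell (SM 10.0) GPUs instead of being overridden by the FlashInfer default when FlashInfer is installed. The sketch assumes the VLLM_ATTENTION_BACKEND environment variable is the path through which selected_backend gets set; the model name is only a placeholder.

    import os

    # Request FlashAttention explicitly; before this series, V1 backend
    # selection on Blackwell preferred FlashInfer even when FLASH_ATTN
    # was requested.
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    print(llm.generate("Hello, my name is"))

PATCH 2/2 swaps `in` for `==` because _Backend.FLASH_ATTN is a single enum member rather than a container, so `selected_backend in _Backend.FLASH_ATTN` would raise a TypeError at runtime instead of testing equality.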