diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 526465ee5886..d10819c0dbe2 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -40,6 +40,7 @@
     get_device,
     get_device_memory_capacity,
     get_device_sm,
+    is_blackwell,
     is_blackwell_supported,
     is_cuda,
     is_fa3_default_architecture,
@@ -1283,7 +1284,8 @@ def _handle_attention_backend_compatibility(self):
         1. Models with MHA Architecture (e.g: Llama, QWen)
            1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1
            or page_size > 1.
-           1.2 In other cases, we will use flashinfer if available, otherwise use triton.
+           1.2 We will use trtllm_mha on Blackwell unless spec decode is used with topk > 1.
+           1.3 In other cases, we will use flashinfer if available, otherwise use triton.
         2. Models with MLA Architecture and using FA3
            2.1 We will use FA3 backend on hopper.
            2.2 We will use Flashinfer backend on blackwell.
@@ -1298,6 +1300,8 @@
             and is_fa3_default_architecture(self.model_config.hf_config)
         ):
             self.attention_backend = "fa3"
+        elif is_blackwell() and is_no_spec_infer_or_topk_one(self):
+            self.attention_backend = "trtllm_mha"
         elif is_hip():
             self.attention_backend = "aiter"
         elif is_npu():
diff --git a/test/srt/test_flash_attention_4.py b/test/srt/test_flash_attention_4.py
index 4322263c459e..44623a132c3a 100644
--- a/test/srt/test_flash_attention_4.py
+++ b/test/srt/test_flash_attention_4.py
@@ -22,6 +22,8 @@ def setUpClass(cls):
             "0.8",
             "--prefill-attention-backend",
             "fa4",
+            "--decode-attention-backend",
+            "flashinfer",
         ]
         cls.process = popen_launch_server(
             cls.model,