From 09edaab13b180a82614ed98cf8d2d11e46d9912c Mon Sep 17 00:00:00 2001
From: Micah Williamson
Date: Tue, 16 Dec 2025 18:07:39 +0000
Subject: [PATCH 1/2] update attn backend for async scheduling without spec decode test

Signed-off-by: Micah Williamson
---
 tests/v1/e2e/test_async_scheduling.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py
index 5cef9b33c998..5af5dd0bf72c 100644
--- a/tests/v1/e2e/test_async_scheduling.py
+++ b/tests/v1/e2e/test_async_scheduling.py
@@ -149,7 +149,7 @@ def run_tests(
                 # Use TRITON_ATTN for spec decoding test for consistency
                 m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
             else:
-                m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
+                m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
         else:
             m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
         # lock matmul precision to full FP32 (IEEE)
@@ -281,14 +281,6 @@ def run_test(
     print(f"---- TESTING {test_str}: {test_config}")
     print("-" * 80)
 
-    # On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
-    # spec decoding test (TRITON_ATTN) for better precision.
-    # On others: always use float32.
-    if current_platform.is_rocm() and not is_testing_with_spec_decoding:
-        dtype = "float16"
-    else:
-        dtype = "float32"
-
     with VllmRunner(
         model,
         max_model_len=512,
@@ -298,7 +290,7 @@ def run_test(
         # enforce_eager=True,
         async_scheduling=async_scheduling,
         distributed_executor_backend=executor,
-        dtype=dtype,
+        dtype="float32",
         speculative_config=spec_config,
         disable_log_stats=False,
         **cache_arg,

From 68b19dffc67a00fe41d4a85c8d7bbf95cacc38e8 Mon Sep 17 00:00:00 2001
From: Micah Williamson
Date: Wed, 17 Dec 2025 02:20:22 +0000
Subject: [PATCH 2/2] update supported dtypes for ROCM_ATTN

Signed-off-by: Micah Williamson
---
 vllm/v1/attention/backends/rocm_attn.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index e2410a70b1a6..441feafa2ad5 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -152,7 +152,11 @@ def build(
 class RocmAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
-    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_dtypes: ClassVar[list[torch.dtype]] = [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ]
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]: