diff --git a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py index 1a615878bb8b..17eb20a29c98 100644 --- a/tests/v1/spec_decode/test_acceptance_length.py +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -83,8 +83,14 @@ class Eagle3ModelConfig: def get_available_attention_backends() -> list[str]: - if not hasattr(current_platform, "get_valid_backends"): - return ["FLASH_ATTN"] + # Check if get_valid_backends is actually defined in the platform class + # (not just returning None from __getattr__) + get_valid_backends = getattr(current_platform.__class__, "get_valid_backends", None) + if get_valid_backends is None: + if current_platform.is_rocm(): + return ["TRITON_ATTN"] + else: + return ["FLASH_ATTN"] device_capability = current_platform.get_device_capability() if device_capability is None: