diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 99c4b5cf5c3b..4c5e32340b9b 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -27,7 +27,6 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
 from vllm.utils.platform_utils import is_pin_memory_available
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.tree_attn import (
     TreeAttentionMetadata,
     TreeAttentionMetadataBuilder,
@@ -167,7 +166,12 @@ def __init__(
         # Determine allowed attention backends once during initialization.
         self.allowed_attn_types: tuple | None = None
         if current_platform.is_rocm():
-            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
+            from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata
+
+            rocm_types = [
+                TritonAttentionMetadata,
+                RocmAttentionMetadata,
+            ]
             # ROCM_AITER_FA is an optional backend
             if find_spec(
                 AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
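
For context, the pattern this patch applies is an availability-guarded allow-list: each platform's accepted attention-metadata classes are collected into a tuple at init time, with optional backends imported (and appended) only when `find_spec` confirms their module is installed. Below is a minimal self-contained sketch of that pattern, not vLLM's actual implementation; the metadata stub classes and the `some_optional_backend` module name are hypothetical stand-ins for `TritonAttentionMetadata`, `RocmAttentionMetadata`, and the ROCM_AITER_FA backend path probed in the diff.

```python
# Minimal sketch of the availability-guarded allow-list pattern used above.
# All class and module names here are illustrative placeholders.
from importlib.util import find_spec


class TritonMetadataStub:  # stand-in for TritonAttentionMetadata
    pass


class RocmMetadataStub:  # stand-in for RocmAttentionMetadata
    pass


def build_allowed_attn_types(is_rocm: bool) -> tuple | None:
    """Return the tuple of metadata types accepted on this platform,
    or None when no restriction applies (the non-ROCm case)."""
    if not is_rocm:
        return None

    allowed = [TritonMetadataStub, RocmMetadataStub]

    # Optional backend: import its metadata class only if the module is
    # actually installed, mirroring the find_spec() guard in the diff.
    if find_spec("some_optional_backend") is not None:  # hypothetical module
        import some_optional_backend  # type: ignore[import-not-found]

        allowed.append(some_optional_backend.Metadata)

    return tuple(allowed)


# Downstream validation is then a plain isinstance check against the tuple.
allowed = build_allowed_attn_types(is_rocm=True)
assert allowed is None or isinstance(TritonMetadataStub(), allowed)
```

Keeping the allow-list as a tuple of types (rather than a list) makes it directly usable as the second argument to `isinstance`, which is presumably why `allowed_attn_types` is annotated `tuple | None` in the diff.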