diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 75f1f040d764..72abef497375 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,7 +3,7 @@ import ast from dataclasses import replace from importlib.util import find_spec -from typing import Optional +from typing import Optional, Protocol import numpy as np import torch @@ -33,6 +33,17 @@ PADDING_SLOT_ID = -1 +class EagleAttentionMetadata(Protocol): + # Required attributes + num_actual_tokens: int + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + class EagleProposer: def __init__( @@ -97,7 +108,7 @@ def __init__( device=device) # Determine allowed attention backends once during initialization. - self.allowed_attn_types: tuple[type, ...] = () + self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend @@ -240,10 +251,6 @@ def propose( # there's a multi-layer MTP module. assert isinstance(attn_metadata, self.allowed_attn_types) - # The mypy errors are caused because mypy cannot infer the type of - # attn_metadata. We add this assert to help mypy. - assert isinstance(attn_metadata, FlashAttentionMetadata) - # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids]