80 changes: 45 additions & 35 deletions vllm/v1/spec_decode/eagle.py
@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from dataclasses import replace
from typing import Optional
from importlib.util import find_spec
from typing import Optional, Protocol

import numpy as np
import torch
@@ -20,8 +21,6 @@
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
@@ -34,6 +33,17 @@
PADDING_SLOT_ID = -1


class EagleAttentionMetadata(Protocol):
# Required attributes
num_actual_tokens: int
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
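
For readers unfamiliar with structural typing, here is a minimal standalone sketch (names such as `MetadataLike` and `ToyMetadata` are illustrative only, not part of this diff) of how a `typing.Protocol` like `EagleAttentionMetadata` admits any metadata class that happens to carry the required attributes:

```python
from dataclasses import dataclass
from typing import Protocol

import torch


class MetadataLike(Protocol):
    # The attributes a consumer relies on; no inheritance is required.
    num_actual_tokens: int
    slot_mapping: torch.Tensor


@dataclass
class ToyMetadata:  # stand-in for e.g. FlashAttentionMetadata
    num_actual_tokens: int
    slot_mapping: torch.Tensor


def last_slot(meta: MetadataLike) -> int:
    # A type checker accepts ToyMetadata here purely because it defines
    # the attributes listed in the Protocol.
    return int(meta.slot_mapping[meta.num_actual_tokens - 1])


print(last_slot(ToyMetadata(num_actual_tokens=2, slot_mapping=torch.tensor([4, 7]))))
```

Note that `allowed_attn_types` below still stores concrete metadata classes, so the later `isinstance` check works without marking the Protocol `@runtime_checkable`.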


class EagleProposer:

def __init__(
@@ -97,6 +107,20 @@ def __init__(
dtype=self.dtype,
device=device)

# Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
rocm_types.append(AiterFlashAttentionMetadata)
self.allowed_attn_types = tuple(rocm_types)
else:
self.allowed_attn_types = (FlashAttentionMetadata,
TreeAttentionMetadata)
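
The `find_spec` probe above is the standard way to guard an optional import; a self-contained sketch of the same pattern (the module and class names here are placeholders, not a real backend):

```python
from importlib.util import find_spec

optional_types: list[type] = []
# find_spec returns None when the module cannot be located, so the import
# below is skipped entirely on installs without the optional dependency
# and no ImportError is ever raised.
if find_spec("some_optional_backend") is not None:
    from some_optional_backend import OptionalMetadata  # placeholder names
    optional_types.append(OptionalMetadata)

ALLOWED_TYPES = tuple(optional_types)
print(ALLOWED_TYPES)  # () when the optional module is absent
```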

# Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree
self.tree_choices: list[tuple[int,
@@ -165,7 +189,7 @@ def propose(
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
@@ -225,25 +249,13 @@ def propose(
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.

# On ROCm, both AiterFlashAttention and TritonAttention
# support multi-token eagle spec decode.
if current_platform.is_rocm():
assert isinstance(
attn_metadata,
(TritonAttentionMetadata, AiterFlashAttentionMetadata,
FlashAttentionMetadata))
else:
# Currently, only FlashAttention supports multi-token eagle spec
# decode. This is because the code below makes assumptions about
# attn_metadata attributes available.
assert isinstance(attn_metadata, FlashAttentionMetadata)
assert isinstance(attn_metadata, self.allowed_attn_types)

# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]

if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
batch_size <= self.cudagraph_batch_sizes[-1]:
Collaborator
NITs, can you revert all of the unrelated changes?

Contributor Author
Hi @tjtanaa!
These changes were included so I could pass pre-commit. I've only been contributing to the project for a short time, and @mgoin told me that pre-commit is normally expected to be used:

https://marketplace.visualstudio.com/items?itemName=elagil.pre-commit-helper

https://github.com/pre-commit/pre-commit

https://github.com/vllm-project/vllm/blob/main/.github/workflows/pre-commit.yml

The intent was for it to format the file correctly after the changes.

Sorry if this bothered you. Thank you very much for your time and dedication.

If you find that I have it configured incorrectly, please don't hesitate to let me know.

P.S.: If I remove the spaces and run the pre-commit check again, it fails, so I have to apply the fix; that adds the spaces back and everything passes again.

input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
@@ -449,7 +461,7 @@ def propose_tree(
num_tokens, -1)

if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_tokens)
else:
@@ -508,19 +520,19 @@ def prepare_inputs(
"""
# E.g.
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1, q1 + q2, q1 + q2 + q3]
# [0, q1, q1 + q2, q1 + q2 + q3]
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
# num_rejected_tokens: [n1, n2, n3]
# This function computes the intermediate values:
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
# And returns:
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# common_attn_metadata.seq_lens{_cpu}:
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
# token_indices: [0, 1, ..., q1 - n1 - 1,
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
@@ -564,9 +576,9 @@ def prepare_inputs(
old_query_start_locs_expanded = np.repeat(
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
# Final token indices are:
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
token_indices_np = token_offests + old_query_start_locs_expanded
token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True)
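
To make the index bookkeeping above concrete, here is a toy NumPy rendition (hand-picked numbers, none of the actual `CommonAttentionMetadata` plumbing) of how the shrunken `query_start_loc` and the `token_indices` into the original flattened buffer are derived:

```python
import numpy as np

# Three requests with query lengths q = [2, 5, 4] and rejected-token
# counts n = [0, 1, 1], matching the shape of the comment example above.
query_start_loc = np.array([0, 2, 7, 11])      # [0, q1, q1+q2, q1+q2+q3]
num_rejected = np.array([0, 1, 1])

# Tokens kept per request after dropping the rejected draft tokens.
new_num_tokens = np.diff(query_start_loc) - num_rejected       # [2, 4, 3]
new_query_start_loc = np.concatenate(([0], np.cumsum(new_num_tokens)))

# Per-token offsets within each request, shifted by the *old* start of
# that request so they index into the original flattened token buffer.
offsets = np.concatenate([np.arange(n) for n in new_num_tokens])
old_starts_expanded = np.repeat(query_start_loc[:-1], new_num_tokens)
token_indices = offsets + old_starts_expanded

print(new_query_start_loc)  # [0 2 6 9]
print(token_indices)        # [0 1 2 3 4 5 7 8 9]
```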
@@ -615,20 +627,18 @@ def load_model(self, target_model: nn.Module) -> None:
target_language_model = target_model
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1 \
and self.model.model.embed_tokens.weight.shape \
== target_language_model.model.embed_tokens.weight.shape:
and self.model.model.embed_tokens.weight.shape \
== target_language_model.model.embed_tokens.weight.shape:
logger.info(
"Assuming the EAGLE head shares the same vocab embedding" \
" with the target model."
)
"Assuming the EAGLE head shares the same vocab embedding"
" with the target model.")
del self.model.model.embed_tokens
self.model.model.embed_tokens = (
target_language_model.model.embed_tokens)
else:
logger.info(
"The EAGLE head's vocab embedding will be loaded separately" \
" from the target model."
)
"The EAGLE head's vocab embedding will be loaded separately"
" from the target model.")

# share lm_head with the target model if needed
# some model definition do not define lm_head explicitly