7 changes: 7 additions & 0 deletions examples/offline_inference/eagle.py
@@ -105,6 +105,13 @@ def main():
outputs = llm.generate(prompt_token_ids=prompt_ids,
sampling_params=sampling_params)

# print the generated text
for output in outputs:
print("-" * 50)
print(f"prompt: {output.prompt}")
print(f"generated text: {output.outputs[0].text}")
print("-" * 50)

if not hasattr(outputs, "metrics") or outputs.metrics is None:
return

27 changes: 18 additions & 9 deletions vllm/model_executor/models/llama_eagle.py
@@ -8,6 +8,7 @@

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -52,11 +53,15 @@ def __init__(
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)

# if PP is disabled, the draft shares embed_tokens with the target model
if get_pp_group().world_size > 1:
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)

self.layers = nn.ModuleList([
LlamaDecoderLayer(
self.config,
@@ -109,6 +114,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
weight_loader(param, loaded_weight, shard_id)
break
else:

# if PP is disabled, the draft shares embed_tokens with the target model
if get_pp_group().world_size == 1 and \
"embed_tokens." in name:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
@@ -142,14 +153,12 @@ def forward(
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
Comment on lines -145 to -146
Contributor Author:

The EAGLE model definition doesn't have an lm_head, nor the weights for it, so there is nothing to skip here.

Collaborator:

@ekagra-ranjan Do you mean EAGLE1 doesn't have the LM head? I'm wondering because some EAGLE3 weights do include the LM head.

Contributor Author (@ekagra-ranjan, May 14, 2025):

EAGLE1 reuses the target model's lm_head at each speculative step, whereas EAGLE3 does not. For example, yuhuili/EAGLE-LLaMA3-Instruct-8B has these weights:

Number of weights: 10
Key: layers.0.self_attn.q_proj.weight, Shape: torch.Size([4096, 4096]), Dtype: torch.float16
Key: layers.0.self_attn.k_proj.weight, Shape: torch.Size([1024, 4096]), Dtype: torch.float16
Key: layers.0.self_attn.v_proj.weight, Shape: torch.Size([1024, 4096]), Dtype: torch.float16
Key: layers.0.self_attn.o_proj.weight, Shape: torch.Size([4096, 4096]), Dtype: torch.float16
Key: layers.0.mlp.gate_proj.weight, Shape: torch.Size([14336, 4096]), Dtype: torch.float16
Key: layers.0.mlp.up_proj.weight, Shape: torch.Size([14336, 4096]), Dtype: torch.float16
Key: layers.0.mlp.down_proj.weight, Shape: torch.Size([4096, 14336]), Dtype: torch.float16
Key: layers.0.post_attention_layernorm.weight, Shape: torch.Size([4096]), Dtype: torch.float16
Key: embed_tokens.weight, Shape: torch.Size([128256, 4096]), Dtype: torch.float16
Key: fc.weight, Shape: torch.Size([4096, 8192]), Dtype: torch.float16

EAGLE1 sets the target model's lm_head as the draft's lm_head here.

EAGLE3's lm_head is not the same as the target model's; this is noted in #16937 (comment) as well.
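
For reference, a listing like the one above can be reproduced with a short script (a minimal sketch, not part of this PR; it assumes the checkpoint is hosted on the Hugging Face Hub as a single model.safetensors file; adjust the filename if the repo ships pytorch_model.bin instead):

# Sketch: list the tensors stored in an EAGLE draft checkpoint.
# Assumes a single-file safetensors checkpoint.
from huggingface_hub import hf_hub_download
from safetensors import safe_open

path = hf_hub_download("yuhuili/EAGLE-LLaMA3-Instruct-8B", "model.safetensors")

with safe_open(path, framework="pt") as f:
    keys = list(f.keys())
    print(f"Number of weights: {len(keys)}")
    for key in keys:
        tensor = f.get_tensor(key)
        print(f"Key: {key}, Shape: {tensor.shape}, Dtype: {tensor.dtype}")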

skip_prefixes=None,
)

model_weights = {}
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight

loader.load_weights(model_weights.items())
return loader.load_weights(model_weights.items())
15 changes: 10 additions & 5 deletions vllm/model_executor/models/llama_eagle3.py
@@ -8,6 +8,7 @@

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -91,11 +92,15 @@ def __init__(
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)

# if PP is disabled, the draft shares embed_tokens with the target model
if get_pp_group().world_size > 1:
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)

self.layers = nn.ModuleList([
LlamaDecoderLayer(
self.config,
24 changes: 19 additions & 5 deletions vllm/v1/spec_decode/eagle.py
@@ -5,6 +5,7 @@
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
@@ -302,12 +303,25 @@ def load_model(self, target_model: nn.Module) -> None:
self.attn_layer_name = next(iter(draft_attn_layer_names))
loaded_weights = self.model.load_weights(
loader.get_all_weights(draft_model_config, self.model))
if self.vllm_config.speculative_config.method == "eagle3":
if "model.embed_tokens.weight" not in loaded_weights:
logger.info(
"Loading EAGLE embedding weights from the target model.")
self.model.model.embed_tokens = target_model.model.embed_tokens

# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
assert "model.embed_tokens.weight" not in loaded_weights, \
"For PP = 1, Eagle draft should share embed with target model"
logger.info(
"Loading EAGLE embedding weights from the target model.")
self.model.model.embed_tokens = target_model.model.embed_tokens
else:
assert "model.embed_tokens.weight" in loaded_weights, \
"For PP > 0, Eagle draft checkpoint should its own copy of "
" the model.embed_tokens.weight"
logger.info("EAGLE embedding weights are already loaded.")

# share lm_head with the target model if needed
# some model definitions do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
if self.vllm_config.speculative_config.method != "eagle3" and \
hasattr(target_model, "lm_head"):
logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_model.lm_head
