7 changes: 7 additions & 0 deletions examples/offline_inference/eagle.py
@@ -95,6 +95,13 @@ def main():
     outputs = llm.generate(prompt_token_ids=prompt_ids,
                            sampling_params=sampling_params)
 
+    # print the generated text
+    for output in outputs:
+        print("-" * 50)
+        print(f"prompt: {output.prompt}")
+        print(f"generated text: {output.outputs[0].text}")
+        print("-" * 50)
+
     if not hasattr(outputs, "metrics") or outputs.metrics is None:
         return
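For context, the added loop is the standard way to consume the RequestOutput objects returned by llm.generate: each output carries the original prompt and a list of completions whose text is printed. The sketch below shows the surrounding call pattern as a minimal, hedged example, not the PR's code: it uses plain text prompts instead of the example's pre-tokenized prompt_ids, and the model names and speculative_config values are illustrative assumptions.

from vllm import LLM, SamplingParams

# Illustrative setup: the target/draft model names and the speculative_config
# values below are assumptions for the sketch, not taken from this PR.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": 2,
    },
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The capital of France is"], sampling_params)

# Same pattern as the loop added above: one RequestOutput per prompt,
# with the generated text in the first CompletionOutput.
for output in outputs:
    print("-" * 50)
    print(f"prompt: {output.prompt}")
    print(f"generated text: {output.outputs[0].text}")
    print("-" * 50)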
2 changes: 1 addition & 1 deletion vllm/model_executor/models/llama_eagle.py
@@ -139,7 +139,7 @@ def forward(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]
+            skip_prefixes=(["lm_head.", "model.embed_tokens."]
                            if self.config.tie_word_embeddings else None),
         )
 
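The reason "model.embed_tokens." joins "lm_head." in skip_prefixes is that, with this PR, the draft model takes its embedding from the target model (see the eagle.py change below), so neither weight needs to come from the draft checkpoint when embeddings are tied. The toy filter below only sketches what prefix skipping does; it is not vLLM's AutoWeightsLoader.

from typing import Dict, Iterable, List, Optional, Tuple

import torch


def filter_weights(
    weights: Iterable[Tuple[str, torch.Tensor]],
    skip_prefixes: Optional[List[str]],
) -> Dict[str, torch.Tensor]:
    """Keep only checkpoint entries whose name matches no skipped prefix."""
    skip = tuple(skip_prefixes or ())
    return {name: w for name, w in weights if not name.startswith(skip)}


checkpoint = [
    ("model.embed_tokens.weight", torch.zeros(8, 4)),
    ("model.layers.0.self_attn.q_proj.weight", torch.zeros(4, 4)),
    ("lm_head.weight", torch.zeros(8, 4)),
]

tie_word_embeddings = True
skip_prefixes = (["lm_head.", "model.embed_tokens."]
                 if tie_word_embeddings else None)

print(sorted(filter_weights(checkpoint, skip_prefixes)))
# -> ['model.layers.0.self_attn.q_proj.weight']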
17 changes: 9 additions & 8 deletions vllm/v1/spec_decode/eagle.py
@@ -235,17 +235,18 @@ def load_model(self, target_model: nn.Module) -> None:
             model_config=draft_model_config,
             start_layer_id=target_layer_num).to(target_device)
 
-        loaded_weights = self.model.load_weights(
+        _ = self.model.load_weights(
             loader.get_all_weights(
                 self.vllm_config.speculative_config.draft_model_config,
                 self.model))
-        if self.vllm_config.speculative_config.method == "eagle3":
-            if "model.embed_tokens.weight" not in loaded_weights:
-                logger.info(
-                    "Loading EAGLE embedding weights from the target model.")
-                self.model.model.embed_tokens = target_model.model.embed_tokens
-        else:
-            logger.info("Loading EAGLE LM head weights from the target model.")
 
+        # EAGLE-1/3 reuse the same embedding layer for draft and target models
+        self.model.model.embed_tokens = target_model.model.embed_tokens
+
+        # some model definitions do not define lm_head explicitly
+        # and instead reuse embed_tokens for lm_head, e.g., CohereForCausalLM
+        if self.vllm_config.speculative_config.method != "eagle3" and \
+                hasattr(self.model, "lm_head"):
+            self.model.lm_head = target_model.lm_head
 
 
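The net effect of the rewritten load_model block is plain module sharing by reference: the draft model's embed_tokens (and, outside EAGLE-3, its lm_head when it defines one) is replaced by the target model's module, so both models read the same tensors and no duplicate embedding is kept in memory. The sketch below demonstrates the pattern with a hypothetical TinyLM class, not the real vLLM model classes.

import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical stand-in for a decoder-only LM with embeddings and an LM head."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)


target = TinyLM()
draft = TinyLM()

# Share modules by reference, mirroring the new load_model logic.
draft.embed_tokens = target.embed_tokens
if hasattr(draft, "lm_head"):  # some models fold lm_head into embed_tokens
    draft.lm_head = target.lm_head

# Both models now point at the very same parameters: updating one updates the
# other, and only a single copy of the embedding weights lives in memory.
assert draft.embed_tokens.weight.data_ptr() == target.embed_tokens.weight.data_ptr()
assert draft.lm_head.weight.data_ptr() == target.lm_head.weight.data_ptr()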