
Prevent overwriting drafter's lm_head and embed_tokens #27737

Closed
eldarkurtic wants to merge 6 commits into vllm-project:main from
eldarkurtic:fix-eagle3-drafter-init

Conversation

@eldarkurtic
Contributor

@eldarkurtic eldarkurtic commented Oct 29, 2025

Some EAGLE3 drafters ship their own lm_head and/or embed_tokens layers. The existing codebase ignores this and always overwrites them with the target model's layers.
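The fix records, while the drafter's checkpoint is being loaded, whether it ships its own copies of these layers. A minimal standalone sketch of that idea (the flag names has_own_lm_head / has_own_embed_tokens come from the PR diff; the class and loader here are simplified stand-ins, not the actual vLLM code):

```python
class DrafterStub:
    """Simplified stand-in for an EAGLE3 drafter's weight loader."""

    def __init__(self):
        # Default: assume the drafter shares lm_head/embed_tokens
        # with the target (verifier) model.
        self.has_own_lm_head = False
        self.has_own_embed_tokens = False

    def load_weights(self, weights):
        loaded = {}
        for name, weight in weights:
            loaded[name] = weight
            # The checkpoint ships its own copy of this layer, so the
            # target model's layer must not overwrite it later.
            if "lm_head" in name:
                self.has_own_lm_head = True
            if "embed_tokens" in name:
                self.has_own_embed_tokens = True
        return loaded


drafter = DrafterStub()
drafter.load_weights([("model.layers.0.qkv_proj.weight", None),
                      ("lm_head.weight", None)])
print(drafter.has_own_lm_head)       # True: checkpoint has its own head
print(drafter.has_own_embed_tokens)  # False: embeddings come from target
```

A drafter whose checkpoint contains neither layer keeps both flags False and behaves exactly as before, which is what Test-case 1 below checks.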

Test-case 1: for a model that needs lm_head and embed_tokens copied from the verifier model, the behavior should not change.

Eval command to verify:

CUDA_VISIBLE_DEVICES=0,1 python examples/offline_inference/spec_decode.py \
  --method "eagle3" \
  --tp 2 \
  --model-dir "openai/gpt-oss-120b" \
  --eagle-dir "nvidia/gpt-oss-120b-Eagle3" \
  --dataset_name "hf" \
  --dataset_path "philschmid/mt-bench" \
  --num-spec-tokens 3

Before this PR:

--------------------------------------------------
total_num_output_tokens: 241888
num_drafts: 105426
num_draft_tokens: 316278
num_accepted_tokens: 136330
mean acceptance length: 2.29
--------------------------------------------------
acceptance at token 0: 0.65
acceptance at token 1: 0.40
acceptance at token 2: 0.24

After this PR:

--------------------------------------------------
total_num_output_tokens: 241168
num_drafts: 105678
num_draft_tokens: 317034
num_accepted_tokens: 135362
mean acceptance length: 2.28
--------------------------------------------------
acceptance at token 0: 0.65
acceptance at token 1: 0.39
acceptance at token 2: 0.24

Test-case 2: for a model that has its own lm_head and embed_tokens, and therefore does not require copying from the target's layers, acceptance rates improve significantly.

Before this PR:

--------------------------------------------------
total_num_output_tokens: 247973
num_drafts: 187566
num_draft_tokens: 562698
num_accepted_tokens: 59726
mean acceptance length: 1.32
--------------------------------------------------
acceptance at token 0: 0.23
acceptance at token 1: 0.07
acceptance at token 2: 0.02

After this PR:

--------------------------------------------------
total_num_output_tokens: 247974
num_drafts: 99354
num_draft_tokens: 298062
num_accepted_tokens: 148677
mean acceptance length: 2.50
--------------------------------------------------
acceptance at token 0: 0.70
acceptance at token 1: 0.48
acceptance at token 2: 0.31

Note: idea for lm_head check inspired by #27688

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a more robust mechanism for handling the lm_head and embed_tokens layers in EAGLE3 drafter models. By replacing the fragile shape-matching heuristic with explicit flags set during weight loading, the change correctly prevents overwriting these layers when the drafter model provides its own. This is a significant improvement for correctness and maintainability. My review includes a suggestion to further strengthen the flag-setting logic to prevent potential edge cases.

Contributor

@rahul-tuli rahul-tuli left a comment

LGTM!

@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 29, 2025
@eldarkurtic
Copy link
Contributor Author

eldarkurtic commented Oct 30, 2025

Moved all attributes into the SupportsEagle3 interface. Could you please re-review? @NickLucche @rahul-tuli @dsikka

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) October 30, 2025 13:40
Collaborator

@NickLucche NickLucche left a comment

Thanks for the fix @eldarkurtic ! Let's get CI green real quick before merging this one

# To prevent overriding with target model's layers
if "lm_head" in name:
    self.has_own_lm_head = True
if "embed_tokens" in name:
    self.has_own_embed_tokens = True
Collaborator

Is this a change from the default? I thought that EAGLE3 heads usually share the embedding with the base model?

Contributor Author

@eldarkurtic eldarkurtic Oct 30, 2025

Eagle3 by default doesn’t train these layers. But there is no reason not to train them. This doesn’t affect the standard eagle3 flow, just extends it to support this new use case

@benchislett
Collaborator

There are other EAGLE1-based models which do not have the EAGLE3 mixin, causing the CI failures. Please update the logic to cover EAGLE1 as well

logger.info(
    "Assuming the EAGLE head shares the same vocab embedding "
    "with the target model. "
    "Draft model embed_tokens are uninitialized."
)
Collaborator

This was originally done as a memory optimization since the original EAGLE3 models are released with an embedding layer included, but having the same weights as the base model.

Ideally we would have another check here to delete them if they are present but identical to those of the base model, but that's tricky to implement cleanly. I'm fine with doing it this way for now, but it should be noted that the behaviour is changing
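The identity check suggested above could look roughly like the following sketch (the helper drop_if_identical and the plain-dict weight maps are hypothetical illustrations, not part of this PR or of vLLM):

```python
import torch


def drop_if_identical(drafter_weights: dict, target_weights: dict,
                      key: str) -> bool:
    """Drop the drafter's copy of a layer when it is bit-identical to the
    target model's, so both models can share a single tensor in memory."""
    if (key in drafter_weights and key in target_weights
            and torch.equal(drafter_weights[key], target_weights[key])):
        del drafter_weights[key]
        return True
    return False


# Drafter checkpoint ships an embedding identical to the target's.
drafter_weights = {"embed_tokens.weight": torch.zeros(4, 8)}
target_weights = {"embed_tokens.weight": torch.zeros(4, 8)}
dropped = drop_if_identical(drafter_weights, target_weights,
                            "embed_tokens.weight")
print(dropped)  # True: identical copies, drafter's tensor removed
```

As the reviewer notes, doing this cleanly in the real loader is tricky (the comparison needs both sets of weights resident at once), which is why the flag-based approach was accepted for now.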

@eldarkurtic
Contributor Author

eldarkurtic commented Oct 30, 2025

Do they use some other mixin similar to SupportsEagle3?

@@ -922,6 +922,16 @@ class SupportsEagle3(Protocol):
MRO of your model class.
"""

has_own_lm_head: ClassVar[bool] = False
Collaborator

These are not class variables, they are instance variables that are set per-model based on the weight loading.

https://typing.python.org/en/latest/spec/class-compat.html
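The distinction matters for type checkers: ClassVar declares an attribute as class-level and shared, so assigning it on an instance during weight loading is a typing error even though it runs. A small standalone illustration (class names here are stand-ins, not the actual vLLM definitions):

```python
from typing import ClassVar


class WithClassVar:
    # ClassVar says "class-level, shared across all instances"; a type
    # checker will reject `instance.has_own_lm_head = True`.
    has_own_lm_head: ClassVar[bool] = False


class WithInstanceVar:
    # Plain annotation with a class-level default: overriding it on a
    # specific instance during weight loading is type-correct.
    has_own_lm_head: bool = False


m1, m2 = WithInstanceVar(), WithInstanceVar()
m1.has_own_lm_head = True  # set while loading this drafter's weights
print(m1.has_own_lm_head, m2.has_own_lm_head)  # True False
```

Setting the flag on m1 leaves m2 (and the class default) untouched, which is exactly the per-model behavior the reviewer is pointing at.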

@benchislett
Collaborator

The proper fix here is to add a base SupportsEagle mixin and have SupportsEagle3 inherit from that. Then the other EAGLE classes can inherit from the new base mixin that will give them a reasonable default.
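The proposed hierarchy might be sketched as follows (the base class name SupportsEagle is taken from the comment above; the attribute defaults mirror the PR, and the classes here are illustrative, not the actual vLLM interfaces):

```python
class SupportsEagle:
    """Base mixin: default flags for all EAGLE-style drafters."""
    has_own_lm_head: bool = False
    has_own_embed_tokens: bool = False


class SupportsEagle3(SupportsEagle):
    """EAGLE3-specific interface, inheriting the shared defaults."""
    # EAGLE3-only hooks (e.g. auxiliary hidden-state handling) go here.


class LlamaEagleDrafter(SupportsEagle):
    """An EAGLE1-style drafter now also gets sensible defaults."""


print(LlamaEagleDrafter().has_own_lm_head)  # False: default from base mixin
```

With this layout, EAGLE1 models that lack the EAGLE3 mixin still expose the flags, which is what the CI failures mentioned above were hitting.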

@@ -328,6 +328,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            includes_embed_tokens = True
        model_weights[name] = loaded_weight

        # To prevent overriding with target model's layers
        if "lm_head" in name:
            self.has_own_lm_head = True
Collaborator

This should be added to all the EAGLE classes, not just llama_eagle3.

You might want to refactor this into the mixin to reuse the code between them.
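The suggested refactor could pull the flag-setting into the shared mixin so each EAGLE drafter's load_weights reuses it (the method name _mark_own_layers is hypothetical, invented for this sketch):

```python
class SupportsEagle:
    """Shared mixin with the default flags and the flag-setting helper."""
    has_own_lm_head: bool = False
    has_own_embed_tokens: bool = False

    def _mark_own_layers(self, weight_name: str) -> None:
        # Called once per checkpoint weight; records that the drafter
        # ships its own copy so the target's layers won't overwrite it.
        if "lm_head" in weight_name:
            self.has_own_lm_head = True
        if "embed_tokens" in weight_name:
            self.has_own_embed_tokens = True


class SomeEagleDrafter(SupportsEagle):
    def load_weights(self, weights):
        for name, _ in weights:
            self._mark_own_layers(name)


d = SomeEagleDrafter()
d.load_weights([("model.embed_tokens.weight", None)])
print(d.has_own_embed_tokens)  # True
```

Each drafter class then only needs to call the helper instead of duplicating the two if-checks.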

@hjjq
Contributor

hjjq commented Nov 6, 2025

Hi @eldarkurtic @robertgshaw2-redhat , just wondering what is the status of this PR? Are we close to getting it merged? Thanks!

auto-merge was automatically disabled November 12, 2025 10:04

Head branch was pushed to by a user without write access

@mergify mergify bot added the deepseek Related to DeepSeek models label Nov 12, 2025
@mergify

mergify bot commented Nov 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @eldarkurtic.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify

mergify bot commented Nov 12, 2025

Documentation preview: https://vllm--27737.org.readthedocs.build/en/27737/

@mergify mergify bot added documentation Improvements or additions to documentation ci/build frontend multi-modality Related to multi-modality (#4194) new-model Requests to new models performance Performance-related issues qwen Related to Qwen models gpt-oss Related to GPT-OSS models nvidia labels Nov 12, 2025
@mergify mergify bot added rocm Related to AMD ROCm tpu Related to Google TPUs kv-connector labels Nov 12, 2025
eldarkurtic and others added 4 commits November 12, 2025 10:12
…ialized

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
…ration (vllm-project#27670)

Signed-off-by: KevinCheung2259 <2651309292@qq.com>
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
@eldarkurtic eldarkurtic force-pushed the fix-eagle3-drafter-init branch from e69c88f to f766651 Compare November 12, 2025 10:12
@mergify mergify bot removed the tpu Related to Google TPUs label Nov 12, 2025
@eldarkurtic
Contributor Author

Closed in favor of a slightly cleaner approach in #28549

@github-project-automation github-project-automation bot moved this from To Triage to Done in gpt-oss Issues & Enhancements Nov 12, 2025
@github-project-automation github-project-automation bot moved this to Done in NVIDIA Nov 12, 2025

Labels

ci/build
deepseek Related to DeepSeek models
documentation Improvements or additions to documentation
frontend
gpt-oss Related to GPT-OSS models
kv-connector
llama Related to Llama models
multi-modality Related to multi-modality (#4194)
needs-rebase
new-model Requests to new models
nvidia
performance Performance-related issues
qwen Related to Qwen models
ready ONLY add when PR is ready to merge/full CI is needed
rocm Related to AMD ROCm
speculative-decoding
v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

8 participants