[Eagle] [Quantization] Add complete quantization support to the draft model in Eagle#27434

Closed
shreyas269 wants to merge 3 commits intovllm-project:mainfrom
capitalone-contributions:shreyas269-vllm-quantize-eagle

Conversation

@shreyas269
Contributor

@shreyas269 shreyas269 commented Oct 23, 2025

Purpose

This PR adds comprehensive quantization support for Eagle and Eagle3 draft models in speculative decoding, including full KV-cache quantization. Previously, Eagle draft models could not use quantized weights in their fully connected layers or quantized KV caches.

Recently, #26590 was merged to properly obtain the draft model's quantization config, but it does not cover the case where the entire draft model is quantized and we need to read the input and weight scales of the fc layer along with the KV-cache quantization scales.

This PR addresses the following:

  • Define get_draft_quant_config in utils to avoid duplicating code between llama_eagle.py and llama_eagle3.py.
  • Use the ReplicatedLinear class to make the fc layer in drafters quantizable and to handle input and weight quantization scales (in Llama with Eagle/Eagle3).
  • Handle and load KV-cache quantization scales, and attempt to remap them to the name format the model expects.
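The remapping in the last bullet can be illustrated with a toy sketch. The real code uses vLLM's maybe_remap_kv_scale_name; the function and suffix convention below are simplified stand-ins for illustration, not the actual implementation:

```python
def remap_kv_scale_name(name, params_dict):
    """Toy remapper: the checkpoint stores scales as
    '...self_attn.k_scale', while the model registers them as
    '...self_attn.attn.k_scale'. Returns the remapped name, or
    None if the parameter does not exist in the model."""
    for suffix in (".k_scale", ".v_scale"):
        if name.endswith(suffix):
            candidate = name.replace(suffix, ".attn" + suffix)
            return candidate if candidate in params_dict else None
    return name if name in params_dict else None

# The model only registers an attention k_scale parameter.
params = {"model.layers.0.self_attn.attn.k_scale": 0.02}
print(remap_kv_scale_name("model.layers.0.self_attn.k_scale", params))
```

A checkpoint name that has no matching model parameter (for example a v_scale here) maps to None, and the loader skips it.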

Test Plan

Tested with a base Llama 3 Instruct model and a quantized Eagle draft model (one decoder layer plus one fc layer) with static fp8 quantization. Both the base/verifier model and the Eagle draft model were quantized using ModelOpt.

Non-quantized models behave exactly as before (no change in behavior).

Test Result

Before:
KeyError: 'fc.input_scale'

After:

(Worker_TP1 pid=1060248) WARNING 10-21 15:19:08 [modelopt.py:103] Detected ModelOpt fp8 checkpoint. Please note that the format is experimental and could change.
(Worker_TP1 pid=1060248) INFO 10-21 15:19:08 [default_loader.py:309] Loading weights took 0.03 seconds
(Worker_TP1 pid=1060248) INFO 10-21 15:19:08 [eagle.py:973] Assuming the EAGLE head shares the same vocab embedding with the target model.
(Worker_TP1 pid=1060248) INFO 10-21 15:19:08 [eagle.py:995] Loading EAGLE LM head weights from the target model.
(Worker_TP0 pid=1060247) WARNING 10-21 15:19:08 [modelopt.py:103] Detected ModelOpt fp8 checkpoint. Please note that the format is experimental and could change.
(Worker_TP0 pid=1060247) WARNING 10-21 15:19:08 [modelopt.py:103] Detected ModelOpt fp8 checkpoint. Please note that the format is experimental and could change.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 36.17it/s]

Acceptance rate of drafter:

Per-position acceptance rate: 0.830, 0.675, 0.452, 0.309, Avg Draft acceptance rate: 56.6%
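The reported average follows directly from the per-position rates above:

```python
# Average the per-position acceptance rates reported in the test run.
rates = [0.830, 0.675, 0.452, 0.309]
avg = sum(rates) / len(rates)
print(f"Avg draft acceptance rate: {avg:.1%}")
```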

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added llama Related to Llama models speculative-decoding labels Oct 23, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds comprehensive quantization support for Eagle and Eagle3 draft models, which is a valuable feature. The implementation is solid, including the use of ReplicatedLinear for quantizable layers and the refactoring of get_draft_quant_config to reduce code duplication. My review identifies one area for improvement: there's a new block of duplicated code for handling KV cache scales in the load_weights methods of both llama_eagle.py and llama_eagle3.py. Extracting this into a shared utility would enhance the long-term maintainability of the code.

Comment on lines +126 to +143
# Handle kv cache quantization scales
if self.quant_config is not None and (
    scale_name := self.quant_config.get_cache_scale(name)
):
    # Loading kv cache quantization scales
    param = params_dict[scale_name]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    loaded_weight = (
        loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
    )
    weight_loader(param, loaded_weight)
    loaded_params.add(scale_name)
    continue
# Remapping the name FP8 kv-scale
if "scale" in name:
    name = maybe_remap_kv_scale_name(name, params_dict)
    if name is None:
        continue
Contributor

high

This block of code for handling KV cache quantization scales and remapping FP8 scale names is duplicated in vllm/model_executor/models/llama_eagle3.py. To improve maintainability and avoid potential bugs from inconsistent updates, this logic should be extracted into a shared utility function, perhaps in vllm.model_executor.model_loader.weight_utils. This would follow the same good practice you've already applied by refactoring get_draft_quant_config.
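One possible shape for the extracted helper is sketched below. The helper name is hypothetical, and torch tensors and vLLM's quant config are replaced by tiny stand-ins so the sketch runs standalone; the body mirrors the duplicated block from the diff:

```python
class FakeTensor:
    """Stand-in for a 0-dim torch tensor holding one scale value."""
    def __init__(self, value):
        self.value = value
    def dim(self):
        return 0

def default_weight_loader(param, loaded_weight):
    """Stub of vLLM's default_weight_loader: copy the value in."""
    param["data"] = loaded_weight.value

def maybe_load_kv_cache_scale(quant_config, name, loaded_weight,
                              params_dict, loaded_params):
    """Hypothetical shared utility for the duplicated block.

    Returns True (and records the scale) if `name` was a KV-cache
    quantization scale, else False so the caller keeps going.
    """
    if quant_config is None:
        return False
    scale_name = quant_config.get_cache_scale(name)
    if not scale_name:
        return False
    param = params_dict[scale_name]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    if loaded_weight.dim() != 0:
        loaded_weight = loaded_weight[0]
    weight_loader(param, loaded_weight)
    loaded_params.add(scale_name)
    return True

class FakeQuantConfig:
    """Stub mimicking quant_config.get_cache_scale from the diff."""
    def get_cache_scale(self, name):
        if name.endswith(("k_scale", "v_scale")):
            return name.replace("self_attn.", "self_attn.attn.")
        return None

params = {"model.layers.0.self_attn.attn.k_scale": {}}
loaded = set()
consumed = maybe_load_kv_cache_scale(
    FakeQuantConfig(),
    "model.layers.0.self_attn.k_scale",
    FakeTensor(0.125),
    params,
    loaded,
)
print(consumed, sorted(loaded))
```

Both load_weights methods could then replace the duplicated block with a single call and a continue when the helper returns True.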

Comment on lines +220 to +237
# Handle kv cache quantization scales
if self.quant_config is not None and (
    scale_name := self.quant_config.get_cache_scale(name)
):
    # Loading kv cache quantization scales
    param = params_dict[scale_name]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    loaded_weight = (
        loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
    )
    weight_loader(param, loaded_weight)
    loaded_params.add(scale_name)
    continue
# Remapping the name FP8 kv-scale
if "scale" in name:
    name = maybe_remap_kv_scale_name(name, params_dict)
    if name is None:
        continue
Contributor

high

This logic for handling KV cache quantization scales is duplicated from vllm/model_executor/models/llama_eagle.py. To improve maintainability, please consider refactoring this shared logic into a common utility function.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, a small and essential subset of CI tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@shreyas269 shreyas269 changed the title Add complete quantization support to the draft model in Eagle [Eagle] [Quantization] Add complete quantization support to the draft model in Eagle Oct 23, 2025
@shreyas269 shreyas269 marked this pull request as draft October 23, 2025 21:05
Contributor

@dsikka dsikka left a comment

Is this only targeting kv cache quantization or are other quant schemes meant to be supported as well?

@dsikka
Contributor

dsikka commented Oct 23, 2025

@rahul-tuli

@shreyas269
Contributor Author

Is this only targeting kv cache quantization or are other quant schemes meant to be supported as well?

@dsikka, along with KV-cache quantization, it also targets fc-layer quantization in Llama-based Eagle drafters. Currently, if we pass a quant config to the drafter, only the decoder layer is quantized.

@shreyas269 shreyas269 marked this pull request as ready for review October 24, 2025 01:28
Contributor

@rahul-tuli rahul-tuli left a comment

Thanks for getting to this! The PR looks good to me. Could you fix pre-commit, sign off your commits, and add smoke tests for quantized Eagle/Eagle3 heads?

Signed-off-by: Shreyas Kulkarni <shreyas.gp269@gmail.com>
Signed-off-by: Shreyas Kulkarni <shreyas.gp269@gmail.com>
@shreyas269
Contributor Author

Closing this in favor of its duplicate, #28435, which has a DCO sign-off and smoke tests.

@shreyas269 shreyas269 closed this Nov 11, 2025

Labels

llama Related to Llama models speculative-decoding


3 participants