
[Bugfix] Fix the acceptance rate drop issue when applying eagle3 to QuaRot model #6914

Merged
MengqingCao merged 1 commit into vllm-project:main from zhaomingyu13:main
Mar 4, 2026

Conversation

@zhaomingyu13
Contributor

@zhaomingyu13 zhaomingyu13 commented Mar 2, 2026

What this PR does / why we need it?

When using a target model that has undergone rotational quantization, the acceptance rate drops because the fc weight of the draft model has not undergone the same rotational quantization (issue: #6445). This PR fixes the issue by applying the same rotation to the draft model's fc weight as the main model's when loading the draft model.
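As a sanity check on the fix's premise, here is a minimal NumPy sketch (all names are hypothetical, not the actual patch code) of why the fc weight must carry the same rotation as the rotated activations: with an orthogonal rotation, rotating both the activations and the weight leaves the linear output unchanged, while rotating only one side does not.

```python
import numpy as np

def rotate_fc_weight(fc_weight: np.ndarray, rotation: np.ndarray) -> np.ndarray:
    """Apply the global rotation to an fc weight of shape [out, in]."""
    return fc_weight @ rotation

rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.standard_normal((4, 4)))  # random orthogonal matrix
w = rng.standard_normal((2, 4))                   # fc weight [out=2, in=4]
x = rng.standard_normal((3, 4))                   # activations [batch=3, in=4]

# Rotated activations with a rotated weight reproduce the original output ...
assert np.allclose(x @ w.T, (x @ q) @ rotate_fc_weight(w, q).T)
# ... but rotated activations with the UN-rotated weight do not (the bug).
assert not np.allclose(x @ w.T, (x @ q) @ w.T)
```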

Does this PR introduce any user-facing change?

Previously, the bug was worked around with the tool from issue #5974. If your version already includes this PR, use the original eagle3 weights matching the target model quantized by the new version of modelslim, and do not use that tool.

How was this patch tested?

import gc
import torch
import os
from vllm.v1.metrics.reader import Counter, Vector
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import (destroy_distributed_environment,
                                             destroy_model_parallel)

os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"

TP = 2
K = 4
QUANTIZATION = True
# QUANTIZATION = False
MODEL_PATH = "Qwen/Qwen3-32B-w8a8-quarot-eagle3-new" if QUANTIZATION else "Qwen/Qwen3-32B"

def clean_up():
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()

if __name__ == '__main__':
    prompts = [
        "Who are you?",
    ]
    sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=40, ignore_eos=True, max_tokens=200)

    llm_kwargs = dict(
        disable_log_stats=False,
        model=MODEL_PATH,
        tensor_parallel_size=TP,
        enforce_eager=True,
        gpu_memory_utilization=0.9,
        speculative_config={
            "model": "RedHatAI/Qwen3-32B-speculator.eagle3",
            "method": "eagle3",
            "num_speculative_tokens": K,
        },
        max_model_len=4096, 
        enable_prefix_caching=False,
        compilation_config={
            "cudagraph_mode": "FULL_DECODE_ONLY",
            "cudagraph_capture_sizes": [4],
        },)
    
    if QUANTIZATION:
        llm_kwargs["quantization"] = "ascend"

    llm = LLM(**llm_kwargs)

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        # print(output.outputs[0].token_ids)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    total_num_output_tokens = sum(
        len(output.outputs[0].token_ids) for output in outputs
    )

    metrics = llm.get_metrics()
    print('metrics------------------')
    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    acceptance_counts = [0] * K
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            assert isinstance(metric, Counter)
            num_draft_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos in range(len(metric.values)):
                acceptance_counts[pos] += metric.values[pos]

    print("-" * 50)
    print(f"total_num_output_tokens: {total_num_output_tokens}")
    print(f"num_drafts: {num_drafts}")
    print(f"num_draft_tokens: {num_draft_tokens}")
    print(f"num_accepted_tokens: {num_accepted_tokens}")
    acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
    print(f"mean acceptance length: {acceptance_length:.2f}")
    print("-" * 50)

    # print acceptance at each token position
    for i in range(len(acceptance_counts)):
        acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
        print(f"acceptance at token {i}: {acceptance_rate:.2f}")

    del llm
    clean_up()

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a performance degradation observed in speculative decoding setups involving Eagle3 and QuaRot quantized models. By ensuring that the draft model's fully connected layer weights are correctly rotationally quantized, the change aims to restore expected acceptance rates and maintain the efficiency of the quantized model inference pipeline.

Highlights

  • Bugfix for Acceptance Rates Drop: Addressed an issue where the acceptance rates decreased when applying Eagle3 speculative decoding to QuaRot quantized models. This was due to the draft model's 'fc' weights not undergoing rotational quantization.
  • Rotational Quantization for Draft Model: Implemented rotational quantization for the 'fc' weights of the draft model, ensuring consistency with the main model during loading. This is applied specifically when a quantization configuration is present.
  • New Patch Module: Introduced a new patch module, patch_qwen3_quarot.py, which modifies the Eagle3LlamaForCausalLM.load_weights method to incorporate the rotational quantization logic.


Changelog
  • vllm_ascend/patch/worker/init.py
    • Imported the new patch_qwen3_quarot module to enable its functionality.
  • vllm_ascend/patch/worker/patch_qwen3_quarot.py
    • Added a new file containing patch_load_weights and make_load_weights functions.
    • Implemented logic to load the global_rotation matrix from the model path.
    • Modified the load_weights method for Eagle3LlamaForCausalLM to apply rotational quantization to 'fc' weights if present.
  • vllm_ascend/worker/model_runner_v1.py
    • Imported patch_load_weights from the new patch_qwen3_quarot module.
    • Conditionally applied patch_load_weights to the drafter model during its loading process, specifically when a quantization configuration is detected.
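The closure-based patch the changelog describes can be pictured with a short sketch; the name make_load_weights mirrors the changelog, but the body is illustrative, not the merged code. The original load_weights is wrapped so the rotation is applied to the fc weight as the draft weights stream through:

```python
def make_load_weights(orig_load_weights, rotation):
    """Wrap load_weights so fc weights are rotated at load time (illustrative)."""
    def load_weights(self, weights):
        def rotated(items):
            for name, tensor in items:
                if name.endswith("fc.weight"):  # rotate only the draft fc weight
                    tensor = tensor @ rotation
                yield name, tensor
        return orig_load_weights(self, rotated(weights))
    return load_weights

# Usage sketch (the actual class lives in vLLM):
# Eagle3LlamaForCausalLM.load_weights = make_load_weights(
#     Eagle3LlamaForCausalLM.load_weights, global_rotation)
```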
Activity
  • The author provided a detailed Python script demonstrating how the patch was tested, including setup for a quantized Qwen3-32B model with Eagle3 speculative decoding and metrics for acceptance rates.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Mar 2, 2026

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request fixes a bug causing a drop in acceptance rates for speculative decoding when using a rotationally quantized (QuaRot) model with an Eagle3 draft model. The fix involves applying the same rotational quantization to the draft model's fully-connected layer weights during loading. The approach is correct, but I've identified two critical areas for improvement to enhance robustness: adding error handling for loading the rotation matrix file and making the layer selection for weight rotation more specific to prevent accidentally modifying incorrect layers.

As per the repository's style guide, here is a suggested title for the pull request:

Suggested PR Title:

[Quantization][BugFix] Fix acceptance rate drop issue when applying eagle3 to QuaRot model
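The review's two robustness suggestions could be sketched as follows (the file name, helper names, and matching rule are assumptions for illustration, not the merged code): raise a clear error if the rotation matrix file is missing, and match the fc layer by exact suffix rather than a loose substring check.

```python
import os

def load_global_rotation(model_path: str, filename: str = "global_rotation.pt"):
    """Load the QuaRot rotation matrix, failing loudly if the file is absent."""
    rotation_path = os.path.join(model_path, filename)
    if not os.path.isfile(rotation_path):
        raise FileNotFoundError(
            f"Rotation matrix not found at {rotation_path}; the draft model's "
            "fc weight cannot be rotated to match the QuaRot target model.")
    import torch  # deferred import; only needed once the file exists
    return torch.load(rotation_path, map_location="cpu")

def is_fc_weight(name: str) -> bool:
    # Exact suffix match avoids touching look-alikes such as "fc_norm.weight".
    return name == "fc.weight" or name.endswith(".fc.weight")
```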

Comment thread vllm_ascend/patch/worker/patch_qwen3_quarot.py Outdated
Comment thread vllm_ascend/patch/worker/patch_qwen3_quarot.py Outdated
@zhenwenqi2024 zhenwenqi2024 added labels ready (read for review) and ready-for-test (start test by label for PR) Mar 2, 2026
@zhaomingyu13 zhaomingyu13 force-pushed the main branch 6 times, most recently from 46fbb0b to 963e8b2 Compare March 2, 2026 09:20
@zhaomingyu13 zhaomingyu13 force-pushed the main branch 4 times, most recently from bf303e2 to 7cf15d6 Compare March 3, 2026 06:52
…QuaRot model

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
# Dynamically replace the `load_weights` function at runtime,
# and bind `target_config` into the new implementation with a closure.
# Future Plan:
# Remove this patch when vLLM merges the PR.
Collaborator


Could you just put the PR link here?

@MengqingCao MengqingCao merged commit 52d9086 into vllm-project:main Mar 4, 2026
27 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Mar 5, 2026
…to qwen3next_graph

* 'main' of https://github.com/vllm-project/vllm-ascend: (40 commits)
  [Feature] Add docs of batch invariance and make some extra operators patch (vllm-project#6910)
  [bugfix]Qwen2.5VL accurate question (vllm-project#6975)
  [CI] Add DeepSeek-V3.2 large EP nightly ci (vllm-project#6378)
  [Ops][BugFix] Fix RoPE shape mismatch for mtp models with flashcomm v1 enabled (vllm-project#6939)
  [bugfix]fix file not found error in nightly of single-node (vllm-project#6976)
  [Bugfix] Fix the acceptance rates dorp issue when applying eagle3 to QuaRot model (vllm-project#6914)
  [CI] Enable auto upgrade e2e estimated time for auto-partition suites (vllm-project#6840)
  [Doc][Misc] Fix msprobe_guide.md documentation issues (vllm-project#6965)
  [Nightly][Refactor]Migrate nightly single-node model tests from `.py` to `.yaml` (vllm-project#6503)
  [BugFix] Improve GDN layer detection for multimodal models (vllm-project#6941)
  [feat]ds3.2 pcp support mtp and chunkprefill (vllm-project#6917)
  [CPU binding] Implement global CPU slicing and improve IRQ binding for Ascend NPUs (vllm-project#6945)
  [Triton] Centralize Ascend extension op dispatch in triton_utils (vllm-project#6937)
  [csrc][bugfix] Add compile-time Ascend950/910_95 compatibility for custom ops between CANN8.5 and 9.0 (vllm-project#6936)
  [300I][Bugfix] fix unquant model weight nd2nz error (vllm-project#6851)
  [doc] fix supported_models (vllm-project#6930)
  [CI] nightly test timeout (vllm-project#6912)
  [CI] Upgrade CANN to 8.5.1 (vllm-project#6897)
  [Model]Add Qwen3-Omni quantization Ascend NPU adaptation and optimization (vllm-project#6828)
  [P/D][v0.16.0]Adapt to RecomputeScheduler in vLLM 0.16.0 (vllm-project#6898)
  ...
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
…QuaRot model (vllm-project#6914)

### What this PR does / why we need it?
When using the target model after rotational quantization, the
acceptance rate decreases because the fc weight of the draft model has
not undergone rotational quantization(issue: vllm-project#6445). We fixed this issue
by performing rotation quantization on the fc weight of the draft model
in the same way as the main model when loading draft model.

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@15d76f7

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
MengqingCao pushed a commit that referenced this pull request Mar 16, 2026
### What this PR does / why we need it?
Add an e2e test for QuaRot model with eagle3 that runs both the QuaRot
model and the float model, and then compares their acceptance rates. The
QuaRot model adapting eagle3 PR(#6914, #7038)

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@4034c3d

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
Nagisa125 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Mar 17, 2026
…ct#7128)

### What this PR does / why we need it?
Add an e2e test for QuaRot model with eagle3 that runs both the QuaRot
model and the float model, and then compares their acceptance rates. The
QuaRot model adapting eagle3 PR(vllm-project#6914, vllm-project#7038)

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@4034c3d

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
liuchenbing2026 pushed a commit to liuchenbing2026/vllm-ascend that referenced this pull request Mar 19, 2026
…QuaRot model (vllm-project#6914)

When using the target model after rotational quantization, the
acceptance rate decreases because the fc weight of the draft model has
not undergone rotational quantization(issue: vllm-project#6445). We fixed this issue
by performing rotation quantization on the fc weight of the draft model
in the same way as the main model when loading draft model.

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@15d76f7

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
liuchenbing2026 pushed a commit to liuchenbing2026/vllm-ascend that referenced this pull request Mar 19, 2026
…el (vllm-project#6914)

Cherry-pick from upstream main 52d9086.
Perform rotation quantization on the fc weight of the draft model
in the same way as the main model when loading draft model.
liuchenbing2026 pushed a commit to liuchen20/vllm-ascend that referenced this pull request Mar 24, 2026
…QuaRot model (vllm-project#6914)

When using the target model after rotational quantization, the
acceptance rate decreases because the fc weight of the draft model has
not undergone rotational quantization(issue: vllm-project#6445). We fixed this issue
by performing rotation quantization on the fc weight of the draft model
in the same way as the main model when loading draft model.

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@15d76f7

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Apr 1, 2026
…ct#7128)

### What this PR does / why we need it?
Add an e2e test for QuaRot model with eagle3 that runs both the QuaRot
model and the float model, and then compares their acceptance rates. The
QuaRot model adapting eagle3 PR(vllm-project#6914, vllm-project#7038)

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@4034c3d

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>

Labels

ready (read for review), ready-for-test (start test by label for PR)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants