[main][feature] Support quarot for eagle3 without embedding#7038

Merged
MengqingCao merged 1 commit into vllm-project:main from drslark:quarot
Mar 9, 2026

Conversation

drslark (Contributor) commented Mar 6, 2026

What this PR does / why we need it?

If an eagle3 model without embed_tokens is served with a quarot target model, the acceptance rate drops.

This PR fixes that.

The related vLLM PR is vllm-project/vllm#36225.

Does this PR introduce any user-facing change?

N/A

How was this patch tested?

Tested with Qwen3-32B-w8a8-quarot + Qwen3-32B-speculator.eagle3.

Without this change:

--------------------------------------------------
total_num_output_tokens: 200
num_drafts: 186
num_draft_tokens: 744
num_accepted_tokens: 13
mean acceptance length: 1.07
--------------------------------------------------
acceptance at token 0: 0.07
acceptance at token 1: 0.00
acceptance at token 2: 0.00
acceptance at token 3: 0.00

With this change:

--------------------------------------------------
total_num_output_tokens: 200
num_drafts: 65
num_draft_tokens: 260
num_accepted_tokens: 136
mean acceptance length: 3.09
--------------------------------------------------
acceptance at token 0: 0.74
acceptance at token 1: 0.51
acceptance at token 2: 0.43
acceptance at token 3: 0.42
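
For reference, the mean acceptance length above appears to equal 1 + num_accepted_tokens / num_drafts (the leading 1 being the token the target model verifies each step). A quick check under that assumption:

    # Assumed relationship between the counters reported above; this formula
    # is inferred from the numbers, not quoted from the test harness.
    print(1 + 13 / 186)  # ~1.07, the run without the change
    print(1 + 136 / 65)  # ~3.09, the run with the change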

github-actions bot commented Mar 6, 2026

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description, to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves an issue where eagle3 models, particularly those without embed_tokens, experienced a decrease in acceptance rate when integrated with quarot target models. The changes involve refactoring the quarot patching logic to ensure accurate weight loading and the correct application of anti-rotation, thereby maintaining model performance and compatibility across different configurations.

Highlights

  • Refactored Quarot Patching: The quarot patching mechanism has been refactored by renaming the patch file from patch_qwen3_quarot.py to patch_draft_quarot.py, reflecting applicability beyond Qwen3 models.
  • Dynamic Weight Loading for Eagle3: A new patch_draft_quarot.py module dynamically replaces the Eagle3LlamaForCausalLM.load_weights method, specifically addressing eagle3 models that lack embed_tokens when used with quarot target models.
  • Improved Quarot Model Compatibility: Weight loading now applies anti-rotation to fc layers and loads an embedding tensor from the target model when embed_tokens is absent from the drafter's weights, preventing the acceptance-rate drop; a sketch of both transforms follows this list.
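
A minimal sketch of the two transforms named above, with illustrative names and shapes (Q is the quarot global rotation matrix; nothing here is quoted from the patch itself):

    import torch

    def anti_rotate_fc(fc_weight: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
        # eagle3's fc consumes three concatenated aux hidden states, each
        # rotated by Q, so undoing the rotation uses block_diag(Q, Q, Q).
        Q3 = torch.block_diag(Q, Q, Q)
        return fc_weight @ Q3.to(fc_weight.dtype)

    def borrow_embedding(target_embed: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
        # A drafter without embed_tokens borrows the target model's embedding,
        # rotated by Q.T so it lands in the quarot model's rotated space.
        return target_embed @ Q.T.to(target_embed.dtype)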

Changelog
  • vllm_ascend/patch/__init__.py
    • Updated the internal reference to the quarot patch file name.
    • Added a related PR link to the quarot patch documentation.
  • vllm_ascend/patch/worker/__init__.py
    • Updated the import statement to reference the renamed patch_draft_quarot module.
  • vllm_ascend/patch/worker/patch_draft_quarot.py
    • Added a new module containing functions to retrieve embedding tensors, determine rotation paths, load rotation matrices, compute block-diagonal rotation matrices, and patch the Eagle3LlamaForCausalLM.load_weights method for quarot support (see the sketch after this list).
  • vllm_ascend/patch/worker/patch_qwen3_quarot.py
    • Removed the deprecated qwen3_quarot patch file.
  • vllm_ascend/worker/model_runner_v1.py
    • Updated the import statement to use the new patch_draft_quarot module instead of the old patch_qwen3_quarot.
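
As a hedged illustration of the patching approach described in this changelog (the import path and the registration function are assumptions for the sketch, not copied from the PR):

    # Hypothetical wiring; the real patch module resolves the target model
    # and rotation paths from the vLLM configuration.
    from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM

    def apply_draft_quarot_patch(target_model_path: str, rotation_path: str) -> None:
        # Swap in the quarot-aware loader produced by make_load_weights
        # (the factory quoted in the review comments below).
        Eagle3LlamaForCausalLM.load_weights = make_load_weights(
            target_model_path, rotation_path)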

gemini-code-assist bot left a comment

Code Review

This pull request refactors the patch for quarot quantized models and adds support for eagle3 draft models that do not have their own embed_tokens. The logic correctly loads the embedding from the target model and applies the necessary anti-rotation.

I've found a couple of critical issues in the new patch_draft_quarot.py file that could lead to runtime errors, related to uninitialized variables and lack of None checks. I've also pointed out some typos that affect readability. Please see my detailed comments.

Comment on lines +92 to +141
def make_load_weights(target_model_path, rotation_path):
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        Q = get_rotataion_matrix(rotation_path)
        Q3 = compute_rotataion_matrix3(Q)
        dtype = None

        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            elif "fc." in name:
                # anti-rotate fc
                dtype = loaded_weight.dtype
                loaded_weight = loaded_weight @ Q3.to(dtype)
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        # process embedding if drafter does not have embedding
        if not includes_embed_tokens:
            name = "model.embed_tokens.weight"
            loaded_weight = get_embedding_tensor(target_model_path).to(dtype) @ Q.T.to(dtype)
            model_weights[name] = loaded_weight

            includes_embed_tokens = True
            process_eagle_weight(self, name)

        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())

    return load_weights

critical

This function has two critical issues that could lead to a crash:

  1. The dtype variable is initialized to None and only set if a weight name contains "fc.". If no such weight exists, dtype remains None, causing a TypeError on line 121 when .to(dtype) is called.
  2. The function get_embedding_tensor can return None, but the return value is used without a check on line 121, which would cause an AttributeError when .to(dtype) is called on None.

I suggest initializing dtype from self.dtype at the beginning of the function and adding a None check for the result of get_embedding_tensor.

def make_load_weights(target_model_path, rotation_path):
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        Q = get_rotataion_matrix(rotation_path)
        Q3 = compute_rotataion_matrix3(Q)
        dtype = self.dtype

        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            elif "fc." in name:
                # anti-rotate fc
                loaded_weight = loaded_weight @ Q3.to(dtype)
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        # process embedding if drafter does not have embedding
        if not includes_embed_tokens:
            embedding_tensor = get_embedding_tensor(target_model_path)
            if embedding_tensor is not None:
                name = "model.embed_tokens.weight"
                loaded_weight = embedding_tensor.to(dtype) @ Q.T.to(dtype)
                model_weights[name] = loaded_weight

                includes_embed_tokens = True
                process_eagle_weight(self, name)
            else:
                logger.warning(
                    f"Could not find embedding tensor in {target_model_path} "
                    "to patch draft model."
                )

        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())

    return load_weights

Comment on lines +18 to +21
"""
Scans the directory and returns the first tensor found that contains 'embed' in its key.
Returns a tuple of (key_name, tensor) or (None, None) if not found.
"""

high

The docstring states that this function returns a tuple (key_name, tensor) or (None, None). However, the implementation returns only the tensor or None. The docstring should be updated to match the implementation.

Suggested change
      """
      Scans the directory and returns the first tensor found that contains 'embed' in its key.
-     Returns a tuple of (key_name, tensor) or (None, None) if not found.
+     Returns the tensor if found, otherwise None.
      """

Comment on lines +57 to +78
def get_rotataion_matrix(rotation_path):
    """
    Anti-rotate maxtrix.
    """
    try:
        safetensor_data = load_file(rotation_path)
        Q = safetensor_data["global_rotation"]

        return Q
    except Exception as e:
        logger.error(
            f"Failed to load rotation weight from '{rotation_path}'. "
            "If you want to use quarot model with eagle3, take a check."
        )
        raise e


def compute_rotataion_matrix3(Q):
    """
    Anti-rotate matrix for 3 layers of hidden_states.
    """
    return torch.block_diag(Q, Q, Q)

high

There are a few typos in this block that affect readability and maintainability:

  • On line 57, get_rotataion_matrix should be get_rotation_matrix.
  • On line 59, maxtrix should be matrix.
  • On line 74, compute_rotataion_matrix3 should be compute_rotation_matrix3.

Please correct these. Remember to also update the function calls on lines 94 and 95.

def get_rotation_matrix(rotation_path):
    """
    Anti-rotate matrix.
    """
    try:
        safetensor_data = load_file(rotation_path)
        Q = safetensor_data["global_rotation"]

        return Q
    except Exception as e:
        logger.error(
            f"Failed to load rotation weight from '{rotation_path}'. "
            "If you want to use quarot model with eagle3, take a check."
        )
        raise e


def compute_rotation_matrix3(Q):
    """
    Anti-rotate matrix for 3 layers of hidden_states.
    """
    return torch.block_diag(Q, Q, Q)
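
A quick sanity check of the block-diagonal construction, assuming eagle3's fc input is three hidden_size-wide segments concatenated along the last dimension (shapes are illustrative only):

    import torch

    hidden_size = 4
    Q = torch.linalg.qr(torch.randn(hidden_size, hidden_size)).Q  # random orthogonal
    Q3 = torch.block_diag(Q, Q, Q)

    h = torch.randn(2, 3 * hidden_size)  # stand-in for concatenated aux hidden states
    per_segment = torch.cat([seg @ Q for seg in h.split(hidden_size, dim=-1)], dim=-1)
    assert torch.allclose(h @ Q3, per_segment, atol=1e-5)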

MengqingCao added the ready (read for review) and ready-for-test (start test by label for PR) labels Mar 6, 2026
Signed-off-by: drslark <slarksblood@qq.com>

MengqingCao (Collaborator) left a comment


LGTM, thanks for your effort!

@MengqingCao MengqingCao merged commit 6a7115f into vllm-project:main Mar 9, 2026
38 checks passed
MengqingCao pushed a commit that referenced this pull request Mar 16, 2026
### What this PR does / why we need it?
Add an e2e test for the QuaRot model with eagle3 that runs both the QuaRot
model and the float model, and then compares their acceptance rates. The
PRs adapting the QuaRot model to eagle3 are #6914 and #7038.

- vLLM version: v0.16.0
- vLLM main: vllm-project/vllm@4034c3d

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
liuchenbing2026 pushed a commit to liuchenbing2026/vllm-ascend that referenced this pull request Mar 19, 2026
…ject#7038)

If an `eagle3` model without embed_tokens works with a `quarot` target
model, the acceptance rate will drop.
This PR fixes that.
The related vLLM PR is vllm-project/vllm#36225.

- vLLM main: vllm-project/vllm@4034c3d

Signed-off-by: drslark <slarksblood@qq.com>
liuchenbing2026 pushed a commit to liuchenbing2026/vllm-ascend that referenced this pull request Mar 19, 2026
)

Cherry-pick from upstream main 6a7115f.
Rename patch_qwen3_quarot to patch_draft_quarot (already present)
and clean up redundant documentation.