Skip to content

[main][feature] Support quarot for eagle3#36225

Draft
drslark wants to merge 1 commit intovllm-project:mainfrom
drslark:main
Draft

[main][feature] Support quarot for eagle3#36225
drslark wants to merge 1 commit intovllm-project:mainfrom
drslark:main

Conversation

@drslark
Copy link
Copy Markdown
Contributor

@drslark drslark commented Mar 6, 2026

Purpose

As described in #36223.

We have a quarot target model and a eagle3 draft model trained with original target model.

The acceptence rate dropped because of roatation.

We discussed with modelslim and get a new model protocal.

We will adapt to modelslim's protocal.

Test Plan

TODO, we made this pr as draft now.

When completed, we will open it.

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Mar 6, 2026

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify Bot added llama Related to Llama models speculative-decoding labels Mar 6, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for quarot for the eagle3 model, involving utility functions for rotation matrices and updated weight loading in llama_eagle3.py. However, it contains a high-severity Path Traversal vulnerability in the get_rotation_path function. Furthermore, the implementation has critical issues such as missing imports, undefined function calls (e.g., target_config, compute_rotataion_matrix3, get_embedding_tensor, Path, load_file), typos (get_rotataion_matrix), and the use of global state, which will lead to runtime errors and maintenance difficulties. These issues must be addressed before merging.

Comment on lines +365 to +372
target_model_path = Path(target_config.model_config.model)
rotation_path = get_rotation_path(target_config)

use_quarot = rotation_path is not None

if use_quarot:
Q = get_rotataion_matrix(rotation_path)
Q3 = compute_rotataion_matrix3(Q)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This block of code relies on several functions and classes that are not imported or defined in this file: Path (from pathlib), get_rotation_path (from .utils), get_rotataion_matrix (from .utils), and compute_rotataion_matrix3. This will cause NameError exceptions at runtime. Please ensure all necessary components are imported or defined. Additionally, get_rotataion_matrix has a typo and should be get_rotation_matrix.

# process embedding if drafter does not have embedding
if use_quarot and not includes_embed_tokens:
name = "model.embed_tokens.weight"
loaded_weight = get_embedding_tensor(target_model_path).to(dtype) @ Q.T.to(dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The function get_embedding_tensor is called here but it is not defined or imported in this file, nor is it defined in this pull request. This will cause a NameError at runtime.

Comment thread vllm/model_executor/models/utils.py Outdated
Comment on lines +912 to +922
_target_config = None


def set_target_config(target_config):
global _target_config
_target_config = target_config


def get_target_config():
global _target_config
return _target_config
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Using a global variable _target_config with set_target_config and get_target_config introduces global state, which is a poor design choice. It creates hidden dependencies, makes the code harder to reason about and test, and is not thread-safe. This can lead to subtle bugs that are difficult to track down. It is strongly recommended to refactor this to pass target_config explicitly as an argument to the functions that need it.

Comment thread vllm/model_executor/models/utils.py Outdated
Comment on lines +880 to +892
def get_rotation_path(target_config):
"""
Gets the path of the rotation matrix, returns None if the target model is not a quarot model.
"""
target_model_path = target_config.model_config.model
try:
rotation_relative_path = target_config.quant_config.quant_description["optional"]["quarot"]["rotation_map"][
"global_rotation"
]
except KeyError:
return None

return Path(target_model_path) / rotation_relative_path
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

This function is vulnerable to a high-severity Path Traversal. The get_rotation_path function constructs a file path using target_model_path and rotation_relative_path without proper validation. An attacker could exploit this by providing a malicious rotation_relative_path (e.g., containing ../ or an absolute path) to access files outside the intended directory. The suggested code addresses this by validating rotation_relative_path to prevent traversal and also includes the necessary from pathlib import Path import, fixing a NameError.

def get_rotation_path(target_config):
    """
    Gets the path of the rotation matrix, returns None if the target model is not a quarot model.
    """
    import os
    from pathlib import Path
    target_model_path = target_config.model_config.model
    try:
        rotation_relative_path = target_config.quant_config.quant_description["optional"]["quarot"]["rotation_map"][
            "global_rotation"
        ]
    except KeyError:
        return None

    if os.path.isabs(rotation_relative_path) or ".." in Path(rotation_relative_path).parts:
        raise ValueError(f"Invalid rotation_relative_path: {rotation_relative_path}")

    return Path(target_model_path) / rotation_relative_path

if use_quarot:
Q = get_rotataion_matrix(rotation_path)
Q3 = compute_rotataion_matrix3(Q)
dtype = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

dtype is initialized to None. It is only updated if an fc layer's weights are processed. If not, it remains None when used on line 411, which could lead to unexpected behavior or dtype mismatches. It's safer to initialize it to a sensible default, like the model's lm_head dtype.

Suggested change
dtype = None
dtype = self.lm_head.weight.dtype

Comment on lines +895 to +898
def get_rotataion_matrix(rotation_path):
"""
Anti-rotate maxtrix.
"""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There's a typo in the function name (get_rotataion_matrix should be get_rotation_matrix). Also, the docstring "Anti-rotate maxtrix." is not very descriptive and contains a typo (maxtrix). Please correct the function name and provide a more informative docstring.

Suggested change
def get_rotataion_matrix(rotation_path):
"""
Anti-rotate maxtrix.
"""
def get_rotation_matrix(rotation_path):
"""
Loads the anti-rotation matrix Q from the given path.
"""

Signed-off-by: drslark <slarksblood@qq.com>
MengqingCao pushed a commit to vllm-project/vllm-ascend that referenced this pull request Mar 9, 2026
### What this PR does / why we need it?
If some `eagle3` model without embed_tokens works with `quarot` target
model, the acceptence rate will drop.
We solve it in this PR.
The relative vllm pr is vllm-project/vllm#36225.

- vLLM main:
vllm-project/vllm@4034c3d

Signed-off-by: drslark <slarksblood@qq.com>
liuchenbing2026 pushed a commit to liuchenbing2026/vllm-ascend that referenced this pull request Mar 19, 2026
…ject#7038)

If some `eagle3` model without embed_tokens works with `quarot` target
model, the acceptence rate will drop.
We solve it in this PR.
The relative vllm pr is vllm-project/vllm#36225.

- vLLM main:
vllm-project/vllm@4034c3d

Signed-off-by: drslark <slarksblood@qq.com>
liuchenbing2026 pushed a commit to liuchen20/vllm-ascend that referenced this pull request Mar 24, 2026
…ject#7038)

If some `eagle3` model without embed_tokens works with `quarot` target
model, the acceptence rate will drop.
We solve it in this PR.
The relative vllm pr is vllm-project/vllm#36225.

- vLLM main:
vllm-project/vllm@4034c3d

Signed-off-by: drslark <slarksblood@qq.com>
liuchenbing2026 pushed a commit to liuchen20/vllm-ascend that referenced this pull request Mar 24, 2026
…ject#7038)

If some `eagle3` model without embed_tokens works with `quarot` target
model, the acceptence rate will drop.
We solve it in this PR.
The relative vllm pr is vllm-project/vllm#36225.

- vLLM main:
vllm-project/vllm@4034c3d

Signed-off-by: drslark <slarksblood@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

llama Related to Llama models speculative-decoding

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant