[main][feature] Support quarot for eagle3 #36225
drslark wants to merge 1 commit into vllm-project:main
Code Review
This pull request adds quarot support for the eagle3 model, with utility functions for rotation matrices and updated weight loading in llama_eagle3.py. However, it contains a high-severity Path Traversal vulnerability in the get_rotation_path function. The implementation also has critical issues: missing imports and undefined names (e.g., target_config, compute_rotataion_matrix3, get_embedding_tensor, Path, load_file), a typo in a function name (get_rotataion_matrix), and reliance on global state. These will cause runtime errors and make the code harder to maintain, so they must be addressed before merging.

    target_model_path = Path(target_config.model_config.model)
    rotation_path = get_rotation_path(target_config)

    use_quarot = rotation_path is not None

    if use_quarot:
        Q = get_rotataion_matrix(rotation_path)
        Q3 = compute_rotataion_matrix3(Q)

This block of code relies on several functions and classes that are not imported or defined in this file: Path (from pathlib), get_rotation_path (from .utils), get_rotataion_matrix (from .utils), and compute_rotataion_matrix3. This will cause NameError exceptions at runtime. Please ensure all necessary components are imported or defined. Additionally, get_rotataion_matrix has a typo and should be get_rotation_matrix.

    # process embedding if drafter does not have embedding
    if use_quarot and not includes_embed_tokens:
        name = "model.embed_tokens.weight"
        loaded_weight = get_embedding_tensor(target_model_path).to(dtype) @ Q.T.to(dtype)

    _target_config = None


    def set_target_config(target_config):
        global _target_config
        _target_config = target_config


    def get_target_config():
        global _target_config
        return _target_config

Using a global variable _target_config with set_target_config and get_target_config introduces global state, which is a poor design choice. It creates hidden dependencies, makes the code harder to reason about and test, and is not thread-safe. This can lead to subtle bugs that are difficult to track down. It is strongly recommended to refactor this to pass target_config explicitly as an argument to the functions that need it.
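
A minimal sketch of the explicit-argument style recommended here. The names `TargetInfo` and `needs_anti_rotation` are hypothetical and not part of vLLM. The small dataclass stands in for whichever `target_config` fields the drafter actually needs, and the condition mirrors the `use_quarot and not includes_embed_tokens` check from the diff.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class TargetInfo:
    """Hypothetical stand-in for the target_config fields the drafter needs."""

    model_path: str
    rotation_file: Optional[str] = None  # None means the target is not a quarot model


def needs_anti_rotation(target: TargetInfo, includes_embed_tokens: bool) -> bool:
    """Decide from explicit arguments only; no module-level state is read."""
    use_quarot = target.rotation_file is not None
    return use_quarot and not includes_embed_tokens


# The caller builds the value once and passes it down the call chain.
quarot_target = TargetInfo(model_path="/models/target", rotation_file="rotation.safetensors")
plain_target = TargetInfo(model_path="/models/target")
```

Because every input is a parameter, the function can be exercised in a unit test without first calling a setter, and two loaders running in different threads cannot overwrite each other's configuration.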

    def get_rotation_path(target_config):
        """
        Gets the path of the rotation matrix, returns None if the target model is not a quarot model.
        """
        target_model_path = target_config.model_config.model
        try:
            rotation_relative_path = target_config.quant_config.quant_description["optional"]["quarot"]["rotation_map"][
                "global_rotation"
            ]
        except KeyError:
            return None

        return Path(target_model_path) / rotation_relative_path

This function is vulnerable to a high-severity Path Traversal. The get_rotation_path function constructs a file path using target_model_path and rotation_relative_path without proper validation. An attacker could exploit this by providing a malicious rotation_relative_path (e.g., containing ../ or an absolute path) to access files outside the intended directory. The suggested code addresses this by validating rotation_relative_path to prevent traversal and also includes the necessary from pathlib import Path import, fixing a NameError.
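
As a complement to the suggestion that follows, a stricter containment check can resolve the joined path and confirm that it still lies inside the model directory. This is a sketch only: `safe_join` is a hypothetical helper name, not an existing vLLM function, and it assumes the model path is a local directory.

```python
from pathlib import Path


def safe_join(base_dir: str, relative: str) -> Path:
    """Join `relative` onto `base_dir`, refusing any result outside `base_dir`.

    Hypothetical helper for illustration; not part of vLLM.
    """
    base = Path(base_dir).resolve()
    # Joining an absolute path discards `base`, so the check below rejects it too.
    candidate = (base / relative).resolve()
    try:
        candidate.relative_to(base)
    except ValueError:
        raise ValueError(f"Path escapes model directory: {relative!r}") from None
    return candidate
```

Resolving before comparing also normalizes inputs such as `a/../../b`, where the `..` components only escape the base directory after normalization.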

    def get_rotation_path(target_config):
        """
        Gets the path of the rotation matrix, returns None if the target model is not a quarot model.
        """
        import os
        from pathlib import Path

        target_model_path = target_config.model_config.model
        try:
            rotation_relative_path = target_config.quant_config.quant_description["optional"]["quarot"]["rotation_map"][
                "global_rotation"
            ]
        except KeyError:
            return None
        if os.path.isabs(rotation_relative_path) or ".." in Path(rotation_relative_path).parts:
            raise ValueError(f"Invalid rotation_relative_path: {rotation_relative_path}")
        return Path(target_model_path) / rotation_relative_path

    if use_quarot:
        Q = get_rotataion_matrix(rotation_path)
        Q3 = compute_rotataion_matrix3(Q)
    dtype = None

dtype is initialized to None. It is only updated if an fc layer's weights are processed. If not, it remains None when used on line 411, which could lead to unexpected behavior or dtype mismatches. It's safer to initialize it to a sensible default, like the model's lm_head dtype.

    - dtype = None
    + dtype = self.lm_head.weight.dtype


    def get_rotataion_matrix(rotation_path):
        """
        Anti-rotate maxtrix.
        """

There's a typo in the function name (get_rotataion_matrix should be get_rotation_matrix). Also, the docstring "Anti-rotate maxtrix." is not very descriptive and contains a typo (maxtrix). Please correct the function name and provide a more informative docstring.

    - def get_rotataion_matrix(rotation_path):
    -     """
    -     Anti-rotate maxtrix.
    -     """
    + def get_rotation_matrix(rotation_path):
    +     """
    +     Loads the anti-rotation matrix Q from the given path.
    +     """

Signed-off-by: drslark <slarksblood@qq.com>
### What this PR does / why we need it?
If an `eagle3` draft model without `embed_tokens` is paired with a `quarot` target model, the acceptance rate drops. This PR fixes that. The related vLLM PR is vllm-project/vllm#36225.
- vLLM main: vllm-project/vllm@4034c3d
Signed-off-by: drslark <slarksblood@qq.com>
…ject#7038) If an `eagle3` draft model without `embed_tokens` is paired with a `quarot` target model, the acceptance rate drops. This PR fixes that. The related vLLM PR is vllm-project/vllm#36225. - vLLM main: vllm-project/vllm@4034c3d Signed-off-by: drslark <slarksblood@qq.com>
Purpose
As described in #36223.
We have a `quarot` target model and an `eagle3` draft model trained with the original target model. The acceptance rate dropped because of the rotation.
We discussed this with modelslim and agreed on a new model protocol.
We will adapt to modelslim's protocol.
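
The effect can be illustrated with a toy example. It assumes, as in QuaRot-style schemes, that the rotation is an orthogonal matrix Q, so tensors in the rotated basis differ from what a drafter trained on the unrotated model expects, and multiplying by Q's transpose maps them back. The sketch uses a 2x2 rotation in pure Python; the helper names are illustrative, and it does not reflect the modelslim protocol or the actual loading code.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Return the transpose of a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two matrices."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


theta = math.pi / 6
# An orthogonal matrix: Q multiplied by its transpose is the identity.
Q = [[math.cos(theta), -math.sin(theta)],
     [math.sin(theta), math.cos(theta)]]

# Toy "embedding table": 3 tokens, hidden size 2.
E = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

E_rot = matmul(E, Q)                   # values in the rotated basis
E_back = matmul(E_rot, transpose(Q))   # anti-rotation recovers the original

rotation_gap = max_abs_diff(E, E_rot)  # clearly non-zero
restore_gap = max_abs_diff(E, E_back)  # zero up to floating-point error
```

The rotated table differs from the original, while the anti-rotated one matches it up to rounding, which is the property the `@ Q.T` step in the diff relies on.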
Test Plan
TODO. This PR is currently a draft; we will mark it ready for review once it is complete.
Test Result