Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,14 @@
# https://github.com/vllm-project/vllm/pull/34336
# Future Plan:
# Remove this patch when vLLM merges the PR.
# ** 16. File: worker/patch_qwen3_quarot.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.llama_eagle3.Eagle3LlamaForCausalLM.load_weights`
# Why:
# vllm-ascend reuses vLLM's weight-loading logic for the drafter model,
# but that logic does not apply Ascend quantization to the drafter weights.
# How:
# Dynamically replace the `load_weights` function at runtime,
# and fix `target_config` into the new implementation with a closure.
# Future Plan:
# Remove this patch when vLLM merges the PR.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you just put the pr link here?

1 change: 1 addition & 0 deletions vllm_ascend/patch/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@
import vllm_ascend.patch.worker.patch_routed_experts_capturer # noqa
import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa
import vllm_ascend.patch.worker.patch_kimi_k25 # noqa
import vllm_ascend.patch.worker.patch_qwen3_quarot # noqa
79 changes: 79 additions & 0 deletions vllm_ascend/patch/worker/patch_qwen3_quarot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import logging
from collections.abc import Iterable
from pathlib import Path

import torch
from safetensors.torch import load_file
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
process_eagle_weight,
)


def patch_load_weights(target_config):
    """Monkey-patch ``Eagle3LlamaForCausalLM.load_weights`` at runtime.

    Replaces the drafter's weight loader with a quarot-aware implementation
    that captures ``target_config`` in a closure.
    """
    patched_loader = make_load_weights(target_config)
    setattr(Eagle3LlamaForCausalLM, "load_weights", patched_loader)


def make_load_weights(target_config):
    """Build a quarot-aware ``load_weights`` for ``Eagle3LlamaForCausalLM``.

    Reads the QuaRot global rotation matrix referenced by the quantization
    config, if one is declared, and returns a ``load_weights`` implementation
    that right-multiplies the drafter's ``fc`` weights by the rotation while
    loading. When the rotation cannot be resolved (non-quarot model, missing
    file), loading proceeds without it.

    Args:
        target_config: vLLM config object; assumed to expose ``quant_config``
            (with a ``quant_description`` dict) and ``model_config.model``
            (the model directory path) -- verify against caller.

    Returns:
        A function with the ``load_weights(self, weights)`` method signature,
        suitable for patching onto ``Eagle3LlamaForCausalLM``.
    """
    logger = logging.getLogger(__name__)
    quant_cfg = target_config.quant_config
    rotation_matrix3 = None

    model_path = target_config.model_config.model
    try:
        rotation_rel_path = quant_cfg.quant_description["optional"]["quarot"][
            "rotation_map"]["global_rotation"]
    except KeyError as e:
        # Best-effort: a non-quarot model simply has no rotation entry.
        # Lazy %-args so the message is only formatted when emitted.
        logger.error(
            "Invalid quant_config: missing key "
            "quant_description['optional']['quarot']['rotation_map']"
            "['global_rotation']. "
            "If you don't use quarot model, please ignore it. "
            "Error: %s", e)
    else:
        rotation_path = Path(model_path) / rotation_rel_path
        try:
            safetensor_data = load_file(rotation_path)
            Q = safetensor_data["global_rotation"]
            # Rotation is applied blockwise to three concatenated hidden
            # states, hence the 3x block-diagonal replication of Q.
            rotation_matrix3 = torch.block_diag(Q, Q, Q)
        except Exception as e:
            # Deliberate broad catch: failure to load the rotation is
            # non-fatal; weights are loaded unrotated.
            logger.error(
                "Failed to load rotation weight from '%s'. "
                "If you don't use quarot model, please ignore it. "
                "Error: %s", rotation_path, e)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load drafter weights, rotating ``fc.`` weights when quarot is on."""
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "fc." in name and rotation_matrix3 is not None:
                # Cast the rotation to the weight's dtype before applying.
                loaded_weight = loaded_weight @ rotation_matrix3.to(
                    loaded_weight.dtype)
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        # Skip weights the checkpoint did not provide (or that the model
        # configuration does not use).
        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())

    return load_weights
3 changes: 3 additions & 0 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
from vllm_ascend.patch.worker.patch_qwen3_quarot import patch_load_weights
from vllm_ascend.sample.sampler import AscendSampler
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
Expand Down Expand Up @@ -2419,6 +2420,8 @@ def load_model(self) -> None:
model_register(self.model)
if self.drafter:
logger.info("Loading drafter model...")
if self.vllm_config.quant_config is not None:
patch_load_weights(self.vllm_config)
with get_tp_context(self.drafter):
self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
Expand Down
Loading