diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 73c9c586029..48cabca293f 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -305,3 +305,14 @@
 #   https://github.com/vllm-project/vllm/pull/34336
 # Future Plan:
 #   Remove this patch when vLLM merges the PR.
+# ** 16. File: worker/patch_qwen3_quarot.py **
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.model_executor.models.llama_eagle3.Eagle3LlamaForCausalLM.load_weights`
+# Why:
+#   vllm-ascend reuses vLLM's weight-loading logic for the drafter model,
+#   but vLLM's implementation does not apply the Ascend QuaRot rotation.
+# How:
+#   Dynamically replace the `load_weights` function at runtime,
+#   binding `target_config` into the new implementation via a closure.
+# Future Plan:
+#   Remove this patch when vLLM merges the PR.
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index ad2429d120e..60320de093c 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -35,3 +35,4 @@
 import vllm_ascend.patch.worker.patch_routed_experts_capturer  # noqa
 import vllm_ascend.patch.worker.patch_npugraph_ex_triton  # noqa
 import vllm_ascend.patch.worker.patch_kimi_k25  # noqa
+import vllm_ascend.patch.worker.patch_qwen3_quarot  # noqa
diff --git a/vllm_ascend/patch/worker/patch_qwen3_quarot.py b/vllm_ascend/patch/worker/patch_qwen3_quarot.py
new file mode 100644
index 00000000000..3780c6640f9
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_qwen3_quarot.py
@@ -0,0 +1,79 @@
+import logging
+from collections.abc import Iterable
+from pathlib import Path
+
+import torch
+from safetensors.torch import load_file
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.model_executor.models.utils import (
+    AutoWeightsLoader,
+    process_eagle_weight,
+)
+
+
+def patch_load_weights(target_config):
+    Eagle3LlamaForCausalLM.load_weights = make_load_weights(target_config)
+
+
+def make_load_weights(target_config):
+    logger = logging.getLogger(__name__)
+    quant_cfg = target_config.quant_config
+    rotation_matrix3 = None
+
+    model_path = target_config.model_config.model
+    try:
+        rotation_rel_path = quant_cfg.quant_description["optional"]["quarot"]["rotation_map"]["global_rotation"]
+    except KeyError as e:
+        logger.error(
+            "Invalid quant_config: missing key "
+            "quant_description['optional']['quarot']['rotation_map']['global_rotation']. "
+            "If you are not using a QuaRot model, you can ignore this. "
+            f"Error: {e}"
+        )
+    else:
+        rotation_path = Path(model_path) / rotation_rel_path
+        try:
+            safetensor_data = load_file(rotation_path)
+            Q = safetensor_data["global_rotation"]
+            rotation_matrix3 = torch.block_diag(Q, Q, Q)
+        except Exception as e:
+            logger.error(
+                f"Failed to load rotation weight from '{rotation_path}'. "
+                "If you are not using a QuaRot model, you can ignore this. "
+                f"Error: {e}"
+            )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        model_weights = {}
+        includes_draft_id_mapping = False
+        includes_embed_tokens = False
+        for name, loaded_weight in weights:
+            if "t2d" in name:
+                continue
+            if "d2t" in name:
+                name = name.replace("d2t", "draft_id_to_target_id")
+                includes_draft_id_mapping = True
+            elif "lm_head" not in name:
+                name = "model." + name
+            if "fc." in name and rotation_matrix3 is not None:
+                loaded_weight = loaded_weight @ rotation_matrix3.to(loaded_weight.dtype)
+            if "embed_tokens" in name:
+                includes_embed_tokens = True
+            model_weights[name] = loaded_weight
+            process_eagle_weight(self, name)
+
+        skip_substrs = []
+        if not includes_draft_id_mapping:
+            skip_substrs.append("draft_id_to_target_id")
+        if not includes_embed_tokens:
+            skip_substrs.append("embed_tokens")
+        if not self.model.use_aux_hidden_state:
+            skip_substrs.append("fc.")
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=None,
+            skip_substrs=skip_substrs,
+        )
+        loader.load_weights(model_weights.items())
+
+    return load_weights
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 4ed3f58795c..dc8011d5426 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -106,6 +106,7 @@
 from vllm_ascend.eplb.utils import model_register
 from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
 from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
+from vllm_ascend.patch.worker.patch_qwen3_quarot import patch_load_weights
 from vllm_ascend.sample.sampler import AscendSampler
 from vllm_ascend.spec_decode import get_spec_decode_method
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
@@ -2419,6 +2420,8 @@ def load_model(self) -> None:
             model_register(self.model)
         if self.drafter:
             logger.info("Loading drafter model...")
+            if self.vllm_config.quant_config is not None:
+                patch_load_weights(self.vllm_config)
             with get_tp_context(self.drafter):
                 self.drafter.load_model(self.model)
             if self.use_aux_hidden_state_outputs:
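
For context on the rotation step: the Eagle3 drafter's `fc` layer projects three concatenated auxiliary hidden states back to the hidden size, which is why the patch expands the single `global_rotation` matrix `Q` into `torch.block_diag(Q, Q, Q)` before folding it into the `fc.` weight. The sketch below illustrates the orthogonality invariance this relies on; the toy size, the random orthogonal `Q`, and the rotated-activation convention `x' = R.T @ x` are assumptions for illustration, not values taken from the patch.

```python
# Toy illustration of the block-diagonal QuaRot rotation; the size and the
# orthogonal Q below are made up for the example.
import torch

hidden = 4  # stand-in for the model's hidden size

# Random orthogonal Q (stand-in for the "global_rotation" tensor that the
# patch loads from the safetensors file).
Q, _ = torch.linalg.qr(torch.randn(hidden, hidden))

# The fc weight maps 3*hidden -> hidden (three concatenated aux hidden
# states), so the rotation is expanded block-diagonally, as in the patch.
rotation3 = torch.block_diag(Q, Q, Q)        # (3*hidden, 3*hidden)
fc_weight = torch.randn(hidden, 3 * hidden)  # nn.Linear stores (out, in)

# Folding the rotation into the weight preserves the layer's output on
# rotated inputs: (W @ R) @ (R.T @ x) == W @ x, because R @ R.T == I.
rotated_weight = fc_weight @ rotation3

x = torch.randn(3 * hidden)
assert torch.allclose(fc_weight @ x,
                      rotated_weight @ (rotation3.T @ x), atol=1e-5)
```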
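The `__init__.py` note says the replacement binds `target_config` into the new implementation via a closure. vLLM calls `load_weights(self, weights)` with a fixed signature, so the extra configuration cannot be passed as an argument; a factory that closes over it is the usual workaround. A minimal, self-contained sketch of that pattern (`Model` and `make_load_weights` here are hypothetical stand-ins, not vLLM APIs):

```python
class Model:
    def load_weights(self, weights):
        print("original loader")


def make_load_weights(config):
    # `config` is captured by the closure; the replacement keeps the exact
    # `load_weights(self, weights)` signature that callers expect.
    def load_weights(self, weights):
        print(f"patched loader, extra={config['extra']}")
    return load_weights


# Assigning a plain function to the class attribute makes it a normal
# method: `self` is bound on attribute access, as with `def` in the class.
Model.load_weights = make_load_weights({"extra": 42})
Model().load_weights([])  # -> patched loader, extra=42
```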