From 7a07a9c93d328fddece47df7d938cf505e914fe7 Mon Sep 17 00:00:00 2001
From: drslark
Date: Fri, 6 Mar 2026 14:57:18 +0800
Subject: [PATCH] [main][feature] Support quarot for eagle3

Signed-off-by: drslark
---
 vllm/model_executor/models/llama_eagle3.py | 44 ++++++++++++++++++++
 vllm/model_executor/models/utils.py        | 80 +++++++++++++++++++++++++++++
 2 files changed, 124 insertions(+)

diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 5f66716d5454..ff784ff5828e 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -30,6 +30,10 @@
     get_draft_quant_config,
     maybe_prefix,
     process_eagle_weight,
+    get_rotation_path,
+    get_rotation_matrix,
+    compute_rotation_matrix3,
+    get_embedding_tensor,
 )
 
 logger = init_logger(__name__)
@@ -266,6 +270,10 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config.speculative_config.draft_model_config.hf_config
+
+        self.target_model_config = vllm_config.speculative_config.target_model_config
+        self.target_quant_config = vllm_config.quant_config
+
         # Ensure draft_vocab_size is set
         # default to the base vocab size when absent
         if getattr(self.config, "draft_vocab_size", None) is None:
@@ -360,6 +368,23 @@ def combine_hidden_states(
         return self.model.fc(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        # TODO: maybe extract the quarot setup below into a helper
+        rotation_path = get_rotation_path(
+            self.target_model_config.model, self.target_quant_config
+        )
+
+        use_quarot = rotation_path is not None
+
+        if use_quarot:
+            Q = get_rotation_matrix(rotation_path)
+            # the fc consumes three concatenated hidden states, so the
+            # anti-rotation is block-diagonal
+            Q3 = compute_rotation_matrix3(Q)
+            if isinstance(self.config.dtype, str):
+                embed_dtype = getattr(torch, self.config.dtype)
+            else:
+                embed_dtype = self.config.dtype
+
         model_weights = {}
         includes_draft_id_mapping = False
         includes_embed_tokens = False
@@ -384,11 +409,30 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
                 continue
             elif "lm_head" not in name:
                 name = "model." + name
+            if "fc." in name and use_quarot:
+                # anti-rotate the fc weight so it accepts the target
+                # model's rotated hidden states
+                dtype = loaded_weight.dtype
+                loaded_weight = loaded_weight @ Q3.to(dtype)
             if "embed_tokens" in name:
                 includes_embed_tokens = True
             model_weights[name] = loaded_weight
             process_eagle_weight(self, name)
 
+        # the drafter checkpoint may ship without an embedding table; reuse
+        # the target model's embedding, rotated into the quarot basis
+        # (Q is orthogonal, so Q.T applies the inverse rotation)
+        if use_quarot and not includes_embed_tokens:
+            name = "model.embed_tokens.weight"
+            loaded_weight = (
+                get_embedding_tensor(self.target_model_config.model).to(embed_dtype)
+                @ Q.T.to(embed_dtype)
+            )
+            model_weights[name] = loaded_weight
+
+            includes_embed_tokens = True
+            process_eagle_weight(self, name)
+
         if not includes_mask_hidden and self.use_parallel_drafting:
             raise ValueError(
                 "mask_hidden not found in weights but "
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index abc953b7f980..1ffef8f93b2b 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -875,3 +875,83 @@ def get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
     if feature_layer_index < 0:
         return num_hidden_layers + feature_layer_index + 1
     return feature_layer_index
+
+
+def get_rotation_path(target_model_path, quant_config):
+    """
+    Gets the path of the rotation matrix. Returns None if the target
+    model is not a quarot model.
+ """ + try: + rotation_relative_path = quant_config.quant_description["optional"]["quarot"]["rotation_map"][ + "global_rotation" + ] + except KeyError: + return None + + return Path(target_model_path) / rotation_relative_path + + +def get_rotataion_matrix(rotation_path): + """ + Anti-rotate maxtrix. + """ + try: + safetensor_data = load_file(rotation_path) + Q = safetensor_data["global_rotation"] + + return Q + except Exception as e: + logger.error( + f"Failed to load rotation weight from '{rotation_path}'. " + "If you want to use quarot model with eagle3, take a check." + ) + raise e + +def compute_rotataion_matrix3(Q): + """ + Anti-rotate matrix for 3 layers of hidden_states. + """ + return torch.block_diag(Q, Q, Q)