diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py
index a7699f0d5983..b4da15ec2eb7 100644
--- a/vllm/model_executor/models/mimo.py
+++ b/vllm/model_executor/models/mimo.py
@@ -38,14 +38,10 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader,
-    maybe_remap_kv_scale_name,
-)
 from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
 from vllm.sequence import IntermediateTensors
 
-from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
+from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
 
 logger = init_logger(__name__)
 
@@ -89,62 +85,6 @@ def forward(
         hidden_states = hidden_states + residual
         return hidden_states
 
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if "mtp_layers" in name:
-                continue
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if self.quant_config is not None and (
-                scale_name := self.quant_config.get_cache_scale(name)
-            ):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                loaded_weight = (
-                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
-                )
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        return loaded_params
-
 
 class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -186,3 +126,12 @@ def compute_logits(
         hidden_states = self.model.norm(hidden_states)
         logits = self.logits_processor(self.lm_head, hidden_states)
         return logits
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        skip_prefixes = ["lm_head."] if self.config.tie_word_embeddings else None
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=skip_prefixes,
+            skip_substrs=["mtp_layers"],
+        )
+        return loader.load_weights(weights)