From d3b4c367047dfb1c918b145740155eac0bf9fa49 Mon Sep 17 00:00:00 2001
From: Vensenmu
Date: Wed, 25 Jun 2025 17:05:12 +0800
Subject: [PATCH] Fix(gemma3_mm): Add robust weight remapping for VLM

Signed-off-by: Vensenmu
---
Review note: the load_weights hunk was rewritten to keep AutoWeightsLoader;
the post-image blob id on the "index" line predates this revision.

 vllm/model_executor/models/gemma3_mm.py | 34 ++++++++++++++++++++++++++++++++--
 1 file changed, 32 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 3a1c14978b45..f810e895d60d 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -707,8 +707,38 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+        """Load checkpoint weights, repairing double-prefixed vision names.
+
+        The class-level ``hf_to_vllm_mapper`` is applied first. A mapped name
+        that starts with ``vision_model.`` but matches no parameter is retried
+        with that prefix repeated, and renamed only when the candidate is an
+        existing parameter. Loading itself is delegated to
+        ``AutoWeightsLoader``, so the returned set is whatever that loader
+        reports as loaded; this method never marks a name as loaded itself.
+        """
+        # Only names are needed here to probe for the corrected spelling.
+        param_names = {name for name, _ in self.named_parameters()}
+
+        def fix_vision_prefix(
+            named_weights: Iterable[tuple[str, torch.Tensor]],
+        ) -> Iterable[tuple[str, torch.Tensor]]:
+            for name, tensor in named_weights:
+                if (name not in param_names
+                        and name.startswith("vision_model.")):
+                    candidate = f"vision_model.{name}"
+                    if candidate in param_names:
+                        logger.debug("Remapping weight %s -> %s", name,
+                                     candidate)
+                        name = candidate
+                yield name, tensor
+
+        # Map up front (rather than via ``mapper=``) so that the prefix fix
+        # below sees already-mapped names.
+        if self.hf_to_vllm_mapper is not None:
+            weights = self.hf_to_vllm_mapper.apply(weights)
+
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(fix_vision_prefix(weights))
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """