diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index c8eef850c497..4f210b8e5f0c 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -511,7 +511,9 @@ def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
             )
 
         def is_patch_merger(weight: tuple[str, torch.Tensor]):
-            return weight[0].startswith("patch_merger")
+            return weight[0].startswith(
+                ("patch_merger", "multi_modal_projector.patch_merger")
+            )
 
         def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
             return weight[0].startswith("pre_mm_projector_norm")
@@ -554,18 +556,23 @@ def llm_weights_generator():
                     if self.patch_merger is None:
                         continue
                     # Load vision patch merger weights directly
-                    trimmed_name = ".".join(name.split(".")[1:])
-                    param = patch_merger_dict[trimmed_name]
-                    with torch.no_grad():
-                        default_weight_loader(param, w)
+                    if name.startswith("multi_modal_projector.patch_merger"):
+                        trimmed_name = ".".join(name.split(".")[2:])
+                    else:
+                        trimmed_name = ".".join(name.split(".")[1:])
+                    param = patch_merger_dict.get(trimmed_name)
+                    if param is not None:
+                        with torch.no_grad():
+                            default_weight_loader(param, w)
                 elif is_pre_mm_projector_norm((name, w)):
                     if self.pre_mm_projector_norm is None:
                         continue
                     # Load vision pre_mm_projector_norm weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
-                    param = pre_mm_projector_norm_dict[trimmed_name]
-                    with torch.no_grad():
-                        default_weight_loader(param, w)
+                    param = pre_mm_projector_norm_dict.get(trimmed_name)
+                    if param is not None:
+                        with torch.no_grad():
+                            default_weight_loader(param, w)
                 elif is_vision_lang_adapter_weights((name, w)):
                     if self.vision_language_adapter is None:
                         continue
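
Note (not part of the patch): a minimal sketch of the prefix handling the second hunk introduces. Checkpoints that namespace the merger under multi_modal_projector need two leading name components dropped before the lookup key matches patch_merger_dict; legacy names only need one. The weight names below are hypothetical examples used for illustration.

# Illustration only (hypothetical weight names): how the new branch trims
# checkpoint prefixes before the patch_merger_dict lookup.
def trim_patch_merger_name(name: str) -> str:
    if name.startswith("multi_modal_projector.patch_merger"):
        # "multi_modal_projector.patch_merger.merging_layer.weight"
        # -> "merging_layer.weight" (drop two leading components)
        return ".".join(name.split(".")[2:])
    # "patch_merger.merging_layer.weight" -> "merging_layer.weight"
    return ".".join(name.split(".")[1:])

assert trim_patch_merger_name(
    "multi_modal_projector.patch_merger.merging_layer.weight"
) == "merging_layer.weight"
assert trim_patch_merger_name(
    "patch_merger.merging_layer.weight"
) == "merging_layer.weight"

The switch from dict indexing to .get() in both branches makes the loader skip trimmed names absent from the module's state dict instead of raising KeyError.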