Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,9 @@ def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
)

def is_patch_merger(weight: tuple[str, torch.Tensor]):
    """Return True if a (name, tensor) pair belongs to the patch merger.

    Checkpoints may prefix the patch-merger weights either directly
    ("patch_merger.*") or under the multi-modal projector namespace
    ("multi_modal_projector.patch_merger.*"); both are accepted.
    str.startswith takes a tuple of prefixes, so one call covers both.
    """
    return weight[0].startswith(
        ("patch_merger", "multi_modal_projector.patch_merger")
    )

def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
    """Return True if a (name, tensor) pair is a pre-projector-norm weight.

    Matches any weight whose name lives under the
    "pre_mm_projector_norm" prefix.
    """
    weight_name, _ = weight
    return weight_name.startswith("pre_mm_projector_norm")
Expand Down Expand Up @@ -554,18 +556,23 @@ def llm_weights_generator():
if self.patch_merger is None:
continue
# Load vision patch merger weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = patch_merger_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
if name.startswith("multi_modal_projector.patch_merger"):
trimmed_name = ".".join(name.split(".")[2:])
else:
trimmed_name = ".".join(name.split(".")[1:])
param = patch_merger_dict.get(trimmed_name)
if param is not None:
with torch.no_grad():
default_weight_loader(param, w)
elif is_pre_mm_projector_norm((name, w)):
if self.pre_mm_projector_norm is None:
continue
# Load vision pre_mm_projector_norm weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = pre_mm_projector_norm_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
param = pre_mm_projector_norm_dict.get(trimmed_name)
if param is not None:
with torch.no_grad():
default_weight_loader(param, w)
elif is_vision_lang_adapter_weights((name, w)):
if self.vision_language_adapter is None:
continue
Expand Down