Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
SupportsPP,
)
from .module_mapping import MultiModelKeys
from .utils import init_vllm_registered_model, maybe_prefix
from .utils import StageMissingLayer, init_vllm_registered_model, maybe_prefix
from .vision import (
VisionEncoderInfo,
VisionFeatureSelectStrategy,
Expand All @@ -93,6 +93,10 @@
PATCH_MERGE = "patch_merge"


def _is_layer_none_or_staged(layer: nn.Module) -> bool:
return layer is None or isinstance(layer, StageMissingLayer)


class PixtralImagePixelInputs(TensorSchema):
"""
Dimensions:
Expand Down Expand Up @@ -542,7 +546,7 @@ def llm_weights_generator():
# Single pass over weights
for name, w in weights:
if is_vision_encoder_weights((name, w)):
if self.vision_encoder is None:
if _is_layer_none_or_staged(self.vision_encoder):
continue
# Load vision encoder weights directly
trimmed_name = ".".join(name.split(".")[1:])
Expand All @@ -551,23 +555,23 @@ def llm_weights_generator():
with torch.no_grad():
default_weight_loader(param, w)
elif is_patch_merger((name, w)):
if self.patch_merger is None:
if _is_layer_none_or_staged(self.patch_merger):
continue
# Load vision patch merger weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = patch_merger_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_pre_mm_projector_norm((name, w)):
if self.pre_mm_projector_norm is None:
if _is_layer_none_or_staged(self.pre_mm_projector_norm):
continue
# Load vision pre_mm_projector_norm weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = pre_mm_projector_norm_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_vision_lang_adapter_weights((name, w)):
if self.vision_language_adapter is None:
if _is_layer_none_or_staged(self.vision_language_adapter):
continue
# Load vision-language adapter weights directly
trimmed_name = ".".join(name.split(".")[1:])
Expand Down