diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 45e5010e2a3a..684c32a3a9e5 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -70,7 +70,7 @@ SupportsPP, ) from .module_mapping import MultiModelKeys -from .utils import init_vllm_registered_model, maybe_prefix +from .utils import StageMissingLayer, init_vllm_registered_model, maybe_prefix from .vision import ( VisionEncoderInfo, VisionFeatureSelectStrategy, @@ -93,6 +93,10 @@ PATCH_MERGE = "patch_merge" +def _is_layer_none_or_staged(layer: nn.Module) -> bool: + return layer is None or isinstance(layer, StageMissingLayer) + + class PixtralImagePixelInputs(TensorSchema): """ Dimensions: @@ -542,7 +546,7 @@ def llm_weights_generator(): # Single pass over weights for name, w in weights: if is_vision_encoder_weights((name, w)): - if self.vision_encoder is None: + if _is_layer_none_or_staged(self.vision_encoder): continue # Load vision encoder weights directly trimmed_name = ".".join(name.split(".")[1:]) @@ -551,7 +555,7 @@ def llm_weights_generator(): with torch.no_grad(): default_weight_loader(param, w) elif is_patch_merger((name, w)): - if self.patch_merger is None: + if _is_layer_none_or_staged(self.patch_merger): continue # Load vision patch merger weights directly trimmed_name = ".".join(name.split(".")[1:]) @@ -559,7 +563,7 @@ def llm_weights_generator(): with torch.no_grad(): default_weight_loader(param, w) elif is_pre_mm_projector_norm((name, w)): - if self.pre_mm_projector_norm is None: + if _is_layer_none_or_staged(self.pre_mm_projector_norm): continue # Load vision pre_mm_projector_norm weights directly trimmed_name = ".".join(name.split(".")[1:]) @@ -567,7 +571,7 @@ def llm_weights_generator(): with torch.no_grad(): default_weight_loader(param, w) elif is_vision_lang_adapter_weights((name, w)): - if self.vision_language_adapter is None: + if _is_layer_none_or_staged(self.vision_language_adapter): continue # Load vision-language adapter weights directly trimmed_name = ".".join(name.split(".")[1:])