diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index b08810892006..305d13996b5a 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -1151,6 +1151,28 @@ def get_mm_mapping(self) -> MultiModelKeys:
         """
         return MultiModelKeys.from_string_field(
             language_model="language_model",
-            connector="multi_modal_projector.",
+            connector=[
+                "multi_modal_projector.",
+                "vision_model.vision_adapter.",
+            ],
             tower_model="vision_model.",
         )
+
+    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
+        vision_config = self.config.vision_config
+        patches_per_chunk = Mllama4ProcessingInfo.get_patch_per_chunk(vision_config)
+        if num_image_tokens <= 0 or patches_per_chunk <= 0:
+            return 0
+        raw_patches = (vision_config.image_size // vision_config.patch_size) ** 2
+        num_chunks = num_image_tokens // patches_per_chunk
+        # Encoder processes raw_patches + 1 (CLS) per chunk
+        return num_chunks * (raw_patches + 1)
+
+    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
+        vision_config = self.config.vision_config
+        raw_patches = (vision_config.image_size // vision_config.patch_size) ** 2
+        if num_vision_tokens <= 0:
+            return 0
+        num_chunks = num_vision_tokens // (raw_patches + 1)
+        patches_per_chunk = Mllama4ProcessingInfo.get_patch_per_chunk(vision_config)
+        return num_chunks * patches_per_chunk