diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index f7b5d8899502..3360ce59a763 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -218,7 +218,7 @@ def apply( if "mm_token_type_ids" in processed_data else "token_type_ids" ) - mm_token_type_ids = processed_data.pop(token_type_key) + mm_token_type_ids = processed_data.get(token_type_key) # We can infer vLLM style placeholder from token type ids, if we split # it for each input `mm_data`. @@ -353,6 +353,7 @@ def embed_multimodal(self, **kwargs): num_image_patches = kwargs.pop("num_image_patches") kwargs.pop("token_type_ids", None) # used only in `forward` + kwargs.pop("mm_token_type_ids", None) # used only in `model.get_rope_index` if pixel_values is not None: # ROCm: Force math SDP backend for vision encoder to avoid accuracy issues @@ -443,6 +444,7 @@ def get_mrope_input_positions( { "image_grid_thw", "video_grid_thw", + "mm_token_type_ids", "second_per_grid_ts", "audio_feature_lengths", "use_audio_in_video", @@ -451,7 +453,7 @@ def get_mrope_input_positions( if any( v for k, v in kwargs.items() - if k not in {"image_grid_thw", "video_grid_thw"} + if k not in {"image_grid_thw", "mm_token_type_ids"} ): raise NotImplementedError( "Transformers modeling backend only supports images." @@ -459,6 +461,7 @@ def get_mrope_input_positions( image_grid_thw = kwargs.get("image_grid_thw", []) video_grid_thw = kwargs.get("video_grid_thw", []) + mm_token_type_ids = kwargs.get("mm_token_type_ids") image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( image_grid_thw @@ -467,10 +470,17 @@ def get_mrope_input_positions( video_grid_thw ) + # In v4 `get_rope_index` doesn't have wildcard `kwargs`, and + # can't accept arbitrary args, even if its value is `None` + kwargs = {} + if mm_token_type_ids: + kwargs["mm_token_type_ids"] = torch.cat(mm_token_type_ids) + mrope_positions, mrope_position_delta = self.model.get_rope_index( input_ids=torch.tensor(input_tokens).unsqueeze(0), image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, + **kwargs, ) mrope_positions = mrope_positions[:, 0]