diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index f287cff12086..25754cb64d27 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -34,7 +34,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from transformers import BatchFeature, PretrainedConfig
+from transformers import BatchFeature
 
 from vllm.attention.backends.registry import _Backend
 from vllm.attention.layer import (
@@ -58,6 +58,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
+    MultiModalFeatureSpec,
     MultiModalFieldConfig,
     MultiModalKwargsItems,
 )
@@ -1432,15 +1433,16 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
 
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
-        hf_config: PretrainedConfig,
-        image_grid_thw: list[list[int]] | torch.Tensor,
-        video_grid_thw: list[list[int]] | torch.Tensor,
-        second_per_grid_ts: list[float] | None = None,
-        audio_feature_lengths: torch.Tensor | None = None,
-        use_audio_in_video: bool = False,
+        mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
-        """Get mrope input positions and delta value for Ernie VL."""
+        kwargs = MultiModalFeatureSpec.gather_kwargs(
+            mm_features,
+            {"image_grid_thw", "video_grid_thw"},
+        )
+        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
+        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
+        hf_config = self.config
         image_token_id = hf_config.im_patch_id
         video_start_token_id = hf_config.video_start_token_id
         video_end_token_id = hf_config.video_end_token_id
@@ -1448,10 +1450,7 @@
         temporal_conv_size = hf_config.temporal_conv_size
         llm_pos_ids_list: list = []
 
-        if not (image_grid_thw is None and video_grid_thw is None):
-            if isinstance(image_grid_thw, torch.Tensor):
-                image_grid_thw = image_grid_thw.tolist()
-
+        if image_grid_thw or video_grid_thw:
             input_token_type: list[str] = []
             video_check_flg = False
             for token in input_tokens:
@@ -1483,11 +1482,7 @@
                     llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                 )
                 if modality_type == "image":
-                    t, h, w = (
-                        image_grid_thw[mm_data_idx][0],
-                        image_grid_thw[mm_data_idx][1],
-                        image_grid_thw[mm_data_idx][2],
-                    )
+                    t, h, w = image_grid_thw[mm_data_idx]
                     llm_grid_t, llm_grid_h, llm_grid_w = (
                         t,
                         h // spatial_conv_size,
@@ -1518,11 +1513,7 @@
                     mm_data_idx += 1
 
                 elif modality_type == "video":
-                    t, h, w = (
-                        video_grid_thw[mm_data_idx][0],
-                        video_grid_thw[mm_data_idx][1],
-                        video_grid_thw[mm_data_idx][2],
-                    )
+                    t, h, w = video_grid_thw[mm_data_idx]
                     llm_grid_t, llm_grid_h, llm_grid_w = (
                         t // temporal_conv_size,
                         h // spatial_conv_size,
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index b9cd3545ec45..6eda049d148d 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -37,7 +37,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import BatchFeature, PretrainedConfig
+from transformers import BatchFeature
 from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
 from transformers.models.glm4v.image_processing_glm4v import (
     Glm4vImageProcessor,
@@ -70,6 +70,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
+    MultiModalFeatureSpec,
     MultiModalFieldConfig,
     MultiModalKwargsItems,
     VideoItem,
@@ -1618,25 +1619,23 @@ def get_multimodal_embeddings(
 
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
-        hf_config: "PretrainedConfig",
-        image_grid_thw: list[list[int]] | torch.Tensor | None,
-        video_grid_thw: list[list[int]] | torch.Tensor | None,
-        second_per_grid_ts: list[float] | None = None,
-        audio_feature_lengths: torch.Tensor | None = None,
-        use_audio_in_video: bool = False,
+        mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
-        """Get mrope input positions and delta value for GLM4V."""
+        kwargs = MultiModalFeatureSpec.gather_kwargs(
+            mm_features,
+            {"image_grid_thw", "video_grid_thw"},
+        )
+        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
+        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
+        hf_config = self.config
         image_token_id = hf_config.image_token_id
         video_start_token_id = hf_config.video_start_token_id
         video_end_token_id = hf_config.video_end_token_id
         spatial_merge_size = hf_config.vision_config.spatial_merge_size
         llm_pos_ids_list: list = []
 
-        if not (image_grid_thw is None and video_grid_thw is None):
-            if isinstance(image_grid_thw, torch.Tensor):
-                image_grid_thw = image_grid_thw.tolist()
-
+        if image_grid_thw or video_grid_thw:
             input_token_type: list[str] = []
             video_check_flg = False
             for token in input_tokens:
@@ -1668,11 +1667,7 @@
                     llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                 )
                 if modality_type == "image":
-                    t, h, w = (
-                        image_grid_thw[mm_data_idx][0],
-                        image_grid_thw[mm_data_idx][1],
-                        image_grid_thw[mm_data_idx][2],
-                    )
+                    t, h, w = image_grid_thw[mm_data_idx]
                     llm_grid_t, llm_grid_h, llm_grid_w = (
                         t,
                         h // spatial_merge_size,
@@ -1705,8 +1700,7 @@
                 elif modality_type == "video":
                     t, h, w = (
                         video_frame_num,
-                        image_grid_thw[mm_data_idx][1],
-                        image_grid_thw[mm_data_idx][2],
+                        *image_grid_thw[mm_data_idx][1:],
                     )
                     llm_grid_t, llm_grid_h, llm_grid_w = (
                         t,
diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py
index ebf6934dddea..899797a51053 100644
--- a/vllm/model_executor/models/glm4v.py
+++ b/vllm/model_executor/models/glm4v.py
@@ -15,7 +15,7 @@
 from torch.nn import LayerNorm
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
-from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
+from transformers import BatchFeature, PreTrainedTokenizer, TensorType
 from transformers.image_utils import ImageInput
 from transformers.tokenization_utils_base import TextInput
 
@@ -36,6 +36,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
+    MultiModalFeatureSpec,
     MultiModalFieldConfig,
     MultiModalKwargsItems,
 )
@@ -622,25 +623,23 @@ def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tenso
 
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
-        hf_config: PretrainedConfig,
-        image_grid_thw: list[list[int]] | torch.Tensor,
-        video_grid_thw: list[list[int]] | torch.Tensor,
-        second_per_grid_ts: list[float] | None = None,
-        audio_feature_lengths: torch.Tensor | None = None,
-        use_audio_in_video: bool = False,
+        mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
-        """Get mrope input positions and delta value for GLM4V."""
+        kwargs = MultiModalFeatureSpec.gather_kwargs(
+            mm_features,
+            {"image_grid_thw", "video_grid_thw"},
+        )
+        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
+        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
+        hf_config = self.config
         image_token_id = hf_config.image_token_id
         video_start_token_id = hf_config.video_start_token_id
         video_end_token_id = hf_config.video_end_token_id
         spatial_merge_size = hf_config.vision_config.spatial_merge_size
         llm_pos_ids_list: list = []
 
-        if not (image_grid_thw is None and video_grid_thw is None):
-            if isinstance(image_grid_thw, torch.Tensor):
-                image_grid_thw = image_grid_thw.tolist()
-
+        if image_grid_thw or video_grid_thw:
             input_token_type: list[str] = []
             video_check_flg = False
             for token in input_tokens:
@@ -672,11 +671,7 @@
                     llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                 )
                 if modality_type == "image":
-                    t, h, w = (
-                        image_grid_thw[mm_data_idx][0],
-                        image_grid_thw[mm_data_idx][1],
-                        image_grid_thw[mm_data_idx][2],
-                    )
+                    t, h, w = image_grid_thw[mm_data_idx]
                     llm_grid_t, llm_grid_h, llm_grid_w = (
                         t,
                         h // spatial_merge_size,
@@ -709,8 +704,7 @@
                 elif modality_type == "video":
                     t, h, w = (
                         video_frame_num,
-                        image_grid_thw[mm_data_idx][1],
-                        image_grid_thw[mm_data_idx][2],
+                        *image_grid_thw[mm_data_idx][1:],
                     )
                     llm_grid_t, llm_grid_h, llm_grid_w = (
                         t,
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index d6a8f86d998b..88b45bf07c0d 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -16,7 +16,6 @@
 import torch
 import torch.nn as nn
 from torch import Tensor
-from transformers import PretrainedConfig
 from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from typing_extensions import Self, TypeIs
 
@@ -32,10 +31,12 @@
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
     from vllm.model_executor.models.utils import WeightsMapper
+    from vllm.multimodal.inputs import MultiModalFeatureSpec
     from vllm.sequence import IntermediateTensors
 else:
     VllmConfig = object
     WeightsMapper = object
+    MultiModalFeatureSpec = object
     IntermediateTensors = object
 
 logger = init_logger(__name__)
@@ -991,12 +992,7 @@ class SupportsMRoPE(Protocol):
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
-        hf_config: PretrainedConfig,
-        image_grid_thw: list[list[int]] | torch.Tensor | None,
-        video_grid_thw: list[list[int]] | torch.Tensor | None,
-        second_per_grid_ts: list[float] | None = None,
-        audio_feature_lengths: torch.Tensor | None = None,
-        use_audio_in_video: bool = False,
+        mm_features: list["MultiModalFeatureSpec"],
     ) -> tuple[torch.Tensor, int]:
         """
         Get M-RoPE input positions and delta value for this specific model.
@@ -1006,17 +1002,11 @@
 
         Args:
             input_tokens: List of input token IDs
-            hf_config: HuggingFace model configuration
-            image_grid_thw: Image grid dimensions (t, h, w)
-            video_grid_thw: Video grid dimensions (t, h, w)
-            second_per_grid_ts: Seconds per grid timestep for videos
-            audio_feature_lengths: Audio feature lengths for multimodal models
-            use_audio_in_video: Whether to use audio in video for interleaving
+            mm_features: Information about each multi-modal data item
 
         Returns:
-            Tuple of (llm_positions, mrope_position_delta)
-            - llm_positions: Tensor of shape [3, num_tokens]
-              with T/H/W positions
+            Tuple of `(llm_positions, mrope_position_delta)`
+            - llm_positions: Tensor of shape `[3, num_tokens]` with T/H/W positions
            - mrope_position_delta: Delta for position calculations
         """
        ...
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 42f16ad9f3b3..821710ef1bec 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -40,6 +40,7 @@
     ImageItem,
     ModalityData,
     MultiModalDataDict,
+    MultiModalFeatureSpec,
     MultiModalFieldConfig,
     MultiModalKwargsItems,
     VideoItem,
@@ -1627,16 +1628,17 @@ def _process_video_input(
 
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
-        hf_config: PretrainedConfig,
-        image_grid_thw: list[list[int]] | torch.Tensor,
-        video_grid_thw: list[list[int]] | torch.Tensor,
-        second_per_grid_ts: list[float] | None = None,
-        audio_feature_lengths: torch.Tensor | None = None,
-        use_audio_in_video: bool = False,
+        mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
+        kwargs = MultiModalFeatureSpec.gather_kwargs(
+            mm_features,
+            {"image_grid_thw", "video_grid_thw"},
+        )
+        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
+        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
+        if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
+            video_grid_thw = video_grid_thw[0]
-        """Get mrope input positions and delta value (Keye series)."""
 
         def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
             """
@@ -1662,6 +1664,7 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
 
         video_grid_thw = split_thw(video_grid_thw)
 
+        hf_config = self.config
         image_token_id = hf_config.image_token_id
         video_token_id = hf_config.video_token_id
         spatial_merge_size = hf_config.vision_config.spatial_merge_size
@@ -1691,20 +1694,12 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
                 ed_video = len(input_tokens) + 1
             if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
+                t, h, w = image_grid_thw[image_index]
                 image_index += 1
                 remain_images -= 1
                 ed = ed_image
             else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
+                t, h, w = video_grid_thw[video_index]
                 video_index += 1
                 remain_frames -= 1
                 ed = ed_video
diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py
index 6f95a59d36d2..124e9c2afa21 100644
--- a/vllm/model_executor/models/keye_vl1_5.py
+++ b/vllm/model_executor/models/keye_vl1_5.py
@@ -21,6 +21,7 @@
 from vllm.multimodal.inputs import (
     ImageItem,
     ModalityData,
+    MultiModalFeatureSpec,
     MultiModalFieldConfig,
     MultiModalKwargsItems,
     VideoItem,
@@ -597,16 +598,17 @@ def _process_video_input(
 
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
-        hf_config: PretrainedConfig,
-        image_grid_thw: list[list[int]] | torch.Tensor,
-        video_grid_thw: list[list[int]] | torch.Tensor,
-        second_per_grid_ts: list[float] | None = None,
-        audio_feature_lengths: torch.Tensor | None = None,
-        use_audio_in_video: bool = False,
+        mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
+        kwargs = MultiModalFeatureSpec.gather_kwargs(
+            mm_features,
+            {"image_grid_thw", "video_grid_thw"},
+        )
+        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
+        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
+        if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
+            video_grid_thw = video_grid_thw[0]
-        """Get mrope input positions and delta value (Keye series)."""
 
         def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
""" @@ -632,6 +634,7 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: video_grid_thw = split_thw(video_grid_thw) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size @@ -661,20 +664,12 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_index += 1 remain_frames -= 1 ed = ed_video diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 631475c964c0..12544cc391a0 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -61,6 +61,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargs, ) @@ -1175,15 +1176,17 @@ def compute_logits( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float], - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1220,20 +1223,12 @@ def get_mrope_input_positions( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_second_per_grid_t = 1.0 if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[video_index] diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index fac281d2caf4..8f74cab0534d 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -68,6 +68,7 @@ ImageItem, ModalityData, MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, @@ -923,21 +924,9 @@ def get_language_model(self) -> torch.nn.Module: def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | 
torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value (Qwen2.5-Omni version). - - Differences from MRotaryEmbedding: - 1. Add audio support (and related `audio_feature_lengths`). - 2. Add `use_audio_in_video` option to read audio from video inputs. - In this case, audio and vision position ids will be split into - chunks and interleaved. - + """ Example: (V_i are vision position ids, A_i are audio position ids) @@ -945,11 +934,33 @@ def get_mrope_input_positions( |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... """ + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + audio_feature_lengths = kwargs.get("audio_feature_lengths", []) + use_audio_in_video = any(kwargs.get("use_audio_in_video", [])) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( + image_grid_thw + ) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( + video_grid_thw + ) # TODO(fyabc): refactor and share more code with # _vl_get_input_positions_tensor. - thinker_config = hf_config.thinker_config + thinker_config = self.config audio_token_id = thinker_config.audio_token_index image_token_id = thinker_config.image_token_index video_token_id = thinker_config.video_token_index @@ -963,11 +974,6 @@ def get_mrope_input_positions( thinker_config.vision_config, "tokens_per_second", 25 ) - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) - src_item = input_tokens audio_seqlens = audio_feature_lengths if not second_per_grid_ts: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 48834ba699e4..e7f876d5fb5c 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -35,7 +35,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, @@ -75,7 +75,11 @@ compute_retention_mask, recompute_mrope_positions, ) -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalFieldConfig, + MultiModalKwargs, +) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors @@ -1114,15 +1118,17 @@ class Qwen2_5_VLForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float], - audio_feature_lengths: torch.Tensor | None = 
None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1159,20 +1165,12 @@ def get_mrope_input_positions( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_second_per_grid_t = 1.0 if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[video_index] diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index b3999e6c934e..975c1dce40dc 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -34,7 +34,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor from transformers.models.qwen2_vl.configuration_qwen2_vl import ( Qwen2VLConfig, @@ -70,6 +70,7 @@ ImageItem, ModalityData, MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, @@ -1236,21 +1237,17 @@ class Qwen2VLForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get M-RoPE input positions for Qwen2-VL model.""" - if image_grid_thw is None: - image_grid_thw = [] - if video_grid_thw is None: - video_grid_thw = [] - if second_per_grid_ts is None: - second_per_grid_ts = [] + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1287,20 +1284,12 @@ def get_mrope_input_positions( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed 
= ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_second_per_grid_t = 1.0 if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[video_index] diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index da489a812f55..b5aeb09b5dbd 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -65,7 +65,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargsItems +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems from vllm.multimodal.processing import ( BaseMultiModalProcessor, @@ -1413,39 +1413,48 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - config = hf_config.thinker_config - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + audio_feature_lengths = kwargs.get("audio_feature_lengths", []) + use_audio_in_video = any(kwargs.get("use_audio_in_video", [])) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( + image_grid_thw + ) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( + video_grid_thw + ) + input_ids = torch.tensor(input_tokens) if input_ids is None or input_ids.ndim != 1: raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids") seq_len = input_ids.shape[0] - if audio_feature_lengths is not None and not isinstance( - audio_feature_lengths, torch.Tensor - ): - audio_feature_lengths = torch.as_tensor( + + if isinstance(audio_feature_lengths, list): + audio_feature_lengths = torch.tensor( audio_feature_lengths, dtype=torch.long ) - if second_per_grid_ts is None: - if video_grid_thw is not None and video_grid_thw.numel() > 0: - second_per_grids = torch.ones( - video_grid_thw.shape[0], dtype=torch.float32 - ) - else: - second_per_grids = torch.tensor([], dtype=torch.float32) + + if not len(second_per_grid_ts) and len(video_grid_thw): + second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32) else: second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32) + config = self.config spatial_merge_size = config.vision_config.spatial_merge_size image_token_id = config.image_token_id video_token_id = 
config.video_token_id diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index fe0124ef3258..e14533f4e929 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -34,7 +34,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( smart_resize as image_smart_resize, @@ -70,6 +70,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItem, MultiModalKwargsItems, @@ -1416,17 +1417,18 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1455,20 +1457,12 @@ def get_mrope_input_positions( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_index += 1 remain_videos -= 1 ed = ed_video diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index 476074542e6a..2efcef68d1c7 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -27,6 +27,7 @@ from vllm.multimodal import MultiModalKwargsItems from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalInputs, MultiModalUUIDDict, @@ -38,7 +39,7 @@ from vllm.sequence import IntermediateTensors if TYPE_CHECKING: - from transformers import BatchFeature, PretrainedConfig + from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -367,20 +368,34 @@ def get_multimodal_embeddings(self, **kwargs): def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: "PretrainedConfig", - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | 
None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - if any((second_per_grid_ts, audio_feature_lengths, use_audio_in_video)): + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + if any( + v + for k, v in kwargs.items() + if k not in {"image_grid_thw", "video_grid_thw"} + ): raise NotImplementedError("Transformers backend only supports images.") - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( + image_grid_thw + ) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( + video_grid_thw + ) mrope_positions, mrope_position_delta = self.model.get_rope_index( input_ids=torch.tensor(input_tokens).unsqueeze(0), diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index a05f54191f04..7518a023c5f5 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -249,6 +249,19 @@ class MultiModalFeatureSpec: mm_position: PlaceholderRange """e.g., PlaceholderRange(offset=2, length=336)""" + @staticmethod + def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]): + kwargs = defaultdict[str, list[NestedTensors]](list) + + for f in features: + item = f.data + if item is not None: + for k in keys: + if k in item: + kwargs[k].append(item[k].data) + + return dict(kwargs) + @dataclass class MultiModalFieldElem: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 26007d29d61b..40f64c3e3cb3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -889,38 +889,13 @@ def _update_states_after_model_execute( self.input_batch.num_accepted_tokens_cpu[i] = num_tokens def _init_mrope_positions(self, req_state: CachedRequestState): - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for mm_feature in req_state.mm_features: - mm_item = mm_feature.data - if mm_item is None: - continue - mm_input = mm_item.get_data() - if (t := mm_input.get("image_grid_thw")) is not None: - image_grid_thw.append(t.tolist()) - if (t := mm_input.get("video_grid_thw")) is not None: - video_grid_thw.append(t.tolist()) - if (t := mm_input.get("second_per_grid_ts")) is not None: - second_per_grid_ts.append(t) - if (t := mm_input.get("audio_feature_lengths")) is not None: - audio_feature_lengths.append(t) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True - - assert supports_mrope(self.get_model()), "M-RoPE support is not implemented." + model = self.get_model() + assert supports_mrope(model), "M-RoPE support is not implemented." req_state.mrope_positions, req_state.mrope_position_delta = ( - self.model.get_mrope_input_positions( + model.get_mrope_input_positions( req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, + req_state.mm_features, ) )
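
Note (not part of the patch): the sketch below illustrates how the new `MultiModalFeatureSpec.gather_kwargs` helper added in `vllm/multimodal/inputs.py` is intended to be consumed by the per-model `get_mrope_input_positions` implementations above. `FakeFieldElem` and `FakeFeatureSpec` are simplified, hypothetical stand-ins for vLLM's `MultiModalFieldElem`/`MultiModalFeatureSpec` so the snippet runs on its own; only the body of `gather_kwargs` and the `.tolist()` post-processing mirror the code in this diff.

```python
# Minimal, self-contained sketch (assumes only torch is installed).
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass

import torch


@dataclass
class FakeFieldElem:
    """Stand-in for MultiModalFieldElem: exposes the raw tensor via `.data`."""

    data: torch.Tensor


@dataclass
class FakeFeatureSpec:
    """Stand-in for MultiModalFeatureSpec: `.data` maps kwarg name -> field elem."""

    data: dict[str, FakeFieldElem] | None

    @staticmethod
    def gather_kwargs(features, keys):
        # Mirrors the gather_kwargs added in this diff: collect `item[k].data`
        # per requested key, skipping features that carry no kwargs data.
        kwargs = defaultdict(list)
        for f in features:
            item = f.data
            if item is not None:
                for k in keys:
                    if k in item:
                        kwargs[k].append(item[k].data)
        return dict(kwargs)


# One image feature, one feature without data, and one video feature.
features = [
    FakeFeatureSpec({"image_grid_thw": FakeFieldElem(torch.tensor([1, 8, 8]))}),
    FakeFeatureSpec(None),  # handled by the `item is not None` check
    FakeFeatureSpec({"video_grid_thw": FakeFieldElem(torch.tensor([4, 6, 6]))}),
]

kwargs = FakeFeatureSpec.gather_kwargs(features, {"image_grid_thw", "video_grid_thw"})
# Same post-processing the per-model implementations in this diff apply:
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
print(image_grid_thw)  # [[1, 8, 8]]
print(video_grid_thw)  # [[4, 6, 6]]
```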
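Note (not part of the patch): a quick check of the `(torch.stack if xs else torch.tensor)(xs)` idiom used in the Qwen-Omni and Transformers-backend paths above. `torch.stack` turns a non-empty list of `[3]`-shaped grid tensors into an `[N, 3]` tensor, while `torch.tensor([])` serves as the empty fallback, since `torch.stack` rejects an empty sequence.

```python
import torch

grids = [torch.tensor([1, 8, 8]), torch.tensor([4, 6, 6])]
stacked = (torch.stack if grids else torch.tensor)(grids)
print(stacked.shape)  # torch.Size([2, 3])

empty: list[torch.Tensor] = []
fallback = (torch.stack if empty else torch.tensor)(empty)
print(fallback.shape)  # torch.Size([0])
```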