diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 6e9e46368f26..b8da164ee8e3 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -26,11 +26,12 @@ # limitations under the License. """Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from functools import lru_cache, partial from typing import Annotated, Any, Literal, TypeAlias import einops +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -1044,121 +1045,82 @@ class Qwen2_5_VLForConditionalGeneration( supports_encoder_tp_data = True + def iter_mm_grid_thw( + self, mm_features: list[MultiModalFeatureSpec] + ) -> Iterator[tuple[int, int, int, int, float]]: + """ + Iterate over multimodal features and yield grid information. + + Args: + mm_features: List of multimodal feature specifications + + Yields: + Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each frame/image + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0) + for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): + offset = mm_feature.mm_position.offset + if mm_feature.modality == "image": + t, h, w = mm_feature.data["image_grid_thw"].data.tolist() + assert t == 1, f"Image must have 1 frame, got {t}" + yield offset, 1, h // spatial_merge_size, w // spatial_merge_size, 1.0 + elif mm_feature.modality == "video": + t, h, w = mm_feature.data["video_grid_thw"].data.tolist() + second_per_grid_ts = 1.0 + if mm_feature.data.get("second_per_grid_ts", None): + second_per_grid_ts = mm_feature.data[ + "second_per_grid_ts" + ].data.item() + t_factor = second_per_grid_ts * tokens_per_second + yield ( + offset, + t, + h // spatial_merge_size, + w // spatial_merge_size, + t_factor, + ) + else: + raise ValueError(f"Unsupported modality: {mm_feature.modality}") + def get_mrope_input_positions( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - kwargs = MultiModalFeatureSpec.gather_kwargs( - mm_features, - {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, - ) - image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] - video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] - second_per_grid_ts = kwargs.get("second_per_grid_ts", []) - - hf_config = self.config - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - vision_start_token_id = hf_config.vision_start_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) - - input_tokens_tensor = torch.tensor(input_tokens) - vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id - ).squeeze(1) - vision_tokens = input_tokens_tensor[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + video_nums): - video_second_per_grid_t = 0.0 - if remain_images > 0: - try: - ed_image = input_tokens.index(image_token_id, st) - 
except ValueError: - ed_image = len(input_tokens) + 1 - else: - ed_image = len(input_tokens) + 1 - if remain_videos > 0: - try: - ed_video = input_tokens.index(video_token_id, st) - except ValueError: - ed_video = len(input_tokens) + 1 - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = image_grid_thw[image_index] - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = video_grid_thw[video_index] - video_second_per_grid_t = 1.0 - if second_per_grid_ts: - video_second_per_grid_t = second_per_grid_ts[video_index] - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - text_len = ed - st + for ( + offset, + llm_grid_t, + llm_grid_h, + llm_grid_w, + t_factor, + ) in self.iter_mm_grid_thw(mm_features): + text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - t_index = ( - ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - * video_second_per_grid_t - * tokens_per_second - ) - .long() - .flatten() - ) - - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w + grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)) + if t_factor != 1.0: + grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64) + llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx) + st = offset + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - return llm_positions, mrope_position_delta + return torch.from_numpy(llm_positions), mrope_position_delta @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 3b0dce7fcd17..226d94c944a4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -26,7 +26,7 @@ """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" import math -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from functools import partial from typing import Annotated, Any, Literal, TypeAlias @@ -1137,121 +1137,82 @@ class Qwen2VLForConditionalGeneration( supports_encoder_tp_data = True + def iter_mm_grid_thw( + self, mm_features: list[MultiModalFeatureSpec] + ) -> Iterator[tuple[int, int, int, int, float]]: + """ + Iterate over multimodal features and yield grid information. 
+ + Args: + mm_features: List of multimodal feature specifications + + Yields: + Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each frame/image + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0) + for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): + offset = mm_feature.mm_position.offset + if mm_feature.modality == "image": + t, h, w = mm_feature.data["image_grid_thw"].data.tolist() + assert t == 1, f"Image must have 1 frame, got {t}" + yield offset, 1, h // spatial_merge_size, w // spatial_merge_size, 1.0 + elif mm_feature.modality == "video": + t, h, w = mm_feature.data["video_grid_thw"].data.tolist() + second_per_grid_ts = 1.0 + if mm_feature.data.get("second_per_grid_ts", None): + second_per_grid_ts = mm_feature.data[ + "second_per_grid_ts" + ].data.item() + t_factor = second_per_grid_ts * tokens_per_second + yield ( + offset, + t, + h // spatial_merge_size, + w // spatial_merge_size, + t_factor, + ) + else: + raise ValueError(f"Unsupported modality: {mm_feature.modality}") + def get_mrope_input_positions( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - kwargs = MultiModalFeatureSpec.gather_kwargs( - mm_features, - {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, - ) - image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] - video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] - second_per_grid_ts = kwargs.get("second_per_grid_ts", []) - - hf_config = self.config - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - vision_start_token_id = hf_config.vision_start_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) - - input_tokens_tensor = torch.tensor(input_tokens) - vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id - ).squeeze(1) - vision_tokens = input_tokens_tensor[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + video_nums): - video_second_per_grid_t = 0.0 - if remain_images > 0: - try: - ed_image = input_tokens.index(image_token_id, st) - except ValueError: - ed_image = len(input_tokens) + 1 - else: - ed_image = len(input_tokens) + 1 - if remain_videos > 0: - try: - ed_video = input_tokens.index(video_token_id, st) - except ValueError: - ed_video = len(input_tokens) + 1 - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = image_grid_thw[image_index] - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = video_grid_thw[video_index] - video_second_per_grid_t = 1.0 - if second_per_grid_ts: - video_second_per_grid_t = second_per_grid_ts[video_index] - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - text_len = ed - st + for ( + offset, + llm_grid_t, + llm_grid_h, + llm_grid_w, + t_factor, + ) in self.iter_mm_grid_thw(mm_features): + text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( - 
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - t_index = ( - ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - * video_second_per_grid_t - * tokens_per_second - ) - .long() - .flatten() + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w + grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)) + if t_factor != 1.0: + grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64) + llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx) + st = offset + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - return llm_positions, mrope_position_delta + return torch.from_numpy(llm_positions), mrope_position_delta @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None:
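
For reference, here is a minimal standalone sketch (not part of the diff) of the position layout that the new numpy-based `get_mrope_input_positions` produces. All values are made up: a toy prompt with 4 leading text tokens, one image whose merged grid is 1 x 2 x 2, and 2 trailing text tokens; `t_factor` would be `second_per_grid_ts * tokens_per_second` for video features and 1.0 for images, as in `iter_mm_grid_thw`.

```python
import numpy as np

# Toy illustration only (not the vllm code path): 4 text tokens, one image
# occupying a 1 x 2 x 2 merged grid, then 2 trailing text tokens.
text_len, grid_t, grid_h, grid_w, t_factor = 4, 1, 2, 2, 1.0
num_tokens = text_len + grid_t * grid_h * grid_w + 2

chunks = []
# Leading text: all three rope axes (t, h, w) share the same position index.
chunks.append(np.broadcast_to(np.arange(text_len), (3, text_len)))

# Vision block: per-axis (t, h, w) indices, with the temporal axis scaled by
# t_factor, shifted so the block starts right after the leading text.
grid = np.indices((grid_t, grid_h, grid_w))
grid[0] = (grid[0] * t_factor).astype(np.int64)
chunks.append(grid.reshape(3, -1) + text_len)

# Trailing text resumes from the largest position used so far, plus one.
st_idx = chunks[-1].max() + 1
chunks.append(np.broadcast_to(np.arange(2), (3, 2)) + st_idx)

positions = np.concatenate(chunks, axis=1)      # shape (3, num_tokens)
delta = int(positions.max() + 1 - num_tokens)   # mrope_position_delta
print(positions)
print(delta)
```

The sketch mirrors the assembly in both files: text segments are broadcast to three identical rows, vision segments use `np.indices` in place of the previous three `torch.arange(...).view(...).expand(...)` constructions, and the final `mrope_position_delta` is the maximum position plus one minus the prompt length.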