From 409b7822f524ad1dc1d13c8febb070ae10415c93 Mon Sep 17 00:00:00 2001
From: "chenkui.shen"
Date: Mon, 15 Dec 2025 16:10:00 +0800
Subject: [PATCH 1/2] [Feature][Model] Add video input support for transformers modeling backend

Key changes:
- Extended multimodal classes (`MultiModalProcessingInfo`, `MultiModalProcessor`,
  `MultiModalMixin`) to handle video-specific logic, including token calculation,
  dummy data generation, and embedding.
- Corrected the frame size extraction for video frames in `vllm/multimodal/parse.py`.
- Updated documentation to reflect video support.
- Fixed a potential OOM issue in the dummy batch generator for multimodal models.

Signed-off-by: chenkui.shen
---
 docs/models/supported_models.md | 2 -
 .../models/transformers/multimodal.py | 181 +++++++++++++++---
 vllm/multimodal/parse.py | 2 +-
 vllm/v1/worker/gpu_model_runner.py | 13 +-
 4 files changed, 167 insertions(+), 31 deletions(-)

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 9d8cdfe8b130..8cb2da5a4bb7 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -23,8 +23,6 @@ Currently, the Transformers modeling backend works for the following:
 - Architectures: encoder-only, decoder-only, mixture-of-experts
 - Attention types: full attention and/or sliding attention
 
-_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._
-
 If the Transformers model implementation follows all the steps in [writing a custom model](#writing-custom-models) then, when used with the Transformers modeling backend, it will be compatible with the following features of vLLM:
 
 - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py
index 9d77dee2810c..a7b6cefba848 100644
--- a/vllm/model_executor/models/transformers/multimodal.py
+++ b/vllm/model_executor/models/transformers/multimodal.py
@@ -33,7 +33,11 @@
     MultiModalUUIDDict,
     PlaceholderRange,
 )
-from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
+from vllm.multimodal.parse import (
+    ImageProcessorItems,
+    MultiModalDataItems,
+    VideoProcessorItems,
+)
 from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
@@ -55,10 +59,13 @@ class MultiModalProcessingInfo(BaseProcessingInfo):
     def get_supported_mm_limits(self):
-        return {"image": None}
+        return {"image": None, "video": None}
 
     def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
-        return {"image": self.get_max_image_tokens()}
+        return {
+            "image": self.get_max_image_tokens(),
+            "video": self.get_max_video_tokens(seq_len),
+        }
 
     def get_max_image_tokens(self) -> int:
         width, height = self.get_max_image_size()
@@ -71,20 +78,53 @@ def get_max_image_tokens(self) -> int:
         image_tokens = mm_tokens["num_image_tokens"][0]
         return image_tokens
 
+    def _get_video_tokens(self, num_frames, width, height) -> int:
+        processor = self.get_hf_processor()
+        multimodal_config = self.ctx.model_config.multimodal_config
+        mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
+        mm_tokens = processor._get_num_multimodal_tokens(
+            video_sizes=([num_frames, height, width],), **mm_processor_kwargs
+        )
+        video_tokens = mm_tokens["num_video_tokens"][0]
+        return video_tokens
+
+    def get_max_video_tokens(self, seq_len: int) -> int:
+        width, height = self.get_max_image_size()
+        num_frames = self.get_max_video_frames(seq_len)
+        return self._get_video_tokens(num_frames, width, height)
+
     def get_max_image_size(self):
         return 10_000, 10_000  # hardcode for arbitrary very large size
 
+    def get_max_video_frames(self, seq_len: int) -> int:
+        width, height = self.get_max_image_size()
+
+        max_num_frames = 1
+
+        while True:
+            next_num_frames = max_num_frames + 1
+            video_tokens = self._get_video_tokens(next_num_frames, width, height)
+            if video_tokens > seq_len:
+                break
+
+            max_num_frames = next_num_frames
+
+        return max_num_frames
+
 
 class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_images = mm_counts.get("image", 0)
+        num_videos = mm_counts.get("video", 0)
 
         processor = self.info.get_hf_processor()
         if "gemma3" in processor.__class__.__name__.lower():
             image_token = processor.boi_token
+            video_token = ""
         else:
             image_token = getattr(processor, "image_token", "")
-        return image_token * num_images
+            video_token = getattr(processor, "video_token", "")
+        return image_token * num_images + video_token * num_videos
 
     def get_dummy_mm_data(
         self,
@@ -93,8 +133,11 @@ def get_dummy_mm_data(
         mm_options: Mapping[str, "BaseDummyOptions"] | None = None,
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
+        num_videos = mm_counts.get("video", 0)
 
         target_width, target_height = self.info.get_max_image_size()
+        max_total_frames = self.info.get_max_video_frames(seq_len)
+        target_num_frames = max_total_frames // max(num_videos, 1)
 
         image_overrides = mm_options.get("image") if mm_options else None
 
@@ -105,6 +148,12 @@ def get_dummy_mm_data(
                 num_images=num_images,
                 overrides=image_overrides,
             ),
+            "video": self._get_dummy_videos(
+                width=target_width,
+                height=target_height,
+                num_frames=target_num_frames,
+                num_videos=num_videos,
+            ),
         }
 
 
     def _get_mm_fields_config(
         self,
@@ -148,8 +197,18 @@ def _get_mm_fields_config(
         # Keep these as batched, as they always have batch size as first dim
         mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
-        mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
         mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
+
+        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+        video_grid_sizes = video_grid_thw.prod(-1)
+        mm_fields["pixel_values_videos"] = MultiModalFieldConfig.flat_from_sizes(
+            "video", video_grid_sizes
+        )
+        mm_fields["video_embeds"] = MultiModalFieldConfig.flat_from_sizes(
+            "video", video_grid_sizes
+        )
+        mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("video")
+        mm_fields["num_video_patches"] = MultiModalFieldConfig.batched("video")
         return mm_fields
 
     def _get_hf_mm_data(
         self,
@@ -211,24 +270,36 @@ def apply(
         # We can infer vLLM style placeholder from token type ids, if we split
         # it for each input `mm_data`.
-        mm_positions = torch.where(mm_token_type_ids == 1)[1]
-        images = mm_items.get_items("image", ImageProcessorItems)
+        image_sizes = []
+        if "image" in mm_items:
+            images = mm_items.get_items("image", ImageProcessorItems)
+            for item_idx in range(len(images)):
+                image_size = images.get_image_size(item_idx)
+                image_sizes.append((image_size.height, image_size.width))
+
+        video_sizes = []
+        if "video" in mm_items:
+            videos = mm_items.get_items("video", VideoProcessorItems)
+            for item_idx in range(len(videos)):
+                video_size = videos.get_frame_size(item_idx)
+                num_frames = videos.get_num_frames(item_idx)
+                video_sizes.append((num_frames, video_size.height, video_size.width))
+
         multimodal_config = self.info.ctx.model_config.multimodal_config
         mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
-        image_sizes = []
-        for item_idx in range(len(images)):
-            image_size = images.get_image_size(item_idx)
-            image_sizes.append((image_size.height, image_size.width))
-
         mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
-            image_sizes=image_sizes, **mm_processor_kwargs
+            image_sizes=image_sizes, video_sizes=video_sizes, **mm_processor_kwargs
         )
 
         mm_placeholders = {}
+
+        # image_token_ids
+        mm_positions = torch.where(mm_token_type_ids == 1)[1]
         split_sizes = mm_tokens_per_modality["num_image_tokens"]
         if split_sizes:
             chunked_mm_positions = torch.split(mm_positions, split_sizes)
-            mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
+            mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0] == 1]
             chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
             ranges = [
                 PlaceholderRange(
@@ -238,11 +309,34 @@
                 )
                 for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
             ]
-            mm_placeholders = {"image": ranges}
+            mm_placeholders["image"] = ranges
 
         processed_data["num_image_patches"] = torch.tensor(
             mm_tokens_per_modality["num_image_patches"]
         )
+
+        # video_token_ids
+        mm_positions = torch.where(mm_token_type_ids == 2)[1]
+
+        split_sizes = mm_tokens_per_modality["num_video_tokens"]
+        if split_sizes:
+            chunked_mm_positions = torch.split(mm_positions, split_sizes)
+            mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0] == 2]
+            chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
+            ranges = [
+                PlaceholderRange(
+                    offset=positions[0].item(),
+                    length=positions.shape[0],
+                    is_embed=(mm_tokens == hf_processor.video_token_id).bool(),
+                )
+                for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
+            ]
+            mm_placeholders["video"] = ranges
+
+        processed_data["num_video_patches"] = torch.tensor(
+            mm_tokens_per_modality["num_video_patches"]
+        )
+
         mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
             processed_data,
             self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
@@ -333,21 +427,29 @@ def __init__(self, multimodal_model):
 
         return LanguageModel(self)
 
     def embed_multimodal(self, **kwargs):
         pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
         image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
+        pixel_values_videos: torch.Tensor | None = kwargs.pop(
+            "pixel_values_videos", None
+        )
+        video_embeds: torch.Tensor | None = kwargs.pop("video_embeds", None)
+
         # Model might use `image_patches` instead of `pixel_values`
         if pixel_values is None:
             pixel_values = kwargs.pop("image_patches", None)
 
-        if image_embeds is not None:
-            return image_embeds
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
 
-        if pixel_values is None:
-            return None
+        if image_embeds is not None:
+            multimodal_embeddings += tuple(image_embeds)
 
-        num_image_patches = kwargs.pop("num_image_patches")
         kwargs.pop("token_type_ids", None)  # used only in `forward`
 
+        if pixel_values is not None:
+            num_image_patches = kwargs.pop("num_image_patches")
             vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
+            if isinstance(vision_embeddings, tuple):
+                # For Qwen3-VL, the deepstack visual features are also returned
+                vision_embeddings = vision_embeddings[0]
             if isinstance(vision_embeddings, torch.Tensor):
                 if vision_embeddings.ndim == 2:
                     vision_embeddings = vision_embeddings.unsqueeze(0)
@@ -362,8 +464,37 @@ def embed_multimodal(self, **kwargs):
                 embed.flatten(start_dim=0, end_dim=-2)
                 for embed in vision_embeddings
             ]
+            multimodal_embeddings += tuple(vision_embeddings)
+
+        if video_embeds is not None:
+            multimodal_embeddings += tuple(video_embeds)
+
+        if pixel_values_videos is not None:
+            num_video_patches = kwargs.pop("num_video_patches")
+            vision_embeddings = self.model.get_video_features(
+                pixel_values_videos, **kwargs
+            )
+
+            if isinstance(vision_embeddings, tuple):
+                # For Qwen3-VL, the deepstack visual features are also returned
+                vision_embeddings = vision_embeddings[0]
+            if isinstance(vision_embeddings, torch.Tensor):
+                if vision_embeddings.ndim == 2:
+                    vision_embeddings = vision_embeddings.unsqueeze(0)
+
+                # Embeddings have to be 2D tensors of length `num_videos`
+                # but transformers returns concat tensors if each patch
+                # is of different size. We split it back to make vLLM happy
+                vision_embeddings = torch.split(
+                    vision_embeddings, num_video_patches.flatten().tolist()
+                )
+                vision_embeddings = [
+                    embed.flatten(start_dim=0, end_dim=-2)
+                    for embed in vision_embeddings
+                ]
+                multimodal_embeddings += tuple(vision_embeddings)
 
-        return vision_embeddings
+        return multimodal_embeddings
 
     def get_mrope_input_positions(
         self,
@@ -386,18 +517,14 @@ def get_mrope_input_positions(
             if k not in {"image_grid_thw", "video_grid_thw"}
         ):
             raise NotImplementedError(
-                "Transformers modeling backend only supports images."
+                "Transformers modeling backend only supports images and videos."
             )
 
         image_grid_thw = kwargs.get("image_grid_thw", [])
         video_grid_thw = kwargs.get("video_grid_thw", [])
 
-        image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
-            image_grid_thw
-        )
-        video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
-            video_grid_thw
-        )
+        image_grid_thw = torch.stack(image_grid_thw) if image_grid_thw else None
+        video_grid_thw = torch.stack(video_grid_thw) if video_grid_thw else None
 
         mrope_positions, mrope_position_delta = self.model.get_rope_index(
             input_ids=torch.tensor(input_tokens).unsqueeze(0),
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index a69afc3176ca..3ef256486b22 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -280,7 +280,7 @@ def get_frame_size(self, item_idx: int) -> ImageSize:
         if isinstance(image, PILImage.Image):
             return ImageSize(*image.size)
         if isinstance(image, (np.ndarray, torch.Tensor)):
-            _, h, w = image.shape
+            h, w, _ = image.shape
             return ImageSize(w, h)
 
         assert_never(image)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 978224faae65..178b49ffc218 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1118,7 +1118,18 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
         assert mm_budget is not None
 
         dummy_modality = mm_budget.get_modality_with_max_tokens()
-        return self._get_mm_dummy_batch(dummy_modality, num_seqs)
+
+        # TBD:
+        # The mm_dummy_batch below is only retrieved when
+        # supports_multimodal_raw_input_only is True.
+        # Currently, only the Transformers modeling backend and terratorch
+        # set supports_multimodal_raw_input_only to True.
+        # When testing the Transformers modeling backend, it was found that
+        # passing num_seqs (usually the default 256) here
+        # causes an OOM error.
+        # It still needs to be confirmed what value should be passed in here;
+        # for now it is fixed to 1.
+        return self._get_mm_dummy_batch(dummy_modality, 1)
 
     def _get_cumsum_and_arange(
         self,

From c8b65f7e4008527d55cc48f71a8e0f7d6fc311c5 Mon Sep 17 00:00:00 2001
From: "chenkui.shen"
Date: Tue, 16 Dec 2025 17:06:18 +0800
Subject: [PATCH 2/2] resolve some comments

Signed-off-by: chenkui.shen
---
 .../models/transformers/multimodal.py | 31 +++++++++++--------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py
index a7b6cefba848..b53de79c2521 100644
--- a/vllm/model_executor/models/transformers/multimodal.py
+++ b/vllm/model_executor/models/transformers/multimodal.py
@@ -22,7 +22,11 @@
 import torch
 
 from vllm.config.utils import getattr_iter
-from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
+from vllm.model_executor.models.interfaces import (
+    MultiModalEmbeddings,
+    SupportsMRoPE,
+    SupportsMultiModal,
+)
 from vllm.model_executor.models.utils import WeightsMapper
 from vllm.multimodal import MultiModalKwargsItems
 from vllm.multimodal.inputs import (
@@ -120,10 +124,9 @@ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         processor = self.info.get_hf_processor()
         if "gemma3" in processor.__class__.__name__.lower():
             image_token = processor.boi_token
-            video_token = ""
         else:
             image_token = getattr(processor, "image_token", "")
-            video_token = getattr(processor, "video_token", "")
+        video_token = getattr(processor, "video_token", "")
         return image_token * num_images + video_token * num_videos
 
     def get_dummy_mm_data(
@@ -140,6 +143,7 @@ def get_dummy_mm_data(
         target_num_frames = max_total_frames // max(num_videos, 1)
 
         image_overrides = mm_options.get("image") if mm_options else None
+        video_overrides = mm_options.get("video") if mm_options else None
 
         return {
             "image": self._get_dummy_images(
@@ -153,6 +157,7 @@ def get_dummy_mm_data(
                 height=target_height,
                 num_frames=target_num_frames,
                 num_videos=num_videos,
+                overrides=video_overrides,
             ),
         }
 
@@ -424,7 +429,7 @@ def __init__(self, multimodal_model):
 
         return LanguageModel(self)
 
-    def embed_multimodal(self, **kwargs):
+    def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
         pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
         image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
         pixel_values_videos: torch.Tensor | None = kwargs.pop(
             "pixel_values_videos", None
         )
@@ -436,15 +441,16 @@ def embed_multimodal(self, **kwargs):
         if pixel_values is None:
             pixel_values = kwargs.pop("image_patches", None)
 
-        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+        multimodal_embeddings: list[torch.Tensor] = []
 
         if image_embeds is not None:
-            multimodal_embeddings += tuple(image_embeds)
+            multimodal_embeddings += image_embeds
 
         kwargs.pop("token_type_ids", None)  # used only in `forward`
+        num_image_patches = kwargs.pop("num_image_patches", None)
+        num_video_patches = kwargs.pop("num_video_patches", None)
 
         if pixel_values is not None:
-            num_image_patches = kwargs.pop("num_image_patches")
             vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
 
             if isinstance(vision_embeddings, tuple):
@@ -464,13 +470,12 @@
                 embed.flatten(start_dim=0, end_dim=-2)
                 for embed in vision_embeddings
             ]
-            multimodal_embeddings += tuple(vision_embeddings)
+            multimodal_embeddings += vision_embeddings
 
         if video_embeds is not None:
-            multimodal_embeddings += tuple(video_embeds)
+            multimodal_embeddings += video_embeds
 
         if pixel_values_videos is not None:
-            num_video_patches = kwargs.pop("num_video_patches")
             vision_embeddings = self.model.get_video_features(
                 pixel_values_videos, **kwargs
             )
@@ -492,7 +497,7 @@
                 embed.flatten(start_dim=0, end_dim=-2)
                 for embed in vision_embeddings
             ]
-            multimodal_embeddings += tuple(vision_embeddings)
+            multimodal_embeddings += vision_embeddings
 
         return multimodal_embeddings
 
@@ -520,8 +525,8 @@
             "Transformers modeling backend only supports images and videos."
         )
 
-        image_grid_thw = kwargs.get("image_grid_thw", [])
-        video_grid_thw = kwargs.get("video_grid_thw", [])
+        image_grid_thw = kwargs.get("image_grid_thw", None)
+        video_grid_thw = kwargs.get("video_grid_thw", None)
 
         image_grid_thw = torch.stack(image_grid_thw) if image_grid_thw else None
         video_grid_thw = torch.stack(video_grid_thw) if video_grid_thw else None