From a3bd4e2c89c261b50b44919066eccb92d43c6ee7 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenia Date: Wed, 8 Oct 2025 11:32:09 +0300 Subject: [PATCH 01/14] Allow passing "mm_processor_kwargs": dict(max_num_tiles=2), Signed-off-by: Eugene Khvedchenia --- .../model_executor/models/nano_nemotron_vl.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 039ffbddf8db..7c6b82a4b313 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -93,6 +93,7 @@ # Profiling MAX_FRAMES = 16 +DEFAULT_NUM_TILES = 12 class NanoNemotronVLImagePixelInputs(TypedDict): @@ -255,13 +256,19 @@ class BaseNanoNemotronVLProcessor(ABC): """ def __init__( - self, config: PretrainedConfig, tokenizer: AnyTokenizer, *args, **kwargs + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *args, + max_num_tiles: Optional[int] = None, + **kwargs, ) -> None: super().__init__() self.config = config self.tokenizer = tokenizer + self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES image_size: int = config.force_image_size patch_size: int = config.patch_size @@ -361,7 +368,7 @@ def __call__( ) -> BatchFeature: # Use default if not provided if max_num_tiles is None: - max_num_tiles = 12 + max_num_tiles = self.max_num_tiles text, images = [self._make_batch_input(x) for x in (text, images)] @@ -390,6 +397,7 @@ def __init__( config: PretrainedConfig, tokenizer: AnyTokenizer, *, + max_num_tiles: Optional[int] = None, min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, @@ -399,6 +407,7 @@ def __init__( super().__init__( config=config, tokenizer=tokenizer, + max_num_tiles=max_num_tiles, min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, @@ -506,7 +515,7 @@ def __call__( ) -> BatchFeature: # Use default if not provided if max_num_tiles is None: - max_num_tiles = 12 + max_num_tiles = self.max_num_tiles text, images, videos = [ self._make_batch_input(x) for x in (text, images, videos) @@ -635,7 +644,7 @@ def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize: def get_max_image_tokens(self) -> int: processor = self.get_hf_processor() # Use default max_num_tiles for max tokens calculation - max_num_tiles = 12 + max_num_tiles = processor.max_num_tiles target_width, target_height = self.get_image_size_with_most_features( max_num_tiles ) @@ -768,7 +777,9 @@ def get_replacement_custom(item_idx: int): else: image_size = images.get_image_size(item_idx) # Extract max_num_tiles from kwargs, default to 12 - max_num_tiles = hf_processor_mm_kwargs.get("max_num_tiles", 12) + max_num_tiles = hf_processor_mm_kwargs.get( + "max_num_tiles", hf_processor.max_num_tiles + ) feature_size = self.info.get_num_image_tokens( image_width=image_size.width, image_height=image_size.height, From 16568945e0635f06580ae69554ac095b9f008fea Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenia Date: Wed, 8 Oct 2025 11:53:39 +0300 Subject: [PATCH 02/14] Ensure video modality always uses 1 tile (performance optimization) Signed-off-by: Eugene Khvedchenia --- vllm/model_executor/models/nano_nemotron_vl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 7c6b82a4b313..91dfa6735534 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -228,6 +228,8 @@ def video_to_pixel_values( max_num_tiles: int = 1, use_thumbnail: bool, ) -> torch.Tensor: + assert max_num_tiles == 1, "Video modality always uses one tile" + # Convert each frame to a single resized tile tensor consistent # with image path frames_tensors: list[torch.Tensor] = [] @@ -530,7 +532,7 @@ def __call__( text, video_inputs = self._preprocess_video( text=text, videos=videos, - max_num_tiles=max_num_tiles, + max_num_tiles=1, dynamic_image_size=dynamic_image_size, ) From e8562b8d2b94ec18e8004d7a29890e1193ed7d46 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenia Date: Sun, 12 Oct 2025 20:04:23 +0300 Subject: [PATCH 03/14] Cherry pick video loading support Signed-off-by: Eugene Khvedchenia --- vllm/multimodal/video.py | 72 ++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 400d6a6be9be..7e5ba917be71 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -6,7 +6,7 @@ from functools import partial from io import BytesIO from pathlib import Path -from typing import Any, Union +from typing import Any import numpy as np import numpy.typing as npt @@ -175,6 +175,16 @@ def load_bytes( max_duration: int = 300, **kwargs, ) -> tuple[npt.NDArray, dict[str, Any]]: + """ + Args: + num_frames (int): Maximum number of frames to load. + A total sampled number of frames will never be larger + than this value. Set it -1 to remove the upper limit. + + fps (int): Desired video sampling rate. A real samping + rate may be lower if we encounter long video and + num_frames upper limit is set to positive value. + """ import cv2 backend = cls().get_cv2_video_api() @@ -183,36 +193,42 @@ def load_bytes( raise ValueError("Could not open video stream") total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - original_fps = cap.get(cv2.CAP_PROP_FPS) - duration = total_frames_num / original_fps if original_fps > 0 else 0 + if total_frames_num == 0: + raise ValueError("CAP_PROP_FRAME_COUNT returned 0") - # resample video to target num_frames - max_frame_idx = total_frames_num - 1 - duration = duration or round(max_frame_idx / original_fps) + 1 - - # Refer to: - # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140 - frame_indices: Union[range, list[int]] - if duration <= max_duration: - n = int(math.floor(duration * fps)) - frame_indices = sorted( - { - min(max_frame_idx, int(math.ceil(i * original_fps / fps))) - for i in range(n) - } + original_fps = cap.get(cv2.CAP_PROP_FPS) + if not (original_fps > 0): + print( + f"WARNING: CAP_PROP_FPS returned {original_fps}. " + f"We will use 30 FPS as default fallback." ) + original_fps = 30 + + duration = total_frames_num / original_fps + + # Determine target number of samples + if num_frames > 0: + # Hard upper bound + max_samples = int(num_frames) else: - num_samples = int(max_duration * fps) - if num_samples >= total_frames_num: - frame_indices = range(total_frames_num) - else: - target_seconds = np.linspace(0, duration, num_samples, endpoint=True) - frame_indices = sorted( - { - min(max_frame_idx, int(math.ceil(t * original_fps))) - for t in target_seconds - } - ) + # No cap -> sample at desired fps + max_samples = int(max(1, math.floor(duration * fps))) + + # Clamp to available frames if count is known + max_samples = max(1, min(max_samples, total_frames_num)) + + # Uniform coverage of the entire timeline within the cap + # Use linspace over [0, total_frames-1] + raw = np.linspace(0, total_frames_num - 1, max_samples, endpoint=True) + frame_indices = np.unique(raw.round().astype(int)).tolist() + + effective_fps = len(frame_indices) / duration + print( + f"Video [{total_frames_num} fames]({duration:.2f}sec " + f"at {original_fps:.2f}fps) sampled " + f"into frame [{len(frame_indices)}] indexes {frame_indices} " + f"at {effective_fps:.2f}fps." + ) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) From 5809b49ed2751ccbcd0950aca3b9fc2632735b0f Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenia Date: Sat, 11 Oct 2025 23:32:52 +0300 Subject: [PATCH 04/14] Cherry-pick image normalization Signed-off-by: Eugene Khvedchenia --- .../model_executor/models/nano_nemotron_vl.py | 12 +++++- vllm/model_executor/models/radio.py | 38 +++---------------- 2 files changed, 15 insertions(+), 35 deletions(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 91dfa6735534..dfe77ee22fda 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -248,6 +248,10 @@ def video_to_pixel_values( return torch.stack(frames_tensors) +def input_conditioner(x, norm_mean, norm_std): + return (x - norm_mean) / norm_std + + class BaseNanoNemotronVLProcessor(ABC): """ This model doesn't define its own HF processor, @@ -341,7 +345,9 @@ def _preprocess_image( else: pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles) image_inputs = { - "pixel_values_flat": torch.cat(pixel_values_lst), + "pixel_values_flat": input_conditioner( + torch.cat(pixel_values_lst), self.norm_mean, self.norm_std + ), "image_num_patches": torch.tensor( [len(item) for item in pixel_values_lst] ), @@ -465,7 +471,9 @@ def _preprocess_video( ) video_inputs = { - "pixel_values_flat_video": torch.cat(pixel_values_lst_video), + "pixel_values_flat_video": input_conditioner( + torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std + ), "video_num_patches": torch.tensor( [len(item) for item in pixel_values_lst_video] ), diff --git a/vllm/model_executor/models/radio.py b/vllm/model_executor/models/radio.py index 2313b98348b7..03d56eead6e0 100644 --- a/vllm/model_executor/models/radio.py +++ b/vllm/model_executor/models/radio.py @@ -43,32 +43,6 @@ def parse(x): to_ntuple = _ntuple -class InputConditioner(nn.Module): - def __init__( - self, - input_scale: float, - norm_mean: norm_t, - norm_std: norm_t, - dtype: torch.dtype = None, - ): - super().__init__() - - self.dtype = dtype - - self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale) - self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale) - - def forward(self, x: torch.Tensor): - y = (x - self.norm_mean) / self.norm_std - if self.dtype is not None: - y = y.to(self.dtype) - return y - - -def _to_tensor(v: norm_t): - return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1) - - class ClsToken(nn.Module): def __init__( self, @@ -507,11 +481,6 @@ def __init__( super().__init__() self.config = config - self.input_conditioner = InputConditioner( - input_scale=1.0, - norm_mean=config.norm_mean, - norm_std=config.norm_std, - ) self.model = RadioInternVisionModel( config=config, quant_config=quant_config, @@ -525,8 +494,7 @@ def forward( pixel_values: Optional[torch.Tensor] = None, pixel_embeds: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: - x = self.input_conditioner(pixel_values) - y = self.model(x) + y = self.model(pixel_values) return self._extract_final(y) def load_weights(self, weights) -> set[str]: @@ -548,6 +516,10 @@ def load_weights(self, weights) -> set[str]: # Skip buffers not used in vLLM if sub in {"summary_idxs"}: continue + if sub.startswith("input_conditioner."): + # we normalize in the input processor, + # based on norm and std values from the config + continue vllm_key = None if sub.startswith("model.patch_generator."): From 44f364be8f869ccabbf3349baa04343c097f0b57 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenia Date: Sun, 12 Oct 2025 20:04:23 +0300 Subject: [PATCH 05/14] Add video timestamps support Fix image normalization precision bug Fix possible inconsistent tokenization of frame prefixes Change logic of video sampling in OpenCVDynamicBackend Signed-off-by: Eugene Khvedchenia --- .../model_executor/models/nano_nemotron_vl.py | 283 +++++++++++++++--- vllm/model_executor/models/radio.py | 38 +-- vllm/model_executor/models/utils.py | 1 - vllm/multimodal/video.py | 72 +++-- 4 files changed, 293 insertions(+), 101 deletions(-) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 91dfa6735534..e9d757a78e04 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -14,6 +14,7 @@ from typing import Annotated, Any, Literal, Optional, TypedDict, TypeVar, Union import numpy.typing as npt +import regex as re import torch import torch.nn as nn import torchvision.transforms as T @@ -21,7 +22,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions +from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -54,12 +55,14 @@ MultiModalFieldConfig, MultiModalKwargs, MultiModalKwargsItems, + VideoItem, ) from vllm.multimodal.parse import ( ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems, + MultiModalDataParser, ) from vllm.multimodal.processing import ( BaseMultiModalProcessor, @@ -128,7 +131,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema): """ Dimensions: - bvf: Batch size * number of videos * num_frames - - bn: Batch size * number of images + - bn: Batch size * number of videos + - f: Number of frames - c: Number of channels (3) - h: Height of each video frame - w: Width of each video frame @@ -137,6 +141,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema): type: Literal["pixel_values_videos"] pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] + frames_indices: Annotated[torch.Tensor, TensorShape("bvf")] + frame_duration_ms: Annotated[torch.Tensor, TensorShape("bn")] class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): @@ -248,6 +254,21 @@ def video_to_pixel_values( return torch.stack(frames_tensors) +def input_conditioner(x, norm_mean, norm_std): + return (x - norm_mean) / norm_std + + +def calculate_timestamps( + indices: list[int] | torch.Tensor, + frame_duration_ms: int, +): + if not isinstance(indices, list): + indices = indices.tolist() + + timestamps = [int(i) * frame_duration_ms / 1000.0 for i in indices] + return timestamps + + class BaseNanoNemotronVLProcessor(ABC): """ This model doesn't define its own HF processor, @@ -341,17 +362,30 @@ def _preprocess_image( else: pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles) image_inputs = { - "pixel_values_flat": torch.cat(pixel_values_lst), + "pixel_values_flat": input_conditioner( + torch.cat(pixel_values_lst), self.norm_mean, self.norm_std + ), "image_num_patches": torch.tensor( [len(item) for item in pixel_values_lst] ), } - for pixel_values in pixel_values_lst: + assert len(text) == 1, ( + "hf_processor is called on the output of get_dummy_text, " + "which should be a single string" + ) + parts = [x for x in re.split(r"()", text[0]) if x] + assert parts.count("") == len(pixel_values_lst), ( + "the number of tokens in the text should be the " + "same as the number of images" + ) + + for i, pixel_values in enumerate(pixel_values_lst): num_patches = pixel_values.shape[0] feature_size = num_patches * self.num_image_token image_repl = self.get_image_repl(feature_size, num_patches) - text = [t.replace("", image_repl.full, 1) for t in text] + parts[i] = parts[i].replace("", image_repl.full) + text = ["".join(parts)] return text, image_inputs def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None): @@ -418,6 +452,18 @@ def __init__( self.video_token = video_token self.video_pruning_rate = video_pruning_rate + # Pre-tokenize special tokens for video processing + # to avoid repeated tokenization + self._img_start_token_ids = encode_tokens( + tokenizer, IMG_START, add_special_tokens=False + ) + self._img_end_token_ids = encode_tokens( + tokenizer, IMG_END, add_special_tokens=False + ) + self._img_context_token_ids = encode_tokens( + tokenizer, IMG_CONTEXT, add_special_tokens=False + ) + @property def supports_video(self) -> bool: return self.video_token_id is not None @@ -451,24 +497,43 @@ def _videos_to_pixel_values_lst( def _preprocess_video( self, text: list[str], - videos: list[npt.NDArray], + videos: list[tuple[npt.NDArray, dict[str, Any]]], max_num_tiles: int, dynamic_image_size: Optional[bool] = None, ): if len(videos) == 0 or not self.supports_video: video_inputs = {} else: + videos_lst = [v[0] for v in videos] + video_metadata_lst = [v[1] for v in videos] pixel_values_lst_video = self._videos_to_pixel_values_lst( - videos, + videos_lst, max_num_tiles=max_num_tiles, dynamic_image_size=dynamic_image_size, ) + # We use frame duration in milliseconds (as integer) to ensure + # we have consistent timestamps calculation. At preprocessing + # fps parameter is given in fp32, while at inference it is bf16 + # which leads to inaccurate timestamp calculation and causes + # timestamp values to differ.In rare cases this causes + # mismatching number of output tokens for tokenized frame prefixes + frame_duration_ms_lst = [ + int(1000.0 / metadata["fps"]) for metadata in video_metadata_lst + ] + frames_indices_lst = [ + metadata["frames_indices"] for metadata in video_metadata_lst + ] + video_inputs = { - "pixel_values_flat_video": torch.cat(pixel_values_lst_video), + "pixel_values_flat_video": input_conditioner( + torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std + ), "video_num_patches": torch.tensor( [len(item) for item in pixel_values_lst_video] ), + "frames_indices": frames_indices_lst, + "frame_duration_ms": torch.tensor(frame_duration_ms_lst), } image_size: int = self.config.force_image_size @@ -478,7 +543,12 @@ def _preprocess_video( (image_size * image_size // patch_size**2) * (downsample_ratio**2) ) - for pixel_values in pixel_values_lst_video: + for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip( + pixel_values_lst_video, + video_metadata_lst, + frames_indices_lst, + frame_duration_ms_lst, + ): num_frames = pixel_values.shape[0] if ( @@ -501,16 +571,29 @@ def _preprocess_video( else: tokens_per_frame = [tokens_in_single_frame] * num_frames - video_repl = self.get_video_repl(tokens_per_frame, self.video_token) + video_repl = self.get_video_repl( + tokens_per_frame=tokens_per_frame, + frames_indices=frames_indices, + frame_duration_ms=frame_duration_ms, + tokenizer=self.tokenizer, + img_start_token_ids=self._img_start_token_ids, + img_end_token_ids=self._img_end_token_ids, + img_context_token_ids=self._img_context_token_ids, + ) - text = [t.replace("