diff --git a/vllm/transformers_utils/processors/nano_nemotron_vl.py b/vllm/transformers_utils/processors/nano_nemotron_vl.py
index 42659c8c1430..6ad495235842 100644
--- a/vllm/transformers_utils/processors/nano_nemotron_vl.py
+++ b/vllm/transformers_utils/processors/nano_nemotron_vl.py
@@ -8,7 +8,6 @@
 # --------------------------------------------------------
 
 import math
-import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from dataclasses import dataclass
@@ -26,7 +25,7 @@
 from vllm.model_executor.models.parakeet import ParakeetExtractor
 from vllm.multimodal.evs import compute_retained_tokens_count
 from vllm.multimodal.inputs import AudioItem
-from vllm.multimodal.processing.processor import PromptUpdateDetails, _seq2tokens
+from vllm.multimodal.processing.processor import PromptUpdateDetails
 from vllm.tokenizers.hf import HfTokenizer
 
 from .internvl import calculate_internvl_targets, get_internvl_target_ratios
@@ -63,42 +62,50 @@ def calculate_timestamps(
     return timestamps
 
 
-def input_conditioner(x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.Tensor):
-    return (x - norm_mean) / norm_std
+@torch.compile(dynamic=True)
+def _bicubic_resize_and_normalize(
+    tensor: torch.Tensor,
+    size: tuple[int, int] | None = None,
+    norm_mean: torch.Tensor | None = None,
+    norm_std: torch.Tensor | None = None,
+    dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """Permute NHWC→NCHW, optional bicubic resize, rescale + normalize.
 
+    Input must be a raw 4-D **NHWC** tensor.
 
-def _bicubic_from_ndarray(
-    array: npt.NDArray[Any], *, size: tuple[int, int]
-) -> torch.Tensor:
-    """
-    Convert a 4D NHWC ndarray to NCHW and interpolate with bicubic.
-    Suppresses PyTorch's non-writable NumPy warning because interpolate copies,
-    and torch.from_numpy(array) is discarded at the end of function scope.
+    *size*: target ``(H, W)``; skips interpolation when ``None``.
+    *norm_mean* / *norm_std*: when both provided, fused
+    ``(x/255 - mean) / std`` + dtype cast; otherwise ``x/255`` + cast.
     """
-
-    with warnings.catch_warnings():
-        msg = "The given NumPy array is not writ.*"
-        # Apparently, different versions of PyTorch use writable or writeable.
-        warnings.filterwarnings("ignore", message=msg, category=UserWarning)
-        tensor = torch.from_numpy(array)
-    assert tensor.ndim == 4, f"{tensor.ndim=}"
-    tensor = tensor.permute(0, 3, 1, 2)
-    return (
-        torch.nn.functional.interpolate(
+    tensor = tensor.permute(0, 3, 1, 2).to(dtype=torch.float32)
+    if size is not None:
+        tensor = torch.nn.functional.interpolate(
             tensor, size=size, mode="bicubic", align_corners=False, antialias=True
         )
-        / 255.0
+    if norm_mean is not None and norm_std is not None:
+        return ((tensor / 255.0 - norm_mean) / norm_std).to(dtype=dtype).contiguous()
+    return (tensor / 255.0).to(dtype=dtype).contiguous()
+
+
+def _pil_to_nhwc_tensor(image: Image.Image) -> torch.Tensor:
+    """Convert a PIL image to a 4-D NHWC tensor suitable for compiled ops."""
+    array = np.asarray(
+        image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
     )
+    return torch.from_numpy(np.expand_dims(array, axis=0))
 
 
 def dynamic_preprocess(
-    image,
+    image: Image.Image,
     *,
-    image_size=512,
-    max_num_tiles=12,
-    use_thumbnail=True,
-    idx=0,
-):
+    image_size: int = 512,
+    max_num_tiles: int = 12,
+    use_thumbnail: bool = True,
+    norm_mean: torch.Tensor | None = None,
+    norm_std: torch.Tensor | None = None,
+    dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
     orig_width, orig_height = image.size
 
     target_ratios = get_internvl_target_ratios(1, max_num_tiles)
@@ -111,13 +118,15 @@ def dynamic_preprocess(
         use_thumbnail=False,
     )
 
-    image = np.asarray(
-        image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
-    )
+    tensor = _pil_to_nhwc_tensor(image)
 
-    image = np.expand_dims(image, axis=0)
-
-    resized_img = _bicubic_from_ndarray(image, size=(target_height, target_width))
+    resized_img = _bicubic_resize_and_normalize(
+        tensor,
+        size=(target_height, target_width),
+        norm_mean=norm_mean,
+        norm_std=norm_std,
+        dtype=dtype,
+    )
     B, C, H, W = resized_img.shape
     hp, wp = H // image_size, W // image_size
     patches = (
@@ -127,30 +136,16 @@
     )
 
     if use_thumbnail and patches.shape[0] > 1:
-        thumb = _bicubic_from_ndarray(image, size=(image_size, image_size))
+        thumb = _bicubic_resize_and_normalize(
+            tensor,
+            size=(image_size, image_size),
+            norm_mean=norm_mean,
+            norm_std=norm_std,
+            dtype=dtype,
+        )
         patches = torch.cat([patches, thumb], dim=0)
 
-    return list(patches)
-
-
-def image_to_pixel_values(
-    image: Image.Image,
-    *,
-    input_size: int,
-    max_num: int,
-    use_thumbnail: bool,
-    idx: int,
-) -> torch.Tensor:
-    images = dynamic_preprocess(
-        image,
-        image_size=input_size,
-        max_num_tiles=max_num,
-        use_thumbnail=use_thumbnail,
-        idx=idx,
-    )
-
-    pixel_values = torch.stack(images)
-    return pixel_values
+    return patches
 
 
 def _compute_aspect_preserving_size(
@@ -233,14 +228,16 @@ def video_to_pixel_values(
     video_maintain_aspect_ratio: bool = False,
     patch_size: int = 16,
    downsample_ratio: float = 0.5,
+    norm_mean: torch.Tensor | None = None,
+    norm_std: torch.Tensor | None = None,
+    dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
-    # (num_frames, H, W, C) -> (num_frames, C, H, W)
-    video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2)
+    """Convert video ndarray (T, H, W, C) to normalized pixel tensor (T, C, H, W)."""
+    orig_h, orig_w = video.shape[1], video.shape[2]
+    size: tuple[int, int] | None = None
 
     if video_target_num_patches is not None:
-        # Resize to target patch count (aspect-preserving or square).
-        orig_h, orig_w = video_tensor.shape[2], video_tensor.shape[3]
-        target_w, target_h, _ = get_video_target_size_and_feature_size(
+        tw, th, _ = get_video_target_size_and_feature_size(
            orig_w=orig_w,
            orig_h=orig_h,
            target_patches=video_target_num_patches,
@@ -248,14 +245,13 @@
             patch_size=patch_size,
             downsample_ratio=downsample_ratio,
         )
-        if video_tensor.shape[2] != target_h or video_tensor.shape[3] != target_w:
-            return _bicubic_from_ndarray(video, size=(target_h, target_w))
-    elif video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
-        return _bicubic_from_ndarray(video, size=(input_size, input_size))
+        if orig_h != th or orig_w != tw:
+            size = (th, tw)
+    elif orig_h != input_size or orig_w != input_size:
+        size = (input_size, input_size)
 
-    video_tensor = video_tensor / 255.0
-
-    return video_tensor
+    tensor = torch.from_numpy(video)
+    return _bicubic_resize_and_normalize(tensor, size, norm_mean, norm_std, dtype)
 
 
 class DynamicResolutionImageTiler:
@@ -343,6 +339,7 @@ def _images_to_pixel_values_lst(
         self,
         text_prompt_length: int,
         images: list[Image.Image],
+        dtype: torch.dtype = torch.float32,
     ) -> tuple[list[torch.Tensor], list[int]]:
         num_tokens_available = self.max_num_tokens_available(text_prompt_length)
         params_per_image = self.compute_params(images, num_tokens_available)
@@ -350,7 +347,7 @@
         feature_sizes = []
         images = []
         for param in params_per_image:
-            for t in self.apply_params(param):
+            for t in self.apply_params(param, dtype=dtype):
                 assert t.ndim == 3, f"{t.ndim=}: expected 3 dim tensor"
                 images.append(t)
                 feature_sizes.append(param.num_embeddings)
@@ -363,17 +360,23 @@ class DynamicResolutionParams:
         num_embeddings: int
         patch_size: tuple[int, int]
 
-    def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
+    def apply_params(
+        self,
+        params: DynamicResolutionParams,
+        dtype: torch.dtype = torch.float32,
+    ) -> list[torch.Tensor]:
         target_size = (
             params.patch_size[1] * self._patch_size,
             params.patch_size[0] * self._patch_size,
         )
-        image = np.asarray(
-            params.media.convert("RGB") if params.media.mode != "RGB" else params.media,
-            dtype=np.uint8,
+        tensor = _pil_to_nhwc_tensor(params.media)
+        resized_img = _bicubic_resize_and_normalize(
+            tensor,
+            size=target_size,
+            norm_mean=self.norm_mean,
+            norm_std=self.norm_std,
+            dtype=dtype,
         )
-        image = np.expand_dims(image, axis=0)
-        resized_img = _bicubic_from_ndarray(image, size=target_size)
         return list(resized_img)
 
     def process_media(
@@ -619,6 +622,7 @@ def __init__(
             norm_mean=config.norm_mean,
             norm_std=config.norm_std,
         )
+        self.dtype: torch.dtype = getattr(config, "dtype", torch.float32)
 
     @staticmethod
     def use_dynamic_resolution(config: PretrainedConfig) -> bool:
@@ -662,14 +666,16 @@ def _images_to_pixel_values_lst(
         max_num_tiles: int,
     ) -> list[torch.Tensor]:
         return [
-            image_to_pixel_values(
+            dynamic_preprocess(
                 image,
-                input_size=self.image_size,
-                max_num=max_num_tiles,
+                image_size=self.image_size,
+                max_num_tiles=max_num_tiles,
                 use_thumbnail=self.use_thumbnail,
-                idx=idx,
+                norm_mean=self.norm_mean,
+                norm_std=self.norm_std,
+                dtype=self.dtype,
             )
-            for idx, image in enumerate(images)
+            for image in images
         ]
 
     def _preprocess_image(
@@ -690,23 +696,22 @@
             pixel_values_lst, num_tokens_per_image = tiler._images_to_pixel_values_lst(
                 text_prompt_length=text_prompt_length,
                 images=images,
+                dtype=self.dtype,
             )
             imgs_sizes = [(pv.shape[-2], pv.shape[-1]) for pv in pixel_values_lst]
-            normalized = [
-                input_conditioner(img, tiler.norm_mean, tiler.norm_std)
-                for img in pixel_values_lst
-            ]
             image_num_patches = torch.tensor([1] * len(num_tokens_per_image))
             image_inputs = {
-                "pixel_values_flat": normalized,
+                "pixel_values_flat": pixel_values_lst,
                 "imgs_sizes": imgs_sizes,
                 "num_tokens_per_image": num_tokens_per_image,
             }
         else:
             pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
             image_num_patches = torch.tensor([len(item) for item in pixel_values_lst])
-            pixel_values_flat = input_conditioner(
-                torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
+            pixel_values_flat = (
+                torch.cat(pixel_values_lst)
+                if len(pixel_values_lst) > 1
+                else pixel_values_lst[0]
             )
             image_inputs = {
                 "pixel_values_flat": pixel_values_flat,
@@ -861,6 +866,8 @@ def image_token_id(self) -> int:
     def _videos_to_pixel_values_lst(
         self,
         videos: list[npt.NDArray],
+        *,
+        dtype: torch.dtype = torch.float32,
     ) -> list[torch.Tensor]:
         return [
             video_to_pixel_values(
@@ -870,6 +877,9 @@
                 video_maintain_aspect_ratio=self.video_maintain_aspect_ratio,
                 patch_size=self.config.patch_size,
                 downsample_ratio=self.config.downsample_ratio,
+                norm_mean=self.norm_mean,
+                norm_std=self.norm_std,
+                dtype=dtype,
             )
             for video in videos
         ]
@@ -884,8 +894,10 @@ def _preprocess_video(
 
         videos_lst = [v[0] for v in videos]
         video_metadata_lst = [v[1] for v in videos]
+
         pixel_values_lst_video = self._videos_to_pixel_values_lst(
             videos_lst,
+            dtype=self.dtype,
         )
 
         # We use frame duration in milliseconds (as integer) to ensure
@@ -901,10 +913,15 @@
             metadata["frames_indices"] for metadata in video_metadata_lst
         ]
         video_num_patches = torch.tensor([len(item) for item in pixel_values_lst_video])
+
+        # Normalization already fused into resize above.
+        # Skip the torch.cat copy when there is exactly one video
+        if len(pixel_values_lst_video) == 1:
+            pixel_values_flat = pixel_values_lst_video[0]
+        else:
+            pixel_values_flat = torch.cat(pixel_values_lst_video)
         video_inputs = {
-            "pixel_values_flat_video": input_conditioner(
-                torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
-            ),
+            "pixel_values_flat_video": pixel_values_flat,
             "video_num_patches": video_num_patches,
             "frames_indices": frames_indices_lst,
             "frame_duration_ms": torch.tensor(frame_duration_ms_lst),
@@ -1176,20 +1193,21 @@ def get_video_repl(
        for i, _ in enumerate(tokens_per_frame)
    ]
 
-    # Tokenize frame separator independently
-    frame_separators_tokenized = [
-        _seq2tokens(tokenizer, sep) for sep in frame_separators
-    ]
+    # Batch-tokenize all frame separators at once — the HuggingFace
+    # tokenizers Rust backend parallelizes batch encoding across threads.
+    batch_encoded = tokenizer(
+        frame_separators,
+        add_special_tokens=False,
+        return_attention_mask=False,
+    )
+    frame_separators_tokenized: list[list[int]] = batch_encoded["input_ids"]
 
     # Tokenize each component independently to avoid tokenizer merging tokens
     # across boundaries. This ensures consistent tokenization regardless of
     # num_tokens_per_frame values.
     all_token_ids = []
     for i, num_tokens in enumerate(tokens_per_frame):
-        frame_sep_token_ids = frame_separators_tokenized[i]
-        all_token_ids.extend(frame_sep_token_ids)
-
-        # Add pre-tokenized special tokens
+        all_token_ids.extend(frame_separators_tokenized[i])
         all_token_ids.extend(img_start_token_ids)
         all_token_ids.extend(img_context_token_ids * num_tokens)
         all_token_ids.extend(img_end_token_ids)
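
Reviewer note, appended after the patch rather than inside it: a minimal sketch checking that the fused transform in `_bicubic_resize_and_normalize` reproduces the old two-step pipeline (bicubic resize plus `/ 255.0`, followed by `input_conditioner`'s `(x - mean) / std`). The shape and the ImageNet-style statistics below are illustrative assumptions, not values from the model config, and the ops run eagerly rather than under `torch.compile`.

```python
import torch

# Assumed values for illustration only; the real mean/std come from the
# model config (config.norm_mean / config.norm_std).
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

# A fake uint8 NHWC frame batch, as produced by _pil_to_nhwc_tensor.
nhwc = torch.randint(0, 256, (1, 64, 48, 3), dtype=torch.uint8)

# Old path: permute, resize, rescale to [0, 1], then normalize separately.
x = nhwc.permute(0, 3, 1, 2).float()
x = torch.nn.functional.interpolate(
    x, size=(32, 32), mode="bicubic", align_corners=False, antialias=True
)
old = (x / 255.0 - norm_mean) / norm_std

# New path: the same arithmetic expressed as one fused function body.
y = nhwc.permute(0, 3, 1, 2).float()
y = torch.nn.functional.interpolate(
    y, size=(32, 32), mode="bicubic", align_corners=False, antialias=True
)
new = ((y / 255.0 - norm_mean) / norm_std).contiguous()

torch.testing.assert_close(old, new)
```

Under `@torch.compile(dynamic=True)` the pointwise rescale, subtract, divide, and dtype cast can fuse into a single kernel instead of materializing an intermediate tensor at each step, which is the motivation for folding `input_conditioner` into the resize helper.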
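A similar self-contained sketch for the tokenization change in `get_video_repl`: batch-encoding the separator strings in one call yields the same per-string token IDs as encoding each separator individually (roughly what the removed per-string `_seq2tokens` loop did), so the swap only removes per-call Python overhead. The model name is an arbitrary stand-in, not the model this processor targets.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary example model
frame_separators = [f"Frame {i}: " for i in range(4)]

# One encode call per separator string (the old approach).
per_string = [
    tokenizer(sep, add_special_tokens=False)["input_ids"]
    for sep in frame_separators
]

# One batched call (the approach in the patch).
batched = tokenizer(
    frame_separators,
    add_special_tokens=False,
    return_attention_mask=False,
)["input_ids"]

assert per_string == batched
```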