214 changes: 116 additions & 98 deletions vllm/transformers_utils/processors/nano_nemotron_vl.py
@@ -8,7 +8,6 @@
# --------------------------------------------------------

import math
-import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
@@ -26,7 +25,7 @@
from vllm.model_executor.models.parakeet import ParakeetExtractor
from vllm.multimodal.evs import compute_retained_tokens_count
from vllm.multimodal.inputs import AudioItem
-from vllm.multimodal.processing.processor import PromptUpdateDetails, _seq2tokens
+from vllm.multimodal.processing.processor import PromptUpdateDetails
from vllm.tokenizers.hf import HfTokenizer

from .internvl import calculate_internvl_targets, get_internvl_target_ratios
@@ -63,42 +62,50 @@ def calculate_timestamps(
return timestamps


-def input_conditioner(x: torch.Tensor, norm_mean: torch.Tensor, norm_std: torch.Tensor):
-    return (x - norm_mean) / norm_std
-
-
-def _bicubic_from_ndarray(
-    array: npt.NDArray[Any], *, size: tuple[int, int]
-) -> torch.Tensor:
-    """
-    Convert a 4D NHWC ndarray to NCHW and interpolate with bicubic.
-    Suppresses PyTorch's non-writable NumPy warning because interpolate copies,
-    and torch.from_numpy(array) is discarded at the end of function scope.
-    """
-
-    with warnings.catch_warnings():
-        msg = "The given NumPy array is not writ.*"
-        # Apparently, different versions of PyTorch use writable or writeable.
-        warnings.filterwarnings("ignore", message=msg, category=UserWarning)
-        tensor = torch.from_numpy(array)
-    assert tensor.ndim == 4, f"{tensor.ndim=}"
-    tensor = tensor.permute(0, 3, 1, 2)
-    return (
-        torch.nn.functional.interpolate(
-            tensor, size=size, mode="bicubic", align_corners=False, antialias=True
-        )
-        / 255.0
-    )
+@torch.compile(dynamic=True)
+def _bicubic_resize_and_normalize(
+    tensor: torch.Tensor,
+    size: tuple[int, int] | None = None,
+    norm_mean: torch.Tensor | None = None,
+    norm_std: torch.Tensor | None = None,
+    dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """Permute NHWC→NCHW, optional bicubic resize, rescale + normalize.
+
+    Input must be a raw 4-D **NHWC** tensor.
+
+    *size*: target ``(H, W)``; skips interpolation when ``None``.
+    *norm_mean* / *norm_std*: when both provided, fused
+    ``(x/255 - mean) / std`` + dtype cast; otherwise ``x/255`` + cast.
+    """
+    assert tensor.ndim == 4, f"{tensor.ndim=}"
+    tensor = tensor.permute(0, 3, 1, 2).to(dtype=torch.float32)
+    if size is not None:
+        tensor = torch.nn.functional.interpolate(
+            tensor, size=size, mode="bicubic", align_corners=False, antialias=True
+        )
+    if norm_mean is not None and norm_std is not None:
+        return ((tensor / 255.0 - norm_mean) / norm_std).to(dtype=dtype).contiguous()
+    return (tensor / 255.0).to(dtype=dtype).contiguous()
+
+
+def _pil_to_nhwc_tensor(image: Image.Image) -> torch.Tensor:
+    """Convert a PIL image to a 4-D NHWC tensor suitable for compiled ops."""
+    array = np.asarray(
+        image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
+    )
+    return torch.from_numpy(np.expand_dims(array, axis=0))
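
A minimal usage sketch of the two helpers above (the 1x3x1x1 broadcast layout
for norm_mean / norm_std and the concrete sizes are assumptions, not values
taken from this PR):

    import torch
    from PIL import Image

    img = Image.new("RGB", (640, 480))
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)  # assumed layout
    norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)   # assumed layout

    nhwc = _pil_to_nhwc_tensor(img)  # (1, 480, 640, 3), uint8
    out = _bicubic_resize_and_normalize(
        nhwc,
        size=(512, 512),
        norm_mean=norm_mean,
        norm_std=norm_std,
        dtype=torch.bfloat16,
    )
    assert out.shape == (1, 3, 512, 512) and out.dtype == torch.bfloat16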


def dynamic_preprocess(
-    image,
+    image: Image.Image,
    *,
-    image_size=512,
-    max_num_tiles=12,
-    use_thumbnail=True,
-    idx=0,
-):
+    image_size: int = 512,
+    max_num_tiles: int = 12,
+    use_thumbnail: bool = True,
+    norm_mean: torch.Tensor | None = None,
+    norm_std: torch.Tensor | None = None,
+    dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
orig_width, orig_height = image.size

target_ratios = get_internvl_target_ratios(1, max_num_tiles)
@@ -111,13 +118,15 @@ def dynamic_preprocess(
use_thumbnail=False,
)

-    image = np.asarray(
-        image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
-    )
-
-    image = np.expand_dims(image, axis=0)
-
-    resized_img = _bicubic_from_ndarray(image, size=(target_height, target_width))
+    tensor = _pil_to_nhwc_tensor(image)
+
+    resized_img = _bicubic_resize_and_normalize(
+        tensor,
+        size=(target_height, target_width),
+        norm_mean=norm_mean,
+        norm_std=norm_std,
+        dtype=dtype,
+    )
B, C, H, W = resized_img.shape
hp, wp = H // image_size, W // image_size
patches = (
@@ -127,30 +136,16 @@
)

    if use_thumbnail and patches.shape[0] > 1:
-        thumb = _bicubic_from_ndarray(image, size=(image_size, image_size))
+        thumb = _bicubic_resize_and_normalize(
+            tensor,
+            size=(image_size, image_size),
+            norm_mean=norm_mean,
+            norm_std=norm_std,
+            dtype=dtype,
+        )
        patches = torch.cat([patches, thumb], dim=0)

-    return list(patches)
-
-
-def image_to_pixel_values(
-    image: Image.Image,
-    *,
-    input_size: int,
-    max_num: int,
-    use_thumbnail: bool,
-    idx: int,
-) -> torch.Tensor:
-    images = dynamic_preprocess(
-        image,
-        image_size=input_size,
-        max_num_tiles=max_num,
-        use_thumbnail=use_thumbnail,
-        idx=idx,
-    )
-
-    pixel_values = torch.stack(images)
-    return pixel_values
+    return patches
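
A sketch of the new call shape (image size, tile cap, and mean/std values are
illustrative assumptions): dynamic_preprocess now normalizes inside the fused
helper and returns one (N, C, image_size, image_size) tensor of tiles, with
the thumbnail appended as the last tile when more than one tile is produced.

    import torch
    from PIL import Image

    image = Image.new("RGB", (1024, 768))
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)  # assumed layout
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)   # assumed layout

    tiles = dynamic_preprocess(
        image,
        image_size=512,
        max_num_tiles=12,
        use_thumbnail=True,
        norm_mean=mean,
        norm_std=std,
        dtype=torch.float16,
    )
    assert tiles.ndim == 4 and tiles.shape[1:] == (3, 512, 512)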


def _compute_aspect_preserving_size(
@@ -233,29 +228,30 @@ def video_to_pixel_values(
video_maintain_aspect_ratio: bool = False,
patch_size: int = 16,
downsample_ratio: float = 0.5,
+    norm_mean: torch.Tensor | None = None,
+    norm_std: torch.Tensor | None = None,
+    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
-    # (num_frames, H, W, C) -> (num_frames, C, H, W)
-    video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2)
+    """Convert video ndarray (T, H, W, C) to normalized pixel tensor (T, C, H, W)."""
+    orig_h, orig_w = video.shape[1], video.shape[2]
+    size: tuple[int, int] | None = None

    if video_target_num_patches is not None:
        # Resize to target patch count (aspect-preserving or square).
-        orig_h, orig_w = video_tensor.shape[2], video_tensor.shape[3]
-        target_w, target_h, _ = get_video_target_size_and_feature_size(
+        tw, th, _ = get_video_target_size_and_feature_size(
            orig_w=orig_w,
            orig_h=orig_h,
            target_patches=video_target_num_patches,
            maintain_aspect_ratio=video_maintain_aspect_ratio,
            patch_size=patch_size,
            downsample_ratio=downsample_ratio,
        )
-        if video_tensor.shape[2] != target_h or video_tensor.shape[3] != target_w:
-            return _bicubic_from_ndarray(video, size=(target_h, target_w))
-    elif video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
-        return _bicubic_from_ndarray(video, size=(input_size, input_size))
+        if orig_h != th or orig_w != tw:
+            size = (th, tw)
+    elif orig_h != input_size or orig_w != input_size:
+        size = (input_size, input_size)

-    video_tensor = video_tensor / 255.0
-
-    return video_tensor
+    tensor = torch.from_numpy(video)
+    return _bicubic_resize_and_normalize(tensor, size, norm_mean, norm_std, dtype)
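
A hedged sketch of the video path (the leading keyword names, input_size and
video_target_num_patches, are inferred from the function body; the full
signature is collapsed in this view):

    import numpy as np
    import torch

    frames = np.random.randint(0, 256, size=(8, 360, 640, 3), dtype=np.uint8)
    pixels = video_to_pixel_values(
        frames,
        input_size=512,                  # assumed parameter name
        video_target_num_patches=None,   # assumed parameter name
        norm_mean=torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1),
        norm_std=torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1),
    )
    assert pixels.shape == (8, 3, 512, 512)  # every frame resized to input_size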


class DynamicResolutionImageTiler:
@@ -343,14 +339,15 @@ def _images_to_pixel_values_lst(
self,
text_prompt_length: int,
images: list[Image.Image],
dtype: torch.dtype = torch.float32,
) -> tuple[list[torch.Tensor], list[int]]:
num_tokens_available = self.max_num_tokens_available(text_prompt_length)
params_per_image = self.compute_params(images, num_tokens_available)

feature_sizes = []
images = []
for param in params_per_image:
-            for t in self.apply_params(param):
+            for t in self.apply_params(param, dtype=dtype):
assert t.ndim == 3, f"{t.ndim=}: expected 3 dim tensor"
images.append(t)
feature_sizes.append(param.num_embeddings)
@@ -363,17 +360,23 @@ class DynamicResolutionParams:
num_embeddings: int
patch_size: tuple[int, int]

-    def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
+    def apply_params(
+        self,
+        params: DynamicResolutionParams,
+        dtype: torch.dtype = torch.float32,
+    ) -> list[torch.Tensor]:
target_size = (
params.patch_size[1] * self._patch_size,
params.patch_size[0] * self._patch_size,
)
-        image = np.asarray(
-            params.media.convert("RGB") if params.media.mode != "RGB" else params.media,
-            dtype=np.uint8,
-        )
-        image = np.expand_dims(image, axis=0)
-        resized_img = _bicubic_from_ndarray(image, size=target_size)
+        tensor = _pil_to_nhwc_tensor(params.media)
+        resized_img = _bicubic_resize_and_normalize(
+            tensor,
+            size=target_size,
+            norm_mean=self.norm_mean,
+            norm_std=self.norm_std,
+            dtype=dtype,
+        )
        return list(resized_img)
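
Since apply_params now feeds the fused helper, a quick equivalence check
against the removed input_conditioner math (random mean/std stand-ins; size
None skips the resize so only the normalization is compared):

    import torch

    x = torch.randint(0, 256, (1, 64, 64, 3), dtype=torch.uint8)
    mean = torch.rand(1, 3, 1, 1)
    std = torch.rand(1, 3, 1, 1) + 0.5

    fused = _bicubic_resize_and_normalize(x, size=None, norm_mean=mean, norm_std=std)
    # Old two-step path: x / 255.0, then input_conditioner's (x - mean) / std
    unfused = (x.permute(0, 3, 1, 2).float() / 255.0 - mean) / std
    torch.testing.assert_close(fused, unfused)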

def process_media(
@@ -619,6 +622,7 @@ def __init__(
norm_mean=config.norm_mean,
norm_std=config.norm_std,
)
self.dtype: torch.dtype = getattr(config, "dtype", torch.float32)

@staticmethod
def use_dynamic_resolution(config: PretrainedConfig) -> bool:
@@ -662,14 +666,16 @@ def _images_to_pixel_values_lst(
max_num_tiles: int,
) -> list[torch.Tensor]:
return [
-            image_to_pixel_values(
+            dynamic_preprocess(
                image,
-                input_size=self.image_size,
-                max_num=max_num_tiles,
+                image_size=self.image_size,
+                max_num_tiles=max_num_tiles,
                use_thumbnail=self.use_thumbnail,
-                idx=idx,
+                norm_mean=self.norm_mean,
+                norm_std=self.norm_std,
+                dtype=self.dtype,
            )
-            for idx, image in enumerate(images)
+            for image in images
]

def _preprocess_image(
@@ -690,23 +696,22 @@ def _preprocess_image(
pixel_values_lst, num_tokens_per_image = tiler._images_to_pixel_values_lst(
text_prompt_length=text_prompt_length,
images=images,
dtype=self.dtype,
)
imgs_sizes = [(pv.shape[-2], pv.shape[-1]) for pv in pixel_values_lst]
-            normalized = [
-                input_conditioner(img, tiler.norm_mean, tiler.norm_std)
-                for img in pixel_values_lst
-            ]
            image_num_patches = torch.tensor([1] * len(num_tokens_per_image))
            image_inputs = {
-                "pixel_values_flat": normalized,
+                "pixel_values_flat": pixel_values_lst,
"imgs_sizes": imgs_sizes,
"num_tokens_per_image": num_tokens_per_image,
}
else:
pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
image_num_patches = torch.tensor([len(item) for item in pixel_values_lst])
-            pixel_values_flat = input_conditioner(
-                torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
+            pixel_values_flat = (
+                torch.cat(pixel_values_lst)
+                if len(pixel_values_lst) > 1
+                else pixel_values_lst[0]
)
image_inputs = {
"pixel_values_flat": pixel_values_flat,
@@ -861,6 +866,8 @@ def image_token_id(self) -> int:
def _videos_to_pixel_values_lst(
self,
videos: list[npt.NDArray],
*,
dtype: torch.dtype = torch.float32,
) -> list[torch.Tensor]:
return [
video_to_pixel_values(
Expand All @@ -870,6 +877,9 @@ def _videos_to_pixel_values_lst(
video_maintain_aspect_ratio=self.video_maintain_aspect_ratio,
patch_size=self.config.patch_size,
downsample_ratio=self.config.downsample_ratio,
norm_mean=self.norm_mean,
norm_std=self.norm_std,
dtype=dtype,
)
for video in videos
]
@@ -884,8 +894,10 @@ def _preprocess_video(

videos_lst = [v[0] for v in videos]
video_metadata_lst = [v[1] for v in videos]

pixel_values_lst_video = self._videos_to_pixel_values_lst(
videos_lst,
dtype=self.dtype,
)

# We use frame duration in milliseconds (as integer) to ensure
@@ -901,10 +913,15 @@
metadata["frames_indices"] for metadata in video_metadata_lst
]
video_num_patches = torch.tensor([len(item) for item in pixel_values_lst_video])

# Normalization already fused into resize above.
# Skip the torch.cat copy when there is exactly one video
if len(pixel_values_lst_video) == 1:
pixel_values_flat = pixel_values_lst_video[0]
else:
pixel_values_flat = torch.cat(pixel_values_lst_video)
video_inputs = {
"pixel_values_flat_video": input_conditioner(
torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
),
"pixel_values_flat_video": pixel_values_flat,
"video_num_patches": video_num_patches,
"frames_indices": frames_indices_lst,
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
@@ -1176,20 +1193,21 @@ def get_video_repl(
for i, _ in enumerate(tokens_per_frame)
]

-    # Tokenize frame separator independently
-    frame_separators_tokenized = [
-        _seq2tokens(tokenizer, sep) for sep in frame_separators
-    ]
+    # Batch-tokenize all frame separators at once — the HuggingFace
+    # tokenizers Rust backend parallelizes batch encoding across threads.
+    batch_encoded = tokenizer(
+        frame_separators,
+        add_special_tokens=False,
+        return_attention_mask=False,
+    )
+    frame_separators_tokenized: list[list[int]] = batch_encoded["input_ids"]

# Tokenize each component independently to avoid tokenizer merging tokens
# across boundaries. This ensures consistent tokenization regardless of
# num_tokens_per_frame values.
all_token_ids = []
for i, num_tokens in enumerate(tokens_per_frame):
-        frame_sep_token_ids = frame_separators_tokenized[i]
-        all_token_ids.extend(frame_sep_token_ids)
-
-        # Add pre-tokenized special tokens
+        all_token_ids.extend(frame_separators_tokenized[i])
all_token_ids.extend(img_start_token_ids)
all_token_ids.extend(img_context_token_ids * num_tokens)
all_token_ids.extend(img_end_token_ids)
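
A standalone sketch of the batched tokenizer call used above ("gpt2" is an
arbitrary stand-in; any HF fast tokenizer behaves the same way):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in model, an assumption
    seps = ["Frame 1: ", "Frame 2: ", "Frame 3: "]
    batch = tok(seps, add_special_tokens=False, return_attention_mask=False)
    ids: list[list[int]] = batch["input_ids"]
    # Same ids as encoding one separator at a time, minus the Python loop:
    assert ids == [tok(s, add_special_tokens=False)["input_ids"] for s in seps]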