diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py
index 41fd272397e6..f1bb18716b40 100644
--- a/vllm/model_executor/models/ovis.py
+++ b/vllm/model_executor/models/ovis.py
@@ -19,7 +19,7 @@
 """ PyTorch Ovis model."""
 import math
 from collections.abc import Iterable, Mapping
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -49,6 +49,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processors.ovis import OvisProcessor
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import merge_multimodal_embeddings
@@ -201,25 +202,22 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return tokens
 
 
-class OvisImagePatchInputs(TypedDict):
-    type: Literal["image_patches"]
-    flat_data: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
-    """
-
-    indicator_tokens: torch.Tensor
+class OvisImagePatchInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * (num_patches + 1))`
-    """
-
-    patches_per_image: list[int]
-    """
-    List of number of total patches for each image in the batch.
-    This is used to restore the first two dimensions of `flat_data`.
+    Dimensions:
+        - batch_patches: Batch size * number of patches
+        - patch_size: patch_size_x * patch_size_y * num_channels
+        - patch_indicators: Batch size * (number of patches + 1)
+        - num_patches_per_image: List of number of total patches for each
+          image in the batch.
     """
+    type: Literal["image_patches"]
+    flat_data: Annotated[torch.Tensor,
+                         TensorShape("batch_patches", "patch_size")]
+    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
+    patches_per_image: Annotated[list[int],
+                                 TensorShape("num_patches_per_image")]
+    # This is used to restore the first two dimensions of `flat_data`.
 
 
 class VisualEmbedding(torch.nn.Embedding):
@@ -458,9 +456,12 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of indicator_tokens. "
                                  f"Got type: {type(pixel_values)}")
 
+            flat_data = flatten_bn(pixel_values, concat=True)
+            if flat_data.ndim >= 3:
+                flat_data = flat_data.flatten(start_dim=1)
             return OvisImagePatchInputs(
                 type="image_patches",
-                flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
+                flat_data=flat_data,
                 patches_per_image=[
                     x.shape[0] for x in flatten_bn(pixel_values)
                 ],
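
Reviewer note: the new `ndim` guard exists because the concatenated patch tensor may still carry per-patch spatial dimensions, while the schema declares `flat_data` as 2-D `("batch_patches", "patch_size")`. Below is a minimal, self-contained sketch of the intended shape flow, using plain `torch.cat` as a stand-in for `flatten_bn(..., concat=True)`; the patch counts and the `(3, 14, 14)` patch shape are made-up values for illustration, not taken from the Ovis config.

```python
import torch

# Stand-in inputs: two images with 3 and 5 patches each, where every
# patch is an unflattened (channels, H, W) = (3, 14, 14) pixel block.
pixel_values = [torch.randn(3, 3, 14, 14), torch.randn(5, 3, 14, 14)]

# Rough equivalent of flatten_bn(pixel_values, concat=True): concatenate
# the per-image patch tensors along the leading (patch) axis.
flat_data = torch.cat(pixel_values, dim=0)          # (8, 3, 14, 14)

# The new guard from the diff: if patches still have spatial structure,
# collapse all trailing dims into one feature axis so the tensor matches
# the ("batch_patches", "patch_size") shape declared in OvisImagePatchInputs.
if flat_data.ndim >= 3:
    flat_data = flat_data.flatten(start_dim=1)      # (8, 588)

# patches_per_image records how many rows of flat_data belong to each
# image, which is what lets the model restore the per-image grouping.
patches_per_image = [x.shape[0] for x in pixel_values]   # [3, 5]
assert flat_data.shape == (sum(patches_per_image), 3 * 14 * 14)
```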