Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions vllm/model_executor/models/ovis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
""" PyTorch Ovis model."""
import math
from collections.abc import Iterable, Mapping
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -49,6 +49,7 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings
Expand Down Expand Up @@ -201,25 +202,22 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return tokens


class OvisImagePatchInputs(TensorSchema):
    """
    Validated multimodal input schema for Ovis image patches.

    Dimensions:
        - batch_patches: Batch size * number of patches
        - patch_size: patch_size_x * patch_size_y * num_channels
        - patch_indicators: Batch size * (number of patches + 1)
        - patches_per_image: List of number of total patches for each image
          in the batch.
    """
    type: Literal["image_patches"]
    flat_data: Annotated[torch.Tensor,
                         TensorShape("batch_patches", "patch_size")]
    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
    # Used to restore the first two dimensions of `flat_data`.
    patches_per_image: Annotated[list[int],
                                 TensorShape("num_patches_per_image")]

class VisualEmbedding(torch.nn.Embedding):
Expand Down Expand Up @@ -458,9 +456,12 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of indicator_tokens. "
f"Got type: {type(pixel_values)}")

flat_data = flatten_bn(pixel_values, concat=True)
if flat_data.ndim >= 3:
flat_data = flat_data.flatten(start_dim=1)
return OvisImagePatchInputs(
type="image_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
flat_data=flat_data,
patches_per_image=[
x.shape[0] for x in flatten_bn(pixel_values)
],
Expand Down