Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions vllm/model_executor/models/mistral3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union)

import torch
Expand Down Expand Up @@ -32,6 +32,7 @@
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
Expand All @@ -42,16 +43,24 @@
from .vision import get_vision_encoder_info


class Mistral3ImagePixelInputs(TypedDict):
type: Literal["pixel_values_pixtral"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
class Mistral3ImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the Pixtral-style vision tower of Mistral3.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Discriminator used by the input-parsing code to tell pixel inputs
    # apart from other modality payloads.
    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"

    # Note that `height` or `width` may be different per batch and image,
    # in which case the data is passed as a list instead of a batched tensor.
    # `dynamic_dims={"h", "w"}` tells the TensorSchema validator not to
    # require h/w to agree across images.
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
    ]


class Mistral3PatchMerger(nn.Module):
"""
Expand Down Expand Up @@ -456,19 +465,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])

if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")

return data

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
Expand Down