Merged
34 changes: 34 additions & 0 deletions examples/offline_inference/vision_language.py
@@ -815,6 +815,39 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")


def run_minimax_vl_01(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"

model_name = "MiniMaxAI/MiniMax-VL-01"

engine_args = EngineArgs(
model=model_name,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
trust_remote_code=True,
tensor_parallel_size=8,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [
[
{
"role": "user",
"content": [{"type": "image"}, {"type": "text", "text": question}],
}
]
for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)

return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)


# Mistral-3 HF-format
def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1463,6 +1496,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
"mantis": run_mantis,
"minicpmo": run_minicpmo,
"minicpmv": run_minicpmv,
"minimax_vl_01": run_minimax_vl_01,
"mistral3": run_mistral3,
"mllama": run_mllama,
"molmo": run_molmo,
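Note (not part of the diff): with the new entry registered in `model_example_map`, the example can be exercised by name. A minimal invocation sketch, assuming the script's usual `--model-type` selector and a host with the 8 GPUs implied by `tensor_parallel_size=8`:

    python examples/offline_inference/vision_language.py --model-type minimax_vl_01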
1 change: 0 additions & 1 deletion tests/models/multimodal/test_tensor_schema.py
@@ -30,7 +30,6 @@

ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements",
"MiniMaxVL01ForConditionalGeneration": "broken model",
}
ARCH_NEEDS_EXTRAS = [
"InternVLChatModel",
120 changes: 89 additions & 31 deletions vllm/model_executor/models/minimax_vl_01.py
@@ -1,11 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
from typing import Literal, Optional, TypedDict, Union, cast
from typing import Annotated, Literal, Optional, Union, cast

import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)

from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
@@ -17,6 +19,7 @@
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -29,24 +32,36 @@
maybe_prefix, merge_multimodal_embeddings)


class MiniMaxVL01ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
class MiniMaxVL01ImagePixelInputs(TensorSchema):
"""
Shape: `(batch_size * num_images, num_channels, height, width)`

Note that `height` or `width` may be different per batch and image,
Dimensions:
- bn: Batch size * number of images
- np: Number of patches + 1
- c: Number of channels (3)
- h: Height
- w: Width

Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})]

image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
# This should be in `(height, width)` format.

class MiniMaxVL01ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

`hidden_size` must match the hidden size of language model backbone.
class MiniMaxVL01ImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
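
    Illustrative sketch (not part of the diff): the `TensorShape` annotations above allow the per-image patch count and spatial size to vary across list elements, while `image_sizes` stays a fixed `(bn, 2)` tensor. The concrete sizes below are made up for the example; the shapes are checked when the inputs object is constructed, and dict-style access matches how the model code reads the fields.

        import torch

        pixel_values = [
            torch.empty(5, 3, 336, 336),  # image 1: base tile + 4 anyres patches
            torch.empty(3, 3, 336, 336),  # image 2: base tile + 2 anyres patches
        ]
        inputs = MiniMaxVL01ImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_sizes=torch.tensor([[480, 640], [720, 1280]]),  # (height, width) per image
        )
        assert len(inputs["pixel_values"]) == 2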


MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs,
@@ -141,6 +156,7 @@ def _get_mm_fields_config(
) -> Mapping[str, MultiModalFieldConfig]:
return {
"pixel_values": MultiModalFieldConfig.batched("image"),
"image_sizes": MultiModalFieldConfig.batched("image"),
"image_embeds": MultiModalFieldConfig.batched("image"),
}

@@ -239,7 +255,7 @@ def _image_pixels_to_features(
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
image_features = tuple(vision_tower(p) for p in pixel_values)

def select_features(leaf: torch.Tensor):
return self._select_image_features(
@@ -252,14 +268,63 @@ def select_features(leaf: torch.Tensor):
json_map_leaves(select_features, image_features),
)

# adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
def pack_image_features(self, image_features: list[torch.Tensor],
image_sizes: torch.Tensor):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = (self.config.vision_config.image_size //
self.config.vision_config.patch_size)
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with "
"the image size.")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)

image_feature = image_feature.view(num_patch_height,
num_patch_width, height,
width, -1)
image_feature = image_feature.permute(4, 0, 2, 1,
3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature,
image_sizes[image_idx])

image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1).to(
image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature),
dim=0)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature,
self.image_newline[None].to(image_feature)),
dim=0)
new_image_features.append(image_feature)
return new_image_features
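
    Rough shape check (illustrative values only, not MiniMax-VL-01's actual configuration): with a hypothetical 336-pixel vision tower and 14-pixel patches, `height = width = 24` feature tokens per tile, and the anyres grid helper decides how the non-base patches are tiled before the view/permute/unpad sequence above.

        from transformers.models.llava_next.modeling_llava_next import (
            get_anyres_image_grid_shape)

        height = width = 336 // 14  # 24 feature tokens per side of each tile
        grid_pinpoints = [[336, 672], [672, 336], [672, 672], [336, 1008], [672, 1008]]
        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            (672, 1008), grid_pinpoints, 336)  # -> (2, 3) tiles for a 672x1008 image
        # pack_image_features views the patch features as
        # (num_patch_height, num_patch_width, height, width, hidden), unpads them back
        # to the original aspect ratio, then appends the learned image_newline per row.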

def _process_image_pixels(
self,
inputs: MiniMaxVL01ImagePixelInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None

pixel_values = inputs["pixel_values"]

return self._image_pixels_to_features(self.vision_tower, pixel_values)

def _process_image_input(
Expand All @@ -281,38 +346,31 @@ def _process_image_input(

image_embeds = self.multi_modal_projector(torch.cat(image_features))
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds

def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])

if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")

return data
image_sizes = image_input.get("image_sizes")
return self.pack_image_features(image_embeds, image_sizes)

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)

if pixel_values is None and image_embeds is None:
return None

if pixel_values is not None:
if pixel_values is not None and image_sizes is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")

return MiniMaxVL01ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
pixel_values=flatten_bn(pixel_values),
image_sizes=flatten_bn(image_sizes, concat=True),
)

if image_embeds is not None: