From 3ede00f74bf079d0a04a92ad7f4e9c6be18025b4 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 2 Aug 2025 12:52:24 +0800 Subject: [PATCH 1/7] fix minimax shape check Signed-off-by: Isotr0py <2037008807@qq.com> --- examples/offline_inference/vision_language.py | 35 +++++++++++++++++++ vllm/model_executor/models/minimax_vl_01.py | 5 +-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index a75b8e2b047d..da6a45bc2bed 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -766,6 +766,40 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData: return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6") +def run_minimax(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + model_name = "MiniMaxAI/MiniMax-VL-01" + + # The configuration below has been confirmed to launch on a single L40 GPU. + engine_args = EngineArgs( + model=model_name, + max_model_len=14336, + max_num_seqs=2, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + messages = [ + [ + { + "role": "user", + "content": [{"type": "image"}, {"type": "text", "text": question}], + } + ] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Mistral-3 HF-format def run_mistral3(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1412,6 +1446,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: "mantis": run_mantis, "minicpmo": run_minicpmo, "minicpmv": run_minicpmv, + "minimax": run_minimax, "mistral3": run_mistral3, "mllama": run_mllama, "molmo": run_molmo, diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 62a7d37ec9d3..afd8f65b81a6 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -33,7 +33,8 @@ class MiniMaxVL01ImagePixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values: torch.Tensor """ - Shape: `(batch_size * num_images, num_channels, height, width)` + Shape: + `(batch_size * num_images * num_patches, num_channels, height, width)` Note that `height` or `width` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. 
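
For context on the hunk below: the `flatten_bn` helper collapses the leading batch dimension of multimodal inputs. MiniMax-VL-01's processor hands pixel values over as a nested structure (per prompt, then per image, with each image a stack of anyres patches), so a single flatten leaves one level of nesting intact; flattening twice before concatenation yields the flat patch batch that the shape check expects. A rough sketch of the helper's behavior, assuming its usual Tensor-or-list contract (illustrative only, not vLLM's actual implementation):

    import torch

    def flatten_bn_sketch(x, *, concat: bool = False):
        # Tensor input: fold (batch, n, ...) into (batch * n, ...).
        if isinstance(x, torch.Tensor):
            return x.flatten(0, 1)
        # List input: either concatenate along dim 0 (requires matching
        # trailing shapes) or peel off exactly one level of nesting.
        if concat:
            return torch.cat(x)
        return [t for batch in x for t in batch]

    # e.g. two prompts, each holding one image split into a different
    # number of anyres patches: one flatten yields per-image tensors,
    # the second (with concat=True) yields one flat patch batch.
    nested = [torch.zeros(1, 5, 3, 336, 336), torch.zeros(1, 3, 3, 336, 336)]
    per_image = flatten_bn_sketch(nested)             # list of (np, 3, 336, 336)
    flat = flatten_bn_sketch(per_image, concat=True)  # (8, 3, 336, 336)
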
@@ -312,7 +313,7 @@ def _parse_and_validate_image_input( return MiniMaxVL01ImagePixelInputs( type="pixel_values", pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + flatten_bn(flatten_bn(pixel_values), concat=True)), ) if image_embeds is not None: From d5d4bfbe7318ffe311a7fd313879c72e9c7749c2 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 19 Aug 2025 12:46:59 +0800 Subject: [PATCH 2/7] fix Signed-off-by: Isotr0py --- vllm/model_executor/models/minimax_vl_01.py | 88 ++++++++++++++++++--- 1 file changed, 77 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 5369e40417f0..9efe7fc17751 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn from transformers import BatchFeature, PretrainedConfig +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, unpad_image) from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn @@ -142,6 +144,7 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return { "pixel_values": MultiModalFieldConfig.batched("image"), + "image_sizes": MultiModalFieldConfig.batched("image"), "image_embeds": MultiModalFieldConfig.batched("image"), } @@ -240,7 +243,7 @@ def _image_pixels_to_features( ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) + image_features = tuple(vision_tower(p) for p in pixel_values) def select_features(leaf: torch.Tensor): return self._select_image_features( @@ -253,6 +256,58 @@ def select_features(leaf: torch.Tensor): json_map_leaves(select_features, image_features), ) + # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631 + def pack_image_features(self, image_features: list[torch.Tensor], + image_sizes: torch.Tensor): + new_image_features = [] + feature_lens = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = (self.config.vision_config.image_size // + self.config.vision_config.patch_size) + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with " + "the image size.") + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + + image_feature = image_feature.view(num_patch_height, + num_patch_width, height, + width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, + 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, + image_sizes[image_idx]) + + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1).to( + image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), + dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, + self.image_newline[None].to(image_feature)), + dim=0) + 
new_image_features.append(image_feature) + feature_lens.append(image_feature.size(0)) + return new_image_features + def _process_image_pixels( self, inputs: MiniMaxVL01ImagePixelInputs, @@ -260,7 +315,6 @@ def _process_image_pixels( assert self.vision_tower is not None pixel_values = inputs["pixel_values"] - return self._image_pixels_to_features(self.vision_tower, pixel_values) def _process_image_input( @@ -282,38 +336,50 @@ def _process_image_input( image_embeds = self.multi_modal_projector(torch.cat(image_features)) image_embeds = torch.split(image_embeds, feature_sizes) - return image_embeds + image_sizes = image_input.get("image_sizes") + return self.pack_image_features(image_embeds, image_sizes) def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") + for x in data: + actual_dims = x.shape[1:] + if len(actual_dims) != len(expected_dims) or actual_dims[ + 0] != expected_dims[0] or any( + actual_dims[i] > expected_dims[i] + for i in range(1, 3)): + expected_expr = ("batch_size", *map(str, expected_dims)) + actual_expr = ("batch_size", *map(str, tuple(actual_dims))) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {actual_expr}.") return data def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: return None - if pixel_values is not None: + if pixel_values is not None and image_sizes is not None: if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + if not isinstance(image_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. 
" + f"Got type: {type(image_sizes)}") + return MiniMaxVL01ImagePixelInputs( type="pixel_values", pixel_values=self._validate_pixel_values( - flatten_bn(flatten_bn(pixel_values), concat=True)), + flatten_bn(pixel_values)), + image_sizes=flatten_bn(image_sizes, concat=True), ) if image_embeds is not None: From 2f91b39d5ca8105d2fad6d78716641ed3484d336 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 19 Aug 2025 13:15:51 +0800 Subject: [PATCH 3/7] enable test Signed-off-by: Isotr0py --- examples/offline_inference/vision_language.py | 5 ++--- tests/models/multimodal/test_tensor_schema.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index b429a237f29c..5fa0f942f9ec 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -815,18 +815,17 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData: return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6") -def run_minimax(questions: list[str], modality: str) -> ModelRequestData: +def run_minimax_vl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "MiniMaxAI/MiniMax-VL-01" - # The configuration below has been confirmed to launch on a single L40 GPU. engine_args = EngineArgs( model=model_name, - max_model_len=14336, max_num_seqs=2, limit_mm_per_prompt={modality: 1}, trust_remote_code=True, + tensor_parallel_size=8, ) tokenizer = AutoTokenizer.from_pretrained(model_name) diff --git a/tests/models/multimodal/test_tensor_schema.py b/tests/models/multimodal/test_tensor_schema.py index 51e5b84b6c08..143b4c8fc8c4 100644 --- a/tests/models/multimodal/test_tensor_schema.py +++ b/tests/models/multimodal/test_tensor_schema.py @@ -30,7 +30,6 @@ ARCH_TO_SKIP = { "MolmoForCausalLM": "incompatible requirements", - "MiniMaxVL01ForConditionalGeneration": "broken model", } ARCH_NEEDS_EXTRAS = [ "InternVLChatModel", From ebe05d1089c0db9fd8eae36339754c63d04f486c Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 19 Aug 2025 13:25:48 +0800 Subject: [PATCH 4/7] use tensor schema Signed-off-by: Isotr0py --- vllm/model_executor/models/minimax_vl_01.py | 41 +++++++++++++-------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 9efe7fc17751..2018869cbd05 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping -from typing import Literal, Optional, TypedDict, Union, cast +from typing import Annotated, Literal, Optional, Union, cast import torch import torch.nn as nn @@ -19,6 +19,7 @@ from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -31,25 +32,36 @@ maybe_prefix, merge_multimodal_embeddings) -class MiniMaxVL01ImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class MiniMaxVL01ImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images * num_patches, num_channels, height, 
width)` - - Note that `height` or `width` may be different per batch and image, + Dimensions: + - bn: Batch size * number of images + - np: Number of patches + 1 + - c: Number of channels (3) + - h: Height + - w: Width + + Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})] + image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + # This should be in `(height, width)` format. -class MiniMaxVL01ImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. +class MiniMaxVL01ImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs, @@ -377,8 +389,7 @@ def _parse_and_validate_image_input( return MiniMaxVL01ImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values)), + pixel_values=flatten_bn(pixel_values), image_sizes=flatten_bn(image_sizes, concat=True), ) From c5f3058d5b580f778b60add38e51d00c709e7e6b Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 19 Aug 2025 13:27:46 +0800 Subject: [PATCH 5/7] clean Signed-off-by: Isotr0py --- vllm/model_executor/models/minimax_vl_01.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 2018869cbd05..96b9751d3c9b 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -272,7 +272,6 @@ def select_features(leaf: torch.Tensor): def pack_image_features(self, image_features: list[torch.Tensor], image_sizes: torch.Tensor): new_image_features = [] - feature_lens = [] for image_idx, image_feature in enumerate(image_features): if image_feature.shape[0] > 1: base_image_feature = image_feature[0] @@ -317,7 +316,6 @@ def pack_image_features(self, image_features: list[torch.Tensor], self.image_newline[None].to(image_feature)), dim=0) new_image_features.append(image_feature) - feature_lens.append(image_feature.size(0)) return new_image_features def _process_image_pixels( From 238fced5e27cbc4ae4d67c3a1ef86b8c8709b4c7 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 19 Aug 2025 14:01:02 +0800 Subject: [PATCH 6/7] update Signed-off-by: Isotr0py --- examples/offline_inference/vision_language.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 5fa0f942f9ec..75dd527b05bb 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -815,7 +815,7 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData: return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6") -def run_minimax_vl(questions: list[str], modality: str) -> ModelRequestData: +def run_minimax_vl_01(questions: list[str], modality: str) -> ModelRequestData: assert modality 
== "image" model_name = "MiniMaxAI/MiniMax-VL-01" @@ -1496,7 +1496,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: "mantis": run_mantis, "minicpmo": run_minicpmo, "minicpmv": run_minicpmv, - "minimax": run_minimax, + "minimax_vl_01": run_minimax_vl_01, "mistral3": run_mistral3, "mllama": run_mllama, "molmo": run_molmo, From aaae7cf5f3496ef28576e51771d20ec6061f9c76 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 19 Aug 2025 14:02:06 +0800 Subject: [PATCH 7/7] clean Signed-off-by: Isotr0py --- vllm/model_executor/models/minimax_vl_01.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 96b9751d3c9b..cc7db849a28b 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -349,24 +349,6 @@ def _process_image_input( image_sizes = image_input.get("image_sizes") return self.pack_image_features(image_embeds, image_sizes) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - for x in data: - actual_dims = x.shape[1:] - if len(actual_dims) != len(expected_dims) or actual_dims[ - 0] != expected_dims[0] or any( - actual_dims[i] > expected_dims[i] - for i in range(1, 3)): - expected_expr = ("batch_size", *map(str, expected_dims)) - actual_expr = ("batch_size", *map(str, tuple(actual_dims))) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {actual_expr}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]: pixel_values = kwargs.pop("pixel_values", None)