11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33from collections .abc import Iterable , Mapping
4- from typing import Literal , Optional , TypedDict , Union , cast
4+ from typing import Annotated , Literal , Optional , Union , cast
55
66import torch
77import torch .nn as nn
88from transformers import BatchFeature , PretrainedConfig
9+ from transformers .models .llava_next .modeling_llava_next import (
10+ get_anyres_image_grid_shape , unpad_image )
911
1012from vllm .config import VllmConfig
1113from vllm .model_executor .layers .activation import get_act_fn
1719from vllm .multimodal .inputs import MultiModalFieldConfig
1820from vllm .sequence import IntermediateTensors
1921from vllm .utils .jsontree import json_map_leaves
22+ from vllm .utils .tensor_schema import TensorSchema , TensorShape
2023
2124from .clip import CLIPVisionModel
2225from .interfaces import MultiModalEmbeddings , SupportsMultiModal , SupportsPP
2932 maybe_prefix , merge_multimodal_embeddings )
3033
3134
32- class MiniMaxVL01ImagePixelInputs (TypedDict ):
33- type : Literal ["pixel_values" ]
34- pixel_values : torch .Tensor
class MiniMaxVL01ImagePixelInputs(TensorSchema):
    """
    Pixel-value image inputs, tagged with `type == "pixel_values"`.

    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches + 1
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
    type: Literal["pixel_values"] = "pixel_values"
    # Either one batched tensor or a list with one tensor per image.
    # NOTE(review): `dynamic_dims` presumably marks `np`, `h` and `w` as
    # allowed to differ between list entries -- confirm against TensorSchema.
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})]

    # One row of two values per image; annotated as optional.
    # This should be in `(height, width)` format.
    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
4254
43- class MiniMaxVL01ImageEmbeddingInputs (TypedDict ):
44- type : Literal ["image_embeds" ]
45- data : torch .Tensor
46- """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
4755
48- `hidden_size` must match the hidden size of language model backbone.
class MiniMaxVL01ImageEmbeddingInputs(TensorSchema):
    """
    Image-embedding inputs, tagged with `type == "image_embeds"`.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """
    type: Literal["image_embeds"] = "image_embeds"
    # Single batched tensor of shape (bn, ifs, hs); no list form is declared.
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
5065
5166
5267MiniMaxVL01ImageInputs = Union [MiniMaxVL01ImagePixelInputs ,
@@ -141,6 +156,7 @@ def _get_mm_fields_config(
141156 ) -> Mapping [str , MultiModalFieldConfig ]:
142157 return {
143158 "pixel_values" : MultiModalFieldConfig .batched ("image" ),
159+ "image_sizes" : MultiModalFieldConfig .batched ("image" ),
144160 "image_embeds" : MultiModalFieldConfig .batched ("image" ),
145161 }
146162
@@ -239,7 +255,7 @@ def _image_pixels_to_features(
239255 ) -> Union [torch .Tensor , tuple [torch .Tensor , ...]]:
240256 # NOTE: we skip the step to select the vision feature layer since
241257 # this is already done inside the vision tower
242- image_features = vision_tower (pixel_values )
258+ image_features = tuple ( vision_tower (p ) for p in pixel_values )
243259
244260 def select_features (leaf : torch .Tensor ):
245261 return self ._select_image_features (
@@ -252,14 +268,63 @@ def select_features(leaf: torch.Tensor):
252268 json_map_leaves (select_features , image_features ),
253269 )
254270
271+ # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
def pack_image_features(self, image_features: list[torch.Tensor],
                        image_sizes: torch.Tensor):
    """Merge per-tile image features into one token sequence per image.

    Each entry of `image_features` is indexed as
    `(num_tiles, num_patches, hidden)`, where tile 0 is the base image and
    any further tiles form the any-resolution grid.

    * Multi-tile entries: the grid tiles are rearranged into a single
      `(hidden, grid_h, grid_w)` map, passed through `unpad_image`, given
      one `image_newline` column at the end of every row, flattened to
      `(tokens, hidden)` and appended after the base-image features.
    * Single-tile entries: the `image_newline` vector is appended as one
      extra token after the tile's features.

    Args:
        image_features: One feature tensor per image.
        image_sizes: Per-image sizes, indexed by image position. Only read
            for multi-tile entries.

    Returns:
        A list with one `(tokens, hidden)` tensor per input image.

    Raises:
        ValueError: If the base tile's patch count does not equal
            `(image_size // patch_size) ** 2`, or if a multi-tile entry is
            given while `image_sizes` is `None`.
    """
    vision_config = self.config.vision_config
    # Side length (in patches) of one square tile. It does not depend on
    # the image, so compute it once instead of once per loop iteration.
    grid_side = vision_config.image_size // vision_config.patch_size

    packed_features = []
    for image_idx, image_feature in enumerate(image_features):
        if image_feature.shape[0] > 1:
            base_image_feature = image_feature[0]
            tiles = image_feature[1:]
            if grid_side * grid_side != base_image_feature.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with "
                    "the image size.")
            if image_sizes is None:
                # Fail with a clear message instead of the TypeError that
                # subscripting None would raise below.
                raise ValueError(
                    "`image_sizes` is required to pack multi-patch "
                    "image features.")
            image_size = image_sizes[image_idx]
            num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                image_size,
                self.config.image_grid_pinpoints,
                vision_config.image_size,
            )

            # (tiles, patches, hidden)
            #   -> (grid_h, grid_w, side, side, hidden)
            tiles = tiles.view(num_patch_height, num_patch_width, grid_side,
                               grid_side, -1)
            #   -> (hidden, grid_h, side, grid_w, side)
            tiles = tiles.permute(4, 0, 2, 1, 3).contiguous()
            #   -> (hidden, grid_h * side, grid_w * side)
            tiles = tiles.flatten(1, 2).flatten(2, 3)
            tiles = unpad_image(tiles, image_size)

            # Match dtype *and* device of the features, consistent with the
            # single-tile branch, before broadcasting one newline column
            # onto the end of every row.
            newline = self.image_newline.to(tiles)
            tiles = torch.cat(
                (
                    tiles,
                    newline[:, None, None].expand(*tiles.shape[:-1], 1),
                ),
                dim=-1,
            )
            #   -> (rows * (cols + 1), hidden)
            tiles = tiles.flatten(1, 2).transpose(0, 1)
            image_feature = torch.cat((base_image_feature, tiles), dim=0)
        else:
            image_feature = image_feature[0]
            image_feature = torch.cat(
                (image_feature,
                 self.image_newline[None].to(image_feature)),
                dim=0)
        packed_features.append(image_feature)
    return packed_features
320+
def _process_image_pixels(
    self,
    inputs: MiniMaxVL01ImagePixelInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
    """Feed the `pixel_values` entry of `inputs` through the vision tower.

    Delegates to `_image_pixels_to_features` and returns its result
    unchanged. The vision tower must be set.
    """
    tower = self.vision_tower
    assert tower is not None

    return self._image_pixels_to_features(tower, inputs["pixel_values"])
264329
265330 def _process_image_input (
@@ -281,38 +346,31 @@ def _process_image_input(
281346
282347 image_embeds = self .multi_modal_projector (torch .cat (image_features ))
283348 image_embeds = torch .split (image_embeds , feature_sizes )
284- return image_embeds
285-
286- def _validate_pixel_values (self , data : torch .Tensor ) -> torch .Tensor :
287- h = w = self .config .vision_config .image_size
288- expected_dims = (3 , h , w )
289- actual_dims = tuple (data .shape [1 :])
290-
291- if actual_dims != expected_dims :
292- expected_expr = ("batch_size" , * map (str , expected_dims ))
293- raise ValueError (
294- f"The expected shape of pixel values is { expected_expr } . "
295- f"You supplied { tuple (data .shape )} ." )
296-
297- return data
349+ image_sizes = image_input .get ("image_sizes" )
350+ return self .pack_image_features (image_embeds , image_sizes )
298351
299352 def _parse_and_validate_image_input (
300353 self , ** kwargs : object ) -> Optional [MiniMaxVL01ImageInputs ]:
301354 pixel_values = kwargs .pop ("pixel_values" , None )
355+ image_sizes = kwargs .pop ("image_sizes" , None )
302356 image_embeds = kwargs .pop ("image_embeds" , None )
303357
304358 if pixel_values is None and image_embeds is None :
305359 return None
306360
307- if pixel_values is not None :
361+ if pixel_values is not None and image_sizes is not None :
308362 if not isinstance (pixel_values , (torch .Tensor , list )):
309363 raise ValueError ("Incorrect type of pixel values. "
310364 f"Got type: { type (pixel_values )} " )
311365
366+ if not isinstance (image_sizes , (torch .Tensor , list )):
367+ raise ValueError ("Incorrect type of image sizes. "
368+ f"Got type: { type (image_sizes )} " )
369+
312370 return MiniMaxVL01ImagePixelInputs (
313371 type = "pixel_values" ,
314- pixel_values = self . _validate_pixel_values (
315- flatten_bn (pixel_values , concat = True ) ),
372+ pixel_values = flatten_bn ( pixel_values ),
373+ image_sizes = flatten_bn (image_sizes , concat = True ),
316374 )
317375
318376 if image_embeds is not None :
0 commit comments