 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 logger = init_logger(__name__)
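The key addition here is the `flatten_bn` helper imported from `.utils`. Its implementation is not part of this diff, but from the call sites below (`flatten_bn(pixel_values)` and `flatten_bn(image_sizes, concat=True)`) and the updated `batch_size * num_images` shape annotations, it evidently folds the batch and per-request image dimensions into a single leading dimension. A minimal sketch of those assumed semantics, not the actual implementation:

```python
import torch
from typing import List, Union

# Hypothetical stand-in for the helper imported from .utils above;
# the real flatten_bn may differ in details.
def flatten_bn(x: Union[torch.Tensor, List[torch.Tensor]],
               *, concat: bool = False):
    """Fold the batch and per-request image dims into one leading dim."""
    if isinstance(x, torch.Tensor):
        # (batch_size, num_images, ...) -> (batch_size * num_images, ...)
        return x.flatten(0, 1)
    if concat:
        # List of per-request tensors -> one concatenated tensor.
        return torch.cat(x)
    # Otherwise keep a flat list with one entry per image.
    return [img for per_request in x for img in per_request]
```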
@@ -47,15 +47,16 @@ class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: Union[torch.Tensor, List[torch.Tensor]]
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+    Shape:
+    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Note that `num_patches` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
     """

     image_sizes: NotRequired[torch.Tensor]
     """
-    Shape: `(batch_size, 2)`
+    Shape: `(batch_size * num_images, 2)`

     This should be in `(height, width)` format.
     """
@@ -64,7 +65,7 @@ class LlavaNextImagePixelInputs(TypedDict):
 class LlavaNextImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -315,10 +316,19 @@ def __init__(self,
                 torch.empty(config.text_config.hidden_size))

     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
+        expected_dims = (2, )
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    f"The expected shape of image sizes per image per batch "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)

         return data

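The rewritten `_validate_image_sizes` checks each entry independently rather than the tensor's trailing shape, so it works the same for a concatenated `(batch_size * num_images, 2)` tensor. A standalone rendering of the same logic, for illustration only:

```python
import torch

def validate_image_sizes(data: torch.Tensor) -> torch.Tensor:
    """Mirror of the validation added above: every entry must have shape (2,)."""
    expected_dims = (2, )
    for d in data:
        if tuple(d.shape) != expected_dims:
            raise ValueError(
                f"The expected shape of image sizes per image per batch "
                f"is {expected_dims}. You supplied {tuple(d.shape)}.")
    return data

validate_image_sizes(torch.tensor([[336, 672], [672, 336]]))  # passes
# validate_image_sizes(torch.zeros(2, 3))  # raises ValueError: supplied (3,)
```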
@@ -335,7 +345,7 @@ def _validate_shape(d: torch.Tensor):
             if actual_dims != expected_dims:
                 expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
+                    "The expected shape of pixel values per image per batch "
                     f"is {expected_expr}. You supplied {tuple(d.shape)}.")

         for d in data:
@@ -357,35 +367,25 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            if not isinstance(image_sizes, torch.Tensor):
+            if not isinstance(image_sizes, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")

-            # Remove the N dimension until multiple images are supported.
-            if isinstance(pixel_values, torch.Tensor):
-                pixel_values = pixel_values.squeeze(1)
-            else:
-                pixel_values = [t.squeeze(0) for t in pixel_values]
-
-            image_sizes = image_sizes.squeeze(1)
-
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes),
+                data=self._validate_pixel_values(flatten_bn(pixel_values)),
+                image_sizes=self._validate_image_sizes(
+                    flatten_bn(image_sizes, concat=True)),
             )

         if image_embeds is not None:
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeds. "
                                  f"Got type: {type(image_embeds)}")

-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
-
             return LlavaNextImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds),
             )

         raise AssertionError("This line should be unreachable.")
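The removed code assumed a single image per prompt (the squeezed `N` dimension was always 1, per its own comment); `flatten_bn` instead folds an arbitrary number of images per request into the batch dimension. A small sketch of the difference, assuming the `flatten_bn` semantics sketched earlier:

```python
import torch

# (batch_size, num_images, 1 + num_patches, channels, height, width)
pixel_values = torch.empty(2, 3, 5, 3, 336, 336)

old = pixel_values.squeeze(1)     # no-op here: shape stays (2, 3, 5, 3, 336, 336)
new = pixel_values.flatten(0, 1)  # (6, 5, 3, 336, 336): one entry per image

print(old.shape, new.shape)
```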