
Commit 905569b

DarkLight1337 authored and root committed
[Bugfix][VLM] Fix incompatibility between vllm-project#7902 and vllm-project#7230 (vllm-project#7948)
1 parent 418fb05 commit 905569b

File tree

10 files changed: +120 -92 lines changed


vllm/model_executor/models/blip2.py

+2 -2

@@ -40,13 +40,13 @@
 class Blip2ImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
 
 
 class Blip2ImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
 
     `hidden_size` must match the hidden size of language model backbone.
     """

vllm/model_executor/models/chameleon.py

+1 -1

@@ -53,7 +53,7 @@
 class ChameleonImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: `(batch_size, num_channels, height, width)`"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
 
 
 def get_max_chameleon_image_tokens(ctx: InputContext):

vllm/model_executor/models/internvl.py

+15 -31

@@ -29,7 +29,7 @@
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 IMG_START = '<img>'
@@ -42,19 +42,17 @@
 
 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    data: torch.Tensor
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
-
-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Shape:
+    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
     """
 
 
 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    data: torch.Tensor
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
 
     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -357,7 +355,7 @@ def pixel_shuffle(self, x, scale_factor=0.5):
         x = x.permute(0, 2, 1, 3).contiguous()
         return x
 
-    def extract_feature(self, pixel_values):
+    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
         vit_embeds = self.vision_model(pixel_values=pixel_values)
         vit_embeds = vit_embeds[:, 1:, :]
 
@@ -370,17 +368,7 @@ def extract_feature(self, pixel_values):
         vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds
 
-    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
-
-        return data
-
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
 
         h = w = self.config.vision_config.image_size
         expected_dims = (3, h, w)
@@ -389,10 +377,11 @@ def _validate_shape(d: torch.Tensor):
             actual_dims = tuple(d.shape)
 
             if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
+                expected_expr = str(expected_dims)
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+                    "The expected shape of pixel values per image per batch "
+                    f" per patch is {expected_expr}. "
+                    f"You supplied {tuple(d.shape)}.")
 
         for d in data:
             _validate_shape(d)
@@ -413,12 +402,9 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
 
-            # Flatten the B and N dimensions
-            image_embeds = image_embeds.flatten(0, 2)
-
             return InternVLImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds),
             )
 
         self.img_context_token_id = image_token_id[0]
@@ -428,12 +414,10 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
-            # Flatten the B and N dimensions
-            pixel_values = pixel_values.flatten(0, 2)
-
             return InternVLImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True).flatten(0, 1)),
             )
 
         raise AssertionError("This line should be unreachable.")
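
The rewritten parsing code delegates the merging to flatten_bn, a helper imported from .utils whose body is not shown in this diff. Based purely on how it is called here, a plausible sketch is the following; treat it as an assumption about the helper's contract, not the actual vLLM implementation:

from typing import List, Union

import torch


def flatten_bn(
    x: Union[torch.Tensor, List[torch.Tensor]],
    *,
    concat: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Assumed contract: merge the batch (B) and per-prompt image (N) dims.

    - A batched tensor of shape (B, N, ...) has its first two dims flattened.
    - A list of per-prompt tensors of shape (N_i, ...) is either concatenated
      along dim 0 (concat=True) or unrolled into a flat list of per-image
      tensors (concat=False), which tolerates ragged trailing shapes.
    """
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    return [x_n for x_b in x for x_n in x_b]

Under this reading, `flatten_bn(pixel_values, concat=True).flatten(0, 1)` in the InternVL path first merges the batch and image dimensions and then folds the per-image patch dimension into the leading axis, matching the `(batch_size * num_images * (1 + num_patches), ...)` shape the new docstring documents.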

vllm/model_executor/models/llava.py

+2 -2

@@ -30,13 +30,13 @@
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: `(batch_size, num_channels, height, width)`"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
 
 
 class LlavaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
 
     `hidden_size` must match the hidden size of language model backbone.
     """

vllm/model_executor/models/llava_next.py

+26 -26

@@ -29,7 +29,7 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
@@ -47,15 +47,16 @@ class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: Union[torch.Tensor, List[torch.Tensor]]
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+    Shape:
+    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
 
-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Note that `num_patches` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
     """
 
     image_sizes: NotRequired[torch.Tensor]
     """
-    Shape: `(batch_size, 2)`
+    Shape: `(batch_size * num_images, 2)`
 
     This should be in `(height, width)` format.
     """
@@ -64,7 +65,7 @@ class LlavaNextImagePixelInputs(TypedDict):
 class LlavaNextImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
 
     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -315,10 +316,19 @@ def __init__(self,
            torch.empty(config.text_config.hidden_size))
 
     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
+        expected_dims = (2, )
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    f"The expected shape of image sizes per image per batch "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
 
         return data
 
@@ -335,7 +345,7 @@ def _validate_shape(d: torch.Tensor):
             if actual_dims != expected_dims:
                 expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
+                    "The expected shape of pixel values per image per batch "
                     f"is {expected_expr}. You supplied {tuple(d.shape)}.")
 
         for d in data:
@@ -357,35 +367,25 @@ def _parse_and_validate_image_input(
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
 
-            if not isinstance(image_sizes, torch.Tensor):
+            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")
 
-            # Remove the N dimension until multiple images are supported.
-            if isinstance(pixel_values, torch.Tensor):
-                pixel_values = pixel_values.squeeze(1)
-            else:
-                pixel_values = [t.squeeze(0) for t in pixel_values]
-
-            image_sizes = image_sizes.squeeze(1)
-
            return LlavaNextImagePixelInputs(
                type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes),
+                data=self._validate_pixel_values(flatten_bn(pixel_values)),
+                image_sizes=self._validate_image_sizes(
+                    flatten_bn(image_sizes, concat=True)),
            )
 
        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")
 
-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
-
            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds),
            )
 
        raise AssertionError("This line should be unreachable.")
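
The per-element `_validate_shape` rewrite means `_validate_image_sizes` no longer assumes a batched 2-D tensor: iterating a `(total_images, 2)` tensor yields `(2,)` rows, and a list of `(2,)` tensors works the same way. A hedged usage example, reusing the `flatten_bn` sketch above and made-up sizes:

import torch

# flatten_bn here refers to the sketch above (an assumed contract).
# Hypothetical image_sizes kwarg: 2 prompts with 2 and 1 images.
image_sizes = [
    torch.tensor([[336, 672], [672, 336]]),  # prompt 0: shape (2, 2)
    torch.tensor([[336, 336]]),              # prompt 1: shape (1, 2)
]

# concat=True stacks every row into one (total_images, 2) tensor;
# the validator then checks each (2,) row individually.
flat_sizes = flatten_bn(image_sizes, concat=True)
print(flat_sizes.shape)  # torch.Size([3, 2])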

vllm/model_executor/models/paligemma.py

+2 -2

@@ -34,13 +34,13 @@
 class PaliGemmaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
 
 
 class PaliGemmaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
 
     `hidden_size` must match the hidden size of language model backbone.
     """

vllm/model_executor/models/phi3v.py

+27 -23

@@ -44,7 +44,7 @@
 
 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal
-from .utils import merge_multimodal_embeddings
+from .utils import flatten_bn, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -75,15 +75,16 @@ class Phi3VImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: Union[torch.Tensor, List[torch.Tensor]]
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+    Shape:
+    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
 
-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Note that `num_patches` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
     """
 
     image_sizes: torch.Tensor
     """
-    Shape: `(batch_size, 2)`
+    Shape: `(batch_size * num_images, 2)`
 
     This should be in `(height, width)` format.
     """
@@ -92,7 +93,7 @@ class Phi3VImagePixelInputs(TypedDict):
 class Phi3VImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
 
     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -511,10 +512,19 @@ def __init__(self,
         self.sampler = Sampler()
 
     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected shape of image sizes is batch dimension plus "
-                f"{[2]}. You supplied {tuple(data.shape)}.")
+        expected_dims = (2, )
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    f"The expected shape of image sizes per image per batch "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
 
         return data
 
@@ -531,7 +541,7 @@ def _validate_shape(d: torch.Tensor):
             if actual_dims != expected_dims:
                 expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
+                    "The expected shape of pixel values per image per batch "
                     f"is {expected_expr}. You supplied {tuple(d.shape)}.")
 
         for d in data:
@@ -556,30 +566,24 @@ def _parse_and_validate_image_input(
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
 
-            if not isinstance(image_sizes, torch.Tensor):
+            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")
 
-            # Merge the B and N dimensions.
-            if isinstance(pixel_values, torch.Tensor):
-                pixel_values = pixel_values.flatten(0, 1)
-            else:
-                pixel_values = torch.cat(pixel_values)
-
-            image_sizes = image_sizes.flatten(0, 1)
-
            return Phi3VImagePixelInputs(
                type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes))
+                data=self._validate_pixel_values(flatten_bn(pixel_values)),
+                image_sizes=self._validate_image_sizes(
+                    flatten_bn(image_sizes, concat=True)))
 
        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
+
            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds),
            )
 
        raise AssertionError("This line should be unreachable.")
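
Phi-3-Vision keeps `Union[torch.Tensor, List[torch.Tensor]]` for pixel values because `num_patches` can differ between images. Assuming the `flatten_bn` sketch above, the list branch without `concat` would yield a flat list of per-image tensors for `_validate_pixel_values` to check one by one (shapes below are illustrative):

import torch

# flatten_bn here refers to the sketch above (an assumed contract).
# Hypothetical ragged input: 2 prompts, one image each, with a
# different number of patches per image.
pixel_values = [
    torch.randn(1, 5, 3, 336, 336),  # prompt 0: 1 + 4 patches
    torch.randn(1, 2, 3, 336, 336),  # prompt 1: 1 + 1 patch
]

# Without concat, the batch and image dims collapse into a flat list
# of per-image tensors of shape (1 + num_patches, C, H, W).
flat = flatten_bn(pixel_values)
print([tuple(t.shape) for t in flat])
# [(5, 3, 336, 336), (2, 3, 336, 336)]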

vllm/model_executor/models/ultravox.py

+1 -1

@@ -49,7 +49,7 @@
 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
     data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size, 80, M)"""
+    """Shape: `(batch_size * num_audios, 80, M)"""
 
 
 class UltravoxAudioEmbeddingInputs(TypedDict):
