Skip to content

Commit a607827

Browse files
ywang96Alvant
authored andcommitted
[Bugfix] Fix img_sizes Parsing in Phi3-Vision (vllm-project#5888)
Signed-off-by: Alvant <[email protected]>
1 parent 77cde6d commit a607827

File tree

1 file changed

+6
-20
lines changed

1 file changed

+6
-20
lines changed

vllm/model_executor/models/phi3v.py

+6-20
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,6 @@ def __init__(self, wte=None) -> None:
6565
self.type_feature: str
6666
self.img_processor: CLIPVisionModel
6767

68-
def set_img_features(self, img_features: torch.FloatTensor) -> None:
69-
self.img_features = img_features
70-
71-
def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
72-
self.img_sizes = img_sizes
73-
7468
def get_img_features(self,
7569
img_embeds: torch.FloatTensor) -> torch.FloatTensor:
7670
LAYER_IDX = self.layer_idx
@@ -144,21 +138,16 @@ def __init__(self,
144138
self.layer_idx = config.img_processor.get('layer_idx', -2)
145139
self.type_feature = config.img_processor.get('type_feature', 'patch')
146140

147-
def forward(self,
148-
input_ids: torch.LongTensor,
141+
def forward(self, input_ids: torch.LongTensor,
149142
pixel_values: torch.FloatTensor,
150-
image_sizes=None) -> torch.FloatTensor:
143+
image_sizes: torch.Tensor) -> torch.FloatTensor:
151144
"""process and merge text embeddings with image embeddings."""
152145

146+
# (batch_size, max_num_crops, 3, height, width)
153147
img_embeds = pixel_values
154-
img_sizes = image_sizes
155148

156-
if self.img_features is not None:
157-
img_embeds = self.img_features.clone()
158-
self.img_features = None
159-
160-
if self.img_sizes is not None:
161-
img_sizes = self.img_sizes
149+
# (batch_size, 2)
150+
img_sizes = image_sizes
162151

163152
input_shape = input_ids.size()
164153
input_ids = input_ids.view(-1, input_shape[-1])
@@ -190,11 +179,8 @@ def forward(self,
190179
output_imgs = []
191180
output_len = []
192181

193-
if isinstance(img_sizes, torch.Tensor):
194-
img_sizes.squeeze_(0)
195-
196182
for _bs in range(bs):
197-
h, w = img_sizes
183+
h, w = img_sizes[_bs]
198184
h = h // 336
199185
w = w // 336
200186
B_ = h * w

0 commit comments

Comments
 (0)