Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hotfix][Pixtral] Fix multiple images bugs #8415

Merged
merged 16 commits on
Sep 12, 2024
40 changes: 25 additions & 15 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,30 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder

mm_config = ctx.model_config.multimodal_config
max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
max_model_len = ctx.model_config.max_model_len
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
patch_size = mm_encoder.mm_config.image_patch_size
image_token_id = mm_encoder.special_ids.img

# approximate image size
size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
mm_config = ctx.model_config.multimodal_config
num_images = mm_config.limit_per_prompt.get("image", 1)

# dummy size
size = 256
image = Image.new("RGB", (size, size), color=0)
img_chunk = ImageChunk(image=image)

tokens = mm_encoder(img_chunk).tokens
token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
tokens)
image_feature_size = (size ** 2) // (patch_size ** 2)

num_image_tokens = image_feature_size * num_images

token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * num_image_tokens
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - num_image_tokens)

seq_data = SequenceData(token_ids)
mm_data = {"image": max_num_images_per_request * [image]}
mm_data = {"image": num_images * [image]}
return seq_data, mm_data


Expand Down Expand Up @@ -115,8 +122,9 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
"to image features dim {D_img}")
assert (seq_len == N_txt +
N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
f"{(N_txt, N_img, image_locations.sum().item())}")
N_img), (f"seq_len should be equal to N_txt + N_img, "
f" but got {seq_len=} != {N_txt=} + {N_img=}, "
f" for {image_locations.sum().item()} img locs")
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved

inputs_embeds[image_locations, :] = image_features
return inputs_embeds
Expand Down Expand Up @@ -201,10 +209,12 @@ def _parse_and_validate_image_input(
return None

if isinstance(images, torch.Tensor):
# always take last images
images = [images[-1][i] for i in range(images.size(1))]
# if passed as batch take all images
N, B, C, W, H = images.shape
images = images.reshape(N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# always take last images
# if passed as list always take last images
images = [images[-1][i] for i in range(len(images[0]))]

return images
Expand Down
Loading