examples/offline_inference/vision_language_multi_image.py (2 additions, 2 deletions)

@@ -229,8 +229,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
-    placeholders = "<|image|>" * len(image_urls)
-    prompt = f"{placeholders}<|begin_of_text|>{question}"
+    img_prompt = "Given the first image <|image|> and the second image<|image|>"
+    prompt = f"<|begin_of_text|>{img_prompt}, {question}?"
     return ModelRequestData(
         engine_args=engine_args,
         prompt=prompt,
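
For reference, a minimal sketch of the prompt the updated example builds at runtime; the question text here is a hypothetical stand-in, not from the source:

question = "What is the difference between these images"  # hypothetical stand-in
img_prompt = "Given the first image <|image|> and the second image<|image|>"
prompt = f"<|begin_of_text|>{img_prompt}, {question}?"
# prompt == '<|image|>...' no longer: <|begin_of_text|> now leads, i.e.
# '<|begin_of_text|>Given the first image <|image|> and the second '
# 'image<|image|>, What is the difference between these images?'

Compared to the old version, <|begin_of_text|> now precedes the image placeholders, and the <|image|> tokens are interleaved with text rather than stacked at the front of the prompt.
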
vllm/model_executor/models/mllama.py (50 additions, 15 deletions)

@@ -181,19 +181,66 @@ def apply(
         mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                   return_mm_hashes)
 
+        image_token_id = self.info.get_hf_config().image_token_index
         # Check that the number of image tokens in the decoder prompt matches
         # the number of images provided in mm_data
-        num_image_tokens = mm_inputs['prompt_token_ids'].count(
-            self.info.get_hf_config().image_token_index)
+        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
         image_data = mm_data.get("image", [])
         num_images = 1 if isinstance(image_data, Image) else len(image_data)
         if num_image_tokens != num_images:
             raise ValueError(
                 f"The number of image tokens ({num_image_tokens}) must be"
                 f" the same as the number of images ({num_images})")
+
+        # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6...., (P-prefill, D-decode)  # noqa: E501
+        # P0 & P1 do cross attention with placeholder of <IMG0>
+        # P3 P4 D5 D6 do cross attention with placeholder of <IMG1> and <IMG2>
+        # Example input to encoder and decoder:
+        # {
+        #     'encoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128256, 128256, ..., 128256],
+        #         'prompt': '<|image|><|image|>...<|image|>',
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        #     'decoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        # }
+
+        if mm_data:
+            # Since only the last group of consecutive images
+            # are attended by the decoded tokens, we only need to
+            # get the number of tokens for those images.
+            token_per_chunk = self.info.get_token_per_chunk_from_config()
+            num_decode_images = self._get_num_image_in_last_group(
+                mm_inputs["prompt_token_ids"])
+            num_encode_images = num_images - num_decode_images
+
+            # Set encoder prompt length based on the number of tiles.
+            # This tells the block manager to allocate correct number
+            # of slots for encoder tokens.
+            num_tiles = mm_inputs["mm_kwargs"]["num_tiles"]
+            decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
+            num_tokens = decode_tiles * token_per_chunk
+            mm_inputs["encoder_prompt_token_ids"] = [image_token_id
+                                                     ] * num_tokens
+            mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
 
         return mm_inputs
 
+    def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int:
+        num_images = 0
+        for token_id in prompt_token_ids[::-1]:
+            if token_id == self.info.get_hf_config().image_token_index:
+                num_images += 1
+            elif num_images > 0:
+                break
+        return num_images
+
     def _call_hf_processor(
         self,
         prompt: str,

@@ -211,19 +258,7 @@ def _call_hf_processor
             processed_outputs["num_tiles"] = torch.tensor(num_tiles)
             for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
                 processed_outputs[k] = processed_outputs[k].squeeze(0)
-        # Example input to encoder and decoder:
-        # {
-        #     'encoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
-        #     },
-        #     'decoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128000],
-        #     },
-        # }
+
         processed_token_ids = processed_outputs.pop("input_ids")
         start_idx, end_idx = 0, processed_token_ids.size(1)
         processed_prompt_text = tokenizer.decode(processed_token_ids[0])
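
The _get_num_image_in_last_group helper added above walks the decoder prompt backwards and counts the trailing run of consecutive image tokens; per the comments in apply(), only those images are cross-attended by the decode tokens. A self-contained sketch of the same logic, assuming the image token id 128256 taken from the example comment:

IMAGE_TOKEN_ID = 128256  # <|image|> token id, per the example comment

def num_images_in_last_group(prompt_token_ids: list[int]) -> int:
    # Scan from the end; count consecutive image tokens and stop
    # once the run is broken by a non-image token.
    num_images = 0
    for token_id in reversed(prompt_token_ids):
        if token_id == IMAGE_TOKEN_ID:
            num_images += 1
        elif num_images > 0:
            break
    return num_images

# Given prompt <IMG0> P0 P1 <IMG1> <IMG2> P3 P4, the last group is
# {<IMG1>, <IMG2>}, matching the cross-attention comment in apply().
assert num_images_in_last_group([128256, 10, 11, 128256, 128256, 12, 13]) == 2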
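
The encoder prompt sizing in apply() is then plain arithmetic: the tile count of the images in that last group, times the tokens contributed per tile (chunk). A hedged illustration of that step; the tile counts below are made up, and 1601 tokens per chunk is an assumption based on Llama 3.2 Vision's 560x560 tiles with 14x14 patches, not a value from the source:

import torch

num_tiles = torch.tensor([4, 2, 3])  # tiles per image (hypothetical values)
token_per_chunk = 1601               # assumed: (560 // 14) ** 2 + 1

num_images = 3
num_decode_images = 2                # last group: the final two images
num_encode_images = num_images - num_decode_images

# Only the last group's tiles get encoder slots from the block manager.
decode_tiles = num_tiles[num_encode_images:num_images].sum().item()  # 2 + 3 = 5
num_tokens = decode_tiles * token_per_chunk                          # 5 * 1601 = 8005
encoder_prompt = "<|image|>" * num_tokens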