Merged
12 changes: 6 additions & 6 deletions tests/v1/engine/test_processor_multi_modal_uuids.py
@@ -152,8 +152,8 @@ def fake_preprocess(prompt,
*,
tokenization_kwargs=None,
lora_request=None,
mm_hash_overrides=None):
captured["mm_hash_overrides"] = mm_hash_overrides
mm_uuids=None):
captured["mm_uuids"] = mm_uuids
# Minimal processed inputs for decoder-only flow
return {"type": "token", "prompt_token_ids": [1]}

@@ -180,7 +180,7 @@ def fake_preprocess(prompt,
params=SamplingParams(),
)

assert captured["mm_hash_overrides"] == mm_uuids
assert captured["mm_uuids"] == mm_uuids


def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
@@ -196,8 +196,8 @@ def fake_preprocess(prompt,
*,
tokenization_kwargs=None,
lora_request=None,
mm_hash_overrides=None):
captured["mm_hash_overrides"] = mm_hash_overrides
mm_uuids=None):
captured["mm_uuids"] = mm_uuids
return {"type": "token", "prompt_token_ids": [1]}

monkeypatch.setattr(processor.input_preprocessor,
@@ -223,7 +223,7 @@ def fake_preprocess(prompt,
)

# Expect request-id-based overrides are passed through
assert captured["mm_hash_overrides"] == {
assert captured["mm_uuids"] == {
"image": [f"{request_id}-image-0", f"{request_id}-image-1"],
"video": [f"{request_id}-video-0"],
}
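
For reference, a minimal self-contained sketch of the capture pattern these tests use, outside the vLLM harness: a stand-in preprocess function records the keyword under its new mm_uuids name so the assertions above can compare it against the expected UUIDs. The stub below is illustrative only; it mirrors the fake_preprocess defined in the diff and does not import vLLM.

    captured = {}

    # Stand-in for InputPreprocessor.preprocess; only the keyword-only
    # mm_uuids parameter matters for the rename under test.
    def fake_preprocess(prompt, *, tokenization_kwargs=None,
                        lora_request=None, mm_uuids=None):
        captured["mm_uuids"] = mm_uuids
        return {"type": "token", "prompt_token_ids": [1]}

    fake_preprocess("hello", mm_uuids={"image": ["img-uuid-0"]})
    assert captured["mm_uuids"] == {"image": ["img-uuid-0"]}
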
88 changes: 37 additions & 51 deletions vllm/inputs/preprocess.py
@@ -258,8 +258,7 @@ def _process_multimodal(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
"""
Apply the model's multi-modal processor to a multi-modal prompt,
@@ -281,7 +280,7 @@ def _process_multimodal(
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
mm_hashes = mm_input["mm_hashes"]

@@ -302,8 +301,7 @@ async def _process_multimodal_async(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
"""
Async version of
@@ -325,7 +323,7 @@ async def _process_multimodal_async(
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
mm_hashes = mm_input["mm_hashes"]

@@ -390,8 +388,7 @@ def _process_tokens(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
@@ -404,7 +401,7 @@ def _process_tokens(
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
else:
inputs = token_inputs(prompt_token_ids=prompt_token_ids)
@@ -420,8 +417,7 @@ async def _process_tokens_async(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
@@ -434,7 +430,7 @@ async def _process_tokens_async(
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
else:
inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
@@ -450,8 +446,7 @@ def _process_text(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]

@@ -463,7 +458,7 @@ def _process_text(
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = self._tokenize_prompt(
@@ -487,8 +482,7 @@ async def _process_text_async(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]

@@ -500,7 +494,7 @@ async def _process_text_async(
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
@@ -524,8 +518,7 @@ def _prompt_to_llm_inputs(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
"""
Extract the singleton inputs from a prompt.
@@ -547,21 +540,21 @@ def _prompt_to_llm_inputs(
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)

assert_never(parsed)
@@ -572,8 +565,7 @@ async def _prompt_to_llm_inputs_async(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
"""
Async version of
@@ -587,21 +579,21 @@ async def _prompt_to_llm_inputs_async(
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)

assert_never(parsed)
@@ -712,8 +704,7 @@ def _process_encoder_decoder_prompt(
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
@@ -755,7 +746,7 @@ def _process_encoder_decoder_prompt(
encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None
@@ -771,7 +762,7 @@ def _process_encoder_decoder_prompt(
inputs = self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
@@ -788,8 +779,7 @@ async def _process_encoder_decoder_prompt_async(
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> EncoderDecoderInputs:
"""
Async version of
@@ -802,7 +792,7 @@ async def _process_encoder_decoder_prompt_async(
encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)

if (decoder_input := prompt["decoder_prompt"]) is None:
@@ -812,7 +802,7 @@ async def _process_encoder_decoder_prompt_async(
decoder_task = self._prompt_to_llm_inputs_async(
decoder_input,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)

encoder_inputs, decoder_inputs = await asyncio.gather(
@@ -828,7 +818,7 @@ async def _process_encoder_decoder_prompt_async(
inputs = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
@@ -856,8 +846,7 @@ def _process_decoder_only_prompt(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
"""
For decoder-only models:
@@ -878,7 +867,7 @@ def _process_decoder_only_prompt(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)

return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -889,8 +878,7 @@ async def _process_decoder_only_prompt_async(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
"""
Async version of
@@ -900,7 +888,7 @@ async def _process_decoder_only_prompt_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)

return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -911,8 +899,7 @@ def preprocess(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
@@ -921,7 +908,7 @@ def preprocess(
return self._process_encoder_decoder_prompt(
prompt,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)

if is_explicit_encoder_decoder_prompt(prompt):
Expand All @@ -933,7 +920,7 @@ def preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)

async def preprocess_async(
Expand All @@ -942,8 +929,7 @@ async def preprocess_async(
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
"""
Async version of
@@ -955,7 +941,7 @@ async def preprocess_async(
return await self._process_encoder_decoder_prompt_async(
prompt,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)

if is_explicit_encoder_decoder_prompt(prompt):
Expand All @@ -967,7 +953,7 @@ async def preprocess_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
mm_uuids=mm_uuids,
)

def clear_cache(self) -> None:
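
Net effect of the preprocess.py changes: every entry point now takes a keyword-only mm_uuids: Optional[MultiModalUUIDDict] in place of the broader mm_hash_overrides: Optional[Union[dict[str, list[str]], MultiModalUUIDDict]], and the value is threaded unchanged through each _process_* layer down to the multi-modal processor. A hedged sketch of the resulting call shape follows; the MultiModalUUIDDict alias below is a simplified stand-in for illustration, not vLLM's actual definition.

    from typing import Any, Optional

    # Simplified stand-in: maps a modality name ("image", "video", ...)
    # to per-item UUID strings, as in the test expectations above.
    MultiModalUUIDDict = dict[str, list[str]]

    def preprocess(prompt: Any,
                   tokenization_kwargs: Optional[dict[str, Any]] = None,
                   lora_request: Optional[Any] = None,
                   *,
                   mm_uuids: Optional[MultiModalUUIDDict] = None) -> dict[str, Any]:
        # The real code forwards mm_uuids untouched at every layer;
        # this stub just surfaces it for inspection.
        return {"prompt": prompt, "mm_uuids": mm_uuids}

    out = preprocess("describe these inputs",
                     mm_uuids={"image": ["req-0-image-0", "req-0-image-1"],
                               "video": ["req-0-video-0"]})
    assert out["mm_uuids"]["video"] == ["req-0-video-0"]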