2 changes: 1 addition & 1 deletion tests/lora/test_default_mm_loras.py
@@ -153,5 +153,5 @@ class MockEngineException(Exception):
# Then check to make sure the submitted lora request
# and text prompt were zipped together correctly
engine_args, engine_kwargs = mock_add_request.call_args
assert engine_args[1]["prompt"] == AUDIO_PROMPT
assert engine_kwargs["lora_request"] is None
assert engine_kwargs["prompt_text"] == AUDIO_PROMPT
19 changes: 10 additions & 9 deletions tests/lora/test_qwenvl.py
@@ -88,9 +88,8 @@ def run_test(
# Validate outputs
for generated, expected in zip(generated_texts, expected_outputs):
assert expected.startswith(generated), (
f"Generated text {generated} doesn't "
f"match expected pattern {expected}"
f"Generated text {generated} doesn't match expected pattern {expected}"
)

def run_beam_search_test(
self,
@@ -118,11 +117,14 @@ def run_beam_search_test(
inputs, beam_search_params, lora_request=lora_request
)

for output_obj, expected_outs in zip(outputs, expected_outputs):
for output_obj, expected_texts in zip(outputs, expected_outputs):
output_texts = [seq.text for seq in output_obj.sequences]
assert output_texts == expected_outs, (
f"Generated texts {output_texts} do not match expected {expected_outs}"
) # noqa: E501

for output_text, expected_text in zip(output_texts, expected_texts):
# NOTE beam search .text contains the whole text including inputs
assert output_text.endswith(expected_text), (
f"Generated {output_text} does not match expected {expected_text}"
)


TEST_IMAGES = [
@@ -151,11 +153,10 @@ def run_beam_search_test(
"A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
]

# NOTE - beam search .text contains the whole text
EXPECTED_BEAM_SEARCH_OUTPUTS = [
[
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic skyscraper stands", # noqa: E501
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic tower stands tall", # noqa: E501
"A majestic skyscraper stands",
"A majestic tower stands tall",
],
]
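
The removed NOTE and the shortened expectations reflect why the assertion switched to endswith: beam search sequence.text carries the full rendered chat template plus the generation, so exact equality against the short expected string can never hold. A self-contained illustration (the template string is abridged from the old expected outputs above):

full_text = (
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n<|im_start|>assistant\n"
    "A majestic tower stands tall"
)
expected = "A majestic tower stands tall"
assert full_text.endswith(expected)  # the new suffix check passes
assert full_text != expected         # an exact match would always fail
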

51 changes: 17 additions & 34 deletions vllm/entrypoints/llm.py
@@ -542,51 +542,31 @@ def wait_for_completion(
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput)

def _resolve_lora_reqs(
self,
prompts: Sequence[ProcessorInputs],
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
):
lora_config = self.llm_engine.vllm_config.lora_config
seq_lora_requests = self._lora_request_to_seq(lora_request, len(prompts))

if (
lora_config is None
or not self.model_config.is_multimodal_model
or (lora_config and lora_config.default_mm_loras is None)
):
return seq_lora_requests

return [
self._resolve_single_prompt_mm_lora(
prompt,
lora_req,
lora_config.default_mm_loras,
)
for prompt, lora_req in zip(prompts, seq_lora_requests)
]

def _resolve_single_prompt_mm_lora(
def _resolve_mm_lora(
self,
prompt: ProcessorInputs,
lora_request: LoRARequest | None,
default_mm_loras: dict[str, str] | None,
):
if not default_mm_loras or prompt["type"] != "multimodal":
) -> LoRARequest | None:
if prompt["type"] != "multimodal":
return lora_request

lora_config = self.llm_engine.vllm_config.lora_config
default_mm_loras = None if lora_config is None else lora_config.default_mm_loras
if not default_mm_loras:
return lora_request

prompt_modalities = prompt["mm_placeholders"].keys()
intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
if not intersection:
return lora_request

if len(intersection) > 1:
# TODO: Would be nice to be able to have multiple loras per prompt
logger.warning(
"Multiple modality specific loras were registered and would be"
" used by a single prompt consuming several modalities; "
" currently we only support one lora per request; as such,"
" lora(s) registered with modalities: %s"
" will be skipped",
"Multiple modality specific loras were registered and would be "
"used by a single prompt consuming several modalities; "
"currently we only support one lora per request; as such, "
"lora(s) registered with modalities: %s will be skipped",
intersection,
)
return lora_request
@@ -1915,7 +1895,10 @@ def _render_and_add_requests(
request_id = self._add_request(
prompt,
params[i],
lora_request=None if lora_requests is None else lora_requests[i],
lora_request=self._resolve_mm_lora(
prompt,
None if lora_requests is None else lora_requests[i],
),
priority=0 if priorities is None else priorities[i],
)
added_request_ids.append(request_id)
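
Taken together, the llm.py changes fold _resolve_lora_reqs and _resolve_single_prompt_mm_lora into a single _resolve_mm_lora that is applied per prompt when the request is added. A standalone sketch of that decision, mirroring the hunks above but not the actual vLLM implementation (the branch that builds a LoRARequest from the matching default path sits outside the shown hunks and is elided here):

def resolve_mm_lora(prompt, lora_request, default_mm_loras):
    # Non-multimodal prompts and engines without registered defaults keep
    # whatever the caller passed in.
    if prompt["type"] != "multimodal" or not default_mm_loras:
        return lora_request
    matches = set(prompt["mm_placeholders"]) & set(default_mm_loras)
    if not matches or len(matches) > 1:
        # No default registered for these modalities, or several defaults
        # compete for one prompt; only one LoRA per request is supported,
        # so fall back to the caller's request.
        return lora_request
    # Exactly one modality-specific default matched; resolving it into a
    # LoRARequest is the part not shown in the diff.
    ...
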