
Commit a9748c4

FerdinandZhong authored and sumitd2 committed
[Frontend] re-enable multi-modality input in the new beam search implementation (vllm-project#9427)
Signed-off-by: Qishuai [email protected]
Signed-off-by: Sumit Dubey <[email protected]>
1 parent d612841 commit a9748c4

File tree

7 files changed: +150 -40 lines changed

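In practice, this change lets an OpenAI-compatible client request beam search over an image prompt. A minimal client-side sketch mirroring the new tests (the base URL and model name are placeholders, not part of this commit):

    # Minimal sketch mirroring the new tests; base_url and model are placeholders.
    import openai

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    chat = client.chat.completions.create(
        model="my-multimodal-model",  # placeholder model name
        messages=[{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/duck.jpg"}},
                {"type": "text", "text": "What's in this image?"},
            ],
        }],
        n=2,
        max_tokens=10,
        # vLLM-specific flag routed through extra_body, as in the tests below.
        extra_body=dict(use_beam_search=True))
    print([choice.message.content for choice in chat.choices])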

tests/entrypoints/openai/test_vision.py

Lines changed: 71 additions & 0 deletions
@@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
+                                                    model_name: str,
+                                                    image_url: str):
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        n=2,
+        max_tokens=10,
+        logprobs=True,
+        top_logprobs=5,
+        extra_body=dict(use_beam_search=True))
+    assert len(chat_completion.choices) == 2
+    assert chat_completion.choices[
+        0].message.content != chat_completion.choices[1].message.content
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded(
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_single_chat_session_image_base64encoded_beamsearch(
+        client: openai.AsyncOpenAI, model_name: str, image_url: str,
+        base64_encoded_image: Dict[str, str]):
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url":
+                    f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        n=2,
+        max_tokens=10,
+        extra_body=dict(use_beam_search=True))
+    assert len(chat_completion.choices) == 2
+    assert chat_completion.choices[
+        0].message.content != chat_completion.choices[1].message.content
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)

vllm/beam_search.py

Lines changed: 8 additions & 1 deletion
@@ -1,8 +1,11 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from vllm.sequence import Logprob
 
+if TYPE_CHECKING:
+    from vllm.multimodal import MultiModalDataDict
+
 
 @dataclass
 class BeamSearchSequence:
@@ -16,6 +19,10 @@ class BeamSearchSequence:
     logprobs: List[Dict[int, Logprob]]
     cum_logprob: float = 0.0
     text: Optional[str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Union[int, str, None] = None
+    multi_modal_data: Optional["MultiModalDataDict"] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
 
 
 @dataclass
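The new fields let each beam carry its multi-modal payload and processor kwargs as it is extended. A minimal sketch of constructing such a sequence (the image path and the PIL-image payload are illustrative assumptions, not part of this commit):

    # Minimal sketch of the extended dataclass in use; the image path is a
    # placeholder and a PIL image is only one possible multi-modal payload.
    from PIL import Image
    from vllm.beam_search import BeamSearchSequence

    image = Image.open("example.jpg")
    beam = BeamSearchSequence(
        tokens=[1, 2, 3],  # prompt token ids
        logprobs=[],
        cum_logprob=0.0,
        multi_modal_data={"image": image},  # carried along as the beam grows
        mm_processor_kwargs=None,
    )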

vllm/engine/protocol.py

Lines changed: 57 additions & 31 deletions
@@ -6,6 +6,7 @@
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -59,7 +60,8 @@ def generate(
 
     async def beam_search(
         self,
-        prompt: Union[str, List[int]],
+        prompt: Union[PromptType, List[int]],
+        model_config: ModelConfig,
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -69,32 +71,40 @@ async def beam_search(
         ignore_eos = params.ignore_eos
         temperature = params.temperature
         length_penalty = params.length_penalty
+        include_stop_str_in_output = params.include_stop_str_in_output
 
-        tokenizer = await self.get_tokenizer(lora_request=None)
-        if isinstance(prompt, str):
-            tokenized_prompt = tokenizer.encode(prompt)
-            prompt_text = prompt
-        else:
-            tokenized_prompt = prompt
-            prompt_text = None
-        tokenized_length = len(tokenized_prompt)
+        tokenizer = await self.get_tokenizer()
+        input_preprocessor = InputPreprocessor(model_config, tokenizer)
+
+        (prompt_text, prompt_token_ids, multi_modal_data,
+         mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
+             prompt,
+             request_id=request_id,
+         )
+        tokenized_length = len(prompt_token_ids)
 
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
 
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
+        beam_search_params = SamplingParams(
+            logprobs=2 * beam_width,
+            max_tokens=1,
+            temperature=temperature,
+        )
         all_beams = [
-            BeamSearchSequence(tokens=tokenized_prompt,
+            BeamSearchSequence(tokens=prompt_token_ids,
+                               cum_logprob=0,
                                logprobs=[],
-                               cum_logprob=0)
+                               multi_modal_data=multi_modal_data,
+                               mm_processor_kwargs=mm_processor_kwargs)
         ]
         completed = []
 
         for _ in range(max_tokens):
             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
+                TokensPrompt(prompt_token_ids=beam.tokens,
+                             multi_modal_data=beam.multi_modal_data,
+                             mm_processor_kwargs=beam.mm_processor_kwargs)
                 for beam in all_beams
             ]
 
@@ -120,17 +130,31 @@ async def beam_search(
                 if result.outputs[0].logprobs is not None:
                     logprobs = result.outputs[0].logprobs[0]
                     for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            logprobs=current_beam.logprobs + [logprobs],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
                         if token_id == tokenizer.eos_token_id and \
                             not ignore_eos:
-                            completed.append(new_beam)
+                            completed.append(
+                                BeamSearchSequence(
+                                    tokens=current_beam.tokens +
                                    [token_id] if include_stop_str_in_output
+                                    else current_beam.tokens,
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    finish_reason="stop",
+                                    stop_reason=tokenizer.eos_token_id))
                         else:
-                            new_beams.append(new_beam)
+                            new_beams.append(
+                                BeamSearchSequence(
+                                    tokens=current_beam.tokens + [token_id],
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    multi_modal_data=current_beam.
+                                    multi_modal_data,
+                                    mm_processor_kwargs=current_beam.
+                                    mm_processor_kwargs))
 
             sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
             all_beams = sorted_beams[:beam_width]
@@ -151,16 +175,18 @@ async def beam_search(
             request_id=request_id,
             prompt=prompt_text,
             outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens[tokenized_length:],
-                    index=i,
-                    logprobs=beam.logprobs,
-                ) for (i, beam) in enumerate(best_beams)
+                CompletionOutput(text=beam.text,
+                                 cumulative_logprob=beam.cum_logprob,
+                                 token_ids=beam.tokens[tokenized_length:],
+                                 index=i,
+                                 logprobs=beam.logprobs,
+                                 finish_reason=beam.finish_reason if
+                                 beam.finish_reason is not None else "length",
+                                 stop_reason=beam.stop_reason)
+                for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=tokenized_prompt,
+            prompt_token_ids=prompt_token_ids,
             prompt_logprobs=None)
 
         yield beam_search_output
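Callers now pass the full prompt plus the model config instead of bare token ids, which is how the serving layers below use it. A minimal sketch of invoking the updated entry point (engine_client, model_config, engine_inputs, and params are assumed to come from an already-initialized serving layer):

    # Minimal sketch of the updated call; all arguments are assumed to exist
    # already, and the request id is a placeholder.
    async def run_beam_search(engine_client, model_config, engine_inputs, params):
        # engine_inputs is the full prompt dict (token ids plus any
        # multi_modal_data / mm_processor_kwargs), not just
        # engine_inputs["prompt_token_ids"].
        async for request_output in engine_client.beam_search(
                prompt=engine_inputs,
                model_config=model_config,
                request_id="beam-search-demo",
                params=params,
        ):
            yield request_output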

vllm/entrypoints/openai/protocol.py

Lines changed: 2 additions & 2 deletions
@@ -308,7 +308,7 @@ def to_beam_search_params(self,
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-        )
+            include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
@@ -606,7 +606,7 @@ def to_beam_search_params(self,
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-        )
+            include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens

vllm/entrypoints/openai/serving_chat.py

Lines changed: 4 additions & 3 deletions
@@ -236,9 +236,10 @@ async def create_chat_completion(
 
             if isinstance(sampling_params, BeamSearchParams):
                 result_generator = self.engine_client.beam_search(
-                    engine_inputs['prompt_token_ids'],
-                    request_id,
-                    sampling_params,
+                    prompt=engine_inputs,
+                    model_config=self.model_config,
+                    request_id=request_id,
+                    params=sampling_params,
                 )
             else:
                 result_generator = self.engine_client.generate(

vllm/entrypoints/openai/serving_completion.py

Lines changed: 7 additions & 3 deletions
@@ -150,9 +150,13 @@ async def create_completion(
 
                 if isinstance(sampling_params, BeamSearchParams):
                     generator = self.engine_client.beam_search(
-                        prompt_inputs["prompt_token_ids"],
-                        request_id_item,
-                        sampling_params,
+                        prompt={
+                            "prompt_token_ids":
+                            prompt_inputs["prompt_token_ids"]
+                        },
+                        model_config=self.model_config,
+                        request_id=request_id,
+                        params=sampling_params,
                     )
                 else:
                     generator = self.engine_client.generate(

vllm/sampling_params.py

Lines changed: 1 addition & 0 deletions
@@ -500,3 +500,4 @@ class BeamSearchParams(
     ignore_eos: bool = False
     temperature: float = 0.0
     length_penalty: float = 1.0
+    include_stop_str_in_output: bool = False
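The new field mirrors the request-level include_stop_str_in_output option wired through to_beam_search_params above. A minimal sketch of constructing BeamSearchParams with it (the beam_width and max_tokens values are illustrative, and those two fields are assumed from the surrounding class, not shown in this hunk):

    # Minimal sketch of the new flag; values are illustrative only.
    from vllm.sampling_params import BeamSearchParams

    params = BeamSearchParams(
        beam_width=2,
        max_tokens=10,
        ignore_eos=False,
        temperature=0.0,
        length_penalty=1.0,
        include_stop_str_in_output=True,  # keep the EOS token in completed beams
    )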
