Merged
Changes from 7 commits

Commits (21)
08ab78e
update of beam search function
FerdinandZhong Oct 15, 2024
cac55e1
update of testing
FerdinandZhong Oct 16, 2024
358f89c
Merge remote-tracking branch 'upstream/main' into beam_search_multi_m…
FerdinandZhong Oct 16, 2024
2dde695
fix error in implementation
FerdinandZhong Oct 16, 2024
eb92b7d
add checking for logprobs and add more test cases
FerdinandZhong Oct 16, 2024
014d753
formatting
FerdinandZhong Oct 16, 2024
eae5b9b
Merge remote-tracking branch 'upstream/main' into beam_search_multi_m…
FerdinandZhong Oct 17, 2024
5f0e1cd
update BeamSequence, prompt preprocess and adding stop_reason
FerdinandZhong Oct 17, 2024
6e29318
Merge branch 'beam_search_multi_modality' of https://github.com/Ferdi…
FerdinandZhong Oct 17, 2024
5a256cb
fix the wrong declaration
FerdinandZhong Oct 17, 2024
b01a615
formatting
FerdinandZhong Oct 17, 2024
bc74931
Merge branch 'main' of https://github.com/vllm-project/vllm into beam…
FerdinandZhong Oct 18, 2024
8291a80
remove checking for logprobs
FerdinandZhong Oct 18, 2024
a682b63
format
FerdinandZhong Oct 18, 2024
8940743
Merge branch 'main' of https://github.com/vllm-project/vllm into beam…
FerdinandZhong Oct 18, 2024
bb53cbd
output beam's logprobs to Output's logprobs
FerdinandZhong Oct 18, 2024
c275ae3
Merge branch 'main' of https://github.com/vllm-project/vllm into beam…
FerdinandZhong Oct 19, 2024
3b7ab92
update calling of beam_search from serving_completion based on latest…
FerdinandZhong Oct 19, 2024
f96fa9a
Merge branch 'main' of github.com:vllm-project/vllm into beam_search_…
FerdinandZhong Oct 22, 2024
314a31e
Merge branch 'main' of https://github.com/vllm-project/vllm into beam…
FerdinandZhong Oct 28, 2024
8705266
Merge branch 'main' of https://github.com/vllm-project/vllm into beam…
FerdinandZhong Oct 29, 2024
82 changes: 82 additions & 0 deletions tests/entrypoints/openai/test_vision.py
@@ -105,6 +105,53 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
model_name: str,
image_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]

chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
n=2,
max_tokens=10,
extra_body=dict(use_beam_search=True))
assert len(chat_completion.choices) == 2
assert chat_completion.choices[
0].message.content != chat_completion.choices[1].message.content

with pytest.raises(openai.BadRequestError) as exc_info:
await client.chat.completions.create(
model=model_name,
messages=messages,
n=2,
max_tokens=10,
logprobs=True,
top_logprobs=5,
extra_body=dict(use_beam_search=True))

# Assert that the exception message is correct
assert "Only the `cumulative_logprob` " in str(exc_info.value)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@@ -160,6 +207,41 @@ async def test_single_chat_session_image_base64encoded(
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image_base64encoded_beamsearch(
client: openai.AsyncOpenAI, model_name: str, image_url: str,
base64_encoded_image: Dict[str, str]):

messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url":
f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
n=2,
max_tokens=10,
extra_body=dict(use_beam_search=True))
assert len(chat_completion.choices) == 2
assert chat_completion.choices[
0].message.content != chat_completion.choices[1].message.content


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
1 change: 1 addition & 0 deletions vllm/beam_search.py
@@ -13,6 +13,7 @@ class BeamSearchSequence:
tokens: List[int]
cum_logprob: float = 0.0
text: Optional[str] = None
finish_reason: Optional[str] = None


@dataclass
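Note: the new finish_reason field lets a completed beam record why it stopped; the engine code in vllm/engine/protocol.py below sets it to "stop" when a beam hits EOS. A minimal illustrative sketch of constructing such a sequence — the token ids, log-probability, and text are made-up values, not outputs of this PR:

from vllm.beam_search import BeamSearchSequence

# Illustrative only: a beam that terminated on an EOS token.
finished_beam = BeamSearchSequence(
    tokens=[1, 887, 526, 263, 2],  # made-up token ids, ending in an assumed EOS id
    cum_logprob=-4.2,              # made-up cumulative log-probability
    text="You are a",              # beam_search() decodes and fills this in later
    finish_reason="stop",          # new field added by this change
)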
72 changes: 53 additions & 19 deletions vllm/engine/protocol.py
@@ -1,5 +1,6 @@
import asyncio
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import AsyncGenerator, List, Mapping, Optional, Union

from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
@@ -69,24 +70,50 @@ async def beam_search(
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output

tokenizer = await self.get_tokenizer(lora_request=None)
tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)

if isinstance(prompt, dict):
if "prompt" in prompt:
tokenized_prompt = tokenizer.encode(prompt.get("prompt"))
multi_modal_data = prompt.get("multi_modal_data")
mm_processor_kwargs = prompt.get("mm_processor_kwargs")
elif "prompt_token_ids" in prompt:
tokenized_prompt = prompt.get("prompt_token_ids")
multi_modal_data = prompt.get("multi_modal_data")
mm_processor_kwargs = prompt.get("mm_processor_kwargs")
else:
raise TypeError(
"Dictionary input must be a TextPrompt or TokensPrompt")
else:
tokenized_prompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
multi_modal_data = None
mm_processor_kwargs = None

tokenized_length = len(tokenized_prompt)

sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
beam_search_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
)
all_beams = [
BeamSearchSequence(tokens=tokenized_prompt, cum_logprob=0)
]
completed = []

for _ in range(max_tokens):
prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
TokensPrompt(
prompt_token_ids=beam.tokens,
multi_modal_data=deepcopy(
multi_modal_data), # always the values from inputs
mm_processor_kwargs=deepcopy(mm_processor_kwargs))
for beam in all_beams
]

@@ -112,16 +139,23 @@ async def beam_search(
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)

if token_id == tokenizer.eos_token_id and \
not ignore_eos:
completed.append(new_beam)
completed.append(
BeamSearchSequence(
tokens=current_beam.tokens +
[token_id] if include_stop_str_in_output
else current_beam.tokens, #
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
finish_reason="stop"))
else:
new_beams.append(new_beam)
new_beams.append(
BeamSearchSequence(
tokens=current_beam.tokens + [token_id], #
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
))

sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]
@@ -131,22 +165,22 @@ async def beam_search(
best_beams = sorted_completed[:beam_width]

for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
beam.text = tokenizer.decode(beam.tokens[tokenized_length:])

beam_search_output = RequestOutput(
request_id=request_id,
prompt=prompt,
prompt=tokenizer.decode(tokenized_prompt),
outputs=[
CompletionOutput(
text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens,
token_ids=beam.tokens[tokenized_length:],
index=i,
logprobs=beam.cum_logprob,
) for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=tokenizedPrompt,
prompt_token_ids=tokenized_prompt,
prompt_logprobs=None)

yield beam_search_output
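Note: with the reworked prompt preprocessing, beam_search() accepts the same TextPrompt/TokensPrompt dict that the OpenAI serving layers build, and the multi-modal data is copied into every per-beam TokensPrompt. A rough sketch of the call shape under those assumptions — engine construction, image loading, and the exact chat template are elided, and the request id is arbitrary:

from PIL import Image

from vllm.sampling_params import BeamSearchParams


async def demo(engine_client):
    # engine_client is assumed to be any EngineClient implementation
    # (e.g. the async engine behind the OpenAI server); construction elided.
    prompt = {
        "prompt": "USER: <image>\nWhat's in this image? ASSISTANT:",  # assumed template
        "multi_modal_data": {"image": Image.open("example.jpg")},
    }
    params = BeamSearchParams(beam_width=2, max_tokens=10)
    async for request_output in engine_client.beam_search(prompt, "demo-request", params):
        for beam in request_output.outputs:
            print(beam.index, beam.cumulative_logprob, beam.text)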
8 changes: 6 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -302,7 +302,7 @@ def to_beam_search_params(self,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
)
include_stop_str_in_output=self.include_stop_str_in_output)

def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
@@ -400,6 +400,10 @@ def check_logprobs(cls, data):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
if data.get("logprobs") and data.get("use_beam_search"):
raise ValueError(
"Only the `cumulative_logprob` of each output will be returned."
)

return data

@@ -594,7 +598,7 @@ def to_beam_search_params(self,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
)
include_stop_str_in_output=self.include_stop_str_in_output)

def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -236,7 +236,7 @@ async def create_chat_completion(

if isinstance(sampling_params, BeamSearchParams):
result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'],
engine_inputs,
request_id,
sampling_params,
)
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -150,7 +150,7 @@ async def create_completion(

if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt_inputs["prompt_token_ids"],
prompt_inputs,
request_id_item,
sampling_params,
)
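Note: because the completions path now forwards the full prompt_inputs dict as well, beam search keeps working through the OpenAI-compatible /v1/completions endpoint. A minimal client-side sketch, assuming a locally running vLLM server; the base URL and model name are placeholders:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="llava-hf/llava-1.5-7b-hf",  # placeholder; use whatever model the server serves
    prompt="The capital of France is",
    n=2,
    max_tokens=10,
    extra_body=dict(use_beam_search=True),  # routed to engine_client.beam_search()
)
print([choice.text for choice in completion.choices])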
1 change: 1 addition & 0 deletions vllm/sampling_params.py
@@ -489,3 +489,4 @@ class BeamSearchParams(
ignore_eos: bool = False
temperature: float = 0.0
length_penalty: float = 1.0
include_stop_str_in_output: bool = False
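Note: a minimal sketch of constructing the extended params object directly; beam_width and max_tokens are assumed to be the existing required fields above this hunk:

from vllm.sampling_params import BeamSearchParams

params = BeamSearchParams(
    beam_width=2,
    max_tokens=10,
    include_stop_str_in_output=True,  # new field: keep the stop token in the returned beams
)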