12 changes: 8 additions & 4 deletions tests/entrypoints/test_openai_server.py
@@ -224,7 +224,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
     assert choice.logprobs.top_logprobs is not None
-    assert len(choice.logprobs.top_logprobs[0]) <= 1
+    assert len(choice.logprobs.top_logprobs[0]) == 1
 
 
 @pytest.mark.asyncio
@@ -246,7 +246,7 @@ async def test_some_logprobs(server, client: openai.AsyncOpenAI,
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
     assert choice.logprobs.top_logprobs is not None
-    assert len(choice.logprobs.top_logprobs[0]) <= 6
+    assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
 
 
 @pytest.mark.asyncio
@@ -1032,8 +1032,9 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI):
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
model_name: str):
model_name: str, logprobs_arg: int):
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# test using text and token IDs
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
@@ -1042,7 +1043,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
                                                      max_tokens=5,
                                                      temperature=0.0,
                                                      echo=True,
-                                                     logprobs=1)
+                                                     logprobs=logprobs_arg)
zifeitong (Contributor) commented on Jun 3, 2024:
Thanks for the fix.

Can you expand the test to verify the size of the dicts in logprobs.top_logprobs?

You probably also need to fix it here:

if num_logprobs > 0:

The PR author replied:

> Can you expand the test to verify the size of the dicts in logprobs.top_logprobs?

Added the verification to the logprobs tests (with and without echo=True).


         prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                              list) else prompt
@@ -1055,6 +1056,9 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
                 and logprobs.token_logprobs[0] is None)
         assert (len(logprobs.top_logprobs) > 5
                 and logprobs.top_logprobs[0] is None)
+        for top_logprobs in logprobs.top_logprobs[1:]:
+            assert max(logprobs_arg,
+                       1) <= len(top_logprobs) <= logprobs_arg + 1
         assert len(logprobs.tokens) > 5


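The bounds in the new loop are worth spelling out: an OpenAI-style top_logprobs dict for one position holds the requested top candidates plus the actually sampled token when it is not already among them, so its size is logprobs_arg or logprobs_arg + 1, and never less than 1. A minimal standalone sketch of that reasoning, assuming this dict shape (the helper name is illustrative, not part of the PR):

```python
# Sketch only: why the test asserts
#   max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
def plausible_sizes(logprobs_arg: int) -> set:
    """Dict sizes the server may legitimately return for one position."""
    return {
        max(logprobs_arg, 1),  # sampled token was already among the top-N
        logprobs_arg + 1,      # sampled token appended on top of the top-N
    }

assert plausible_sizes(0) == {1}     # logprobs=0: only the sampled token
assert plausible_sizes(1) == {1, 2}  # logprobs=1: top-1, maybe + sampled token
```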
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -312,7 +312,7 @@ def request_output_to_completion_response(
         elif request.echo and request.max_tokens > 0:
             token_ids = prompt_token_ids + output.token_ids
             top_logprobs = (prompt_logprobs + output.logprobs
-                            if request.logprobs else None)
+                            if request.logprobs is not None else None)
             output_text = prompt_text + output.text
         else:
             token_ids = output.token_ids
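This one-token change is the heart of the PR: logprobs=0 is a valid request (echo the prompt and report only each sampled token's logprob), but 0 is falsy in Python, so the old truthiness check silently behaved as if logprobs had been omitted. The two changes below fix the same pattern for prompt_logprobs. A standalone illustration of the pitfall (not vLLM code):

```python
logprobs = 0  # client explicitly asked for zero top-logprob candidates

if logprobs:  # buggy: 0 is falsy, so logprobs=0 is treated like None
    print("truthiness check: logprobs attached")

if logprobs is not None:  # fixed: only skips when logprobs was omitted
    print("is-not-None check: logprobs attached")
```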
2 changes: 1 addition & 1 deletion vllm/model_executor/sampling_metadata.py
@@ -233,7 +233,7 @@ def _prepare_seq_groups(
     logits = hidden_states[selected_token_indices]
     """
 
-    if sampling_params.prompt_logprobs:
+    if sampling_params.prompt_logprobs is not None:
         selected_token_indices.extend(
             range(model_output_idx, model_output_idx + prompt_logprob_len))
         model_output_idx += prompt_logprob_len
2 changes: 1 addition & 1 deletion vllm/worker/model_runner.py
@@ -427,7 +427,7 @@ def _prepare_model_input(
                 [lora_id] *
                 (query_len if seq_group_metadata.sampling_params
                  and seq_group_metadata.sampling_params.prompt_logprobs
-                 else 1))
+                 is not None else 1))
 
             mm_data = seq_group_metadata.multi_modal_data
             if mm_data is not None:
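The same falsy-zero pitfall applies here, with a concrete consequence for LoRA: when prompt logprobs are requested, logits are produced for every prompt position, so the LoRA ID must be repeated query_len times; otherwise a single entry suffices. A small sketch of that mapping logic, using an illustrative helper name rather than vLLM's actual API:

```python
def lora_ids_for_seq(lora_id: int, query_len: int, prompt_logprobs) -> list:
    """Repeat the LoRA ID once per position that will produce logits."""
    # prompt_logprobs=0 must count as "requested", hence the is-not-None test.
    n = query_len if prompt_logprobs is not None else 1
    return [lora_id] * n

assert lora_ids_for_seq(7, 5, prompt_logprobs=0) == [7] * 5  # fixed behavior
assert lora_ids_for_seq(7, 5, prompt_logprobs=None) == [7]   # logprobs off
```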