Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,46 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from types import SimpleNamespace
from unittest.mock import AsyncMock

import httpx
import pytest

from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.chat_completion.batch_serving import (
OpenAIServingChatBatch,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
BatchChatCompletionRequest,
)

# any model with a chat template defined in tokenizer_config should work here
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


class _FakeRender:
use_harmony = False
chat_template = None
chat_template_content_format = "auto"
default_chat_template_kwargs: dict = {}
tool_parser = None
trust_request_chat_template = False

def __init__(self):
self.preprocessed_requests = []
self.reasoning_parser_args = []

def validate_chat_template(self, **kwargs):
return None

async def preprocess_chat(self, single_request, messages, **kwargs):
self.preprocessed_requests.append(single_request)
self.reasoning_parser_args.append(kwargs["reasoning_parser"])
single_request.skip_special_tokens = False
return messages, [{"prompt_token_ids": [1, 2]}]


@pytest.fixture(scope="module")
def default_server_args():
return [
Expand All @@ -30,6 +60,37 @@ def server(default_server_args):
yield remote_server


@pytest.mark.asyncio
async def test_batch_render_uses_adjusted_reasoning_requests() -> None:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other tests in this module seem like purely integration tests with a real server, not sure if that matters

request = BatchChatCompletionRequest(
model="test-model",
messages=[
[{"role": "user", "content": "one"}],
[{"role": "user", "content": "two"}],
],
)
reasoning_parser_cls = object()

handler = object.__new__(OpenAIServingChatBatch)
handler._check_model = AsyncMock(return_value=None)
handler.engine_client = SimpleNamespace(errored=False)
handler.openai_serving_render = _FakeRender()
handler.reasoning_parser_cls = reasoning_parser_cls

result = await handler.render_batch_chat_request(request)

conversations, engine_prompts, adjusted_requests = result
assert conversations == request.messages
assert engine_prompts == [{"prompt_token_ids": [1, 2]}] * 2
assert handler.openai_serving_render.preprocessed_requests == adjusted_requests
assert handler.openai_serving_render.reasoning_parser_args == [
reasoning_parser_cls,
reasoning_parser_cls,
]
assert [r.messages for r in adjusted_requests] == request.messages
assert [r.skip_special_tokens for r in adjusted_requests] == [False, False]


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
Expand Down
72 changes: 48 additions & 24 deletions vllm/entrypoints/openai/chat_completion/batch_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.chat_completion.protocol import (
BatchChatCompletionRequest,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
Expand Down Expand Up @@ -43,16 +44,24 @@ class OpenAIServingChatBatch(OpenAIServingChat):
async def render_batch_chat_request(
self,
request: BatchChatCompletionRequest,
) -> tuple[list[list[ConversationMessage]], list[EngineInput]] | ErrorResponse:
) -> (
tuple[
list[list[ConversationMessage]],
list[EngineInput],
list[ChatCompletionRequest],
]
| ErrorResponse
):
"""Validate the model and preprocess a batched chat completion request.

Performs engine-aware checks then delegates per-conversation
preprocessing to OpenAIServingRender, validating the chat template
once for the whole batch.

Returns:
A tuple of (all_conversations, engine_prompts) on success — one
entry per conversation — or an ErrorResponse on failure.
A tuple of (all_conversations, engine_prompts, single_requests)
on success — one entry per conversation — or an ErrorResponse
on failure.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
Expand All @@ -79,6 +88,7 @@ async def render_batch_chat_request(

all_conversations: list[list[ConversationMessage]] = []
all_engine_prompts: list[EngineInput] = []
single_requests: list[ChatCompletionRequest] = []

for messages in request.messages:
single_request = request.to_chat_completion_request(messages)
Expand All @@ -95,11 +105,13 @@ async def render_batch_chat_request(
default_template_kwargs=render.default_chat_template_kwargs,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
reasoning_parser=self.reasoning_parser_cls,
)
all_conversations.append(conversation)
all_engine_prompts.append(engine_prompts[0])
single_requests.append(single_request)

return all_conversations, all_engine_prompts
return all_conversations, all_engine_prompts, single_requests

async def create_batch_chat_completion(
self,
Expand All @@ -114,10 +126,11 @@ async def create_batch_chat_completion(
"""
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
single_requests = [
request.to_chat_completion_request(messages)
for messages in request.messages
]

render_result = await self.render_batch_chat_request(request)
if isinstance(render_result, ErrorResponse):
return render_result
all_conversations, engine_prompts, single_requests = render_result

reasoning_parser: ReasoningParser | None = None
if self.reasoning_parser_cls:
Expand All @@ -129,11 +142,6 @@ async def create_batch_chat_completion(
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
)

render_result = await self.render_batch_chat_request(request)
if isinstance(render_result, ErrorResponse):
return render_result
all_conversations, engine_prompts = render_result

request_id = (
f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
)
Expand All @@ -149,6 +157,7 @@ async def create_batch_chat_completion(
generators: list[AsyncGenerator[RequestOutput, None]] = []
for i, engine_prompt in enumerate(engine_prompts):
sub_request_id = f"{request_id}_{i}"
prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids
max_tokens = get_max_tokens(
max_model_len,
request.max_completion_tokens
Expand All @@ -173,16 +182,33 @@ async def create_batch_chat_completion(
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
if (
not single_request.include_reasoning
or single_request._grammar_from_tool_parser
):
reasoning_ended = True
elif reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end(
prompt_token_ids or []
)
else:
reasoning_ended = None
chat_template_kwargs = self._effective_chat_template_kwargs(single_request)
Comment on lines +185 to +196

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes aren't directly related to the bug described, but are part of what is inconsistent with the regular chat completions logic. I can pull these out if desired. Also, it may be worth trying to abstract more of the common logic into the non-batch serving module to help prevent future regressions

generators.append(
self.engine_client.generate(
engine_prompt,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority if hasattr(request, "priority") else 0,
priority=single_request.priority,
data_parallel_rank=data_parallel_rank,
reasoning_ended=None,
reasoning_ended=reasoning_ended,
reasoning_parser_kwargs={
"chat_template_kwargs": chat_template_kwargs,
}
if reasoning_parser
else None,
)
)

Expand All @@ -195,6 +221,7 @@ async def create_batch_chat_completion(
tokenizer,
request_metadata,
reasoning_parser,
single_requests,
)

async def chat_completion_full_generator_batch(
Expand All @@ -206,7 +233,8 @@ async def chat_completion_full_generator_batch(
all_conversations: list[list[ConversationMessage]],
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None,
reasoning_parser: ReasoningParser | None,
single_requests: list[ChatCompletionRequest],
) -> ErrorResponse | ChatCompletionResponse:
"""Handle batched (non-streaming) chat completions.

Expand Down Expand Up @@ -263,22 +291,18 @@ async def chat_completion_full_generator_batch(
logprobs = None

if reasoning_parser:
single_request = single_requests[prompt_idx]
reasoning, content = reasoning_parser.extract_reasoning(
output.text,
request=request, # type: ignore[arg-type]
request=single_request,
)
if not getattr(request, "include_reasoning", True):
if not single_request.include_reasoning:
reasoning = None
else:
reasoning = None
content = output.text

role = (
self.response_role
if request.add_generation_prompt
else request.messages[prompt_idx][-1]["role"]
)

role = self.get_chat_request_role(single_requests[prompt_idx])
message = ChatMessage(role=role, reasoning=reasoning, content=content)

if request.echo:
Expand Down
Loading