diff --git a/tests/entrypoints/openai/chat_completion/test_batched_chat_completions.py b/tests/entrypoints/openai/chat_completion/test_batched_chat_completions.py index c3a8d0b2bdec..f367ba7947fc 100644 --- a/tests/entrypoints/openai/chat_completion/test_batched_chat_completions.py +++ b/tests/entrypoints/openai/chat_completion/test_batched_chat_completions.py @@ -2,16 +2,46 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +from types import SimpleNamespace +from unittest.mock import AsyncMock import httpx import pytest from tests.utils import RemoteOpenAIServer +from vllm.entrypoints.openai.chat_completion.batch_serving import ( + OpenAIServingChatBatch, +) +from vllm.entrypoints.openai.chat_completion.protocol import ( + BatchChatCompletionRequest, +) # any model with a chat template defined in tokenizer_config should work here MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" +class _FakeRender: + use_harmony = False + chat_template = None + chat_template_content_format = "auto" + default_chat_template_kwargs: dict = {} + tool_parser = None + trust_request_chat_template = False + + def __init__(self): + self.preprocessed_requests = [] + self.reasoning_parser_args = [] + + def validate_chat_template(self, **kwargs): + return None + + async def preprocess_chat(self, single_request, messages, **kwargs): + self.preprocessed_requests.append(single_request) + self.reasoning_parser_args.append(kwargs["reasoning_parser"]) + single_request.skip_special_tokens = False + return messages, [{"prompt_token_ids": [1, 2]}] + + @pytest.fixture(scope="module") def default_server_args(): return [ @@ -30,6 +60,37 @@ def server(default_server_args): yield remote_server +@pytest.mark.asyncio +async def test_batch_render_uses_adjusted_reasoning_requests() -> None: + request = BatchChatCompletionRequest( + model="test-model", + messages=[ + [{"role": "user", "content": "one"}], + [{"role": "user", "content": "two"}], + ], + ) + reasoning_parser_cls = object() + + handler = object.__new__(OpenAIServingChatBatch) + handler._check_model = AsyncMock(return_value=None) + handler.engine_client = SimpleNamespace(errored=False) + handler.openai_serving_render = _FakeRender() + handler.reasoning_parser_cls = reasoning_parser_cls + + result = await handler.render_batch_chat_request(request) + + conversations, engine_prompts, adjusted_requests = result + assert conversations == request.messages + assert engine_prompts == [{"prompt_token_ids": [1, 2]}] * 2 + assert handler.openai_serving_render.preprocessed_requests == adjusted_requests + assert handler.openai_serving_render.reasoning_parser_args == [ + reasoning_parser_cls, + reasoning_parser_cls, + ] + assert [r.messages for r in adjusted_requests] == request.messages + assert [r.skip_special_tokens for r in adjusted_requests] == [False, False] + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", diff --git a/vllm/entrypoints/openai/chat_completion/batch_serving.py b/vllm/entrypoints/openai/chat_completion/batch_serving.py index 0dfcdd925158..2986a526f86b 100644 --- a/vllm/entrypoints/openai/chat_completion/batch_serving.py +++ b/vllm/entrypoints/openai/chat_completion/batch_serving.py @@ -11,6 +11,7 @@ from vllm.entrypoints.chat_utils import ConversationMessage from vllm.entrypoints.openai.chat_completion.protocol import ( BatchChatCompletionRequest, + ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, @@ -43,7 +44,14 @@ class OpenAIServingChatBatch(OpenAIServingChat): async def render_batch_chat_request( self, request: BatchChatCompletionRequest, - ) -> tuple[list[list[ConversationMessage]], list[EngineInput]] | ErrorResponse: + ) -> ( + tuple[ + list[list[ConversationMessage]], + list[EngineInput], + list[ChatCompletionRequest], + ] + | ErrorResponse + ): """Validate the model and preprocess a batched chat completion request. Performs engine-aware checks then delegates per-conversation @@ -51,8 +59,9 @@ async def render_batch_chat_request( once for the whole batch. Returns: - A tuple of (all_conversations, engine_prompts) on success — one - entry per conversation — or an ErrorResponse on failure. + A tuple of (all_conversations, engine_prompts, single_requests) + on success — one entry per conversation — or an ErrorResponse + on failure. """ error_check_ret = await self._check_model(request) if error_check_ret is not None: @@ -79,6 +88,7 @@ async def render_batch_chat_request( all_conversations: list[list[ConversationMessage]] = [] all_engine_prompts: list[EngineInput] = [] + single_requests: list[ChatCompletionRequest] = [] for messages in request.messages: single_request = request.to_chat_completion_request(messages) @@ -95,11 +105,13 @@ async def render_batch_chat_request( default_template_kwargs=render.default_chat_template_kwargs, tool_dicts=tool_dicts, tool_parser=tool_parser, + reasoning_parser=self.reasoning_parser_cls, ) all_conversations.append(conversation) all_engine_prompts.append(engine_prompts[0]) + single_requests.append(single_request) - return all_conversations, all_engine_prompts + return all_conversations, all_engine_prompts, single_requests async def create_batch_chat_completion( self, @@ -114,10 +126,11 @@ async def create_batch_chat_completion( """ tokenizer = self.renderer.tokenizer assert tokenizer is not None - single_requests = [ - request.to_chat_completion_request(messages) - for messages in request.messages - ] + + render_result = await self.render_batch_chat_request(request) + if isinstance(render_result, ErrorResponse): + return render_result + all_conversations, engine_prompts, single_requests = render_result reasoning_parser: ReasoningParser | None = None if self.reasoning_parser_cls: @@ -129,11 +142,6 @@ async def create_batch_chat_completion( chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg] ) - render_result = await self.render_batch_chat_request(request) - if isinstance(render_result, ErrorResponse): - return render_result - all_conversations, engine_prompts = render_result - request_id = ( f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" ) @@ -149,6 +157,7 @@ async def create_batch_chat_completion( generators: list[AsyncGenerator[RequestOutput, None]] = [] for i, engine_prompt in enumerate(engine_prompts): sub_request_id = f"{request_id}_{i}" + prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids max_tokens = get_max_tokens( max_model_len, request.max_completion_tokens @@ -173,6 +182,18 @@ async def create_batch_chat_completion( if raw_request is None else await self._get_trace_headers(raw_request.headers) ) + if ( + not single_request.include_reasoning + or single_request._grammar_from_tool_parser + ): + reasoning_ended = True + elif reasoning_parser: + reasoning_ended = reasoning_parser.is_reasoning_end( + prompt_token_ids or [] + ) + else: + reasoning_ended = None + chat_template_kwargs = self._effective_chat_template_kwargs(single_request) generators.append( self.engine_client.generate( engine_prompt, @@ -180,9 +201,14 @@ async def create_batch_chat_completion( sub_request_id, lora_request=lora_request, trace_headers=trace_headers, - priority=request.priority if hasattr(request, "priority") else 0, + priority=single_request.priority, data_parallel_rank=data_parallel_rank, - reasoning_ended=None, + reasoning_ended=reasoning_ended, + reasoning_parser_kwargs={ + "chat_template_kwargs": chat_template_kwargs, + } + if reasoning_parser + else None, ) ) @@ -195,6 +221,7 @@ async def create_batch_chat_completion( tokenizer, request_metadata, reasoning_parser, + single_requests, ) async def chat_completion_full_generator_batch( @@ -206,7 +233,8 @@ async def chat_completion_full_generator_batch( all_conversations: list[list[ConversationMessage]], tokenizer: TokenizerLike, request_metadata: RequestResponseMetadata, - reasoning_parser: ReasoningParser | None = None, + reasoning_parser: ReasoningParser | None, + single_requests: list[ChatCompletionRequest], ) -> ErrorResponse | ChatCompletionResponse: """Handle batched (non-streaming) chat completions. @@ -263,22 +291,18 @@ async def chat_completion_full_generator_batch( logprobs = None if reasoning_parser: + single_request = single_requests[prompt_idx] reasoning, content = reasoning_parser.extract_reasoning( output.text, - request=request, # type: ignore[arg-type] + request=single_request, ) - if not getattr(request, "include_reasoning", True): + if not single_request.include_reasoning: reasoning = None else: reasoning = None content = output.text - role = ( - self.response_role - if request.add_generation_prompt - else request.messages[prompt_idx][-1]["role"] - ) - + role = self.get_chat_request_role(single_requests[prompt_idx]) message = ChatMessage(role=role, reasoning=reasoning, content=content) if request.echo: