-
-
Notifications
You must be signed in to change notification settings - Fork 17.9k
[Bugfix] Fix Gemma4 reasoning for batch chat completions #42105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| from vllm.entrypoints.chat_utils import ConversationMessage | ||
| from vllm.entrypoints.openai.chat_completion.protocol import ( | ||
| BatchChatCompletionRequest, | ||
| ChatCompletionRequest, | ||
| ChatCompletionResponse, | ||
| ChatCompletionResponseChoice, | ||
| ChatMessage, | ||
|
|
@@ -43,16 +44,24 @@ class OpenAIServingChatBatch(OpenAIServingChat): | |
| async def render_batch_chat_request( | ||
| self, | ||
| request: BatchChatCompletionRequest, | ||
| ) -> tuple[list[list[ConversationMessage]], list[EngineInput]] | ErrorResponse: | ||
| ) -> ( | ||
| tuple[ | ||
| list[list[ConversationMessage]], | ||
| list[EngineInput], | ||
| list[ChatCompletionRequest], | ||
| ] | ||
| | ErrorResponse | ||
| ): | ||
| """Validate the model and preprocess a batched chat completion request. | ||
|
|
||
| Performs engine-aware checks then delegates per-conversation | ||
| preprocessing to OpenAIServingRender, validating the chat template | ||
| once for the whole batch. | ||
|
|
||
| Returns: | ||
| A tuple of (all_conversations, engine_prompts) on success — one | ||
| entry per conversation — or an ErrorResponse on failure. | ||
| A tuple of (all_conversations, engine_prompts, single_requests) | ||
| on success — one entry per conversation — or an ErrorResponse | ||
| on failure. | ||
| """ | ||
| error_check_ret = await self._check_model(request) | ||
| if error_check_ret is not None: | ||
|
|
@@ -79,6 +88,7 @@ async def render_batch_chat_request( | |
|
|
||
| all_conversations: list[list[ConversationMessage]] = [] | ||
| all_engine_prompts: list[EngineInput] = [] | ||
| single_requests: list[ChatCompletionRequest] = [] | ||
|
|
||
| for messages in request.messages: | ||
| single_request = request.to_chat_completion_request(messages) | ||
|
|
@@ -95,11 +105,13 @@ async def render_batch_chat_request( | |
| default_template_kwargs=render.default_chat_template_kwargs, | ||
| tool_dicts=tool_dicts, | ||
| tool_parser=tool_parser, | ||
| reasoning_parser=self.reasoning_parser_cls, | ||
| ) | ||
| all_conversations.append(conversation) | ||
| all_engine_prompts.append(engine_prompts[0]) | ||
| single_requests.append(single_request) | ||
|
|
||
| return all_conversations, all_engine_prompts | ||
| return all_conversations, all_engine_prompts, single_requests | ||
|
|
||
| async def create_batch_chat_completion( | ||
| self, | ||
|
|
@@ -114,10 +126,11 @@ async def create_batch_chat_completion( | |
| """ | ||
| tokenizer = self.renderer.tokenizer | ||
| assert tokenizer is not None | ||
| single_requests = [ | ||
| request.to_chat_completion_request(messages) | ||
| for messages in request.messages | ||
| ] | ||
|
|
||
| render_result = await self.render_batch_chat_request(request) | ||
| if isinstance(render_result, ErrorResponse): | ||
| return render_result | ||
| all_conversations, engine_prompts, single_requests = render_result | ||
|
|
||
| reasoning_parser: ReasoningParser | None = None | ||
| if self.reasoning_parser_cls: | ||
|
|
@@ -129,11 +142,6 @@ async def create_batch_chat_completion( | |
| chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg] | ||
| ) | ||
|
|
||
| render_result = await self.render_batch_chat_request(request) | ||
| if isinstance(render_result, ErrorResponse): | ||
| return render_result | ||
| all_conversations, engine_prompts = render_result | ||
|
|
||
| request_id = ( | ||
| f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" | ||
| ) | ||
|
|
@@ -149,6 +157,7 @@ async def create_batch_chat_completion( | |
| generators: list[AsyncGenerator[RequestOutput, None]] = [] | ||
| for i, engine_prompt in enumerate(engine_prompts): | ||
| sub_request_id = f"{request_id}_{i}" | ||
| prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids | ||
| max_tokens = get_max_tokens( | ||
| max_model_len, | ||
| request.max_completion_tokens | ||
|
|
@@ -173,16 +182,33 @@ async def create_batch_chat_completion( | |
| if raw_request is None | ||
| else await self._get_trace_headers(raw_request.headers) | ||
| ) | ||
| if ( | ||
| not single_request.include_reasoning | ||
| or single_request._grammar_from_tool_parser | ||
| ): | ||
| reasoning_ended = True | ||
| elif reasoning_parser: | ||
| reasoning_ended = reasoning_parser.is_reasoning_end( | ||
| prompt_token_ids or [] | ||
| ) | ||
| else: | ||
| reasoning_ended = None | ||
| chat_template_kwargs = self._effective_chat_template_kwargs(single_request) | ||
|
Comment on lines
+185
to
+196
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These changes aren't directly related to the bug described, but they address part of what is inconsistent with the regular chat completions logic. I can pull them out if desired. Also, it may be worth abstracting more of the common logic into the non-batch serving module to help prevent future regressions. |
||
| generators.append( | ||
| self.engine_client.generate( | ||
| engine_prompt, | ||
| sampling_params, | ||
| sub_request_id, | ||
| lora_request=lora_request, | ||
| trace_headers=trace_headers, | ||
| priority=request.priority if hasattr(request, "priority") else 0, | ||
| priority=single_request.priority, | ||
| data_parallel_rank=data_parallel_rank, | ||
| reasoning_ended=None, | ||
| reasoning_ended=reasoning_ended, | ||
| reasoning_parser_kwargs={ | ||
| "chat_template_kwargs": chat_template_kwargs, | ||
| } | ||
| if reasoning_parser | ||
| else None, | ||
| ) | ||
| ) | ||
|
|
||
|
|
@@ -195,6 +221,7 @@ async def create_batch_chat_completion( | |
| tokenizer, | ||
| request_metadata, | ||
| reasoning_parser, | ||
| single_requests, | ||
| ) | ||
|
|
||
| async def chat_completion_full_generator_batch( | ||
|
|
@@ -206,7 +233,8 @@ async def chat_completion_full_generator_batch( | |
| all_conversations: list[list[ConversationMessage]], | ||
| tokenizer: TokenizerLike, | ||
| request_metadata: RequestResponseMetadata, | ||
| reasoning_parser: ReasoningParser | None = None, | ||
| reasoning_parser: ReasoningParser | None, | ||
| single_requests: list[ChatCompletionRequest], | ||
| ) -> ErrorResponse | ChatCompletionResponse: | ||
| """Handle batched (non-streaming) chat completions. | ||
|
|
||
|
|
@@ -263,22 +291,18 @@ async def chat_completion_full_generator_batch( | |
| logprobs = None | ||
|
|
||
| if reasoning_parser: | ||
| single_request = single_requests[prompt_idx] | ||
| reasoning, content = reasoning_parser.extract_reasoning( | ||
| output.text, | ||
| request=request, # type: ignore[arg-type] | ||
| request=single_request, | ||
| ) | ||
| if not getattr(request, "include_reasoning", True): | ||
| if not single_request.include_reasoning: | ||
| reasoning = None | ||
| else: | ||
| reasoning = None | ||
| content = output.text | ||
|
|
||
| role = ( | ||
| self.response_role | ||
| if request.add_generation_prompt | ||
| else request.messages[prompt_idx][-1]["role"] | ||
| ) | ||
|
|
||
| role = self.get_chat_request_role(single_requests[prompt_idx]) | ||
| message = ChatMessage(role=role, reasoning=reasoning, content=content) | ||
|
|
||
| if request.echo: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The other tests in this module seem to be purely integration tests that run against a real server; I'm not sure whether that matters.