diff --git a/tests/entrypoints/openai/chat_completion/test_serving_chat.py b/tests/entrypoints/openai/chat_completion/test_serving_chat.py index 11793ac8f49d..7c0a46a4e634 100644 --- a/tests/entrypoints/openai/chat_completion/test_serving_chat.py +++ b/tests/entrypoints/openai/chat_completion/test_serving_chat.py @@ -1935,8 +1935,10 @@ async def result_generator(): finished=True, ) - # Collect tool-call deltas per choice from the SSE stream. + # Collect tool-call deltas and finish_reasons per choice from the SSE + # stream. tc_deltas_by_choice: dict[int, list[dict]] = {i: [] for i in range(num_choices)} + finish_reasons_by_choice: dict[int, list[str]] = {i: [] for i in range(num_choices)} async for chunk_str in serving_chat.chat_completion_stream_generator( request=request, result_generator=result_generator(), @@ -1959,6 +1961,8 @@ async def result_generator(): if delta.get("tool_calls"): for tc in delta["tool_calls"]: tc_deltas_by_choice[idx].append(tc) + if choice.get("finish_reason") is not None: + finish_reasons_by_choice[idx].append(choice["finish_reason"]) # Both choices must independently produce the correct tool call. for choice_idx in range(num_choices): @@ -1984,141 +1988,11 @@ async def result_generator(): f"Choice {choice_idx}: expected {{'city': 'Tokyo'}}, got {parsed_args}" ) - -class TestCreateRemainingArgsDelta: - """Tests for _create_remaining_args_delta helper function. - - This helper is used when streaming tool calls to preserve id/type/name - fields in the finish chunk, which would otherwise be lost. - """ - - def test_preserves_id_type_name(self): - """Test that id, type, and name are preserved from original delta.""" - from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat - from vllm.entrypoints.openai.engine.protocol import ( - DeltaFunctionCall, - DeltaMessage, - DeltaToolCall, - ) - - original_delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=0, - id="call_abc123", - type="function", - function=DeltaFunctionCall( - name="get_weather", - arguments='{"location": "Paris"}', - ), - ) - ] + reasons = finish_reasons_by_choice[choice_idx] + assert len(reasons) == 1, ( + f"Choice {choice_idx}: expected exactly 1 finish_reason, got {reasons}" ) - - result = OpenAIServingChat._create_remaining_args_delta( - original_delta, '", "unit": "celsius"}', 0 - ) - - assert len(result.tool_calls) == 1 - tc = result.tool_calls[0] - assert tc.index == 0 - assert tc.id == "call_abc123" - assert tc.type == "function" - assert tc.function.name == "get_weather" - assert tc.function.arguments == '", "unit": "celsius"}' - - def test_matches_by_index(self): - """Test that the correct tool call is matched by index.""" - from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat - from vllm.entrypoints.openai.engine.protocol import ( - DeltaFunctionCall, - DeltaMessage, - DeltaToolCall, - ) - - original_delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=0, - id="call_first", - type="function", - function=DeltaFunctionCall(name="func_a", arguments="{}"), - ), - DeltaToolCall( - index=1, - id="call_second", - type="function", - function=DeltaFunctionCall(name="func_b", arguments="{}"), - ), - ] - ) - - result = OpenAIServingChat._create_remaining_args_delta( - original_delta, '{"extra": true}', 1 + assert reasons[0] == "tool_calls", ( + f"Choice {choice_idx}: expected finish_reason='tool_calls', " + f"got '{reasons[0]}'" ) - - assert len(result.tool_calls) == 1 - tc = result.tool_calls[0] - assert tc.index == 1 - assert tc.id == "call_second" - assert tc.function.name == "func_b" - - def test_no_matching_tool_call(self): - """Test graceful handling when no matching tool call is found.""" - from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat - from vllm.entrypoints.openai.engine.protocol import ( - DeltaFunctionCall, - DeltaMessage, - DeltaToolCall, - ) - - original_delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=0, - id="call_zero", - type="function", - function=DeltaFunctionCall(name="func", arguments="{}"), - ) - ] - ) - - result = OpenAIServingChat._create_remaining_args_delta( - original_delta, '{"arg": 1}', 5 - ) - - assert len(result.tool_calls) == 1 - tc = result.tool_calls[0] - assert tc.index == 5 - assert tc.id is None - assert tc.type is None - assert tc.function.name is None - assert tc.function.arguments == '{"arg": 1}' - - def test_function_is_none(self): - """Test handling when original tool call has no function.""" - from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat - from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall - - original_delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=0, - id="call_nofunc", - type="function", - function=None, - ) - ] - ) - - result = OpenAIServingChat._create_remaining_args_delta( - original_delta, '{"data": "value"}', 0 - ) - - assert len(result.tool_calls) == 1 - tc = result.tool_calls[0] - assert tc.index == 0 - assert tc.id == "call_nofunc" - assert tc.type == "function" - assert tc.function.name is None - assert tc.function.arguments == '{"data": "value"}' diff --git a/tests/parser/test_streaming.py b/tests/parser/test_streaming.py index c4409117ad91..71f214f3ee65 100644 --- a/tests/parser/test_streaming.py +++ b/tests/parser/test_streaming.py @@ -235,3 +235,86 @@ def test_parse_delta_reasoning_only_thinking_disabled(tokenizer, request_obj): assert "Hello" in content assert "assist" in content assert len(tool_calls) == 0 + + +def test_parse_delta_finished_no_flush_without_tool_call_delta(tokenizer, request_obj): + """When finished=True but the final parse_delta produces no + tool-call delta, unstreamed args are not flushed.""" + parser = make_parser(tokenizer, reasoning=False, tool=True) + + results = stream_text( + parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[] + ) + _, _, tool_calls = collect_fields(results) + assert len(tool_calls) > 0 + + streamed = parser._tool_parser.streamed_args_for_tool[0] + assert len(streamed) > 5 + parser._tool_parser.streamed_args_for_tool[0] = streamed[:-5] + + # Prevent normal extraction from catching the gap — without a + # tool-call delta to merge into, the flush is skipped. + parser._tool_parser.extract_tool_calls_streaming = lambda *a, **kw: None + + flush_result = parser.parse_delta("", [], request_obj, finished=True) + assert flush_result is None or flush_result.tool_calls is None + + +def test_parse_delta_finished_no_extra_args_when_fully_streamed(tokenizer, request_obj): + """When all args have been streamed, finished=True must not + produce extra or duplicate arguments.""" + parser = make_parser(tokenizer, reasoning=False, tool=True) + results = stream_text( + parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[] + ) + _, _, tool_calls = collect_fields(results) + + assert len(tool_calls) > 0 + assert tool_calls[0].function.name == "get_weather" + tool_args = "".join( + tc.function.arguments for tc in tool_calls if tc.function.arguments + ) + assert json.loads(tool_args) == {"city": "Dallas"} + + flush_result = parser.parse_delta("", [], request_obj, finished=True) + assert flush_result is None or flush_result.tool_calls is None + + +def test_parse_delta_finished_appends_remaining_args(tokenizer, request_obj): + """When finished=True and the tool parser has unstreamed args, + parse_delta appends the remaining arguments to the tool-call delta.""" + parser = make_parser(tokenizer, reasoning=False, tool=True) + token_ids = tokenizer.encode(MODEL_OUTPUT, add_special_tokens=False) + + remainder = ',"unit":"celsius"}' + prompt_ids: list[int] | None = [] + results: list[DeltaMessage | None] = [] + for i, tid in enumerate(token_ids): + prev = results[-1] if results else None + prev_had_args = ( + prev + and prev.tool_calls + and any(tc.function and tc.function.arguments for tc in prev.tool_calls) + ) + + if prev_had_args: + parser._tool_parser.get_remaining_unstreamed_args = lambda: remainder + + result = parser.parse_delta( + tokenizer.decode([tid]), + [tid], + request_obj, + prompt_token_ids=prompt_ids, + finished=prev_had_args, + ) + prompt_ids = None + results.append(result) + + if prev_had_args: + break + + _, _, tool_calls = collect_fields(results) + tool_args = "".join( + tc.function.arguments for tc in tool_calls if tc.function.arguments + ) + assert tool_args.endswith(remainder) diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 92ffc141548b..412583f8b65a 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -3,7 +3,6 @@ import asyncio import io -import json import time from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import Sequence as GenericSequence @@ -40,9 +39,7 @@ extract_harmony_streaming_delta, ) from vllm.entrypoints.openai.engine.protocol import ( - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, @@ -65,7 +62,7 @@ from vllm.inputs import EngineInput from vllm.logger import init_logger from vllm.logprobs import Logprob -from vllm.outputs import CompletionOutput, RequestOutput +from vllm.outputs import RequestOutput from vllm.parser import ParserManager from vllm.parser.abstract_parser import Parser from vllm.reasoning import ReasoningParser @@ -715,6 +712,7 @@ async def chat_completion_stream_generator( delta_token_ids=as_list(output.token_ids), request=request, prompt_token_ids=res.prompt_token_ids, + finished=output.finish_reason is not None, ) if delta_message and delta_message.tool_calls: tools_streamed[i] = True @@ -805,81 +803,13 @@ async def chat_completion_stream_generator( # finish_reason='error' indicates a retryable error self._raise_if_error(output.finish_reason, request_id) - # check to make sure we haven't "forgotten" to stream - # any tokens that were generated but previously - # matched by partial json parsing - # only happens if we are NOT using structured outputs - index = 0 - auto_tools_called = False - if tool_parser: - auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0 - index = ( - len(tool_parser.prev_tool_call_arr) - 1 - if auto_tools_called - else 0 - ) - should_check = ( - self._should_check_for_unstreamed_tool_arg_tokens( - delta_message, output - ) - ) - # only check if there are any tool calls - # detected by partial parsing - if should_check and tool_parser and auto_tools_called: - latest_delta_len = 0 - if ( - isinstance( - delta_message.tool_calls[0].function, - DeltaFunctionCall, - ) - ) and isinstance( - delta_message.tool_calls[0].function.arguments, str - ): - latest_delta_len = len( - delta_message.tool_calls[0].function.arguments - ) - - # get the expected call based on partial JSON - # parsing which "autocompletes" the JSON. - # Tool parsers (e.g. Qwen3Coder) store - # arguments as a JSON string in - # prev_tool_call_arr. Calling json.dumps() - # on an already-serialized string would - # double-serialize it (e.g. '{"k":1}' becomes - # '"{\\"k\\":1}"'), which then causes the - # replace() below to fail and append the - # entire double-serialized string as a - # spurious final delta. - args = tool_parser.prev_tool_call_arr[index].get( - "arguments", {} - ) - if isinstance(args, str): - expected_call = args - else: - expected_call = json.dumps(args, ensure_ascii=False) - - # get what we've streamed so far for arguments - # for the current tool - actual_call = tool_parser.streamed_args_for_tool[index] - if latest_delta_len > 0: - actual_call = actual_call[:-latest_delta_len] - - # check to see if there's anything left to stream - remaining_call = expected_call.replace(actual_call, "", 1) - # set that as a delta message - delta_message = self._create_remaining_args_delta( - delta_message, remaining_call, index - ) - # Send the finish response for each request.n only once # In OpenAI's API, when a tool is called, the # finish_reason is: # "tool_calls" for "auto" or "required" tool calls, # and "stop" for named tool calls. - if ( - auto_tools_called - or (tools_streamed[i] and not tool_choice_function_name) - or (self.use_harmony and harmony_tools_streamed[i]) + if (tools_streamed[i] and not tool_choice_function_name) or ( + self.use_harmony and harmony_tools_streamed[i] ): finish_reason_ = "tool_calls" else: @@ -1535,56 +1465,3 @@ def _should_stream_with_auto_tool_parsing(self, request: ChatCompletionRequest): and self.enable_auto_tools and request.tool_choice in ["auto", None] ) - - def _should_check_for_unstreamed_tool_arg_tokens( - self, - delta_message: DeltaMessage | None, - output: CompletionOutput, - ) -> bool: - """ - Check to see if we should check for unstreamed tool arguments tokens. - This is only applicable when auto tool parsing is enabled, the delta - is a tool call with arguments. - """ - - return bool( - # if there is a delta message that includes tool calls which - # include a function that has arguments - output.finish_reason is not None - and self.enable_auto_tools - and self.tool_parser - and delta_message - and delta_message.tool_calls - and delta_message.tool_calls[0] - and delta_message.tool_calls[0].function - and delta_message.tool_calls[0].function.arguments is not None - ) - - @staticmethod - def _create_remaining_args_delta( - delta_message: DeltaMessage, - remaining_call: str, - index: int, - ) -> DeltaMessage: - """ - Create a delta message for remaining tool arguments, preserving - id/type/name from the original delta. - """ - original_tc = next( - (tc for tc in delta_message.tool_calls if tc.index == index), - None, - ) - original_fn = original_tc.function if original_tc else None - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=index, - id=original_tc.id if original_tc else None, - type=original_tc.type if original_tc else None, - function=DeltaFunctionCall( - name=original_fn.name if original_fn else None, - arguments=remaining_call, - ), - ) - ] - ) diff --git a/vllm/parser/abstract_parser.py b/vllm/parser/abstract_parser.py index 2a13f138607b..96c04805bcda 100644 --- a/vllm/parser/abstract_parser.py +++ b/vllm/parser/abstract_parser.py @@ -320,6 +320,7 @@ def parse_delta( delta_token_ids: list[int], request: ChatCompletionRequest | ResponsesRequest, prompt_token_ids: list[int] | None = None, + finished: bool = False, ) -> DeltaMessage | None: """Parse a single streaming delta, orchestrating reasoning then tool call extraction via internal stream state. @@ -656,12 +657,28 @@ def _in_tool_call_phase(self, state: StreamState) -> bool: return False return state.reasoning_ended + def _append_unstreamed_tool_args( + self, + delta_message: DeltaMessage | None, + ) -> None: + """Append parsed-but-unstreamed tool-call arguments to *delta_message*.""" + if ( + self._tool_parser is not None + and delta_message + and delta_message.tool_calls + and (last_tc := delta_message.tool_calls[-1]).function + ): + last_tc.function.arguments = ( + last_tc.function.arguments or "" + ) + self._tool_parser.get_remaining_unstreamed_args() + def parse_delta( self, delta_text: str, delta_token_ids: list[int], request: ChatCompletionRequest | ResponsesRequest, prompt_token_ids: list[int] | None = None, + finished: bool = False, ) -> DeltaMessage | None: state = self._stream_state @@ -745,6 +762,10 @@ def parse_delta( state.previous_text = current_text state.previous_token_ids = current_token_ids + + if finished: + self._append_unstreamed_tool_args(delta_message) + return delta_message diff --git a/vllm/tool_parsers/abstract_tool_parser.py b/vllm/tool_parsers/abstract_tool_parser.py index c3438082a72d..94543b82350b 100644 --- a/vllm/tool_parsers/abstract_tool_parser.py +++ b/vllm/tool_parsers/abstract_tool_parser.py @@ -79,6 +79,25 @@ def __init__( else: self.tools = [] + def get_remaining_unstreamed_args(self) -> str: + """Return tool call arguments parsed but not yet streamed.""" + if not self.prev_tool_call_arr: + return "" + index = len(self.prev_tool_call_arr) - 1 + args = self.prev_tool_call_arr[index].get("arguments", {}) + if isinstance(args, str): + expected = args + else: + expected = json.dumps(args, ensure_ascii=False) + actual = ( + self.streamed_args_for_tool[index] + if index < len(self.streamed_args_for_tool) + else "" + ) + if expected.startswith(actual): + return expected[len(actual) :] + return "" + @cached_property def vocab(self) -> dict[str, int]: # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab