diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py
index 01c1360818eb..e99165f3569a 100644
--- a/tests/tool_use/test_tool_choice_required.py
+++ b/tests/tool_use/test_tool_choice_required.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import json
 from copy import deepcopy
-from unittest.mock import MagicMock
 
 import pytest
 import regex as re
@@ -11,7 +10,7 @@
 from vllm.entrypoints.openai.chat_completion.protocol import (
     ChatCompletionToolsParam,
 )
-from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
+from vllm.tool_parsers.streaming import extract_required_tool_call_streaming
 from vllm.tool_parsers.utils import get_json_schema_from_tools
 
 pytestmark = pytest.mark.cpu_test
@@ -281,8 +280,6 @@ def test_structured_outputs_json_without_parameters(
 @pytest.mark.parametrize("empty_params", [False, True])
 @pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
 def test_streaming_output_valid(output, empty_params, delta_len):
-    self = MagicMock()
-
     output = deepcopy(output)
     if empty_params:
         output = [{"name": o["name"], "parameters": {}} for o in output]
@@ -295,14 +292,13 @@ def test_streaming_output_valid(output, empty_params, delta_len):
         delta_text = output_json[i : i + delta_len]
         current_text = previous_text + delta_text
 
-        delta_message, function_name_returned = (
-            OpenAIServingChat.extract_tool_call_required_streaming(
-                self,
-                previous_text=previous_text,
-                current_text=current_text,
-                delta_text=delta_text,
-                function_name_returned=function_name_returned,
-            )
+        delta_message, function_name_returned = extract_required_tool_call_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            function_name_returned=function_name_returned,
+            tool_call_idx=None,
+            tool_call_id_type="random",
         )
 
         if delta_message:
@@ -332,8 +328,6 @@
 
 
 def test_streaming_output_valid_with_trailing_extra_data():
-    self = MagicMock()
-
     output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
     output_json = json.dumps(output) + "\nDONE"
 
@@ -345,14 +339,13 @@ def test_streaming_output_valid_with_trailing_extra_data():
         delta_text = output_json[i : i + delta_len]
         current_text = previous_text + delta_text
 
-        delta_message, function_name_returned = (
-            OpenAIServingChat.extract_tool_call_required_streaming(
-                self,
-                previous_text=previous_text,
-                current_text=current_text,
-                delta_text=delta_text,
-                function_name_returned=function_name_returned,
-            )
+        delta_message, function_name_returned = extract_required_tool_call_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            function_name_returned=function_name_returned,
+            tool_call_idx=None,
+            tool_call_id_type="random",
         )
 
         if delta_message:
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index f42cc8afeeb1..24f48e7b3bf4 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -70,10 +70,6 @@
 from vllm.renderers import ChatParams
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.tokenizers import TokenizerLike
-from vllm.tool_parsers.streaming import (
-    extract_named_tool_call_streaming,
-    extract_required_tool_call_streaming,
-)
 from vllm.utils.collection_utils import as_list
 from vllm.utils.mistral import is_mistral_tokenizer, is_mistral_tool_parser
 
@@ -389,23 +385,6 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
             return self.response_role
         return request.messages[-1]["role"]
 
-    def extract_tool_call_required_streaming(
-        self,
-        previous_text: str,
-        current_text: str | None,
-        delta_text: str,
-        function_name_returned: bool,
-        tool_call_idx: int | None = None,
-    ) -> tuple[DeltaMessage | None, bool]:
-        return extract_required_tool_call_streaming(
-            previous_text=previous_text,
-            current_text=current_text,
-            delta_text=delta_text,
-            function_name_returned=function_name_returned,
-            tool_call_idx=tool_call_idx,
-            tool_call_id_type=self.tool_call_id_type,
-        )
-
     async def chat_completion_stream_generator(
         self,
         request: ChatCompletionRequest,
@@ -448,22 +427,7 @@ async def chat_completion_stream_generator(
             and self._should_stream_with_auto_tool_parsing(request)
         )
 
-        # Determine whether required/named tool_choice should fall back to
-        # the auto tool_parser path instead of the standard JSON-based parsing.
-        # This happens when the parser declares supports_required_and_named=False
-        # (e.g. GLM models that output XML instead of JSON).
-        tool_choice_uses_parser = (
-            self.tool_parser is not None
-            and not self.tool_parser.supports_required_and_named
-            and request.tools
-            and (
-                request.tool_choice == "required"
-                or isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
-            )
-        )
-
         all_previous_token_ids: list[list[int]] | None
-        function_name_returned = [False] * num_choices
         if self.tool_call_id_type == "kimi_k2":
             history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
         else:
@@ -477,10 +441,10 @@
             if (
                 is_mistral_grammar_path
                 or tool_choice_auto
-                or tool_choice_uses_parser
+                or tool_choice_function_name
+                or request.tool_choice == "required"
                 or reasoning_parser
             ):
-                # These are only required in "auto" tool choice case
                 all_previous_token_ids = [[] for _ in range(num_choices)]
                 reasoning_end_arr = [False] * num_choices
                 prompt_is_reasoning_end_arr: list[bool | None] = [None] * num_choices
@@ -501,6 +465,10 @@
                     )
                     for _ in range(num_choices)
                 ]
+                for p in parsers:
+                    if p is not None:
+                        p._stream_state.tool_call_id_type = self.tool_call_id_type
+                        p._stream_state.history_tool_call_cnt = history_tool_call_cnt
             else:
                 parsers = [None] * num_choices
         except Exception as e:
@@ -677,7 +645,8 @@
                 if (
                     is_mistral_grammar_path
                     or tool_choice_auto
-                    or tool_choice_uses_parser
+                    or tool_choice_function_name
+                    or request.tool_choice == "required"
                     or reasoning_parser
                 ):
                     assert previous_texts is not None
@@ -731,135 +700,6 @@
                     current_token_ids = result.current_token_ids
                     if result.tools_called:
                         tools_streamed[i] = True
-                    # handle streaming deltas for tools with named tool_choice
-                    # Skip when tool_choice_uses_parser so it falls through
-                    # to the auto tool_parser branches below.
-                    elif tool_choice_function_name and not tool_choice_uses_parser:
-                        # When encountering think end id in prompt_token_ids
-                        # i.e {"enable_thinking": False},
-                        # check BEFORE calling the parser to avoid a spurious
-                        # reasoning delta on the first chunk.
-                        if (
-                            reasoning_parser
-                            and not reasoning_end_arr[i]
-                            and prompt_is_reasoning_end_arr[i]
-                        ):
-                            reasoning_end_arr[i] = True
-
-                        if (
-                            reasoning_parser
-                            and not reasoning_end_arr[i]
-                            and not reasoning_parser.is_reasoning_end(
-                                previous_token_ids
-                            )
-                        ):
-                            assert reasoning_parser is not None
-                            delta_message = (
-                                reasoning_parser.extract_reasoning_streaming(
-                                    previous_text,
-                                    current_text,
-                                    delta_text,
-                                    previous_token_ids,
-                                    current_token_ids,
-                                    output.token_ids,
-                                )
-                            )
-                            # When encountering think end id in delta_token_ids,
-                            # set reasoning status to end.
-                            # Only keep 'content', remove 'reasoning'.
-                            if reasoning_parser.is_reasoning_end(
-                                as_list(output.token_ids)
-                            ):
-                                reasoning_end_arr[i] = True
-                                if delta_message and delta_message.content:
-                                    current_text = delta_message.content
-                                    delta_message.content = None
-                                else:
-                                    current_text = ""
-                        else:
-                            # Just to add remaining `content`
-                            if reasoning_parser:
-                                delta_text = previous_text + delta_text
-                                current_text = ""
-
-                            delta_message, function_name_returned[i] = (
-                                extract_named_tool_call_streaming(
-                                    delta_text=delta_text,
-                                    function_name=tool_choice_function_name,
-                                    function_name_returned=function_name_returned[i],
-                                    tool_call_idx=history_tool_call_cnt,
-                                    tool_call_id_type=self.tool_call_id_type,
-                                    tokenizer=tokenizer,
-                                    tool_call_array_index=i,
-                                )
-                            )
-                            if (
-                                delta_message
-                                and delta_message.tool_calls
-                                and delta_message.tool_calls[0].id is not None
-                            ):
-                                history_tool_call_cnt += 1
-                                tools_streamed[i] = True
-
-                    # Skip when tool_choice_uses_parser so it falls through
-                    # to the auto tool_parser branches below.
-                    elif (
-                        request.tool_choice == "required"
-                        and not tool_choice_uses_parser
-                    ):
-                        assert previous_texts is not None
-                        previous_text = previous_texts[i]
-                        current_text = previous_text + delta_text
-                        fn_name_returned = function_name_returned[i]
-                        output_token_ids = as_list(output.token_ids)
-
-                        if (
-                            reasoning_parser is not None
-                            and not reasoning_end_arr[i]
-                            and prompt_is_reasoning_end_arr[i]
-                        ):
-                            reasoning_end_arr[i] = True
-
-                        if reasoning_parser and not reasoning_end_arr[i]:
-                            delta_message = (
-                                reasoning_parser.extract_reasoning_streaming(
-                                    previous_text,
-                                    current_text,
-                                    delta_text,
-                                    previous_token_ids,
-                                    current_token_ids,
-                                    output_token_ids,
-                                )
-                            )
-                            if reasoning_parser.is_reasoning_end(output_token_ids):
-                                reasoning_end_arr[i] = True
-                                if delta_message and delta_message.content:
-                                    current_text = delta_message.content
-                                    delta_message.content = None
-                                else:
-                                    # reasoning ended
-                                    current_text = ""
-
-                        else:
-                            # either finished reasoning or no reasoning at all
-                            content = current_text
-
-                            delta_message, function_name_returned[i] = (
-                                self.extract_tool_call_required_streaming(
-                                    previous_text=previous_text,
-                                    current_text=content,
-                                    delta_text=delta_text,
-                                    function_name_returned=fn_name_returned,
-                                    tool_call_idx=history_tool_call_cnt,
-                                )
-                            )
-                            if (
-                                delta_message
-                                and delta_message.tool_calls
-                                and delta_message.tool_calls[0].id is not None
-                            ):
-                                history_tool_call_cnt += 1
-                                tools_streamed[i] = True
 
                 elif parser is not None:
                     delta_message = parser.parse_delta(
@@ -878,7 +718,8 @@
             if (
                 is_mistral_grammar_path
                 or tool_choice_auto
-                or tool_choice_uses_parser
+                or tool_choice_function_name
+                or request.tool_choice == "required"
                 or reasoning_parser
             ) and not self.use_harmony:
                 assert previous_texts is not None
diff --git a/vllm/parser/abstract_parser.py b/vllm/parser/abstract_parser.py
index 1b861dcade87..46acc7fb155e 100644
--- a/vllm/parser/abstract_parser.py
+++ b/vllm/parser/abstract_parser.py
@@ -587,9 +587,15 @@ def _extract_tool_calls_streaming(
         tool_call_id_type: str = "random",
         function_name_returned: bool = False,
     ) -> tuple[DeltaMessage | None, bool]:
-        if request.tool_choice and isinstance(
-            request.tool_choice,
-            (ToolChoiceFunction, ChatCompletionNamedToolChoiceParam),
+        assert self._tool_parser is not None
+        supports_required_and_named = self._tool_parser.supports_required_and_named
+        if (
+            supports_required_and_named
+            and request.tool_choice
+            and isinstance(
+                request.tool_choice,
+                (ToolChoiceFunction, ChatCompletionNamedToolChoiceParam),
+            )
         ):
             delta_message, function_name_returned = extract_named_tool_call_streaming(
                 delta_text=delta_text,
@@ -601,7 +607,7 @@
             )
             return delta_message, function_name_returned
 
-        if request.tool_choice == "required":
+        if supports_required_and_named and request.tool_choice == "required":
            delta_message, function_name_returned = (
                extract_required_tool_call_streaming(
                    previous_text=previous_text,
@@ -706,6 +712,12 @@ def parse_delta(
                         function_name_returned=state.function_name_returned,
                     )
                 )
+                if (
+                    delta_message
+                    and delta_message.tool_calls
+                    and delta_message.tool_calls[0].id is not None
+                ):
+                    state.history_tool_call_cnt += 1
 
         # No phase active: pass through as content
         if (
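
Usage note: after this refactor the required-`tool_choice` streaming helper is a plain function in `vllm.tool_parsers.streaming`, so callers no longer need an `OpenAIServingChat` instance (or a `MagicMock` stand-in). A minimal sketch of driving it directly, assuming only the call signature shown in the updated tests above; the sample model output and the 3-character delta size are illustrative:

```python
import json

from vllm.tool_parsers.streaming import extract_required_tool_call_streaming

# Illustrative model output for tool_choice="required": a JSON array of calls.
output_json = json.dumps(
    [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
)

previous_text = ""
function_name_returned = False
for i in range(0, len(output_json), 3):  # feed the text in 3-char deltas
    delta_text = output_json[i : i + 3]
    current_text = previous_text + delta_text
    delta_message, function_name_returned = extract_required_tool_call_streaming(
        previous_text=previous_text,
        current_text=current_text,
        delta_text=delta_text,
        function_name_returned=function_name_returned,
        tool_call_idx=None,
        tool_call_id_type="random",
    )
    if delta_message and delta_message.tool_calls:
        # The first delta for a call carries the function name (and an id);
        # subsequent deltas stream the arguments incrementally.
        print(delta_message.tool_calls[0])
    previous_text = current_text
```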