From a26962941e1328e10863506df39376e63be1fd64 Mon Sep 17 00:00:00 2001 From: Martin Vit Date: Tue, 31 Mar 2026 13:35:01 +0000 Subject: [PATCH] fix: tool_choice="required" falls back to tool_parser for non-JSON formats When tool_choice="required" and the model produces non-JSON tool calls (e.g. XML from Qwen3 with qwen3_coder parser), both non-streaming and streaming paths now fall back to the configured tool_parser instead of silently dropping tool calls or failing. Non-streaming (engine/serving.py): - Replace contextlib.suppress(ValidationError) from #36841 with try/except that preserves crash-safety (content or "") while adding fallback to tool_parser.extract_tool_calls() for non-JSON formats. Streaming (chat_completion/serving.py): - Initialize tool_parsers for "required" (not just "auto"). - Use separate if blocks (not if/else) so tool parsing runs in the same iteration when reasoning ends (critical for MTP/speculative decoding where and tool call arrive in one chunk). - Dual parser: try tool_parser first (XML), fall back to JSON-only extract_tool_call_required_streaming() for non-deterministic MTP. Signed-off-by: voipmonitor --- .../test_tool_choice_required_fallback.py | 240 ++++++++++++++++++ .../openai/chat_completion/serving.py | 100 ++++++-- vllm/entrypoints/openai/engine/serving.py | 43 +++- 3 files changed, 352 insertions(+), 31 deletions(-) create mode 100644 tests/tool_use/test_tool_choice_required_fallback.py diff --git a/tests/tool_use/test_tool_choice_required_fallback.py b/tests/tool_use/test_tool_choice_required_fallback.py new file mode 100644 index 000000000000..a5357115d9e0 --- /dev/null +++ b/tests/tool_use/test_tool_choice_required_fallback.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for tool_choice="required" fallback to tool_parser. + +When tool_choice="required" and the model produces non-JSON tool calls +(e.g. XML format from Qwen3), the non-streaming path should fall back +to the configured tool_parser instead of returning a 400 error. + +See: https://github.com/vllm-project/vllm/pull/35936 +""" + +import json +from unittest.mock import MagicMock + +import pytest + +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, +) +from vllm.entrypoints.openai.engine.serving import OpenAIServing +from vllm.tool_parsers.abstract_tool_parser import ToolParser +from vllm.tool_parsers.utils import ExtractedToolCallInformation + +pytestmark = pytest.mark.cpu_test + +MODEL = "test-model" + +SAMPLE_TOOLS = [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + }, + }, + "required": ["city"], + }, + }, + ), +] + +# JSON format tool call (standard) +JSON_TOOL_CALL = json.dumps( + [{"name": "get_current_weather", "parameters": {"city": "Dallas"}}] +) + +# XML format tool call (Qwen3 style) +XML_TOOL_CALL = """ + + +Dallas + + +""" + + +class MockToolParser(ToolParser): + """A minimal tool parser that recognizes XML-style tool calls.""" + + def __init__(self, tokenizer, tools=None): + super().__init__(tokenizer, tools) + + def extract_tool_calls(self, model_output, request): + from vllm.entrypoints.openai.engine.protocol import ( + FunctionCall, + ToolCall, + ) + + # Simple check: if it contains ", model_output) + param_matches = re.findall( + r"\n(.*?)\n", + model_output, + re.DOTALL, + ) + + if func_match: + name = func_match.group(1) + args = {k: v.strip() for k, v in param_matches} + tool_calls = [ + ToolCall( + function=FunctionCall( + name=name, + arguments=json.dumps(args), + ) + ) + ] + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=None + ) + + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def extract_tool_calls_streaming(self, *args, **kwargs): + raise NotImplementedError + + +def _make_request(tool_choice="required"): + return ChatCompletionRequest( + model=MODEL, + messages=[], + tools=SAMPLE_TOOLS, + tool_choice=tool_choice, + ) + + +class TestToolChoiceRequiredNonStreaming: + """Tests for _parse_tool_calls_from_content with tool_choice='required'.""" + + def test_json_content_succeeds_directly(self): + """Valid JSON tool calls should be parsed without fallback.""" + request = _make_request() + function_calls, content = OpenAIServing._parse_tool_calls_from_content( + request=request, + tokenizer=None, + enable_auto_tools=False, + tool_parser_cls=None, + content=JSON_TOOL_CALL, + ) + + assert function_calls is not None + assert len(function_calls) == 1 + assert function_calls[0].name == "get_current_weather" + assert json.loads(function_calls[0].arguments) == {"city": "Dallas"} + assert content is None # cleared after tool call + + def test_xml_content_falls_back_to_tool_parser(self): + """XML tool calls should fail JSON validation, then fall back + to the configured tool_parser.""" + request = _make_request() + tokenizer = MagicMock() + + function_calls, content = OpenAIServing._parse_tool_calls_from_content( + request=request, + tokenizer=tokenizer, + enable_auto_tools=True, + tool_parser_cls=MockToolParser, + content=XML_TOOL_CALL, + ) + + assert function_calls is not None + assert len(function_calls) == 1 + assert function_calls[0].name == "get_current_weather" + assert json.loads(function_calls[0].arguments) == {"city": "Dallas"} + assert content is None # cleared after tool call + + def test_xml_content_no_tool_calls_without_tool_parser(self): + """Without a configured tool_parser, XML content should result + in no tool calls (graceful degradation).""" + request = _make_request() + + function_calls, content = OpenAIServing._parse_tool_calls_from_content( + request=request, + tokenizer=None, + enable_auto_tools=False, + tool_parser_cls=None, + content=XML_TOOL_CALL, + ) + + assert function_calls is not None + assert len(function_calls) == 0 + assert content is None # still cleared + + def test_xml_content_no_tool_calls_without_enable_auto_tools(self): + """Even with tool_parser_cls, if enable_auto_tools is False, + the fallback should not activate.""" + request = _make_request() + tokenizer = MagicMock() + + function_calls, content = OpenAIServing._parse_tool_calls_from_content( + request=request, + tokenizer=tokenizer, + enable_auto_tools=False, + tool_parser_cls=MockToolParser, + content=XML_TOOL_CALL, + ) + + assert function_calls is not None + assert len(function_calls) == 0 + assert content is None # still cleared + + def test_multiple_json_tool_calls(self): + """Multiple JSON tool calls should all be parsed.""" + content = json.dumps( + [ + {"name": "get_current_weather", "parameters": {"city": "Dallas"}}, + {"name": "get_current_weather", "parameters": {"city": "Berlin"}}, + ] + ) + request = _make_request() + + function_calls, returned_content = ( + OpenAIServing._parse_tool_calls_from_content( + request=request, + tokenizer=None, + enable_auto_tools=False, + tool_parser_cls=None, + content=content, + ) + ) + + assert function_calls is not None + assert len(function_calls) == 2 + assert function_calls[0].name == "get_current_weather" + assert function_calls[1].name == "get_current_weather" + assert json.loads(function_calls[0].arguments) == {"city": "Dallas"} + assert json.loads(function_calls[1].arguments) == {"city": "Berlin"} + + def test_none_content_does_not_crash(self): + """When content is None (e.g. max_tokens exceeded), should not + crash (regression test for #36841).""" + request = _make_request() + + function_calls, content = OpenAIServing._parse_tool_calls_from_content( + request=request, + tokenizer=None, + enable_auto_tools=False, + tool_parser_cls=None, + content=None, + ) + + assert function_calls is not None + assert len(function_calls) == 0 + assert content is None diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index a426836afd35..96e0289ed147 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -558,7 +558,8 @@ async def chat_completion_stream_generator( # Prepare the tool parser if it's needed try: - if tool_choice_auto and self.tool_parser: + need_tool_parser = tool_choice_auto or request.tool_choice == "required" + if need_tool_parser and self.tool_parser: if tokenizer is None: raise ValueError( "Tokenizer not available when `skip_tokenizer_init=True`" @@ -861,6 +862,7 @@ async def chat_completion_stream_generator( and prompt_is_reasoning_end_arr[i] ): reasoning_end_arr[i] = True + current_token_ids = output_token_ids if reasoning_parser and not reasoning_end_arr[i]: delta_message = ( @@ -875,33 +877,89 @@ async def chat_completion_stream_generator( ) if reasoning_parser.is_reasoning_end(output_token_ids): reasoning_end_arr[i] = True + # Strip reasoning-related token ids so + # the tool parser only sees content ids. + current_token_ids = ( + reasoning_parser.extract_content_ids( + output_token_ids + ) + ) if delta_message and delta_message.content: current_text = delta_message.content delta_message.content = None else: - # reasoning ended current_text = "" - else: - # either finished reasoning or no reasoning at all - content = current_text - - delta_message, function_name_returned[i] = ( - self.extract_tool_call_required_streaming( - previous_text=previous_text, - current_text=content, - delta_text=delta_text, - function_name_returned=fn_name_returned, - tool_call_idx=history_tool_call_cnt, + # Process tool calls after reasoning is done. + # This is a separate `if` (not `else`) so that + # when reasoning ends and tool call tokens + # arrive in the same chunk (common with + # MTP/speculative decoding), the tool parser + # runs immediately in the same iteration. + if reasoning_end_arr[i] if reasoning_parser else True: + if tool_parser is not None and self.enable_auto_tools: + assert added_content_delta_arr is not None + delta_token_ids = output_token_ids + if not added_content_delta_arr[i]: + added_content_delta_arr[i] = True + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=(previous_token_ids), + current_token_ids=(current_token_ids), + delta_token_ids=delta_token_ids, + request=request, + ) ) - ) - if ( - delta_message - and delta_message.tool_calls - and delta_message.tool_calls[0].id is not None - ): - history_tool_call_cnt += 1 - tools_streamed[i] = True + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True + # If tool_parser didn't produce tool + # calls, also try the JSON-only parser + # as fallback (MTP may produce JSON + # instead of XML non-deterministically). + if not tools_streamed[i]: + content = current_text + json_msg, function_name_returned[i] = ( + self.extract_tool_call_required_streaming( + previous_text=previous_text, + current_text=content, + delta_text=delta_text, + function_name_returned=(fn_name_returned), + tool_call_idx=(history_tool_call_cnt), + ) + ) + if ( + json_msg + and json_msg.tool_calls + and json_msg.tool_calls[0].id is not None + ): + delta_message = json_msg + history_tool_call_cnt += 1 + tools_streamed[i] = True + else: + content = current_text + delta_message, function_name_returned[i] = ( + self.extract_tool_call_required_streaming( + previous_text=previous_text, + current_text=content, + delta_text=delta_text, + function_name_returned=(fn_name_returned), + tool_call_idx=(history_tool_call_cnt), + ) + ) + if ( + delta_message + and delta_message.tool_calls + and delta_message.tool_calls[0].id is not None + ): + history_tool_call_cnt += 1 + tools_streamed[i] = True # handle streaming deltas for tools with "auto" tool choice # and reasoning parser diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index f5f011a96f27..910ff53e56c5 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import contextlib import json import time from collections.abc import AsyncGenerator, Mapping @@ -803,19 +802,43 @@ def _parse_tool_calls_from_content( ) content = None # Clear content since tool is called. elif request.tool_choice == "required": - tool_calls = [] - with contextlib.suppress(ValidationError): - content = content or "" + content = content or "" + try: tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json( content ) - for tool_call in tool_calls: - function_calls.append( - FunctionCall( - name=tool_call.name, - arguments=json.dumps(tool_call.parameters, ensure_ascii=False), + for tool_call in tool_calls: + function_calls.append( + FunctionCall( + name=tool_call.name, + arguments=json.dumps( + tool_call.parameters, ensure_ascii=False + ), + ) ) - ) + except (ValidationError, json.JSONDecodeError): + # JSON validation failed — fall back to the configured + # tool parser (e.g. qwen3_coder) which may understand + # non-JSON formats such as XML tool calls. + if tool_parser_cls and enable_auto_tools and tokenizer is not None: + try: + tool_parser = tool_parser_cls(tokenizer, request.tools) + except RuntimeError as e: + logger.exception("Error in tool parser creation.") + raise e + tool_call_info = tool_parser.extract_tool_calls( + content, + request=request, # type: ignore + ) + if tool_call_info is not None and tool_call_info.tools_called: + function_calls.extend( + FunctionCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + for tool_call in tool_call_info.tool_calls + ) content = None # Clear content since tool is called. elif ( tool_parser_cls