diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py
index 35ed8d215f73..6ff37255e48d 100644
--- a/tests/tool_use/test_tool_choice_required.py
+++ b/tests/tool_use/test_tool_choice_required.py
@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
         previous_text = current_text
 
     assert len(messages) > 0
+
     combined_messages = "["
     for message in messages:
         if message.tool_calls[0].function.name:
@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len):
     combined_messages += "}]"
     assert json.loads(combined_messages) == output
     assert json.dumps(json.loads(combined_messages)) == output_json
+
+
+def test_streaming_output_valid_with_trailing_extra_data():
+    self = MagicMock()
+
+    output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
+    output_json = json.dumps(output) + "\nDONE"
+
+    previous_text = ""
+    function_name_returned = False
+    messages = []
+    delta_len = 3
+    for i in range(0, len(output_json), delta_len):
+        delta_text = output_json[i : i + delta_len]
+        current_text = previous_text + delta_text
+
+        delta_message, function_name_returned = (
+            OpenAIServingChat.extract_tool_call_required_streaming(
+                self,
+                previous_text=previous_text,
+                current_text=current_text,
+                delta_text=delta_text,
+                function_name_returned=function_name_returned,
+            )
+        )
+
+        if delta_message:
+            messages.append(delta_message)
+
+        previous_text = current_text
+
+    assert len(messages) > 0
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 5a916f39b128..326f3ef198a1 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -13,6 +13,7 @@
 import regex as re
 from fastapi import Request
 from openai_harmony import Message as OpenAIMessage
+from partial_json_parser.core.options import Allow
 
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (
@@ -76,6 +77,7 @@
 )
 from vllm.tool_parsers import ToolParser
 from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
+from vllm.tool_parsers.utils import partial_json_loads
 from vllm.utils.collection_utils import as_list
 from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
 
@@ -509,8 +511,12 @@ def extract_tool_call_required_streaming(
             # if the current text is empty, we cannot parse it
             return None, function_name_returned
         try:
-            obj = partial_json_parser.loads(current_text)
-        except partial_json_parser.core.exceptions.MalformedJSON:
+            flags = Allow.ALL
+            obj, _ = partial_json_loads(current_text, flags)
+        except (
+            partial_json_parser.core.exceptions.MalformedJSON,
+            json.JSONDecodeError,
+        ):
             logger.debug("not enough tokens to parse into JSON yet")
             obj = None
 