Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/tool_use/test_tool_choice_required.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
previous_text = current_text

assert len(messages) > 0

combined_messages = "["
for message in messages:
if message.tool_calls[0].function.name:
Expand All @@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len):
combined_messages += "}]"
assert json.loads(combined_messages) == output
assert json.dumps(json.loads(combined_messages)) == output_json


def test_streaming_output_valid_with_trailing_extra_data():
    """Streamed tool-call extraction must keep working when the model emits
    extra text (here a trailing "\nDONE") after the closing JSON bracket.

    Feeds the serialized tool-call JSON plus trailing junk to
    ``extract_tool_call_required_streaming`` in fixed-size chunks and checks
    that at least one delta message is produced along the way.
    """
    self = MagicMock()

    output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
    output_json = json.dumps(output) + "\nDONE"

    chunk_size = 3
    seen_text = ""
    name_emitted = False
    deltas = []
    # Replay the full text as a stream of chunk_size-character deltas.
    while seen_text != output_json:
        chunk = output_json[len(seen_text) : len(seen_text) + chunk_size]
        new_text = seen_text + chunk

        message, name_emitted = (
            OpenAIServingChat.extract_tool_call_required_streaming(
                self,
                previous_text=seen_text,
                current_text=new_text,
                delta_text=chunk,
                function_name_returned=name_emitted,
            )
        )

        if message:
            deltas.append(message)

        seen_text = new_text

    # The trailing non-JSON suffix must not suppress all tool-call deltas.
    assert deltas
10 changes: 8 additions & 2 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import regex as re
from fastapi import Request
from openai_harmony import Message as OpenAIMessage
from partial_json_parser.core.options import Allow

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
Expand Down Expand Up @@ -76,6 +77,7 @@
)
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters

Expand Down Expand Up @@ -509,8 +511,12 @@ def extract_tool_call_required_streaming(
# if the current text is empty, we cannot parse it
return None, function_name_returned
try:
obj = partial_json_parser.loads(current_text)
except partial_json_parser.core.exceptions.MalformedJSON:
flags = Allow.ALL
obj, _ = partial_json_loads(current_text, flags)
except (
partial_json_parser.core.exceptions.MalformedJSON,
json.JSONDecodeError,
):
logger.debug("not enough tokens to parse into JSON yet")
obj = None

Expand Down