Skip to content
Closed

Closed #43613

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions tests/tool_use/test_tool_choice_required.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,48 @@ def test_streaming_output_valid(output, empty_params, delta_len):
assert json.dumps(json.loads(combined_messages)) == output_json


def test_streaming_output_emits_header_for_each_tool_call():
output = [
{"name": "get_current_weather", "parameters": {"city": "Vienna"}},
{"name": "get_forecast", "parameters": {"city": "Berlin", "days": 3}},
]
output_json = json.dumps(output)

previous_text = ""
function_name_returned = False
headers: dict[int, str] = {}
arguments: dict[int, str] = {}

for i in range(0, len(output_json), 4):
delta_text = output_json[i : i + 4]
current_text = previous_text + delta_text

delta_message, function_name_returned = extract_required_tool_call_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
tool_call_idx=None,
tool_call_id_type="random",
)

if delta_message:
tool_call = delta_message.tool_calls[0]
index = tool_call.index
if tool_call.function.name:
headers[index] = tool_call.function.name
if tool_call.function.arguments:
arguments[index] = (
arguments.get(index, "") + tool_call.function.arguments
)

previous_text = current_text

assert headers == {0: "get_current_weather", 1: "get_forecast"}
assert json.loads(arguments[0]) == output[0]["parameters"]
assert json.loads(arguments[1]) == output[1]["parameters"]


def test_streaming_output_valid_with_trailing_extra_data():
output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
output_json = json.dumps(output) + "\nDONE"
Expand Down
65 changes: 44 additions & 21 deletions vllm/tool_parsers/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,34 @@ def filter_delta_text(
return updated_delta, passed_zero


def _load_partial_tool_array(text: str) -> list | None:
if not text:
return None
try:
obj, _ = partial_json_loads(text, Allow.ALL)
except (
partial_json_parser.core.exceptions.MalformedJSON,
json.JSONDecodeError,
):
return None

if not isinstance(obj, list) or not obj:
return None
return obj


def _target_required_tool_index(obj: list) -> int:
idx = len(obj) - 1
current_tool_call = obj[idx]
if (
idx > 0
and isinstance(current_tool_call, dict)
and "parameters" not in current_tool_call
):
return idx - 1
return idx


def extract_named_tool_call_streaming(
*,
delta_text: str,
Expand Down Expand Up @@ -112,29 +140,30 @@ def extract_required_tool_call_streaming(
if current_text is None or current_text == "":
# if the current text is empty, we cannot parse it
return None, function_name_returned
try:
flags = Allow.ALL
obj, _ = partial_json_loads(current_text, flags)
except (
partial_json_parser.core.exceptions.MalformedJSON,
json.JSONDecodeError,
):
obj = None
obj = _load_partial_tool_array(current_text)

# check if the current text is a valid array
# containing a partial tool calling object
# if not repeat
if obj is None or not isinstance(obj, list) or not len(obj) > 0:
if obj is None:
function_name_returned = False
delta_message = None
else:
_, finishes_previous_tool = filter_delta_text(delta_text, previous_text)
# take the last tool call from the generated list
current_tool_call = obj[-1]
target_index = _target_required_tool_index(obj)
current_tool_call = obj[target_index]
previous_obj = _load_partial_tool_array(previous_text)
if previous_obj is not None and target_index != _target_required_tool_index(
previous_obj
):
function_name_returned = False

# once parameters have been generated the name is complete as well
if not finishes_previous_tool and (
"name" not in current_tool_call or "parameters" not in current_tool_call
if not isinstance(current_tool_call, dict) or (
not finishes_previous_tool
and (
"name" not in current_tool_call or "parameters" not in current_tool_call
)
):
function_name_returned = False
delta_message = None
Expand All @@ -147,12 +176,6 @@ def extract_required_tool_call_streaming(
arguments = param_match.group(1) if param_match else ""
arguments, _ = filter_delta_text(arguments, previous_text)

# if this iteration finishes a previous tool call but a
# new incomplete tool is already generated, take the
# previous from the list
if finishes_previous_tool and "parameters" not in current_tool_call:
current_tool_call = obj[-2]

function_name_returned = True
tool_call_id = make_tool_call_id(
id_type=tool_call_id_type,
Expand All @@ -166,7 +189,7 @@ def extract_required_tool_call_streaming(
function=DeltaFunctionCall(
name=current_tool_call["name"], arguments=arguments
),
index=len(obj) - 1,
index=target_index,
type="function",
)
]
Expand All @@ -185,7 +208,7 @@ def extract_required_tool_call_streaming(
name=None,
arguments=delta_text,
),
index=len(obj) - 1,
index=target_index,
)
]
)
Expand Down
Loading