Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions tests/tool_parsers/test_qwen3coder_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,55 @@ def test_streaming_multi_param_single_chunk(qwen3_tool_parser, qwen3_tokenizer):
assert args["unit"] == "fahrenheit"


def test_streaming_whole_body_after_start_token(qwen3_tool_parser):
    """Regression test: a whole tool-call body delivered in a single delta.

    With stream_interval=N and speculative decoding, everything from
    <function=...> through </function></tool_call> can show up in one delta,
    right after an earlier delta already flipped is_tool_call_started. The
    parser has to emit both the header and the arguments from that one
    delta; if it does not, prev_tool_call_arr keeps its "{}" placeholder and
    the serving layer's remaining_call backfill sends empty arguments.
    """
    from tests.tool_parsers.utils import run_tool_extraction_streaming

    expected_args = {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}

    # Function header, every parameter and both closing tags, collapsed into
    # the one chunk that follows the bare start token.
    collapsed_body = (
        "\n<function=get_current_weather>"
        "\n<parameter=city>\nDallas\n</parameter>"
        "\n<parameter=state>\nTX\n</parameter>"
        "\n<parameter=unit>\nfahrenheit\n</parameter>"
        "\n</function>"
        "\n</tool_call>"
    )
    chunks = ["<tool_call>", collapsed_body]

    chat_request = ChatCompletionRequest(model=MODEL, messages=[])
    streamed = run_tool_extraction_streaming(
        qwen3_tool_parser,
        chunks,
        chat_request,
        assert_one_tool_per_delta=False,
    )

    # Exactly one tool call must be reconstructed, with its arguments filled.
    assert len(streamed.tool_calls) == 1
    call = streamed.tool_calls[0]
    assert call.function.name == "get_current_weather"
    assert json.loads(call.function.arguments) == expected_args

    # The parser's own bookkeeping must hold the parsed arguments rather than
    # the initial "{}" placeholder: the serving layer reads it to compute the
    # remaining-args backfill at stream end.
    tracked_calls = qwen3_tool_parser.prev_tool_call_arr
    assert len(tracked_calls) == 1
    assert json.loads(tracked_calls[0]["arguments"]) == expected_args


def test_no_double_serialization_string_args(qwen3_tool_parser):
"""Regression: string arguments must not be double-serialized (PR #35615)."""
tools = [
Expand Down
73 changes: 32 additions & 41 deletions vllm/tool_parsers/qwen3coder_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,16 @@ def extract_tool_calls_streaming(
tool_start_idx : tool_end_idx + len(self.tool_call_end_token)
]

# A single delta can deliver multiple phases at once (header +
# opening brace + params + closing brace) when stream_interval is
# large or speculative decoding bursts many tokens. We accumulate
# all phases into one DeltaMessage instead of returning early
# after each phase, so a tool call that arrives in one shot is
# still streamed with its arguments populated.
pending_header_name: str | None = None
pending_header_id: str | None = None
pending_args = ""

# Looking for function header
if not self.header_sent:
if self.tool_call_prefix in tool_text:
Expand Down Expand Up @@ -486,20 +496,11 @@ def extract_tool_calls_streaming(
# accesses streamed_args_for_tool[index].
self.streamed_args_for_tool.append("")

# Send header with function info
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""
),
type="function",
)
]
)
return None
pending_header_name = self.current_function_name
pending_header_id = self.current_tool_id
# Fall through to body processing in the same delta.
if not self.header_sent:
return None

# We've sent header, now handle function body
if self.in_function:
Expand All @@ -511,14 +512,7 @@ def extract_tool_calls_streaming(
if not self.json_started:
self.json_started = True
self.streamed_args_for_tool[self.current_tool_index] += "{"
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
]
)
pending_args += "{"

# Find all parameter start positions in current tool_text
param_starts = []
Expand Down Expand Up @@ -622,14 +616,7 @@ def extract_tool_calls_streaming(
len(self.streamed_args_for_tool),
)

return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments=combined),
)
]
)
pending_args += combined

# Check for function end AFTER processing parameters.
# This ordering is critical: with speculative decoding a
Expand Down Expand Up @@ -672,21 +659,25 @@ def extract_tool_calls_streaming(
len(self.streamed_args_for_tool),
)

result = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
]
)
pending_args += "}"

self.in_function = False
self.json_closed = True
self.accumulated_params = {}

return result

if pending_header_name is not None or pending_args:
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=pending_header_id,
function=DeltaFunctionCall(
name=pending_header_name,
arguments=pending_args,
),
type="function" if pending_header_name else None,
)
]
)
return None

def get_structural_tag(self, request: ChatCompletionRequest):
Expand Down
Loading