Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions tests/tool_parsers/test_qwen3coder_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,3 +1146,157 @@ def test_no_double_serialization_string_args(qwen3_tool_parser):
args = json.loads(raw_arguments)
assert args["message"] == "hello world"
assert '\\"hello world\\"' not in raw_arguments


def test_extract_tool_calls_streaming_split_tag(qwen3_tool_parser):
"""
This highlights the need to use current_text instead of delta_text.
"""
request = ChatCompletionRequest(model=MODEL, messages=[])

# Iteration 1: "<tool"
prev_text_1 = "I will use a tool."
delta_text_1 = "<tool"
curr_text_1 = prev_text_1 + delta_text_1

msg1 = qwen3_tool_parser.extract_tool_calls_streaming(
previous_text=prev_text_1,
current_text=curr_text_1,
delta_text=delta_text_1,
previous_token_ids=[1, 2, 3],
current_token_ids=[1, 2, 3, 4],
delta_token_ids=[4],
request=request
)

# Iteration 2: "_call>"
prev_text_2 = curr_text_1
delta_text_2 = "_call>"
curr_text_2 = prev_text_2 + delta_text_2

msg2 = qwen3_tool_parser.extract_tool_calls_streaming(
previous_text=prev_text_2,
current_text=curr_text_2,
delta_text=delta_text_2,
previous_token_ids=[1, 2, 3, 4],
current_token_ids=[1, 2, 3, 4, 5],
delta_token_ids=[5],
request=request
)

# The assertion must verify that the is_tool_call_started variable correctly switches to True
assert qwen3_tool_parser.is_tool_call_started is True, "is_tool_call_started should be True when '<tool_call>' is completed in current_text."

# and that the function does not return fragments of the tag in DeltaMessage(content=...)
if msg1 and msg1.content:
assert "<tool" not in msg1.content
if msg2 and msg2.content:
assert "_call>" not in msg2.content



def test_extract_tool_calls_streaming_speculative_decode_loss(qwen3_tool_parser):
"""
if json_started=False, and the delta contains the parameters AND the end of the tool call,
the parser should not just return '{' and lose the parameters.
"""

request = ChatCompletionRequest(model="test", messages=[])

text1 = "<tool_call>\n<function=test>\n"
qwen3_tool_parser.extract_tool_calls_streaming(
"", text1, text1, [], [1], [1], request
)

# Delta 2 has the rest of the tool call
delta_str = "<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>"
text2 = text1 + delta_str
delta2 = qwen3_tool_parser.extract_tool_calls_streaming(
text1, text2, delta_str, [1], [1,2], [2], request
)

# The parameters should be in delta2!
assert delta2 is not None
assert delta2.tool_calls is not None
assert len(delta2.tool_calls) == 1
args = delta2.tool_calls[0].function.arguments
assert "Paris" in args, f"Arguments lost! Got: {args}"


def test_extract_tool_calls_streaming_various_chunk_sizes(qwen3_tool_parser):
"""
Test streaming with various chunk sizes using the exact template from Qwen 3.6.
"""

request = ChatCompletionRequest(model="test", messages=[])

# Exact template format from Qwen 3.6
template_text = """<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>"""

# Test with different chunk sizes to simulate different network/speculative decoding behaviors
for chunk_size in [1, 3, 15, len(template_text)]:
# Reset parser state
qwen3_tool_parser._reset_streaming_state()

tool_states = {}

# Simulate custom streaming to precisely control chunk sizes
current_text = ""
previous_text = ""
ptr = 0

while ptr < len(template_text):
delta = template_text[ptr:ptr+chunk_size]
previous_text = current_text
current_text += delta
ptr += chunk_size

delta_message = qwen3_tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=request
)

if delta_message and delta_message.tool_calls:
for tool_call in delta_message.tool_calls:
idx = tool_call.index
if idx not in tool_states:
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"type": None,
}

if tool_call.id:
tool_states[idx]["id"] = tool_call.id
if tool_call.type:
tool_states[idx]["type"] = tool_call.type
if tool_call.function:
if tool_call.function.name:
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments

assert 0 in tool_states
assert tool_states[0]["name"] == "example_function_name"

import json
args = json.loads(tool_states[0]["arguments"])
assert args["example_parameter_1"] == "value_1"
assert args["example_parameter_2"] == "This is the value for the second parameter\nthat can span\nmultiple lines"
133 changes: 71 additions & 62 deletions vllm/tool_parsers/qwen3coder_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import find_tool_properties
from vllm.tool_parsers.utils import find_tool_properties, partial_tag_overlap

logger = init_logger(__name__)

Expand Down Expand Up @@ -109,6 +109,8 @@ def _reset_streaming_state(self):
# Store accumulated parameters for type conversion
self.accumulated_params = {}
self.streaming_request = None
self._sent_content_idx = 0
Comment thread
ExtReMLapin marked this conversation as resolved.
self.current_tool_index = 0
Comment thread
ExtReMLapin marked this conversation as resolved.

def _convert_param_value(
self, param_value: str, param_name: str, param_config: dict, func_name: str
Expand Down Expand Up @@ -372,6 +374,22 @@ def extract_tool_calls_streaming(
# Check if this tool call has ended
tool_ends = current_text.count(self.tool_call_end_token)
if tool_ends > self.current_tool_index:
# Find the end of the tool call that just finished and update
# _sent_content_idx to prevent it from leaking into content.
search_idx = 0
for _ in range(self.current_tool_index + 1):
search_idx = current_text.find(self.tool_call_start_token,
search_idx)
if search_idx == -1:
break
end_idx = current_text.find(self.tool_call_end_token,
search_idx)
if end_idx != -1:
self._sent_content_idx = max(
self._sent_content_idx,
end_idx + len(self.tool_call_end_token))
search_idx += len(self.tool_call_start_token)

# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
Expand All @@ -380,47 +398,55 @@ def extract_tool_calls_streaming(
self.json_closed = False
self.accumulated_params = {}

# Check if there are more tool calls
tool_starts = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts:
# No more tool calls
self.is_tool_call_started = False
# Always reset is_tool_call_started when a tool call ends.
# This allows correctly sending any content between or after
# tool calls.
self.is_tool_call_started = False
# Continue processing next tool
return None

content_message = None
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
tool_starts_count = current_text.count(self.tool_call_start_token)
if (
self.tool_call_start_token_id in delta_token_ids
or self.tool_call_start_token in delta_text
or tool_starts_count > self.current_tool_index
):
self.is_tool_call_started = True
# Return any content before the tool call
if self.tool_call_start_token in delta_text:
content_before = delta_text[
: delta_text.index(self.tool_call_start_token)
]
last_start = current_text.find(self.tool_call_start_token, self._sent_content_idx)
if last_start != -1 and last_start > self._sent_content_idx:
content_before = current_text[self._sent_content_idx:last_start]
self._sent_content_idx = last_start
if content_before:
return DeltaMessage(content=content_before)
return None
content_message = DeltaMessage(content=content_before)
else:
overlap = partial_tag_overlap(current_text, self.tool_call_start_token)
sendable_idx = len(current_text) - overlap

# Check if we're between tool calls - skip whitespace
if (
current_text.rstrip().endswith(self.tool_call_end_token)
and delta_text.strip() == ""
):
# We just ended a tool call, skip whitespace
self._sent_content_idx = len(current_text)
return None
Comment thread
ExtReMLapin marked this conversation as resolved.
# Normal content, no tool call
return DeltaMessage(content=delta_text)

if sendable_idx > self._sent_content_idx:
content = current_text[self._sent_content_idx:sendable_idx]
self._sent_content_idx = sendable_idx
if content:
return DeltaMessage(content=content)
return None

# Check if we're between tool calls (waiting for next one)
# Count tool calls we've seen vs processed
tool_starts_count = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts_count:
# We're past all tool calls, shouldn't be here
return None
return content_message

# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
Expand All @@ -434,8 +460,7 @@ def extract_tool_calls_streaming(
idx += len(self.tool_call_start_token)

if self.current_tool_index >= len(tool_start_positions):
# No more tool calls to process yet
return None
return content_message

tool_start_idx = tool_start_positions[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
Expand All @@ -447,6 +472,7 @@ def extract_tool_calls_streaming(
tool_start_idx : tool_end_idx + len(self.tool_call_end_token)
]

tool_call_fragments = None
# Looking for function header
if not self.header_sent:
if self.tool_call_prefix in tool_text:
Expand Down Expand Up @@ -479,21 +505,16 @@ def extract_tool_calls_streaming(
# accesses streamed_args_for_tool[index].
self.streamed_args_for_tool.append("")

# Send header with function info
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""
),
type="function",
)
]
tool_call_fragments = DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(name=self.current_function_name, arguments=""),
type="function",
)
return None
if not self.header_sent:
return content_message

arguments_to_emit = ""
# We've sent header, now handle function body
if self.in_function:
# Always send opening brace first, regardless of whether
Expand All @@ -504,16 +525,8 @@ def extract_tool_calls_streaming(
if not self.json_started:
self.json_started = True
self.streamed_args_for_tool[self.current_tool_index] += "{"
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
]
)
arguments_to_emit += "{"

# Find all parameter start positions in current tool_text
param_starts = []
search_idx = 0
while True:
Expand Down Expand Up @@ -614,15 +627,7 @@ def extract_tool_calls_streaming(
self.current_tool_index,
len(self.streamed_args_for_tool),
)

return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments=combined),
)
]
)
arguments_to_emit += combined

# Check for function end AFTER processing parameters.
# This ordering is critical: with speculative decoding a
Expand Down Expand Up @@ -664,20 +669,24 @@ def extract_tool_calls_streaming(
self.current_tool_index,
len(self.streamed_args_for_tool),
)

result = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
]
)

arguments_to_emit += "}"
self.in_function = False
self.json_closed = True
self.accumulated_params = {}

return result
if tool_call_fragments or arguments_to_emit:
if not tool_call_fragments:
tool_call_fragments = DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments=arguments_to_emit),
)
else:
tool_call_fragments.function.arguments += arguments_to_emit

if content_message:
content_message.tool_calls = [tool_call_fragments]
return content_message
else:
return DeltaMessage(tool_calls=[tool_call_fragments])

return None
return content_message
Loading