Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions tests/test_tool_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,3 +1163,177 @@ def test_streaming_bare_multi_function_blocks(self):
assert len(emitted_calls) == 2
assert emitted_calls[0]["function"]["name"] == "func1"
assert emitted_calls[1]["function"]["name"] == "func2"


class TestQwenFunctionFormat:
    """Non-streaming extraction of Qwen's ``<function=name>`` tool-call syntax."""

    @pytest.fixture
    def parser(self):
        return QwenToolParser()

    def test_function_format_with_parameters(self, parser):
        """A ``<parameter=key>`` child ends up as a key of the arguments object."""
        outcome = parser.extract_tool_calls(
            "<function=get_weather><parameter=city>Prague</parameter></function>"
        )
        assert outcome.tools_called
        first_call = outcome.tool_calls[0]
        assert first_call["name"] == "get_weather"
        decoded = json.loads(first_call["arguments"])
        assert decoded["city"] == "Prague"

    def test_function_format_with_json(self, parser):
        """A JSON object placed directly inside the tags is used as the arguments."""
        outcome = parser.extract_tool_calls(
            '<function=get_weather>{"city": "Prague"}</function>'
        )
        assert outcome.tools_called
        first_call = outcome.tool_calls[0]
        assert first_call["name"] == "get_weather"
        decoded = json.loads(first_call["arguments"])
        assert decoded["city"] == "Prague"

    def test_function_format_multiple(self, parser):
        """Two back-to-back function blocks yield two calls, in document order."""
        blocks = [
            '<function=read_file>{"path": "/a.py"}</function>',
            '<function=write_file>{"path": "/b.py", "content": "hello"}</function>',
        ]
        outcome = parser.extract_tool_calls("".join(blocks))
        assert outcome.tools_called
        assert len(outcome.tool_calls) == 2
        called_names = [call["name"] for call in outcome.tool_calls]
        assert called_names == ["read_file", "write_file"]

    def test_function_format_with_think_tags(self, parser):
        """A leading ``<think>`` section does not hide the function block."""
        reasoning = "<think>I need to check the weather.</think>\n"
        call_block = '<function=get_weather>{"city": "Prague"}</function>'
        outcome = parser.extract_tool_calls(reasoning + call_block)
        assert outcome.tools_called
        assert outcome.tool_calls[0]["name"] == "get_weather"


class TestQwenStreamingBuffering:
    """Streaming extraction: partial tool-call markers are held back, not leaked."""

    @pytest.fixture
    def parser(self):
        return QwenToolParser()

    @staticmethod
    def _feed(parser, deltas, already_streamed=""):
        """Lazily push each delta through the streaming API.

        Yields the parser's return value for every delta, growing the
        accumulated text the same way a real token stream would. Being a
        generator, it makes each parser call only when the next result is
        requested, so assertions stay interleaved with the calls.
        """
        text_so_far = already_streamed
        for delta in deltas:
            before = text_so_far
            text_so_far = before + delta
            yield parser.extract_tool_calls_streaming(
                previous_text=before,
                current_text=text_so_far,
                delta_text=delta,
            )

    def test_streaming_function_format_complete(self, parser):
        """A function block split over three deltas is emitted as a tool call."""
        deltas = [
            "<function=get_weather>",
            "<parameter=city>Prague</parameter>",
            "</function>",
        ]
        # Stop at the first emission that carries tool calls.
        first_emission = next(
            (
                outcome
                for outcome in self._feed(parser, deltas)
                if outcome is not None and "tool_calls" in outcome
            ),
            None,
        )
        assert first_emission is not None
        assert first_emission["tool_calls"][0]["function"]["name"] == "get_weather"

    def test_streaming_partial_marker_buffered(self, parser):
        """A trailing '<function' must be withheld rather than sent as content."""
        steps = self._feed(parser, ["Sure.", "<function", "=get_weather>"])

        # Ordinary text passes straight through.
        assert next(steps) == {"content": "Sure."}

        # "<function" may still turn into "<function=", so nothing is emitted.
        assert next(steps) is None

        # "=" confirms the marker; the block is still open, so stay silent.
        assert next(steps) is None

    def test_streaming_false_positive_functional(self, parser):
        """Regression: '<functional' split across deltas must not lose text.

        The text in front of the possible marker ('Look at ') is emitted at
        once and only the '<function' suffix is withheld. When the following
        delta shows it was plain prose, the withheld prefix is sent together
        with that delta.
        """
        steps = self._feed(parser, ["Look at <function", "al interface"])

        # Delta 1: emit "Look at ", hold back "<function".
        assert next(steps) == {"content": "Look at "}

        # Delta 2: not a tool call after all, so the held prefix is released.
        assert next(steps) == {"content": "<functional interface"}

    def test_streaming_false_positive_single_angle_bracket(self, parser):
        """A lone '<' is withheld, then released when no marker follows."""
        steps = self._feed(parser, ["<", "div>"], already_streamed="Hello")

        # "<" on its own might start a marker, so it is buffered.
        assert next(steps) is None

        # "div>" rules out a tool marker; both pieces must come back.
        recovered = next(steps)
        assert recovered is not None
        assert "content" in recovered
        assert "<" in recovered["content"]
        assert "div>" in recovered["content"]

    def test_streaming_multiple_function_blocks(self, parser):
        """Two function blocks in one stream produce two emitted calls."""
        deltas = [
            '<function=func1>{"a": 1}</function>',
            "\n",
            "<function=func2>",
            "<parameter=b>2</parameter>",
            "</function>",
        ]
        emitted = [
            call
            for outcome in self._feed(parser, deltas)
            if outcome is not None and "tool_calls" in outcome
            for call in outcome["tool_calls"]
        ]
        assert len(emitted) == 2
        assert emitted[0]["function"]["name"] == "func1"
        assert emitted[1]["function"]["name"] == "func2"
3 changes: 2 additions & 1 deletion vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2590,14 +2590,15 @@ async def stream_chat_completion(
yield f"data: {chunk.model_dump_json()}\n\n"

# Fallback: if tool parser accumulated text but never emitted tool_calls
# (e.g., </tool_call> never arrived - incomplete tool call)
# (e.g., </tool_call> never arrived, or <function= block still incomplete)
if (
tool_parser
and tool_accumulated_text
and not tool_calls_detected
and (
"<tool_call>" in tool_accumulated_text
or "<|tool_call>" in tool_accumulated_text
or "<function" in tool_accumulated_text
)
):
result = tool_parser.extract_tool_calls(tool_accumulated_text)
Expand Down
143 changes: 141 additions & 2 deletions vllm_mlx/tool_parsers/qwen_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
Handles Qwen's tool calling formats:
- XML style: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
- Bracket style: [Calling tool: func_name({"arg": "value"})]
- Function style: <function=name><parameter=key>value</parameter></function>
"""

import ast
import json
import re
import uuid
Expand All @@ -20,6 +22,24 @@
)


def _parse_param_value(val: str) -> Any:
"""Parse a parameter value, handling JSON literals and plain strings."""
try:
return json.loads(val)
except (json.JSONDecodeError, ValueError):
pass
try:
python_val = ast.literal_eval(val)
if isinstance(python_val, set):
python_val = sorted(python_val, key=str)
if isinstance(python_val, (complex, bytes)):
return val
json.dumps(python_val)
return python_val
except (ValueError, SyntaxError, TypeError):
return val


def generate_tool_id() -> str:
    """Return a fresh tool call ID: ``call_`` plus 8 hex digits of a random UUID."""
    short_hex = uuid.uuid4().hex[:8]
    return "call_" + short_hex
Expand All @@ -33,6 +53,7 @@ class QwenToolParser(ToolParser):
Supports multiple Qwen tool call formats:
- XML: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
- Bracket: [Calling tool: func_name({"arg": "value"})]
- Function: <function=name><parameter=key>value</parameter></function>

Used when --enable-auto-tool-choice --tool-call-parser qwen are set.
"""
Expand All @@ -43,6 +64,12 @@ class QwenToolParser(ToolParser):
# Bracket style, e.g. [Calling tool: func_name({...})]
#   group 1 -> tool name, group 2 -> the {...} argument text
BRACKET_PATTERN = re.compile(
    r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]",
    flags=re.DOTALL,
)

# Function style, e.g. <function=name>...</function>
#   group 1 -> text after "=", group 2 -> everything between the two tags
FUNCTION_PATTERN = re.compile(
    r"<function=([^>]+)>(.*?)</function>",
    flags=re.DOTALL,
)

# Parameter tag, e.g. <parameter=key>value</parameter>
#   group 1 -> key, group 2 -> value with surrounding whitespace excluded
PARAM_PATTERN = re.compile(
    r"<parameter=([^>]+)>\s*(.*?)\s*</parameter>",
    flags=re.DOTALL,
)

def extract_tool_calls(
self, model_output: str, request: dict[str, Any] | None = None
) -> ExtractedToolCallInformation:
Expand Down Expand Up @@ -101,6 +128,41 @@ def extract_tool_calls(
if xml_matches:
cleaned_text = self.XML_PATTERN.sub("", cleaned_text).strip()

# Try function-style: <function=name><parameter=key>value</parameter></function>
# Qwen3.5 generates this format natively.
if not tool_calls:
func_matches = self.FUNCTION_PATTERN.findall(cleaned_text)
for name, params_block in func_matches:
# Try JSON arguments first (e.g. <function=name>{"key": "val"}</function>)
params_block_stripped = params_block.strip()
if params_block_stripped.startswith("{"):
try:
arguments = json.loads(params_block_stripped)
tool_calls.append(
{
"id": generate_tool_id(),
"name": name.strip(),
"arguments": json.dumps(arguments, ensure_ascii=False),
}
)
continue
except json.JSONDecodeError:
pass
# Parse <parameter=key>value</parameter> tags
params = self.PARAM_PATTERN.findall(params_block)
arguments = {}
for p_name, p_value in params:
arguments[p_name.strip()] = _parse_param_value(p_value.strip())
tool_calls.append(
{
"id": generate_tool_id(),
"name": name.strip(),
"arguments": json.dumps(arguments, ensure_ascii=False),
}
)
if func_matches:
cleaned_text = self.FUNCTION_PATTERN.sub("", cleaned_text).strip()

if tool_calls:
return ExtractedToolCallInformation(
tools_called=True,
Expand All @@ -112,6 +174,30 @@ def extract_tool_calls(
tools_called=False, tool_calls=[], content=model_output
)

# Tool-call marker stems: each full marker minus its final delimiter
# ("<function" + "=", "[Calling tool" + ":", "<tool_call" + ">").
# While streaming, if the accumulated text ends with ANY non-empty prefix of
# one of these stems (even a lone "<" or "["), that trailing piece is held
# back until a later delta either completes a real marker or shows it was
# ordinary text, at which point the held-back prefix is emitted.
_PARTIAL_MARKERS = ("<function", "[Calling tool", "<tool_call")

def _has_partial_marker(self, text: str) -> bool:
    """Tell whether ``text`` ends with the beginning of a tool-call marker.

    True exactly when ``_get_partial_marker_len`` reports a non-zero length.
    """
    trailing_len = self._get_partial_marker_len(text)
    return trailing_len > 0

def _get_partial_marker_len(self, text: str) -> int:
"""Return the length of a partial tool call marker suffix at end of text."""
tail = text[-20:]
best = 0
for marker in self._PARTIAL_MARKERS:
for length in range(len(marker), 0, -1):
if tail.endswith(marker[:length]) and length > best:
best = length
break
return best

def _was_buffering(self, previous_text: str) -> bool:
    """Report whether the text seen before this delta ended in a partial marker.

    Delegates to ``_has_partial_marker`` on ``previous_text``; a True result
    means the preceding streaming call had a marker prefix at its tail.
    """
    ended_in_partial_marker = self._has_partial_marker(previous_text)
    return ended_in_partial_marker

def extract_tool_calls_streaming(
self,
previous_text: str,
Expand All @@ -125,14 +211,67 @@ def extract_tool_calls_streaming(
"""
Extract tool calls from streaming Qwen model output.
"""
# Check for tool call markers
# Check for complete tool call markers
has_tool_marker = (
"<tool_call>" in current_text or "[Calling tool:" in current_text
"<tool_call>" in current_text
or "[Calling tool:" in current_text
or "<function=" in current_text
)

if not has_tool_marker:
# Buffer partial markers (e.g. "<function" before "=" arrives).
# Only the marker suffix is buffered; content before it in the
# same delta is emitted immediately so no text is lost.
if self._has_partial_marker(current_text):
marker_len = self._get_partial_marker_len(current_text)
marker_start = len(current_text) - marker_len
safe_chars = marker_start - len(previous_text)
if safe_chars > 0:
return {"content": delta_text[:safe_chars]}
return None
# If we were buffering before but the marker didn't complete,
# emit the buffered marker prefix together with the new delta.
if self._was_buffering(previous_text):
for marker in self._PARTIAL_MARKERS:
for length in range(len(marker), 0, -1):
prefix = marker[:length]
if previous_text.endswith(prefix):
return {"content": prefix + delta_text}
return {"content": delta_text}
return {"content": delta_text}

# Handle <function=name>...</function> (Qwen3.5 native format)
if "<function=" in current_text:
func_close_count = current_text.count("</function>")
prev_func_close = previous_text.count("</function>")

if current_text.count("<function=") > func_close_count:
# Inside an incomplete function block, suppress output
return None

if func_close_count > prev_func_close:
# New function block(s) completed
result = self.extract_tool_calls(current_text)
if result.tools_called:
new_calls = result.tool_calls[prev_func_close:]
if new_calls:
return {
"tool_calls": [
{
"index": prev_func_close + i,
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"],
},
}
for i, tc in enumerate(new_calls)
]
}

return None

# If we're in a tool call, accumulate and parse at the end
# For simplicity, return None during accumulation
if "</tool_call>" in delta_text or ")]" in delta_text:
Expand Down
Loading
Loading