Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,93 @@ async def stream_chat(self, messages, **kwargs):
assert payloads[2]["choices"][0]["delta"]["content"] == "world"
assert payloads[2]["choices"][0]["finish_reason"] == "stop"

@pytest.mark.anyio
async def test_auto_parser_streams_bare_bracket_tool_calls(self, monkeypatch):
    """Bare bracket tool calls should stream as structured tool_calls."""
    from vllm_mlx.engine.base import GenerationOutput
    from vllm_mlx.server import (
        ChatCompletionRequest,
        Message,
        stream_chat_completion,
    )
    import vllm_mlx.server as server

    class FakeEngine:
        # Minimal engine stub: yields a bare-bracket tool call split
        # across three streamed chunks.
        model_name = "fake-engine"

        async def stream_chat(self, messages, **kwargs):
            yield GenerationOutput(text="", new_text="[read(", finished=False)
            yield GenerationOutput(
                text="",
                new_text='{"file_path": "/tmp/test.py"}',
                finished=False,
            )
            yield GenerationOutput(
                text="",
                new_text=")]",
                finished=True,
                finish_reason="stop",
                prompt_tokens=4,
                completion_tokens=3,
            )

    # Put the server module into auto tool-call-parsing mode.
    overrides = {
        "_model_name": "served-model",
        "_reasoning_parser": None,
        "_enable_auto_tool_choice": True,
        "_tool_call_parser": "auto",
        "_tool_parser_instance": None,
    }
    for attr, value in overrides.items():
        monkeypatch.setattr(server, attr, value)

    read_tool = {
        "type": "function",
        "function": {
            "name": "read",
            "description": "Read a file",
            "parameters": {
                "type": "object",
                "properties": {"file_path": {"type": "string"}},
                "required": ["file_path"],
            },
        },
    }
    request = ChatCompletionRequest(
        model="served-model",
        messages=[Message(role="user", content="hi")],
        tools=[read_tool],
        stream=True,
    )

    sse_chunks = []
    async for sse_chunk in stream_chat_completion(
        FakeEngine(), request.messages, request
    ):
        sse_chunks.append(sse_chunk)

    # Decode every SSE data frame except the terminal [DONE] sentinel.
    payloads = []
    for sse_chunk in sse_chunks:
        if sse_chunk == "data: [DONE]\n\n":
            continue
        payloads.append(json.loads(sse_chunk.removeprefix("data: ").strip()))

    tool_payloads = []
    for payload in payloads:
        choices = payload["choices"]
        if choices and choices[0]["delta"].get("tool_calls"):
            tool_payloads.append(payload)

    assert len(tool_payloads) == 1
    delta = tool_payloads[0]["choices"][0]["delta"]
    emitted_call = delta["tool_calls"][0]["function"]
    assert emitted_call["name"] == "read"
    assert emitted_call["arguments"] == '{"file_path": "/tmp/test.py"}'
    assert delta["content"] is None
    assert tool_payloads[0]["choices"][0]["finish_reason"] == "tool_calls"

@pytest.mark.anyio
async def test_reasoning_stream_emits_structured_tool_calls(self, monkeypatch):
"""Tool markup after </think> should emit tool_calls chunks."""
Expand Down
43 changes: 43 additions & 0 deletions tests/test_tool_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,16 @@ def test_detects_qwen_bracket(self, parser):
assert result.tools_called
assert result.tool_calls[0]["name"] == "add"

def test_detects_bare_bracket(self, parser):
    """Test auto detection of bare bracket format."""
    model_output = '[read({"file_path": "/tmp/test.py"})]'
    result = parser.extract_tool_calls(model_output)

    assert result.tools_called
    first_call = result.tool_calls[0]
    assert first_call["name"] == "read"
    parsed_args = json.loads(first_call["arguments"])
    assert parsed_args["file_path"] == "/tmp/test.py"

def test_detects_llama(self, parser):
"""Test auto detection of Llama format."""
text = '<function=multiply>{"x": 2}</function>'
Expand Down Expand Up @@ -651,6 +661,39 @@ def test_tool_call_id_uniqueness(self):
assert len(ids) == len(set(ids)), "Tool call IDs should be unique"


class TestBareBracketStreaming:
    """Test streaming for bare bracket tool calls."""

    def test_auto_streaming_bare_bracket(self):
        """Auto parser should emit structured tool calls for bare bracket streaming."""
        parser = AutoToolParser()

        deltas = ["[read(", '{"file_path": "/tmp/test.py"}', ")]"]
        seen = ""
        structured_result = None

        # Feed the deltas one at a time, stopping at the first chunk that
        # the parser turns into a structured tool_calls payload.
        for delta in deltas:
            before = seen
            seen = before + delta
            result = parser.extract_tool_calls_streaming(
                previous_text=before,
                current_text=seen,
                delta_text=delta,
            )
            if result is not None and "tool_calls" in result:
                structured_result = result
                break

        assert structured_result is not None
        emitted_fn = structured_result["tool_calls"][0]["function"]
        assert emitted_fn["name"] == "read"
        parsed_args = json.loads(emitted_fn["arguments"])
        assert parsed_args["file_path"] == "/tmp/test.py"


class TestStreamingParsing:
"""Test streaming tool call parsing."""

Expand Down
8 changes: 7 additions & 1 deletion vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ def _resolve_top_p(request_value: float | None) -> float:
"<minimax:tool_call>",
'<invoke name="',
)
# Matches a complete bare-bracket tool-call opener such as "[read({" anywhere
# in the accumulated text: "[", an identifier, "(", then a JSON object start.
_STREAMING_BARE_BRACKET_MARKER = re.compile(r"\[\w+\(\{")
# Matches a partial opener such as "[read(" only at the very end of the text
# ($-anchored), i.e. the JSON arguments have not started streaming yet.
_STREAMING_BARE_BRACKET_PARTIAL = re.compile(r"\[\w+\($")


def _load_prefix_cache_from_disk() -> None:
Expand Down Expand Up @@ -1561,7 +1563,11 @@ def _get_streaming_tool_parser(request: ChatCompletionRequest | None):

def _streaming_tool_markup_possible(text: str) -> bool:
    """Heuristic marker check to avoid parser work on ordinary text chunks."""
    # Cheap substring scan first: any known tool-call marker token.
    if any(marker in text for marker in _STREAMING_TOOL_MARKERS):
        return True
    # Then the bare-bracket regexes: a complete "[name({" opener, or a
    # partial "[name(" sitting at the end of the accumulated text.
    if _STREAMING_BARE_BRACKET_MARKER.search(text) is not None:
        return True
    return _STREAMING_BARE_BRACKET_PARTIAL.search(text) is not None


def load_embedding_model(
Expand Down
53 changes: 48 additions & 5 deletions vllm_mlx/tool_parsers/auto_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class AutoToolParser(ToolParser):
NEMOTRON_PARAM_PATTERN = re.compile(
r"<parameter=([^>]+)>\s*(.*?)\s*</parameter>", re.DOTALL
)
# Matches one complete bare-bracket tool call, e.g. [read({"path": "x"})],
# capturing the function name and the JSON arguments object. Non-greedy, so
# the arguments end at the earliest "})]" — NOTE(review): a literal "})]"
# inside a JSON string value would truncate the match; confirm acceptable.
BARE_BRACKET_PATTERN = re.compile(r"\[(\w+)\((\{.*?\})\)\]", re.DOTALL)
# Matches a partial call opener like "[read(" only at the end of the text
# ($-anchored); used during streaming to hold output until arguments arrive.
BARE_BRACKET_PARTIAL_PATTERN = re.compile(r"\[\w+\($")

def extract_tool_calls(
self, model_output: str, request: dict[str, Any] | None = None
Expand Down Expand Up @@ -150,7 +152,35 @@ def extract_tool_calls(
if bracket_matches:
cleaned_text = self.QWEN_BRACKET_PATTERN.sub("", cleaned_text).strip()

# 4. Try Nemotron pattern (before Qwen XML as it's more specific)
# 4. Try bare bracket format: [func({...})]
bare_matches = self.BARE_BRACKET_PATTERN.findall(cleaned_text)
for name, args_str in bare_matches:
try:
arguments = json.loads(args_str)
tool_calls.append(
{
"id": generate_tool_id(),
"name": name.strip(),
"arguments": (
json.dumps(arguments, ensure_ascii=False)
if isinstance(arguments, dict)
else str(arguments)
),
}
)
except json.JSONDecodeError:
tool_calls.append(
{
"id": generate_tool_id(),
"name": name.strip(),
"arguments": args_str,
}
)

if bare_matches:
cleaned_text = self.BARE_BRACKET_PATTERN.sub("", cleaned_text).strip()

# 5. Try Nemotron pattern (before Qwen XML as it's more specific)
nemotron_matches = self.NEMOTRON_PATTERN.findall(cleaned_text)
for name, params_block in nemotron_matches:
params = self.NEMOTRON_PARAM_PATTERN.findall(params_block)
Expand All @@ -166,7 +196,7 @@ def extract_tool_calls(
if nemotron_matches:
cleaned_text = self.NEMOTRON_PATTERN.sub("", cleaned_text).strip()

# 5. Try Qwen/Hermes XML pattern
# 6. Try Qwen/Hermes XML pattern
xml_matches = self.QWEN_XML_PATTERN.findall(cleaned_text)
for match in xml_matches:
try:
Expand All @@ -191,7 +221,7 @@ def extract_tool_calls(
if xml_matches:
cleaned_text = self.QWEN_XML_PATTERN.sub("", cleaned_text).strip()

# 6. Try Llama pattern
# 7. Try Llama pattern
llama_matches = self.LLAMA_PATTERN.findall(cleaned_text)
for name, args_str in llama_matches:
try:
Expand Down Expand Up @@ -219,7 +249,7 @@ def extract_tool_calls(
if llama_matches:
cleaned_text = self.LLAMA_PATTERN.sub("", cleaned_text).strip()

# 7. Fallback: Try raw JSON
# 8. Fallback: Try raw JSON
if not tool_calls:
raw_calls = self._parse_raw_json_tool_calls(cleaned_text)
if raw_calls:
Expand Down Expand Up @@ -339,11 +369,24 @@ def extract_tool_calls_streaming(
"<|tool_call>",
self.MISTRAL_TOKEN,
"[Calling tool:",
"[",
"<tool_call>",
"<function=",
]

has_marker = any(m in current_text for m in markers)
has_marker = any(m in current_text for m in markers) and (
self.BARE_BRACKET_PARTIAL_PATTERN.search(current_text) is not None
or self.BARE_BRACKET_PATTERN.search(current_text) is not None
or "[Calling tool:" in current_text
or self.MISTRAL_TOKEN in current_text
or "<" in current_text
)

if (
self.BARE_BRACKET_PARTIAL_PATTERN.search(current_text) is not None
and self.BARE_BRACKET_PATTERN.search(current_text) is None
):
return None

if not has_marker:
return {"content": delta_text}
Expand Down
Loading