Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions tests/test_minimax_tool_calling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Tests for MiniMax tool call parsing."""

import json
import unittest

from vllm_mlx.api.tool_calling import parse_tool_calls


class TestMiniMaxToolCallParsing(unittest.TestCase):
"""Test parsing of MiniMax-style tool calls."""

def test_single_tool_call(self):
text = """<minimax:tool_call>
<invoke name="get_weather">
<parameter name="city">Wanaka</parameter>
<parameter name="units">celsius</parameter>
</invoke>
</minimax:tool_call>"""

cleaned, tool_calls = parse_tool_calls(text)
self.assertIsNotNone(tool_calls)
self.assertEqual(len(tool_calls), 1)
self.assertEqual(tool_calls[0].function.name, "get_weather")
args = json.loads(tool_calls[0].function.arguments)
self.assertEqual(args["city"], "Wanaka")
self.assertEqual(args["units"], "celsius")
self.assertEqual(cleaned, "")

def test_tool_call_with_surrounding_text(self):
text = """Let me check the weather for you.
<minimax:tool_call>
<invoke name="get_weather">
<parameter name="city">Wanaka</parameter>
</invoke>
</minimax:tool_call>"""

cleaned, tool_calls = parse_tool_calls(text)
self.assertIsNotNone(tool_calls)
self.assertEqual(len(tool_calls), 1)
self.assertIn("Let me check", cleaned)

def test_multiple_tool_calls(self):
text = """<minimax:tool_call>
<invoke name="search">
<parameter name="query">MiniMax M2.5</parameter>
</invoke>
</minimax:tool_call>
<minimax:tool_call>
<invoke name="read_file">
<parameter name="path">/tmp/test.txt</parameter>
</invoke>
</minimax:tool_call>"""

cleaned, tool_calls = parse_tool_calls(text)
self.assertIsNotNone(tool_calls)
self.assertEqual(len(tool_calls), 2)
self.assertEqual(tool_calls[0].function.name, "search")
self.assertEqual(tool_calls[1].function.name, "read_file")

def test_json_parameter_value(self):
text = """<minimax:tool_call>
<invoke name="create_event">
<parameter name="title">Meeting</parameter>
<parameter name="attendees">["stuart", "frida"]</parameter>
</invoke>
</minimax:tool_call>"""

cleaned, tool_calls = parse_tool_calls(text)
self.assertIsNotNone(tool_calls)
args = json.loads(tool_calls[0].function.arguments)
self.assertEqual(args["title"], "Meeting")
self.assertEqual(args["attendees"], ["stuart", "frida"])

def test_numeric_parameter(self):
text = """<minimax:tool_call>
<invoke name="set_temperature">
<parameter name="value">42</parameter>
</invoke>
</minimax:tool_call>"""

cleaned, tool_calls = parse_tool_calls(text)
args = json.loads(tool_calls[0].function.arguments)
self.assertEqual(args["value"], 42)

def test_no_parameters(self):
text = """<minimax:tool_call>
<invoke name="get_time">
</invoke>
</minimax:tool_call>"""

cleaned, tool_calls = parse_tool_calls(text)
self.assertIsNotNone(tool_calls)
self.assertEqual(tool_calls[0].function.name, "get_time")
args = json.loads(tool_calls[0].function.arguments)
self.assertEqual(args, {})

def test_with_think_tags_preserved(self):
text = """<think>
I should check the weather first.
</think>
<minimax:tool_call>
<invoke name="get_weather">
<parameter name="city">Wanaka</parameter>
</invoke>
</minimax:tool_call>"""

cleaned, tool_calls = parse_tool_calls(text)
self.assertIsNotNone(tool_calls)
self.assertIn("<think>", cleaned)

def test_no_minimax_tool_calls(self):
text = "Just a regular message with no tool calls."
cleaned, tool_calls = parse_tool_calls(text)
self.assertIsNone(tool_calls)
self.assertEqual(cleaned, text)

def test_tool_call_id_format(self):
text = """<minimax:tool_call>
<invoke name="test">
<parameter name="x">1</parameter>
</invoke>
</minimax:tool_call>"""

_, tool_calls = parse_tool_calls(text)
self.assertTrue(tool_calls[0].id.startswith("call_"))
self.assertEqual(tool_calls[0].type, "function")


if __name__ == "__main__":
unittest.main()
42 changes: 42 additions & 0 deletions vllm_mlx/api/tool_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def parse_tool_calls(
Parse tool calls from model output.

Supports multiple formats:
- MiniMax: <minimax:tool_call><invoke name="..."><parameter name="p">v</parameter></invoke></minimax:tool_call>
- Qwen3 bracket: [Calling tool: function_name({"arg": "value"})]
- Qwen: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
- Llama: <function=name>{"arg": "value"}</function>
Expand All @@ -106,6 +107,47 @@ def parse_tool_calls(
tool_calls = []
cleaned_text = text

# Pattern for MiniMax-style: <minimax:tool_call><invoke name="fn"><parameter name="p">v</parameter></invoke></minimax:tool_call>
minimax_pattern = r"<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>"
minimax_matches = re.findall(minimax_pattern, text, re.DOTALL)

for invoke_block in minimax_matches:
# Parse <invoke name="..."> blocks within the tool_call
invoke_pattern = r'<invoke\s+name="([^"]+)">(.*?)</invoke>'
invoke_matches = re.findall(invoke_pattern, invoke_block, re.DOTALL)

for name, params_block in invoke_matches:
# Parse <parameter name="...">value</parameter> pairs
param_pattern = r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>'
params = re.findall(param_pattern, params_block, re.DOTALL)
arguments = {}
for p_name, p_value in params:
# Try to parse value as JSON (for nested objects/arrays/numbers)
try:
arguments[p_name] = json.loads(p_value)
except (json.JSONDecodeError, ValueError):
arguments[p_name] = p_value

tool_calls.append(
ToolCall(
id=f"call_{uuid.uuid4().hex[:8]}",
type="function",
function=FunctionCall(
name=name.strip(),
arguments=json.dumps(arguments),
),
)
)

# Remove MiniMax tool call tags from cleaned text
if minimax_matches:
cleaned_text = re.sub(
r"<minimax:tool_call>\s*.*?\s*</minimax:tool_call>",
"",
cleaned_text,
flags=re.DOTALL,
).strip()

# Pattern for Qwen3 bracket-style: [Calling tool: function_name({...})]
bracket_pattern = r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]"
bracket_matches = re.findall(bracket_pattern, text, re.DOTALL)
Expand Down
Loading