diff --git a/tests/test_minimax_tool_calling.py b/tests/test_minimax_tool_calling.py
new file mode 100644
index 000000000..2b94f967b
--- /dev/null
+++ b/tests/test_minimax_tool_calling.py
@@ -0,0 +1,141 @@
+"""Tests for MiniMax tool call parsing."""
+
+import json
+import unittest
+
+from vllm_mlx.api.tool_calling import parse_tool_calls
+
+
+class TestMiniMaxToolCallParsing(unittest.TestCase):
+    """Test parsing of MiniMax-style tool calls."""
+
+    def test_single_tool_call(self):
+        """A lone call yields one ToolCall and leaves the cleaned text empty."""
+        text = """<minimax:tool_call>
+<invoke name="get_weather">
+<parameter name="city">Wanaka</parameter>
+<parameter name="units">celsius</parameter>
+</invoke>
+</minimax:tool_call>"""
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertEqual(len(tool_calls), 1)
+        self.assertEqual(tool_calls[0].function.name, "get_weather")
+        args = json.loads(tool_calls[0].function.arguments)
+        self.assertEqual(args["city"], "Wanaka")
+        self.assertEqual(args["units"], "celsius")
+        self.assertEqual(cleaned, "")
+
+    def test_tool_call_with_surrounding_text(self):
+        """Prose before the tool call block survives in the cleaned text."""
+        text = """Let me check the weather for you.
+<minimax:tool_call>
+<invoke name="get_weather">
+<parameter name="city">Wanaka</parameter>
+</invoke>
+</minimax:tool_call>"""
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertEqual(len(tool_calls), 1)
+        self.assertIn("Let me check", cleaned)
+
+    def test_multiple_tool_calls(self):
+        """Two consecutive blocks produce two calls in document order."""
+        text = """<minimax:tool_call>
+<invoke name="search">
+<parameter name="query">MiniMax M2.5</parameter>
+</invoke>
+</minimax:tool_call>
+<minimax:tool_call>
+<invoke name="read_file">
+<parameter name="path">/tmp/test.txt</parameter>
+</invoke>
+</minimax:tool_call>"""
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertEqual(len(tool_calls), 2)
+        self.assertEqual(tool_calls[0].function.name, "search")
+        self.assertEqual(tool_calls[1].function.name, "read_file")
+
+    def test_json_parameter_value(self):
+        """A parameter holding a JSON array is decoded into a list."""
+        text = """<minimax:tool_call>
+<invoke name="create_event">
+<parameter name="title">Meeting</parameter>
+<parameter name="attendees">["stuart", "frida"]</parameter>
+</invoke>
+</minimax:tool_call>"""
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        args = json.loads(tool_calls[0].function.arguments)
+        self.assertEqual(args["title"], "Meeting")
+        self.assertEqual(args["attendees"], ["stuart", "frida"])
+
+    def test_numeric_parameter(self):
+        """A numeric parameter is decoded as a number, not kept as a string."""
+        text = """<minimax:tool_call>
+<invoke name="set_value">
+<parameter name="value">42</parameter>
+</invoke>
+</minimax:tool_call>"""
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        args = json.loads(tool_calls[0].function.arguments)
+        self.assertEqual(args["value"], 42)
+
+    def test_no_parameters(self):
+        """An invoke block without parameters yields an empty argument object."""
+        text = """<minimax:tool_call>
+<invoke name="get_time">
+</invoke>
+</minimax:tool_call>"""
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertEqual(tool_calls[0].function.name, "get_time")
+        args = json.loads(tool_calls[0].function.arguments)
+        self.assertEqual(args, {})
+
+    def test_with_think_tags_preserved(self):
+        """Think tags preceding the tool call stay in the cleaned text."""
+        text = """<think>
+I should check the weather first.
+</think>
+<minimax:tool_call>
+<invoke name="get_weather">
+<parameter name="city">Wanaka</parameter>
+</invoke>
+</minimax:tool_call>"""
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertIn("<think>", cleaned)
+
+    def test_no_minimax_tool_calls(self):
+        """Plain text yields no tool calls and is returned unchanged."""
+        text = "Just a regular message with no tool calls."
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNone(tool_calls)
+        self.assertEqual(cleaned, text)
+
+    def test_tool_call_id_format(self):
+        """Parsed calls carry a "call_"-prefixed id and the "function" type."""
+        text = """<minimax:tool_call>
+<invoke name="test">
+<parameter name="x">1</parameter>
+</invoke>
+</minimax:tool_call>"""
+
+        _, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertTrue(tool_calls[0].id.startswith("call_"))
+        self.assertEqual(tool_calls[0].type, "function")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/vllm_mlx/api/tool_calling.py b/vllm_mlx/api/tool_calling.py
index 1443c1674..364b65993 100644
--- a/vllm_mlx/api/tool_calling.py
+++ b/vllm_mlx/api/tool_calling.py
@@ -89,6 +89,7 @@ def parse_tool_calls(
     Parse tool calls from model output.
 
     Supports multiple formats:
+    - MiniMax: <minimax:tool_call><invoke name="..."><parameter name="p">v</parameter></invoke></minimax:tool_call>
     - Qwen3 bracket: [Calling tool: function_name({"arg": "value"})]
     - Qwen: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
     - Llama: <function=name>{"arg": "value"}</function>
@@ -106,6 +107,47 @@ def parse_tool_calls(
     tool_calls = []
     cleaned_text = text
 
+    # Pattern for MiniMax-style: <minimax:tool_call><invoke name="..."><parameter name="p">v</parameter></invoke></minimax:tool_call>
+    minimax_pattern = r"<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>"
+    minimax_matches = re.findall(minimax_pattern, text, re.DOTALL)
+
+    for invoke_block in minimax_matches:
+        # Parse <invoke name="..."> blocks within the tool_call
+        invoke_pattern = r'<invoke name="([^"]+)">(.*?)</invoke>'
+        invoke_matches = re.findall(invoke_pattern, invoke_block, re.DOTALL)
+
+        for name, params_block in invoke_matches:
+            # Parse <parameter name="...">value</parameter> pairs
+            param_pattern = r'<parameter name="([^"]+)">\s*(.*?)\s*</parameter>'
+            params = re.findall(param_pattern, params_block, re.DOTALL)
+            arguments = {}
+            for p_name, p_value in params:
+                # Try to parse value as JSON (for nested objects/arrays/numbers)
+                try:
+                    arguments[p_name] = json.loads(p_value)
+                except (json.JSONDecodeError, ValueError):
+                    arguments[p_name] = p_value
+
+            tool_calls.append(
+                ToolCall(
+                    id=f"call_{uuid.uuid4().hex[:8]}",
+                    type="function",
+                    function=FunctionCall(
+                        name=name.strip(),
+                        arguments=json.dumps(arguments),
+                    ),
+                )
+            )
+
+    # Remove MiniMax tool call tags from cleaned text
+    if minimax_matches:
+        cleaned_text = re.sub(
+            r"<minimax:tool_call>\s*.*?\s*</minimax:tool_call>",
+            "",
+            cleaned_text,
+            flags=re.DOTALL,
+        ).strip()
+
     # Pattern for Qwen3 bracket-style: [Calling tool: function_name({...})]
     bracket_pattern = r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]"
     bracket_matches = re.findall(bracket_pattern, text, re.DOTALL)