diff --git a/tests/test_server.py b/tests/test_server.py index b8d2c3fd..6e08b211 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1042,6 +1042,93 @@ async def stream_chat(self, messages, **kwargs): assert payloads[2]["choices"][0]["delta"]["content"] == "world" assert payloads[2]["choices"][0]["finish_reason"] == "stop" + @pytest.mark.anyio + async def test_auto_parser_streams_bare_bracket_tool_calls(self, monkeypatch): + """Bare bracket tool calls should stream as structured tool_calls.""" + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + stream_chat_completion, + ) + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-engine" + + async def stream_chat(self, messages, **kwargs): + chunks = [ + GenerationOutput(text="", new_text="[read(", finished=False), + GenerationOutput( + text="", + new_text='{"file_path": "/tmp/test.py"}', + finished=False, + ), + GenerationOutput( + text="", + new_text=")]", + finished=True, + finish_reason="stop", + prompt_tokens=4, + completion_tokens=3, + ), + ] + for chunk in chunks: + yield chunk + + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_reasoning_parser", None) + monkeypatch.setattr(server, "_enable_auto_tool_choice", True) + monkeypatch.setattr(server, "_tool_call_parser", "auto") + monkeypatch.setattr(server, "_tool_parser_instance", None) + + request = ChatCompletionRequest( + model="served-model", + messages=[Message(role="user", content="hi")], + tools=[ + { + "type": "function", + "function": { + "name": "read", + "description": "Read a file", + "parameters": { + "type": "object", + "properties": {"file_path": {"type": "string"}}, + "required": ["file_path"], + }, + }, + } + ], + stream=True, + ) + + chunks = [ + chunk + async for chunk in stream_chat_completion( + FakeEngine(), request.messages, request + ) + ] + + payloads = [ + json.loads(chunk.removeprefix("data: ").strip()) + for chunk in chunks + if chunk != "data: [DONE]\n\n" + ] + tool_payloads = [ + payload + for payload in payloads + if payload["choices"] and payload["choices"][0]["delta"].get("tool_calls") + ] + + assert len(tool_payloads) == 1 + delta = tool_payloads[0]["choices"][0]["delta"] + assert delta["tool_calls"][0]["function"]["name"] == "read" + assert delta["tool_calls"][0]["function"]["arguments"] == ( + '{"file_path": "/tmp/test.py"}' + ) + assert delta["content"] is None + assert tool_payloads[0]["choices"][0]["finish_reason"] == "tool_calls" + @pytest.mark.anyio async def test_reasoning_stream_emits_structured_tool_calls(self, monkeypatch): """Tool markup after should emit tool_calls chunks.""" diff --git a/tests/test_tool_parsers.py b/tests/test_tool_parsers.py index 4f3c287d..1d37e948 100644 --- a/tests/test_tool_parsers.py +++ b/tests/test_tool_parsers.py @@ -555,6 +555,16 @@ def test_detects_qwen_bracket(self, parser): assert result.tools_called assert result.tool_calls[0]["name"] == "add" + def test_detects_bare_bracket(self, parser): + """Test auto detection of bare bracket format.""" + text = '[read({"file_path": "/tmp/test.py"})]' + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert result.tool_calls[0]["name"] == "read" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["file_path"] == "/tmp/test.py" + def test_detects_llama(self, parser): """Test auto detection of Llama format.""" text = '{"x": 2}' @@ -651,6 +661,39 @@ def test_tool_call_id_uniqueness(self): assert len(ids) == len(set(ids)), "Tool call IDs should be unique" +class TestBareBracketStreaming: + """Test streaming for bare bracket tool calls.""" + + def test_auto_streaming_bare_bracket(self): + """Auto parser should emit structured tool calls for bare bracket streaming.""" + parser = AutoToolParser() + + chunks = [ + "[read(", + '{"file_path": "/tmp/test.py"}', + ")]", + ] + accumulated = "" + tool_calls_found = False + + for chunk in chunks: + prev = accumulated + accumulated += chunk + r = parser.extract_tool_calls_streaming( + previous_text=prev, + current_text=accumulated, + delta_text=chunk, + ) + if r is not None and "tool_calls" in r: + tool_calls_found = True + assert r["tool_calls"][0]["function"]["name"] == "read" + args = json.loads(r["tool_calls"][0]["function"]["arguments"]) + assert args["file_path"] == "/tmp/test.py" + break + + assert tool_calls_found + + class TestStreamingParsing: """Test streaming tool call parsing.""" diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 95c8c610..ae75cb91 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -219,6 +219,8 @@ def _resolve_top_p(request_value: float | None) -> float: "", ' None: @@ -1561,7 +1563,11 @@ def _get_streaming_tool_parser(request: ChatCompletionRequest | None): def _streaming_tool_markup_possible(text: str) -> bool: """Heuristic marker check to avoid parser work on ordinary text chunks.""" - return any(marker in text for marker in _STREAMING_TOOL_MARKERS) + return ( + any(marker in text for marker in _STREAMING_TOOL_MARKERS) + or _STREAMING_BARE_BRACKET_MARKER.search(text) is not None + or _STREAMING_BARE_BRACKET_PARTIAL.search(text) is not None + ) def load_embedding_model( diff --git a/vllm_mlx/tool_parsers/auto_tool_parser.py b/vllm_mlx/tool_parsers/auto_tool_parser.py index 37ab10d7..c759e33d 100644 --- a/vllm_mlx/tool_parsers/auto_tool_parser.py +++ b/vllm_mlx/tool_parsers/auto_tool_parser.py @@ -55,6 +55,8 @@ class AutoToolParser(ToolParser): NEMOTRON_PARAM_PATTERN = re.compile( r"]+)>\s*(.*?)\s*", re.DOTALL ) + BARE_BRACKET_PATTERN = re.compile(r"\[(\w+)\((\{.*?\})\)\]", re.DOTALL) + BARE_BRACKET_PARTIAL_PATTERN = re.compile(r"\[\w+\($") def extract_tool_calls( self, model_output: str, request: dict[str, Any] | None = None @@ -150,7 +152,35 @@ def extract_tool_calls( if bracket_matches: cleaned_text = self.QWEN_BRACKET_PATTERN.sub("", cleaned_text).strip() - # 4. Try Nemotron pattern (before Qwen XML as it's more specific) + # 4. Try bare bracket format: [func({...})] + bare_matches = self.BARE_BRACKET_PATTERN.findall(cleaned_text) + for name, args_str in bare_matches: + try: + arguments = json.loads(args_str) + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": ( + json.dumps(arguments, ensure_ascii=False) + if isinstance(arguments, dict) + else str(arguments) + ), + } + ) + except json.JSONDecodeError: + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": args_str, + } + ) + + if bare_matches: + cleaned_text = self.BARE_BRACKET_PATTERN.sub("", cleaned_text).strip() + + # 5. Try Nemotron pattern (before Qwen XML as it's more specific) nemotron_matches = self.NEMOTRON_PATTERN.findall(cleaned_text) for name, params_block in nemotron_matches: params = self.NEMOTRON_PARAM_PATTERN.findall(params_block) @@ -166,7 +196,7 @@ def extract_tool_calls( if nemotron_matches: cleaned_text = self.NEMOTRON_PATTERN.sub("", cleaned_text).strip() - # 5. Try Qwen/Hermes XML pattern + # 6. Try Qwen/Hermes XML pattern xml_matches = self.QWEN_XML_PATTERN.findall(cleaned_text) for match in xml_matches: try: @@ -191,7 +221,7 @@ def extract_tool_calls( if xml_matches: cleaned_text = self.QWEN_XML_PATTERN.sub("", cleaned_text).strip() - # 6. Try Llama pattern + # 7. Try Llama pattern llama_matches = self.LLAMA_PATTERN.findall(cleaned_text) for name, args_str in llama_matches: try: @@ -219,7 +249,7 @@ def extract_tool_calls( if llama_matches: cleaned_text = self.LLAMA_PATTERN.sub("", cleaned_text).strip() - # 7. Fallback: Try raw JSON + # 8. Fallback: Try raw JSON if not tool_calls: raw_calls = self._parse_raw_json_tool_calls(cleaned_text) if raw_calls: @@ -339,11 +369,24 @@ def extract_tool_calls_streaming( "<|tool_call>", self.MISTRAL_TOKEN, "[Calling tool:", + "[", "", "