diff --git a/tests/test_tool_parsers.py b/tests/test_tool_parsers.py index 7caaffbf5..4f3c287d1 100644 --- a/tests/test_tool_parsers.py +++ b/tests/test_tool_parsers.py @@ -1163,3 +1163,177 @@ def test_streaming_bare_multi_function_blocks(self): assert len(emitted_calls) == 2 assert emitted_calls[0]["function"]["name"] == "func1" assert emitted_calls[1]["function"]["name"] == "func2" + + +class TestQwenFunctionFormat: + """Test Qwen parser's format support.""" + + @pytest.fixture + def parser(self): + return QwenToolParser() + + def test_function_format_with_parameters(self, parser): + """Test value.""" + text = "Prague" + result = parser.extract_tool_calls(text) + assert result.tools_called + assert result.tool_calls[0]["name"] == "get_weather" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["city"] == "Prague" + + def test_function_format_with_json(self, parser): + """Test {"key": "val"}.""" + text = '{"city": "Prague"}' + result = parser.extract_tool_calls(text) + assert result.tools_called + assert result.tool_calls[0]["name"] == "get_weather" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["city"] == "Prague" + + def test_function_format_multiple(self, parser): + """Test multiple blocks.""" + text = ( + '{"path": "/a.py"}' + '{"path": "/b.py", "content": "hello"}' + ) + result = parser.extract_tool_calls(text) + assert result.tools_called + assert len(result.tool_calls) == 2 + assert result.tool_calls[0]["name"] == "read_file" + assert result.tool_calls[1]["name"] == "write_file" + + def test_function_format_with_think_tags(self, parser): + """Test with think tags.""" + text = ( + "I need to check the weather.\n" + '{"city": "Prague"}' + ) + result = parser.extract_tool_calls(text) + assert result.tools_called + assert result.tool_calls[0]["name"] == "get_weather" + + +class TestQwenStreamingBuffering: + """Test Qwen parser streaming with partial-marker buffering.""" + + @pytest.fixture + def parser(self): + return QwenToolParser() + + def test_streaming_function_format_complete(self, parser): + """Test streaming with ... format.""" + chunks = [ + "", + "Prague", + "", + ] + accumulated = "" + tool_calls_found = False + for chunk in chunks: + prev = accumulated + accumulated += chunk + r = parser.extract_tool_calls_streaming( + previous_text=prev, + current_text=accumulated, + delta_text=chunk, + ) + if r is not None and "tool_calls" in r: + tool_calls_found = True + assert r["tool_calls"][0]["function"]["name"] == "get_weather" + break + assert tool_calls_found + + def test_streaming_partial_marker_buffered(self, parser): + """Test that partial '" — not a tool marker + r = parser.extract_tool_calls_streaming( + previous_text="Hello<", + current_text="Hello
", + delta_text="div>", + ) + assert r is not None + assert "content" in r + assert "<" in r["content"] + assert "div>" in r["content"] + + def test_streaming_multiple_function_blocks(self, parser): + """Test streaming with multiple {"a": 1}', + "\n", + "", + "2", + "", + ] + accumulated = "" + emitted_calls = [] + for chunk in chunks: + prev = accumulated + accumulated += chunk + r = parser.extract_tool_calls_streaming( + previous_text=prev, + current_text=accumulated, + delta_text=chunk, + ) + if r is not None and "tool_calls" in r: + emitted_calls.extend(r["tool_calls"]) + assert len(emitted_calls) == 2 + assert emitted_calls[0]["function"]["name"] == "func1" + assert emitted_calls[1]["function"]["name"] == "func2" diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index d4f24de22..e080be8ff 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -2590,7 +2590,7 @@ async def stream_chat_completion( yield f"data: {chunk.model_dump_json()}\n\n" # Fallback: if tool parser accumulated text but never emitted tool_calls - # (e.g., never arrived - incomplete tool call) + # (e.g., never arrived, or " in tool_accumulated_text or "<|tool_call>" in tool_accumulated_text + or "{"name": "func", "arguments": {...}} - Bracket style: [Calling tool: func_name({"arg": "value"})] +- Function style: value """ +import ast import json import re import uuid @@ -20,6 +22,24 @@ ) +def _parse_param_value(val: str) -> Any: + """Parse a parameter value, handling JSON literals and plain strings.""" + try: + return json.loads(val) + except (json.JSONDecodeError, ValueError): + pass + try: + python_val = ast.literal_eval(val) + if isinstance(python_val, set): + python_val = sorted(python_val, key=str) + if isinstance(python_val, (complex, bytes)): + return val + json.dumps(python_val) + return python_val + except (ValueError, SyntaxError, TypeError): + return val + + def generate_tool_id() -> str: """Generate a unique tool call ID.""" return f"call_{uuid.uuid4().hex[:8]}" @@ -33,6 +53,7 @@ class QwenToolParser(ToolParser): Supports multiple Qwen tool call formats: - XML: {"name": "func", "arguments": {...}} - Bracket: [Calling tool: func_name({"arg": "value"})] + - Function: value Used when --enable-auto-tool-choice --tool-call-parser qwen are set. """ @@ -43,6 +64,12 @@ class QwenToolParser(ToolParser): # Pattern for bracket-style: [Calling tool: func_name({...})] BRACKET_PATTERN = re.compile(r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]", re.DOTALL) + # Pattern for function-style: ... + FUNCTION_PATTERN = re.compile(r"]+)>(.*?)", re.DOTALL) + + # Pattern for parameter extraction: value + PARAM_PATTERN = re.compile(r"]+)>\s*(.*?)\s*", re.DOTALL) + def extract_tool_calls( self, model_output: str, request: dict[str, Any] | None = None ) -> ExtractedToolCallInformation: @@ -101,6 +128,41 @@ def extract_tool_calls( if xml_matches: cleaned_text = self.XML_PATTERN.sub("", cleaned_text).strip() + # Try function-style: value + # Qwen3.5 generates this format natively. + if not tool_calls: + func_matches = self.FUNCTION_PATTERN.findall(cleaned_text) + for name, params_block in func_matches: + # Try JSON arguments first (e.g. {"key": "val"}) + params_block_stripped = params_block.strip() + if params_block_stripped.startswith("{"): + try: + arguments = json.loads(params_block_stripped) + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": json.dumps(arguments, ensure_ascii=False), + } + ) + continue + except json.JSONDecodeError: + pass + # Parse value tags + params = self.PARAM_PATTERN.findall(params_block) + arguments = {} + for p_name, p_value in params: + arguments[p_name.strip()] = _parse_param_value(p_value.strip()) + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": json.dumps(arguments, ensure_ascii=False), + } + ) + if func_matches: + cleaned_text = self.FUNCTION_PATTERN.sub("", cleaned_text).strip() + if tool_calls: return ExtractedToolCallInformation( tools_called=True, @@ -112,6 +174,30 @@ def extract_tool_calls( tools_called=False, tool_calls=[], content=model_output ) + # Partial marker prefixes — when current_text ends with one of these, + # we suppress output until the next token confirms or denies a tool call. + # These are long enough to avoid false positives on normal text. + _PARTIAL_MARKERS = (" bool: + """Check if text ends with an incomplete tool call marker prefix.""" + return self._get_partial_marker_len(text) > 0 + + def _get_partial_marker_len(self, text: str) -> int: + """Return the length of a partial tool call marker suffix at end of text.""" + tail = text[-20:] + best = 0 + for marker in self._PARTIAL_MARKERS: + for length in range(len(marker), 0, -1): + if tail.endswith(marker[:length]) and length > best: + best = length + break + return best + + def _was_buffering(self, previous_text: str) -> bool: + """Check if the previous call was buffering a partial marker.""" + return self._has_partial_marker(previous_text) + def extract_tool_calls_streaming( self, previous_text: str, @@ -125,14 +211,67 @@ def extract_tool_calls_streaming( """ Extract tool calls from streaming Qwen model output. """ - # Check for tool call markers + # Check for complete tool call markers has_tool_marker = ( - "" in current_text or "[Calling tool:" in current_text + "" in current_text + or "[Calling tool:" in current_text + or "... (Qwen3.5 native format) + if "") + prev_func_close = previous_text.count("") + + if current_text.count(" func_close_count: + # Inside an incomplete function block, suppress output + return None + + if func_close_count > prev_func_close: + # New function block(s) completed + result = self.extract_tool_calls(current_text) + if result.tools_called: + new_calls = result.tool_calls[prev_func_close:] + if new_calls: + return { + "tool_calls": [ + { + "index": prev_func_close + i, + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": tc["arguments"], + }, + } + for i, tc in enumerate(new_calls) + ] + } + + return None + # If we're in a tool call, accumulate and parse at the end # For simplicity, return None during accumulation if "" in delta_text or ")]" in delta_text: diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index 55dec9577..9d200ab9f 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -9,6 +9,7 @@ import json import logging +from pathlib import Path from .chat_templates import DEFAULT_CHATML_TEMPLATE, NEMOTRON_CHAT_TEMPLATE