diff --git a/evals/README.md b/evals/README.md index 3e8ec2b..b087958 100644 --- a/evals/README.md +++ b/evals/README.md @@ -25,24 +25,28 @@ Different models require different server flags for tool calling. Use the correc | Model Family | Server Flags | |-------------|-------------| | **Qwen / Hermes** | `vllm-mlx serve --port 8000 --enable-auto-tool-choice --tool-call-parser hermes` | -| **GPT-OSS** | `vllm-mlx serve --port 8000 --enable-auto-tool-choice --tool-call-parser minimax` | +| **GPT-OSS (Harmony)** | `vllm-mlx serve --port 8000 --enable-auto-tool-choice --tool-call-parser harmony` | | **MiniMax** | `vllm-mlx serve --port 8000 --enable-auto-tool-choice --tool-call-parser minimax` | +| **DeepSeek V3.1 / R1-0528** | `vllm-mlx serve --port 8000 --enable-auto-tool-choice --tool-call-parser deepseek_v31` | | **GLM-4** | `vllm-mlx serve --port 8000 --enable-auto-tool-choice --tool-call-parser glm47` | +| **Qwen3-Coder (XML)** | `vllm-mlx serve --port 8000 --enable-auto-tool-choice --tool-call-parser qwen3_coder_xml` | | **Other / No tools** | `vllm-mlx serve --port 8000` | Then pass the matching `--parser` to the eval script: ```bash python evals/run_eval.py --model "X" --parser hermes # for Qwen/Hermes models -python evals/run_eval.py --model "X" --parser minimax # for GPT-OSS models +python evals/run_eval.py --model "X" --parser harmony # for GPT-OSS (Harmony) models python evals/run_eval.py --model "X" --parser minimax # for MiniMax models +python evals/run_eval.py --model "X" --parser deepseek_v31 # for DeepSeek V3.1 / R1-0528 python evals/run_eval.py --model "X" --parser glm47 # for GLM-4 models +python evals/run_eval.py --model "X" --parser qwen3_coder_xml # for Qwen3-Coder (XML) ``` ## Eval Suites | Suite | Items | What it tests | Scoring | |-------|-------|---------------|---------| -| **Speed** | 4 metrics | TTFT cold/warm, decode tok/s short/long | Absolute numbers | +| **Speed** | 6 metrics | TTFT cold/warm, decode tok/s short/long, RAM active/peak | 
Absolute numbers | | **Tool Calling** | 30 scenarios | Tool detection, parallel calls, irrelevance, error recovery | % fully correct | | **Coding** | 10 tasks | HumanEval+ problems (medium-hard) | % tests pass | | **Reasoning** | 10 problems | MATH-500 competition math (levels 2-5, fractions + integers) | % correct answer | diff --git a/evals/run_all_models.sh b/evals/run_all_models.sh new file mode 100755 index 0000000..4b082cb --- /dev/null +++ b/evals/run_all_models.sh @@ -0,0 +1,126 @@ +#!/bin/bash +# Batch eval runner — runs ALL suites for all text LLMs +# Usage: bash evals/run_all_models.sh [suite1 suite2 ...] +# Examples: +# bash evals/run_all_models.sh # all suites +# bash evals/run_all_models.sh speed tool_calling # specific suites +# NOTE: Model paths below are machine-specific. Update them to match your +# local model directory before running. +# No set -e: server kill/wait returns non-zero which is expected + +PYTHON=python3.12 +CLI_CMD="from vllm_mlx.cli import main; import sys; sys.argv = ['vllm-mlx'] + sys.argv[1:]; main()" +PORT=8000 +EVAL_CMD="$PYTHON evals/run_eval.py" + +# Suites to run (all by default, or from command line) +if [ $# -gt 0 ]; then + SUITES="$*" +else + SUITES="speed tool_calling coding reasoning general" +fi + +# Model configs: name|path|parser|quantization +declare -a MODELS=( + "Qwen3-0.6B-4bit|/Users/raullenstudio/.lmstudio/models/mlx-community/Qwen3-0.6B-MLX-4bit|hermes|4bit" + "GLM-4.7-4bit|/Users/raullenstudio/.lmstudio/models/mlx-community/GLM-4.7-4bit|glm47|4bit" + "GPT-OSS-20B-mxfp4-q8|/Users/raullenstudio/.lmstudio/models/mlx-community/gpt-oss-20b-MXFP4-Q8|harmony|mxfp4-q8" + "MiniMax-M2.5-4bit|/Users/raullenstudio/.lmstudio/models/lmstudio-community/MiniMax-M2.5-MLX-4bit|minimax|4bit" + "Qwen3.5-35B-A3B-4bit|/Users/raullenstudio/.lmstudio/models/mlx-community/Qwen3.5-35B-A3B-4bit|hermes|4bit" + "Qwen3.5-35B-A3B-8bit|/Users/raullenstudio/.lmstudio/models/mlx-community/Qwen3.5-35B-A3B-8bit|hermes|8bit" + 
"Qwen3-Coder-Next-4bit|/Users/raullenstudio/.lmstudio/models/lmstudio-community/Qwen3-Coder-Next-MLX-4bit|hermes|4bit" + "Qwen3-Coder-Next-6bit|/Users/raullenstudio/.lmstudio/models/lmstudio-community/Qwen3-Coder-Next-MLX-6bit|hermes|6bit" + "Qwen3.5-122B-A10B-mxfp4|/Users/raullenstudio/.lmstudio/models/nightmedia/Qwen3.5-122B-A10B-Text-mxfp4-mlx|hermes|mxfp4" + "Qwen3.5-122B-A10B-8bit|/Users/raullenstudio/.lmstudio/models/mlx-community/Qwen3.5-122B-A10B-8bit|hermes|8bit" + # Requested by community — download and uncomment to eval: + # "Mistral-Small-3.2-4bit||hermes|4bit" + # "Devstral-Small-4bit||hermes|4bit" + # "GLM-4.5-Air-4bit||glm47|4bit" + # "Nemotron-Nano-30B-4bit||hermes|4bit" + # "Qwen3.5-4B-4bit||hermes|4bit" + # "Qwen3.5-9B-4bit||hermes|4bit" +) + +start_server() { + local model_path="$1" + local parser="$2" + echo " Starting server: $(basename "$model_path") (parser=$parser)..." + $PYTHON -c "$CLI_CMD" serve "$model_path" --port $PORT \ + --enable-auto-tool-choice --tool-call-parser "$parser" & + SERVER_PID=$! + + for i in $(seq 1 120); do + if curl -s "http://localhost:$PORT/health" | grep -q "healthy"; then + echo " Server ready (${i}s)" + return 0 + fi + sleep 2 + done + echo " ERROR: Server failed to start within 240s" + kill $SERVER_PID 2>/dev/null + return 1 +} + +stop_server() { + if [ -n "$SERVER_PID" ]; then + kill $SERVER_PID 2>/dev/null + wait $SERVER_PID 2>/dev/null + SERVER_PID="" + fi + lsof -ti:$PORT | xargs kill 2>/dev/null || true + sleep 3 +} + +echo "========================================" +echo "vllm-mlx Full Model Evaluation" +echo "========================================" +echo "Models: ${#MODELS[@]}" +echo "Suites: $SUITES" +echo "" + +TOTAL_START=$(date +%s) + +for model_config in "${MODELS[@]}"; do + IFS='|' read -r name path parser quant <<< "$model_config" + + # Skip if model path doesn't exist + if [ ! 
-d "$path" ]; then + echo "SKIP: $name (path not found: $path)" + echo "" + continue + fi + + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + echo "Model: $name ($quant, parser=$parser)" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + stop_server + + if start_server "$path" "$parser"; then + $EVAL_CMD \ + --model "$name" \ + --parser "$parser" \ + --quantization "$quant" \ + --suite $SUITES \ + --server-flags "--enable-auto-tool-choice --tool-call-parser $parser" + echo "" + else + echo " SKIPPED: $name (server failed to start)" + echo "" + fi + + stop_server +done + +TOTAL_END=$(date +%s) +TOTAL_ELAPSED=$((TOTAL_END - TOTAL_START)) +MINUTES=$((TOTAL_ELAPSED / 60)) + +echo "========================================" +echo "All evals complete in ${MINUTES}m ${TOTAL_ELAPSED}s" +echo "========================================" +echo "" +echo "Results:" +ls -la evals/results/*.json +echo "" +echo "Regenerate scorecard with: python3.12 evals/generate_scorecard.py" diff --git a/tests/test_harmony_parsers.py b/tests/test_harmony_parsers.py index af67e0a..046267b 100644 --- a/tests/test_harmony_parsers.py +++ b/tests/test_harmony_parsers.py @@ -1323,14 +1323,18 @@ def test_invalid_parser_not_registered(self): class TestHarmonyNativeFormat: - """Test that Harmony parser correctly declares no native format support.""" + """Test that Harmony parser declares native format support. - def test_supports_native_format_false(self): - """HarmonyToolParser does not support native tool format.""" - assert HarmonyToolParser.SUPPORTS_NATIVE_TOOL_FORMAT is False - assert HarmonyToolParser.supports_native_format() is False + GPT-OSS chat templates natively handle tool_calls and role='tool' + messages using harmony channel tokens. 
+ """ + + def test_supports_native_format_true(self): + """HarmonyToolParser supports native tool format.""" + assert HarmonyToolParser.SUPPORTS_NATIVE_TOOL_FORMAT is True + assert HarmonyToolParser.supports_native_format() is True def test_instance_supports_native_format(self): - """Instance-level check also returns False.""" + """Instance-level check also returns True.""" parser = HarmonyToolParser() - assert parser.supports_native_format() is False + assert parser.supports_native_format() is True diff --git a/tests/test_native_tool_format.py b/tests/test_native_tool_format.py index 1841161..9f6d691 100644 --- a/tests/test_native_tool_format.py +++ b/tests/test_native_tool_format.py @@ -13,6 +13,7 @@ DeepSeekToolParser, FunctionaryToolParser, GraniteToolParser, + HarmonyToolParser, HermesToolParser, KimiToolParser, LlamaToolParser, @@ -37,6 +38,7 @@ def test_parsers_with_native_support(self): FunctionaryToolParser, KimiToolParser, HermesToolParser, + HarmonyToolParser, ] for parser_cls in native_parsers: assert ( @@ -73,6 +75,7 @@ def test_via_manager(self): "functionary", "kimi", "hermes", + "harmony", ]: parser_cls = ToolParserManager.get_tool_parser(name) assert ( diff --git a/tests/test_upstream_regression.py b/tests/test_upstream_regression.py new file mode 100644 index 0000000..c70ad42 --- /dev/null +++ b/tests/test_upstream_regression.py @@ -0,0 +1,1248 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Upstream regression tests — test cases ported from vLLM (vllm-project/vllm) +to verify our tool parser forks haven't broken correctness. 
+ +Sources: + - tests/tool_parsers/test_glm4_moe_tool_parser.py + - tests/tool_parsers/test_mistral_tool_parser.py + - tests/tool_parsers/test_seed_oss_tool_parser.py + - tests/tool_parsers/test_deepseekv31_tool_parser.py + - tests/tool_parsers/test_qwen3coder_tool_parser.py +""" + +import json + +import pytest + +from vllm_mlx.tool_parsers import ToolParserManager + + +# ─── Fixtures ──────────────────────────────────────────────────────── + +@pytest.fixture +def glm47_parser(): + cls = ToolParserManager.get_tool_parser("glm47") + return cls(tokenizer=None) + + +@pytest.fixture +def mistral_parser(): + cls = ToolParserManager.get_tool_parser("mistral") + return cls(tokenizer=None) + + +@pytest.fixture +def glm47_request(): + """Minimal request dict with tools (GLM47 uses tool names for validation).""" + return { + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "state": {"type": "string"}, + "unit": {"type": "string"}, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "date": {"type": "string"}, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_current_time", + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "send_message", + "parameters": { + "type": "object", + "properties": { + "recipient": {"type": "string"}, + "message": {"type": "string"}, + "priority": {"type": "string"}, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "calculate", + "parameters": { + "type": "object", + "properties": { + "operation": {"type": "string"}, + "a": {"type": "number"}, + "b": {"type": "number"}, + "enabled": {"type": "boolean"}, + }, + }, + }, + }, + ] + } + + +# 
═══════════════════════════════════════════════════════════════════════ +# GLM-4.7 (glm47) — ported from vLLM test_glm4_moe_tool_parser.py +# ═══════════════════════════════════════════════════════════════════════ + + +class TestGlm47UpstreamNonStreaming: + """Non-streaming tests ported from upstream vLLM.""" + + def test_no_tools(self, glm47_parser, glm47_request): + """Plain text → no tool calls.""" + result = glm47_parser.extract_tool_calls("This is a test", glm47_request) + assert not result.tools_called + assert result.tool_calls == [] + assert result.content == "This is a test" + + def test_single_tool_call(self, glm47_parser, glm47_request): + """Single tool with 3 args.""" + output = """get_current_weather +cityDallas +stateTX +unitfahrenheit +""" + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc["name"] == "get_current_weather" + args = json.loads(tc["arguments"]) + assert args == {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + + def test_multiple_tool_calls(self, glm47_parser, glm47_request): + """Two tool calls in sequence.""" + output = """get_current_weather +cityDallas +stateTX +unitfahrenheit + +get_current_weather +cityOrlando +stateFL +unitfahrenheit +""" + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert result.tools_called + assert len(result.tool_calls) == 2 + args0 = json.loads(result.tool_calls[0]["arguments"]) + args1 = json.loads(result.tool_calls[1]["arguments"]) + assert args0["city"] == "Dallas" + assert args1["city"] == "Orlando" + + def test_tool_call_with_content_before(self, glm47_parser, glm47_request): + """Content before tool call — upstream expects content preserved.""" + output = """I'll help you check the weather. 
get_current_weather +citySeattle +stateWA +unitcelsius +""" + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_current_weather" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"city": "Seattle", "state": "WA", "unit": "celsius"} + + def test_tool_call_with_chinese_content(self, glm47_parser, glm47_request): + """Chinese content before tool call + date argument.""" + output = """I will help you get the weather.get_weather +cityBeijing +date2025-08-01 +""" + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_weather" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"city": "Beijing", "date": "2025-08-01"} + + def test_thinking_tags(self, glm47_parser, glm47_request): + """Tool call after ... block.""" + output = """I want to get the weather. + +I will help you get the weather. 
+get_weather +cityBeijing +date2025-08-01 +""" + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_weather" + + def test_empty_arguments(self, glm47_parser, glm47_request): + """Tool call with no arguments.""" + output = """get_current_time +""" + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_current_time" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {} + + def test_special_characters(self, glm47_parser, glm47_request): + """Tool call with special characters in values.""" + output = """send_message +recipientAmy +messageIt is a nice day +priorityhigh +""" + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["recipient"] == "Amy" + assert args["message"] == "It is a nice day" + assert args["priority"] == "high" + + def test_incomplete_tool_call(self, glm47_parser, glm47_request): + """Missing → should NOT extract.""" + output = """get_weather +cityBeijing +date2025-08-01""" + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert not result.tools_called + assert result.tool_calls == [] + + def test_numeric_deserialization(self, glm47_parser, glm47_request): + """Integer, float, and boolean arg values should deserialize to correct types.""" + output = """calculate +operationadd +a42 +b3.14 +enabledtrue +""" + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert result.tools_called + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["operation"] == "add" + assert isinstance(args["operation"], str) + assert args["a"] == 42 + assert isinstance(args["a"], int) + assert args["b"] == 3.14 + 
assert isinstance(args["b"], float) + assert args["enabled"] is True + assert isinstance(args["enabled"], bool) + + def test_mixed_content_between_tools(self, glm47_parser, glm47_request): + """Content between two tool calls — both should be extracted.""" + output = """I will help you get the weather info. + +get_weather +cityBeijing +date2025-08-01 + + +meanwhile, I will also check the weather in Shanghai. + +get_weather +cityShanghai +date2025-08-01 +""" + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert result.tools_called + assert len(result.tool_calls) == 2 + args0 = json.loads(result.tool_calls[0]["arguments"]) + args1 = json.loads(result.tool_calls[1]["arguments"]) + assert args0["city"] == "Beijing" + assert args1["city"] == "Shanghai" + + def test_malformed_xml_graceful(self, glm47_parser, glm47_request): + """Malformed XML (missing ) — should not crash.""" + output = """get_weather +citySeattle +incomplete_arg +value +""" + # Should not raise; may or may not extract + result = glm47_parser.extract_tool_calls(output, glm47_request) + assert isinstance(result.tools_called, bool) + assert isinstance(result.tool_calls, list) + + +class TestGlm47UpstreamStreaming: + """Streaming tests ported from upstream vLLM.""" + + def test_streaming_no_tool_calls(self, glm47_parser, glm47_request): + """Regular text in streaming → content delta.""" + result = glm47_parser.extract_tool_calls_streaming( + previous_text="Hello", + current_text="Hello world", + delta_text=" world", + request=glm47_request, + ) + assert result is not None + assert result["content"] == " world" + + def test_streaming_buffers_during_tool_call(self, glm47_parser, glm47_request): + """While inside but before , returns None.""" + result = glm47_parser.extract_tool_calls_streaming( + previous_text="", + current_text="get_weather\ncity", + delta_text="city", + request=glm47_request, + ) + assert result is None + + def test_streaming_emits_on_close(self, glm47_parser, 
glm47_request): + """When arrives, tool calls should be emitted.""" + full = """get_weather +cityBeijing +""" + result = glm47_parser.extract_tool_calls_streaming( + previous_text=full.replace("", ""), + current_text=full, + delta_text="", + request=glm47_request, + ) + assert result is not None + assert "tool_calls" in result + assert len(result["tool_calls"]) == 1 + tc = result["tool_calls"][0] + assert tc["function"]["name"] == "get_weather" + args = json.loads(tc["function"]["arguments"]) + assert args["city"] == "Beijing" + + def test_streaming_multiple_tools_on_close(self, glm47_parser, glm47_request): + """Two tool calls, second emits both.""" + full = """get_weather +cityBeijing + +get_weather +cityShanghai +""" + result = glm47_parser.extract_tool_calls_streaming( + previous_text=full.replace("", "", 1).rsplit("", 1)[0], + current_text=full, + delta_text="", + request=glm47_request, + ) + # Our GLM parser re-parses the full text on close, so both should appear + assert result is not None + assert "tool_calls" in result + assert len(result["tool_calls"]) == 2 + + +# ═══════════════════════════════════════════════════════════════════════ +# Mistral — ported from vLLM test_mistral_tool_parser.py +# ═══════════════════════════════════════════════════════════════════════ + + +class TestMistralUpstreamNonStreaming: + """Non-streaming tests ported from upstream vLLM.""" + + def test_no_tools(self, mistral_parser): + """Plain text → no tool calls.""" + result = mistral_parser.extract_tool_calls("This is a test", request=None) + assert not result.tools_called + assert result.tool_calls == [] + assert result.content == "This is a test" + + # --- Old format (pre v11): [TOOL_CALLS] [{...}] --- + + @pytest.mark.parametrize( + "model_output, expected_name, expected_args", + [ + ( + '[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]', + "add", + {"a": 3.5, "b": 4}, + ), + ( + '[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", 
"state": "CA", "unit": "celsius"}}]', + "get_current_weather", + {"city": "San Francisco", "state": "CA", "unit": "celsius"}, + ), + ( + '[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]', + "get_current_weather", + {"city": "San Francisco", "state": "CA", "unit": "celsius"}, + ), + ( + '[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]', + "get_age", + {"name": "John Doe"}, + ), + ], + ids=[ + "single_tool_add", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", + ], + ) + def test_old_format_single(self, mistral_parser, model_output, expected_name, expected_args): + """Old Mistral format: [TOOL_CALLS] [{"name": ..., "arguments": ...}]""" + result = mistral_parser.extract_tool_calls(model_output, request=None) + assert result.tools_called + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc["name"] == expected_name + assert len(tc["id"]) == 9 # Mistral IDs are 9-char alphanumeric + args = json.loads(tc["arguments"]) + assert args == expected_args + + def test_old_format_multiple(self, mistral_parser): + """Old format with two tools in one JSON array.""" + output = '[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]' + result = mistral_parser.extract_tool_calls(output, request=None) + assert result.tools_called + assert len(result.tool_calls) == 2 + assert result.tool_calls[0]["name"] == "add" + assert result.tool_calls[1]["name"] == "get_current_weather" + + # --- New format (>= v11): [TOOL_CALLS]func_name{...} --- + + @pytest.mark.parametrize( + "model_output, expected_name, expected_args", + [ + ( + '[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}', + "add_this_and_that", + {"a": 3.5, "b": 4}, + ), + ( + '[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": 
"celsius"}', + "get_current_weather", + {"city": "San Francisco", "state": "CA", "unit": "celsius"}, + ), + ], + ids=[ + "new_format_add", + "new_format_weather", + ], + ) + def test_new_format_single(self, mistral_parser, model_output, expected_name, expected_args): + """New Mistral format: [TOOL_CALLS]func_name{...}""" + result = mistral_parser.extract_tool_calls(model_output, request=None) + assert result.tools_called + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc["name"] == expected_name + args = json.loads(tc["arguments"]) + assert args == expected_args + + def test_new_format_multiple(self, mistral_parser): + """New format with two [TOOL_CALLS] in one output.""" + output = '[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}' + result = mistral_parser.extract_tool_calls(output, request=None) + assert result.tools_called + assert len(result.tool_calls) == 2 + assert result.tool_calls[0]["name"] == "add" + assert result.tool_calls[1]["name"] == "multiply" + args0 = json.loads(result.tool_calls[0]["arguments"]) + args1 = json.loads(result.tool_calls[1]["arguments"]) + assert args0 == {"a": 3.5, "b": 4} + assert args1 == {"a": 3, "b": 6} + + def test_content_before_tool_call(self, mistral_parser): + """Content before [TOOL_CALLS] should be preserved.""" + output = 'hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}' + result = mistral_parser.extract_tool_calls(output, request=None) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "bash" + assert result.content == "hi{hi" + + def test_complex_escaped_json(self, mistral_parser): + """Complex JSON with escaped quotes and newlines.""" + output = '[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}' + result = mistral_parser.extract_tool_calls(output, request=None) + assert result.tools_called + assert result.tool_calls[0]["name"] == "bash" + args = 
json.loads(result.tool_calls[0]["arguments"]) + assert "print" in args["command"] + assert "re.compile" in args["command"] + + +# ─── Fixtures for new parsers ──────────────────────────────────────── + +@pytest.fixture +def seed_oss_parser(): + cls = ToolParserManager.get_tool_parser("seed_oss") + return cls(tokenizer=None) + + +@pytest.fixture +def deepseekv31_parser(): + cls = ToolParserManager.get_tool_parser("deepseek_v31") + return cls(tokenizer=None) + + +@pytest.fixture +def qwen3coder_parser(): + cls = ToolParserManager.get_tool_parser("qwen3_coder_xml") + return cls(tokenizer=None) + + +@pytest.fixture +def seed_oss_request(): + """Request with tools for Seed-OSS type conversion tests.""" + return { + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string"}, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "calculate", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "number"}, + "op": {"type": "string"}, + "enabled": {"type": "boolean"}, + "config": {"type": "object"}, + }, + }, + }, + }, + ] + } + + +@pytest.fixture +def qwen3coder_request(): + """Request with tools for Qwen3-Coder type conversion tests.""" + return { + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "state": {"type": "string"}, + "unit": {"type": "string"}, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "calculate_area", + "parameters": { + "type": "object", + "properties": { + "shape": {"type": "string"}, + "dimensions": {"type": "object"}, + "precision": {"type": "integer"}, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "test_types", + "parameters": { + "type": "object", + "properties": { + 
"int_param": {"type": "integer"}, + "float_param": {"type": "float"}, + "bool_param": {"type": "boolean"}, + "str_param": {"type": "string"}, + "obj_param": {"type": "object"}, + }, + }, + }, + }, + ] + } + + +# ═══════════════════════════════════════════════════════════════════════ +# Seed-OSS — ported from vLLM test_seed_oss_tool_parser.py +# ═══════════════════════════════════════════════════════════════════════ + + +class TestSeedOssUpstreamNonStreaming: + """Non-streaming tests ported from upstream vLLM.""" + + def test_no_tools(self, seed_oss_parser): + """Plain text → no tool calls.""" + result = seed_oss_parser.extract_tool_calls( + "This is a test response without any tool calls", request=None + ) + assert not result.tools_called + assert result.tool_calls == [] + assert result.content == "This is a test response without any tool calls" + + def test_single_tool_call(self, seed_oss_parser, seed_oss_request): + """Single tool call with wrapper.""" + output = ( + "\n\n" + "Barcelona, Spain\n" + "\n" + ) + result = seed_oss_parser.extract_tool_calls(output, seed_oss_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc["name"] == "get_weather" + args = json.loads(tc["arguments"]) + assert args == {"location": "Barcelona, Spain"} + + def test_tool_call_with_two_params(self, seed_oss_parser, seed_oss_request): + """Tool call with two parameters.""" + output = ( + "\n\n" + "Barcelona, Spain\n" + "celsius\n" + "\n" + ) + result = seed_oss_parser.extract_tool_calls(output, seed_oss_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"location": "Barcelona, Spain", "unit": "celsius"} + + def test_tool_call_with_thinking(self, seed_oss_parser, seed_oss_request): + """Tool call after ... 
block.""" + output = ( + "I should check the weather.\n" + "\n\n" + "Barcelona, Spain\n" + "\n" + ) + result = seed_oss_parser.extract_tool_calls(output, seed_oss_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_weather" + # Content should include the thinking block + assert result.content is not None + assert "" in result.content + + def test_content_before_tool_call(self, seed_oss_parser, seed_oss_request): + """Content before tool call is preserved.""" + output = ( + "Let me check that for you.\n" + "\n\n" + "Paris, France\n" + "\n" + ) + result = seed_oss_parser.extract_tool_calls(output, seed_oss_request) + assert result.tools_called + assert result.content is not None + assert "Let me check that" in result.content + + def test_multiple_tool_calls(self, seed_oss_parser, seed_oss_request): + """Multiple tool calls in sequence.""" + output = ( + "\n\n" + "Paris\n" + "\n\n" + "\n\n" + "London\n" + "\n" + ) + result = seed_oss_parser.extract_tool_calls(output, seed_oss_request) + assert result.tools_called + assert len(result.tool_calls) == 2 + args0 = json.loads(result.tool_calls[0]["arguments"]) + args1 = json.loads(result.tool_calls[1]["arguments"]) + assert args0["location"] == "Paris" + assert args1["location"] == "London" + + def test_type_conversion_integer(self, seed_oss_parser, seed_oss_request): + """Integer parameter type conversion.""" + output = ( + "\n\n" + "42\n" + "3.14\n" + "add\n" + "true\n" + "\n" + ) + result = seed_oss_parser.extract_tool_calls(output, seed_oss_request) + assert result.tools_called + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["a"] == 42 + assert isinstance(args["a"], int) + assert args["b"] == 3.14 + assert isinstance(args["b"], float) + assert args["op"] == "add" + assert args["enabled"] is True + + def test_type_conversion_object(self, seed_oss_parser, seed_oss_request): + """Object parameter type conversion.""" + output = ( + 
'\n\n' + '{"key": "value"}\n' + 'test\n' + '\n' + ) + result = seed_oss_parser.extract_tool_calls(output, seed_oss_request) + assert result.tools_called + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["config"] == {"key": "value"} + + +class TestSeedOssUpstreamStreaming: + """Streaming tests ported from upstream vLLM.""" + + def test_streaming_no_tools(self, seed_oss_parser): + """Regular text → content delta.""" + result = seed_oss_parser.extract_tool_calls_streaming( + previous_text="Hello", + current_text="Hello world", + delta_text=" world", + ) + assert result is not None + assert result["content"] == " world" + + def test_streaming_buffers_during_tool(self, seed_oss_parser): + """Inside tool call but before close → returns None or tool header.""" + # First delta starts the tool call + result = seed_oss_parser.extract_tool_calls_streaming( + previous_text="", + current_text="\n", + delta_text="\n", + ) + # Should either return None (buffering) or a tool_calls header — not content + assert result is None or "tool_calls" in result + + def test_streaming_content_before_tool(self, seed_oss_parser): + """Content before tool call is streamed as content.""" + result = seed_oss_parser.extract_tool_calls_streaming( + previous_text="", + current_text="Let me check", + delta_text="Let me check", + ) + assert result is not None + assert result["content"] == "Let me check" + + def test_streaming_thinking_content(self, seed_oss_parser): + """Thinking content before seed:think end is streamed.""" + result = seed_oss_parser.extract_tool_calls_streaming( + previous_text="", + current_text="thinking...", + delta_text="thinking...", + ) + assert result is not None + assert "content" in result + + def test_streaming_full_tool_call_multistep(self, seed_oss_parser, seed_oss_request): + """Multi-step streaming: header → { → param → } across calls. 
+ + Streaming parsers emit one piece per call; callers must invoke + extract_tool_calls_streaming once per token/delta (fine-grained). + """ + deltas = [ + "", + "\n", + "\n", + "Paris", + "\n", + "\n", + ] + text = "" + collected = [] + for d in deltas: + prev = text + text += d + r = seed_oss_parser.extract_tool_calls_streaming( + previous_text=prev, current_text=text, delta_text=d, + request=seed_oss_request, + ) + if r: + collected.append(r) + + # Should have: header (name), opening {, param fragment, closing } + names = [c["tool_calls"][0]["function"].get("name") + for c in collected if "tool_calls" in c + and "name" in c["tool_calls"][0].get("function", {})] + assert "get_weather" in names + + # Concatenate all argument fragments + arg_parts = [c["tool_calls"][0]["function"]["arguments"] + for c in collected if "tool_calls" in c + and "arguments" in c["tool_calls"][0].get("function", {})] + full_args = "".join(arg_parts) + assert full_args.startswith("{") + assert full_args.endswith("}") + parsed = json.loads(full_args) + assert parsed["location"] == "Paris" + + def test_streaming_coarse_deltas_complete(self, seed_oss_parser, seed_oss_request): + """Two coarse deltas: header + complete body → full args emitted. + + Reproduces the scenario where the function body is already complete + when the header is first detected (e.g. fast model, large chunk). 
+ """ + deltas = [ + "\n" + "\nParis\n" + "\n", + ] + text = "" + collected = [] + for d in deltas: + prev = text + text += d + r = seed_oss_parser.extract_tool_calls_streaming( + previous_text=prev, current_text=text, delta_text=d, + request=seed_oss_request, + ) + if r: + collected.append(r) + + # Must have at least one tool_calls chunk with non-empty arguments + tc_chunks = [c for c in collected if "tool_calls" in c] + assert len(tc_chunks) >= 1 + # First chunk should have complete arguments (fast-path) + first_tc = tc_chunks[0]["tool_calls"][0] + assert first_tc["function"]["name"] == "get_weather" + args = first_tc["function"]["arguments"] + assert args # not empty + parsed = json.loads(args) + assert parsed["location"] == "Paris" + + +# ═══════════════════════════════════════════════════════════════════════ +# DeepSeek V3.1 — ported from vLLM test_deepseekv31_tool_parser.py +# ═══════════════════════════════════════════════════════════════════════ + + +class TestDeepSeekV31UpstreamNonStreaming: + """Non-streaming tests ported from upstream vLLM.""" + + def test_no_tools(self, deepseekv31_parser): + """Plain text → no tool calls.""" + result = deepseekv31_parser.extract_tool_calls("This is a test", request=None) + assert not result.tools_called + assert result.tool_calls == [] + assert result.content == "This is a test" + + def test_single_tool_call(self, deepseekv31_parser): + """Single tool call in V3.1 format (no code fence, no type prefix).""" + output = ( + "normal text" + "<|tool▁calls▁begin|>" + '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>' + "<|tool▁calls▁end|>" + ) + result = deepseekv31_parser.extract_tool_calls(output, request=None) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "foo" + assert result.tool_calls[0]["arguments"] == '{"x":1}' + assert result.content == "normal text" + + def test_multiple_tool_calls(self, deepseekv31_parser): + """Multiple tool calls in V3.1 
format.""" + output = ( + "some prefix text" + "<|tool▁calls▁begin|>" + '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>' + '<|tool▁call▁begin|>bar<|tool▁sep|>{"y":2}<|tool▁call▁end|>' + "<|tool▁calls▁end|>" + ) + result = deepseekv31_parser.extract_tool_calls(output, request=None) + assert result.tools_called + assert len(result.tool_calls) == 2 + assert result.tool_calls[0]["name"] == "foo" + assert result.tool_calls[0]["arguments"] == '{"x":1}' + assert result.tool_calls[1]["name"] == "bar" + assert result.tool_calls[1]["arguments"] == '{"y":2}' + assert result.content == "some prefix text" + + def test_content_preserved(self, deepseekv31_parser): + """Content before tool calls is preserved.""" + output = ( + "I'll help with that!" + "<|tool▁calls▁begin|>" + '<|tool▁call▁begin|>search<|tool▁sep|>{"q":"test"}<|tool▁call▁end|>' + "<|tool▁calls▁end|>" + ) + result = deepseekv31_parser.extract_tool_calls(output, request=None) + assert result.tools_called + assert result.content == "I'll help with that!" 
+ + def test_no_tool_calls_start(self, deepseekv31_parser): + """Without tool_calls_begin token, treat as content.""" + output = "Just some regular text without any special tokens" + result = deepseekv31_parser.extract_tool_calls(output, request=None) + assert not result.tools_called + assert result.content == output + + def test_complex_json_args(self, deepseekv31_parser): + """Tool call with nested JSON arguments.""" + output = ( + "<|tool▁calls▁begin|>" + '<|tool▁call▁begin|>create_event<|tool▁sep|>' + '{"title":"Meeting","details":{"time":"3pm","room":"A1"}}' + "<|tool▁call▁end|>" + "<|tool▁calls▁end|>" + ) + result = deepseekv31_parser.extract_tool_calls(output, request=None) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "create_event" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["title"] == "Meeting" + assert args["details"]["room"] == "A1" + + +class TestDeepSeekV31UpstreamStreaming: + """Streaming tests ported from upstream vLLM.""" + + def test_streaming_no_tools(self, deepseekv31_parser): + """Regular text → content delta.""" + result = deepseekv31_parser.extract_tool_calls_streaming( + previous_text="Hello", + current_text="Hello world", + delta_text=" world", + ) + assert result is not None + assert result["content"] == " world" + + def test_streaming_content_before_tools(self, deepseekv31_parser): + """Content before tool calls start token.""" + result = deepseekv31_parser.extract_tool_calls_streaming( + previous_text="", + current_text="Some text", + delta_text="Some text", + ) + assert result is not None + assert result["content"] == "Some text" + + +# ═══════════════════════════════════════════════════════════════════════ +# Qwen3-Coder XML — ported from vLLM test_qwen3coder_tool_parser.py +# ═══════════════════════════════════════════════════════════════════════ + + +class TestQwen3CoderUpstreamNonStreaming: + """Non-streaming tests ported from upstream vLLM.""" + 
+ def test_no_tools(self, qwen3coder_parser): + """Plain text → no tool calls.""" + result = qwen3coder_parser.extract_tool_calls( + "This is a test response without any tool calls", request=None + ) + assert not result.tools_called + assert result.tool_calls == [] + assert result.content == "This is a test response without any tool calls" + + def test_single_tool_call(self, qwen3coder_parser, qwen3coder_request): + """Single tool call with wrapper.""" + output = ( + "\n\n" + "\nDallas\n\n" + "\nTX\n\n" + "\nfahrenheit\n\n" + "\n" + ) + result = qwen3coder_parser.extract_tool_calls(output, qwen3coder_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc["name"] == "get_current_weather" + args = json.loads(tc["arguments"]) + assert args == {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + + def test_single_tool_with_content(self, qwen3coder_parser, qwen3coder_request): + """Content before tool call is preserved.""" + output = ( + "Sure! Let me check the weather for you." + "\n\n" + "\nDallas\n\n" + "\nTX\n\n" + "\n" + ) + result = qwen3coder_parser.extract_tool_calls(output, qwen3coder_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.content == "Sure! Let me check the weather for you." 
+ + def test_parallel_tools(self, qwen3coder_parser, qwen3coder_request): + """Multiple parallel tool calls.""" + output = ( + "\n\n" + "\nDallas\n\n" + "\nTX\n\n" + "\n\n" + "\n\n" + "\nOrlando\n\n" + "\nFL\n\n" + "\n" + ) + result = qwen3coder_parser.extract_tool_calls(output, qwen3coder_request) + assert result.tools_called + assert len(result.tool_calls) == 2 + args0 = json.loads(result.tool_calls[0]["arguments"]) + args1 = json.loads(result.tool_calls[1]["arguments"]) + assert args0["city"] == "Dallas" + assert args1["city"] == "Orlando" + + def test_type_conversion(self, qwen3coder_parser, qwen3coder_request): + """Parameter type conversion based on tool schema.""" + output = ( + "\n\n" + "\n42\n\n" + "\n3.14\n\n" + "\ntrue\n\n" + "\nhello world\n\n" + '\n{"key": "value"}\n\n' + "\n" + ) + result = qwen3coder_parser.extract_tool_calls(output, qwen3coder_request) + assert result.tools_called + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["int_param"] == 42 + assert isinstance(args["int_param"], int) + assert args["float_param"] == 3.14 + assert isinstance(args["float_param"], float) + assert args["bool_param"] is True + assert args["str_param"] == "hello world" + assert args["obj_param"] == {"key": "value"} + + def test_object_with_single_quotes(self, qwen3coder_parser, qwen3coder_request): + """Object parameter with single-quote JSON (Python literal).""" + output = ( + "\n\n" + "\n{'key': 'value'}\n\n" + "\n" + ) + result = qwen3coder_parser.extract_tool_calls(output, qwen3coder_request) + assert result.tools_called + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["obj_param"] == {"key": "value"} + + def test_fallback_no_tool_call_tags(self, qwen3coder_parser, qwen3coder_request): + """Bare without wrapper also works.""" + output = ( + "\n" + "\nDallas\n\n" + "\nTX\n\n" + "" + ) + result = qwen3coder_parser.extract_tool_calls(output, qwen3coder_request) + assert result.tools_called + assert len(result.tool_calls) == 1 
+ assert result.tool_calls[0]["name"] == "get_current_weather" + + def test_missing_closing_parameter_tag(self, qwen3coder_parser, qwen3coder_request): + """Missing tag — graceful handling.""" + output = ( + "\n\n" + "\nDallas\n" + "\nTX\n\n" + "\nfahrenheit\n\n" + "\n" + ) + result = qwen3coder_parser.extract_tool_calls(output, qwen3coder_request) + assert result.tools_called + assert len(result.tool_calls) == 1 + args = json.loads(result.tool_calls[0]["arguments"]) + assert "city" in args + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + def test_multiline_object_param(self, qwen3coder_parser, qwen3coder_request): + """Object parameter spanning multiple lines.""" + output = ( + "\n\n" + "\nrectangle\n\n" + '\n{"width": 10, \n "height": 20}\n\n' + "\n2\n\n" + "\n" + ) + result = qwen3coder_parser.extract_tool_calls(output, qwen3coder_request) + assert result.tools_called + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["shape"] == "rectangle" + assert args["dimensions"] == {"width": 10, "height": 20} + assert args["precision"] == 2 + + def test_tool_with_content_and_typed_params( + self, qwen3coder_parser, qwen3coder_request + ): + """Content before tool call with typed parameters.""" + output = ( + "Let me calculate that area for you." + "\n\n" + "\ncircle\n\n" + '\n{"radius": 15.5}\n\n' + "\n3\n\n" + "\n" + ) + result = qwen3coder_parser.extract_tool_calls(output, qwen3coder_request) + assert result.tools_called + assert result.content == "Let me calculate that area for you." 
+ args = json.loads(result.tool_calls[0]["arguments"]) + assert args["shape"] == "circle" + assert args["dimensions"] == {"radius": 15.5} + assert args["precision"] == 3 + + +class TestQwen3CoderUpstreamStreaming: + """Streaming tests ported from upstream vLLM.""" + + def test_streaming_no_tools(self, qwen3coder_parser): + """Regular text → content delta.""" + result = qwen3coder_parser.extract_tool_calls_streaming( + previous_text="Hello", + current_text="Hello world", + delta_text=" world", + ) + assert result is not None + assert result["content"] == " world" + + def test_streaming_content_before_tool(self, qwen3coder_parser): + """Content before tool call is streamed.""" + result = qwen3coder_parser.extract_tool_calls_streaming( + previous_text="", + current_text="Let me check", + delta_text="Let me check", + ) + assert result is not None + assert result["content"] == "Let me check" + + def test_streaming_full_tool_call_multistep( + self, qwen3coder_parser, qwen3coder_request + ): + """Multi-step streaming: header → { → param → } across calls.""" + deltas = [ + "", + "\n", + "\n", + "Dallas", + "\n", + "\n", + ] + text = "" + collected = [] + for d in deltas: + prev = text + text += d + r = qwen3coder_parser.extract_tool_calls_streaming( + previous_text=prev, current_text=text, delta_text=d, + request=qwen3coder_request, + ) + if r: + collected.append(r) + + names = [c["tool_calls"][0]["function"].get("name") + for c in collected if "tool_calls" in c + and "name" in c["tool_calls"][0].get("function", {})] + assert "get_current_weather" in names + + arg_parts = [c["tool_calls"][0]["function"]["arguments"] + for c in collected if "tool_calls" in c + and "arguments" in c["tool_calls"][0].get("function", {})] + full_args = "".join(arg_parts) + assert full_args.startswith("{") + assert full_args.endswith("}") + parsed = json.loads(full_args) + assert parsed["city"] == "Dallas" + + def test_streaming_coarse_deltas_complete( + self, qwen3coder_parser, 
qwen3coder_request + ): + """Single coarse delta with complete tool call → full args emitted.""" + deltas = [ + "\n" + "\nDallas\n" + "\n", + ] + text = "" + collected = [] + for d in deltas: + prev = text + text += d + r = qwen3coder_parser.extract_tool_calls_streaming( + previous_text=prev, current_text=text, delta_text=d, + request=qwen3coder_request, + ) + if r: + collected.append(r) + + tc_chunks = [c for c in collected if "tool_calls" in c] + assert len(tc_chunks) >= 1 + first_tc = tc_chunks[0]["tool_calls"][0] + assert first_tc["function"]["name"] == "get_current_weather" + args = first_tc["function"]["arguments"] + assert args + parsed = json.loads(args) + assert parsed["city"] == "Dallas" + + +# ═══════════════════════════════════════════════════════════════════════ +# Registration tests — verify all new parsers are discoverable +# ═══════════════════════════════════════════════════════════════════════ + + +class TestNewParserRegistration: + """Verify new parsers are registered and discoverable.""" + + @pytest.mark.parametrize( + "name", + ["seed_oss", "seed", "gpt_oss", "deepseek_v31", "deepseek_r1_0528", + "qwen3_coder_xml", "qwen3_xml"], + ) + def test_parser_registered(self, name): + """Parser name should be in the registry.""" + cls = ToolParserManager.get_tool_parser(name) + assert cls is not None + + @pytest.mark.parametrize( + "name", + ["seed_oss", "deepseek_v31", "qwen3_coder_xml"], + ) + def test_parser_instantiation(self, name): + """Parser should instantiate without tokenizer.""" + cls = ToolParserManager.get_tool_parser(name) + parser = cls(tokenizer=None) + assert parser is not None + + @pytest.mark.parametrize( + "name", + ["seed_oss", "deepseek_v31", "qwen3_coder_xml"], + ) + def test_parser_supports_native_format(self, name): + """All new parsers should support native tool format.""" + cls = ToolParserManager.get_tool_parser(name) + assert cls.supports_native_format() is True diff --git a/vllm_mlx/tool_parsers/__init__.py 
b/vllm_mlx/tool_parsers/__init__.py index 7598b88..861d69a 100644 --- a/vllm_mlx/tool_parsers/__init__.py +++ b/vllm_mlx/tool_parsers/__init__.py @@ -19,6 +19,9 @@ - functionary/meetkai: MeetKai Functionary models - glm47/glm4: GLM-4.7 and GLM-4.7-Flash models - harmony/gpt-oss: GPT-OSS models (Harmony format with channels) +- seed_oss/seed/gpt_oss: Seed-OSS / GPT-OSS models (XML format) +- deepseek_v31/deepseek_r1_0528: DeepSeek V3.1 / R1-0528 models +- qwen3_coder_xml/qwen3_xml: Qwen3-Coder models (XML format) Usage: from vllm_mlx.tool_parsers import ToolParserManager @@ -58,6 +61,9 @@ from .glm47_tool_parser import Glm47ToolParser from .harmony_tool_parser import HarmonyToolParser from .minimax_tool_parser import MiniMaxToolParser +from .seed_oss_tool_parser import SeedOssToolParser +from .deepseekv31_tool_parser import DeepSeekV31ToolParser +from .qwen3coder_tool_parser import Qwen3CoderToolParser __all__ = [ # Base classes @@ -79,4 +85,7 @@ "Glm47ToolParser", "HarmonyToolParser", "MiniMaxToolParser", + "SeedOssToolParser", + "DeepSeekV31ToolParser", + "Qwen3CoderToolParser", ] diff --git a/vllm_mlx/tool_parsers/deepseekv31_tool_parser.py b/vllm_mlx/tool_parsers/deepseekv31_tool_parser.py new file mode 100644 index 0000000..bd422d4 --- /dev/null +++ b/vllm_mlx/tool_parsers/deepseekv31_tool_parser.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +DeepSeek V3.1 tool call parser for vllm-mlx. + +Ported from vLLM upstream (vllm/tool_parsers/deepseekv31_tool_parser.py). 
+ +Format (different from V3 — no ```json``` code fence, no "function" type tag): + <|tool▁calls▁begin|> + <|tool▁call▁begin|>NAME<|tool▁sep|>ARGS<|tool▁call▁end|> + <|tool▁calls▁end|> +""" + +import logging +import re +import uuid +from collections.abc import Sequence +from typing import Any + +from .abstract_tool_parser import ( + ExtractedToolCallInformation, + ToolParser, + ToolParserManager, +) + +logger = logging.getLogger(__name__) + + +def _generate_tool_id() -> str: + return f"call_{uuid.uuid4().hex[:8]}" + + +@ToolParserManager.register_module(["deepseek_v31", "deepseek_r1_0528"]) +class DeepSeekV31ToolParser(ToolParser): + """ + Tool call parser for DeepSeek V3.1 and R1-0528 models. + + Uses the same unicode special tokens as V3 but with a simpler format: + <|tool▁call▁begin|>NAME<|tool▁sep|>ARGS<|tool▁call▁end|> + (no "function" type prefix, no ```json``` fencing) + + Used when --enable-auto-tool-choice --tool-call-parser deepseek_v31 are set. + """ + + SUPPORTS_NATIVE_TOOL_FORMAT = True + + TOOL_CALLS_START = "<|tool▁calls▁begin|>" + TOOL_CALLS_END = "<|tool▁calls▁end|>" + TOOL_CALL_START = "<|tool▁call▁begin|>" + TOOL_CALL_END = "<|tool▁call▁end|>" + TOOL_SEP = "<|tool▁sep|>" + + def __init__(self, tokenizer=None): + super().__init__(tokenizer) + + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [] + + self.tool_call_regex = re.compile( + r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>" + r"(?P.*?)<|tool▁call▁end|>", + re.DOTALL, + ) + self.stream_tool_call_portion_regex = re.compile( + r"(?P.*)<|tool▁sep|>(?P.*)", + re.DOTALL, + ) + self.stream_tool_call_name_regex = re.compile( + r"(?P.*)<|tool▁sep|>" + ) + + # Token IDs for streaming (graceful fallback if absent) + self.tool_calls_start_token_id = self.vocab.get(self.TOOL_CALLS_START) + self.tool_calls_end_token_id = self.vocab.get(self.TOOL_CALLS_END) + self.tool_call_start_token_id = self.vocab.get(self.TOOL_CALL_START) + self.tool_call_end_token_id = 
self.vocab.get(self.TOOL_CALL_END) + + def extract_tool_calls( + self, model_output: str, request: dict[str, Any] | None = None + ) -> ExtractedToolCallInformation: + if self.TOOL_CALLS_START not in model_output: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + try: + matches = self.tool_call_regex.findall(model_output) + tool_calls = [] + for func_name, func_args in matches: + tool_calls.append( + { + "id": _generate_tool_id(), + "name": func_name.strip(), + "arguments": func_args.strip(), + } + ) + + content = model_output[: model_output.find(self.TOOL_CALLS_START)] + return ExtractedToolCallInformation( + tools_called=len(tool_calls) > 0, + tool_calls=tool_calls, + content=content if content else None, + ) + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def has_pending_tool_call(self, text: str) -> bool: + return ( + self.TOOL_CALLS_START in text + or self.TOOL_CALL_START in text + or self.has_text_format_tool_call(text) + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int] | None = None, + current_token_ids: Sequence[int] | None = None, + delta_token_ids: Sequence[int] | None = None, + request: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + if not previous_text: + self.current_tool_name_sent = False + self.streamed_args_for_tool = [] + self.current_tool_id = -1 + self.prev_tool_call_arr = [] + + current_token_ids = current_token_ids or [] + previous_token_ids = previous_token_ids or [] + delta_token_ids = delta_token_ids or [] + + # Use token IDs if available, fall back to string matching + has_tool_start = ( + self.tool_calls_start_token_id is not None + and self.tool_calls_start_token_id in current_token_ids + ) or self.TOOL_CALLS_START in current_text + + 
if not has_tool_start: + return {"content": delta_text} + + delta_text = delta_text.replace(self.TOOL_CALLS_START, "").replace( + self.TOOL_CALLS_END, "" + ) + + try: + # Count tool call tokens (string-based fallback) + prev_tool_start_count = previous_text.count(self.TOOL_CALL_START) + prev_tool_end_count = previous_text.count(self.TOOL_CALL_END) + cur_tool_start_count = current_text.count(self.TOOL_CALL_START) + cur_tool_end_count = current_text.count(self.TOOL_CALL_END) + + tool_call_portion = None + + # Generating text (no open tool calls) + if ( + cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.TOOL_CALL_END not in delta_text + ): + return {"content": delta_text} + + if self.TOOL_CALL_END in delta_text: + full_text = current_text + tool_call_portion = ( + full_text.split(self.TOOL_CALL_START)[-1] + .split(self.TOOL_CALL_END)[0] + .rstrip() + ) + delta_text = delta_text.split(self.TOOL_CALL_END)[0].rstrip() + + # Starting new tool call + if ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count + ): + if len(delta_text) > 1: + tool_call_portion = current_text.split(self.TOOL_CALL_START)[-1] + else: + tool_call_portion = None + + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + + # Updating existing tool call + elif ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count + ): + tool_call_portion = current_text.split(self.TOOL_CALL_START)[-1] + + # Closing tool call + elif ( + cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count + ): + if not self.prev_tool_call_arr: + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") + if diff and '"}' in delta_text: + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + self.streamed_args_for_tool[self.current_tool_id] += diff + return { + 
"tool_calls": [ + { + "index": self.current_tool_id, + "function": {"arguments": diff}, + } + ] + } + return None + else: + text = delta_text.replace(self.TOOL_CALL_START, "").replace( + self.TOOL_CALL_END, "" + ) + return {"content": text} if text else None + + # Parse tool call portion + current_tool_call: dict = {} + if tool_call_portion: + m = self.stream_tool_call_portion_regex.match(tool_call_portion) + if m: + current_tool_call["name"] = m.group("function_name") + current_tool_call["arguments"] = m.group("function_arguments") + else: + m2 = self.stream_tool_call_name_regex.match(tool_call_portion) + if m2: + current_tool_call["name"] = m2.group("function_name") + current_tool_call["arguments"] = "" + else: + return None + + # Send tool name + if not self.current_tool_name_sent: + if not current_tool_call: + return None + func_name = current_tool_call.get("name") + if func_name: + self.current_tool_name_sent = True + return { + "tool_calls": [ + { + "index": self.current_tool_id, + "id": _generate_tool_id(), + "type": "function", + "function": {"name": func_name, "arguments": ""}, + } + ] + } + return None + + if tool_call_portion is None: + return None + + # Ensure prev_tool_call_arr has entry + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) + cur_arguments = current_tool_call.get("arguments") + + delta = None + if not cur_arguments and not prev_arguments: + delta = None + elif cur_arguments and not prev_arguments: + delta = { + "tool_calls": [ + { + "index": self.current_tool_id, + "function": {"arguments": cur_arguments}, + } + ] + } + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments + elif cur_arguments and prev_arguments: + if ( + len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments) + ): + diff = cur_arguments[len(prev_arguments):] + delta = { + "tool_calls": [ + { + 
"index": self.current_tool_id, + "function": {"arguments": diff}, + } + ] + } + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments + + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None diff --git a/vllm_mlx/tool_parsers/glm47_tool_parser.py b/vllm_mlx/tool_parsers/glm47_tool_parser.py index fc8238b..48eba6b 100644 --- a/vllm_mlx/tool_parsers/glm47_tool_parser.py +++ b/vllm_mlx/tool_parsers/glm47_tool_parser.py @@ -175,8 +175,14 @@ def extract_tool_calls_streaming( } return None - # No tool call detected yet; strip think tags and emit content - clean_delta = self.strip_think_tags(delta_text) - if clean_delta: - return {"content": clean_delta} + # No tool call detected yet; emit content delta. + # Only strip think tags if they're actually present (avoid .strip() + # on normal deltas which would eat inter-word spaces). + if "" in delta_text: + clean_delta = self.strip_think_tags(delta_text) + if clean_delta: + return {"content": clean_delta} + return None + if delta_text: + return {"content": delta_text} return None diff --git a/vllm_mlx/tool_parsers/harmony_tool_parser.py b/vllm_mlx/tool_parsers/harmony_tool_parser.py index 207378b..78b9a11 100644 --- a/vllm_mlx/tool_parsers/harmony_tool_parser.py +++ b/vllm_mlx/tool_parsers/harmony_tool_parser.py @@ -4,16 +4,11 @@ Harmony uses control tokens and channels for tool calling: - <|channel|>commentary to=functions.get_weather - <|constrain|>json - <|message|>{"location": "San Francisco"} - <|call|> + <|start|>assistant to=functions.get_weather<|channel|>commentary json<|message|>{"location": "SF"}<|call|> The final response is in the 'final' channel: - <|channel|>final - <|message|>The weather is 72F. 
- <|return|> + <|start|>assistant<|channel|>final<|message|>The weather is 72F.<|end|> """ import json @@ -34,17 +29,23 @@ def _generate_tool_id() -> str: return f"call_{uuid.uuid4().hex[:8]}" -# Pattern: <|channel|>commentary to=functions.tool_name ... <|call|> +# Tool call pattern — supports both formats from the harmony spec: +# Model-generated: <|channel|>commentary to=functions.NAME <|constrain|>json<|message|>ARGS<|call|> +# Template-encoded (history): to=functions.NAME<|channel|>commentary json<|message|>ARGS<|call|> _COMMENTARY_BLOCK_PATTERN = re.compile( - r"<\|channel\|>commentary\s+to=functions\.(\w+)" - r"(?:\s*<\|constrain\|>\w+)?" - r"\s*<\|message\|>(.*?)<\|call\|>", + r"(?:" + # Real format: to=functions.NAME<|channel|>commentary [content_type]<|message|> + r"to=functions\.(\w+)<\|channel\|>commentary(?:\s+\w+)?<\|message\|>(.*?)<\|call\|>" + r"|" + # Legacy format: <|channel|>commentary to=functions.NAME ... <|message|> + r"<\|channel\|>commentary\s+to=functions\.(\w+)(?:\s*<\|constrain\|>\w+)?\s*<\|message\|>(.*?)<\|call\|>" + r")", re.DOTALL, ) -# Pattern: <|channel|>final ... <|return|> +# Final channel — both <|end|> and <|return|> terminators _FINAL_BLOCK_PATTERN = re.compile( - r"<\|channel\|>final\s*<\|message\|>(.*?)<\|return\|>", + r"<\|channel\|>final\s*<\|message\|>(.*?)(?:<\|end\|>|<\|return\|>)", re.DOTALL, ) @@ -62,7 +63,11 @@ class HarmonyToolParser(ToolParser): Used when --enable-auto-tool-choice --tool-call-parser harmony are set. """ - SUPPORTS_NATIVE_TOOL_FORMAT = False + # GPT-OSS chat template natively handles tool_calls and role="tool" + # messages using harmony channel tokens (to=functions.NAME, <|call|>). + # Without this, tool history is converted to "[Calling tool: ...]" text + # which breaks the model's understanding of the tool flow. 
+ SUPPORTS_NATIVE_TOOL_FORMAT = True def extract_tool_calls( self, model_output: str, request: dict[str, Any] | None = None @@ -76,9 +81,10 @@ def extract_tool_calls( tool_calls = [] # Extract tool calls from commentary channel blocks + # Regex has 4 groups: (1,2) for real format, (3,4) for legacy format for match in _COMMENTARY_BLOCK_PATTERN.finditer(model_output): - tool_name = match.group(1) - args_str = match.group(2).strip() + tool_name = match.group(1) or match.group(3) + args_str = (match.group(2) or match.group(4) or "").strip() try: arguments = json.loads(args_str) @@ -193,7 +199,7 @@ def extract_tool_calls_streaming( def has_pending_tool_call(self, text: str) -> bool: """Check if text contains incomplete Harmony tool call markup.""" - return "commentary to=functions." in text + return "to=functions." in text def _strip_control_tokens(text: str) -> str: diff --git a/vllm_mlx/tool_parsers/qwen3coder_tool_parser.py b/vllm_mlx/tool_parsers/qwen3coder_tool_parser.py new file mode 100644 index 0000000..31b34cb --- /dev/null +++ b/vllm_mlx/tool_parsers/qwen3coder_tool_parser.py @@ -0,0 +1,531 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Qwen3-Coder XML tool call parser for vllm-mlx. + +Ported from vLLM upstream (vllm/tool_parsers/qwen3coder_tool_parser.py). + +Format: + + + VALUE + + + +Similar to Seed-OSS but without the seed: namespace prefix. 
+""" + +import ast +import json +import logging +import re +import uuid +from collections.abc import Sequence +from typing import Any + +from .abstract_tool_parser import ( + ExtractedToolCallInformation, + ToolParser, + ToolParserManager, +) + +logger = logging.getLogger(__name__) + + +def _generate_tool_id() -> str: + return f"call_{uuid.uuid4().hex[:8]}" + + +def _get_arguments_config(func_name: str, tools: list[dict] | None) -> dict: + """Extract argument config from tools list for type conversion.""" + if tools is None: + return {} + for tool in tools: + if not isinstance(tool, dict): + continue + func = tool.get("function", {}) + if func.get("name") == func_name: + params = func.get("parameters", {}) + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + return {} + return {} + + +def _convert_param_value( + param_value: str, param_name: str, param_config: dict, func_name: str +) -> Any: + """Convert parameter value based on its type in the schema.""" + if param_value.lower() == "null": + return None + + if param_name not in param_config: + return param_value + + cfg = param_config[param_name] + if isinstance(cfg, dict) and "type" in cfg: + param_type = str(cfg["type"]).strip().lower() + else: + param_type = "string" + + if param_type in ("string", "str", "text", "varchar", "char", "enum"): + return param_value + elif param_type.startswith(("int", "uint", "long", "short", "unsigned")): + try: + return int(param_value) + except (ValueError, TypeError): + return param_value + elif param_type.startswith(("num", "float", "double")): + try: + return float(param_value) + except (ValueError, TypeError): + return param_value + elif param_type in ("boolean", "bool", "binary"): + return param_value.lower() == "true" + else: + if param_type in ("object", "array", "arr") or param_type.startswith( + ("dict", "list") + ): + try: + return json.loads(param_value) + except 
(json.JSONDecodeError, TypeError, ValueError): + pass + try: + return ast.literal_eval(param_value) + except (ValueError, SyntaxError): + return param_value + + +@ToolParserManager.register_module(["qwen3_coder_xml", "qwen3_xml"]) +class Qwen3CoderToolParser(ToolParser): + """ + Tool call parser for Qwen3-Coder models using XML format. + + Supports the XML-based tool call format with / + tags and type conversion from tool schema. + + Used when --enable-auto-tool-choice --tool-call-parser qwen3_coder_xml are set. + """ + + SUPPORTS_NATIVE_TOOL_FORMAT = True + + def __init__(self, tokenizer=None): + super().__init__(tokenizer) + + self.tool_call_start_token = "" + self.tool_call_end_token = "" + self.tool_call_prefix = "(.*?)", re.DOTALL + ) + self.tool_call_regex = re.compile( + r"(.*?)|(.*?)$", re.DOTALL + ) + self.tool_call_function_regex = re.compile( + r"||(?=)|$)", + re.DOTALL, + ) + + # Token IDs for streaming (graceful fallback if tokenizer absent) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + self._reset_streaming_state() + + def _reset_streaming_state(self): + self.current_tool_index = 0 + self.is_tool_call_started = False + self.header_sent = False + self._current_tool_id = None + self.current_function_name = None + self.param_count = 0 + self.in_param = False + self.in_function = False + self.accumulated_text = "" + self.json_started = False + self.json_closed = False + self.accumulated_params = {} + self._streaming_request = None + self.prev_tool_call_arr = [] + + def _parse_xml_function_call( + self, function_call_str: str, tools: list[dict] | None + ) -> dict | None: + """Parse a single function call from XML and return a tool call dict.""" + try: + end_index = function_call_str.index(">") + except ValueError: + return None + function_name = function_call_str[:end_index] + param_config = _get_arguments_config(function_name, tools) + parameters = 
function_call_str[end_index + 1:] + param_dict = {} + for match_text in self.tool_call_parameter_regex.findall(parameters): + try: + idx = match_text.index(">") + except ValueError: + continue + p_name = match_text[:idx] + p_value = str(match_text[idx + 1:]) + if p_value.startswith("\n"): + p_value = p_value[1:] + if p_value.endswith("\n"): + p_value = p_value[:-1] + param_dict[p_name] = _convert_param_value( + p_value, p_name, param_config, function_name + ) + return { + "id": _generate_tool_id(), + "name": function_name, + "arguments": json.dumps(param_dict, ensure_ascii=False), + } + + def _get_function_calls(self, model_output: str) -> list[str]: + matched_ranges = self.tool_call_regex.findall(model_output) + raw_tool_calls = [ + m[0] if m[0] else m[1] for m in matched_ranges + ] + if not raw_tool_calls: + raw_tool_calls = [model_output] + + raw_function_calls = [] + for tc in raw_tool_calls: + raw_function_calls.extend(self.tool_call_function_regex.findall(tc)) + return [m[0] if m[0] else m[1] for m in raw_function_calls] + + def extract_tool_calls( + self, model_output: str, request: dict[str, Any] | None = None + ) -> ExtractedToolCallInformation: + if self.tool_call_prefix not in model_output: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + try: + function_calls = self._get_function_calls(model_output) + if not function_calls: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + tools = None + if request and isinstance(request, dict): + tools = request.get("tools") + + tool_calls = [] + for fc_str in function_calls: + tc = self._parse_xml_function_call(fc_str, tools) + if tc: + tool_calls.append(tc) + + # Extract content before tool calls + content_index = model_output.find(self.tool_call_start_token) + idx = model_output.find(self.tool_call_prefix) + content_index = content_index if content_index >= 0 else idx + content = 
model_output[:content_index] + + return ExtractedToolCallInformation( + tools_called=len(tool_calls) > 0, + tool_calls=tool_calls, + content=content if content else None, + ) + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int] | None = None, + current_token_ids: Sequence[int] | None = None, + delta_token_ids: Sequence[int] | None = None, + request: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + if not previous_text: + self._reset_streaming_state() + self._streaming_request = request + + if not delta_text: + return None + + delta_token_ids = delta_token_ids or [] + self.accumulated_text = current_text + + # Check if we need to advance to next tool + if self.json_closed and not self.in_function: + tool_ends = current_text.count(self.tool_call_end_token) + if tool_ends > self.current_tool_index: + self.current_tool_index += 1 + self.header_sent = False + self.param_count = 0 + self.json_started = False + self.json_closed = False + self.accumulated_params = {} + if self.current_tool_index >= current_text.count( + self.tool_call_start_token + ): + self.is_tool_call_started = False + return None + + # Handle content before tool calls + if not self.is_tool_call_started: + if ( + self.tool_call_start_token_id is not None + and self.tool_call_start_token_id in delta_token_ids + ) or self.tool_call_start_token in delta_text: + self.is_tool_call_started = True + if self.tool_call_start_token in delta_text: + content_before = delta_text[ + : delta_text.index(self.tool_call_start_token) + ] + if content_before: + return {"content": content_before} + # Fall through to header parsing below instead of returning + # None — the function header may already be in current_text. 
+ else: + if ( + current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == "" + ): + return None + return {"content": delta_text} + + # Find current tool call portion + tool_starts_count = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts_count: + return None + + tool_start_positions: list[int] = [] + idx = 0 + while True: + idx = current_text.find(self.tool_call_start_token, idx) + if idx == -1: + break + tool_start_positions.append(idx) + idx += len(self.tool_call_start_token) + + if self.current_tool_index >= len(tool_start_positions): + return None + + tool_start_idx = tool_start_positions[self.current_tool_index] + tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx) + if tool_end_idx == -1: + tool_text = current_text[tool_start_idx:] + else: + tool_text = current_text[ + tool_start_idx: tool_end_idx + len(self.tool_call_end_token) + ] + + # Parse function header + if not self.header_sent: + if self.tool_call_prefix in tool_text: + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix + ) + func_end = tool_text.find(">", func_start) + if func_end != -1: + self.current_function_name = tool_text[func_start:func_end] + self._current_tool_id = _generate_tool_id() + self.header_sent = True + self.in_function = True + + # If the function body is already complete, emit the full + # tool call in one chunk to prevent header-only output + # when coarse deltas or max_tokens truncation leave no + # further parser calls. 
+ if self.function_end_token in tool_text: + tools = None + if request and isinstance(request, dict): + tools = request.get("tools") + fc = tool_text[func_start:tool_text.find( + self.function_end_token, func_start + )] + parsed = self._parse_xml_function_call(fc, tools) + args = parsed["arguments"] if parsed else "{}" + self.json_started = True + self.json_closed = True + self.in_function = False + self.accumulated_params = {} + self.prev_tool_call_arr.append( + {"name": self.current_function_name, "arguments": args} + ) + return { + "tool_calls": [ + { + "index": self.current_tool_index, + "id": self._current_tool_id, + "type": "function", + "function": { + "name": self.current_function_name, + "arguments": args, + }, + } + ] + } + + self.prev_tool_call_arr.append( + {"name": self.current_function_name, "arguments": "{}"} + ) + return { + "tool_calls": [ + { + "index": self.current_tool_index, + "id": self._current_tool_id, + "type": "function", + "function": { + "name": self.current_function_name, + "arguments": "", + }, + } + ] + } + return None + + # Handle function body + if self.in_function: + if not self.json_started: + self.json_started = True + return { + "tool_calls": [ + { + "index": self.current_tool_index, + "function": {"arguments": "{"}, + } + ] + } + + # Find all parameter start positions + param_starts = [] + si = 0 + while True: + si = tool_text.find(self.parameter_prefix, si) + if si == -1: + break + param_starts.append(si) + si += len(self.parameter_prefix) + + # Process complete parameters + json_fragments = [] + while not self.in_param and self.param_count < len(param_starts): + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + + if ">" not in remaining: + break + + name_end = remaining.find(">") + current_param_name = remaining[:name_end] + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + 
value_text = value_text[1:] + + param_end_idx = value_text.find(self.parameter_end_token) + if param_end_idx == -1: + # Try next parameter or function end as delimiter + next_param = value_text.find(self.parameter_prefix) + func_end = value_text.find(self.function_end_token) + if next_param != -1 and (func_end == -1 or next_param < func_end): + param_end_idx = next_param + elif func_end != -1: + param_end_idx = func_end + else: + tool_end_in_val = value_text.find(self.tool_call_end_token) + if tool_end_in_val != -1: + param_end_idx = tool_end_in_val + else: + break + + if param_end_idx == -1: + break + + pv = value_text[:param_end_idx] + if pv.endswith("\n"): + pv = pv[:-1] + + self.accumulated_params[current_param_name] = pv + + # Type conversion + tools = None + if self._streaming_request: + tools = self._streaming_request.get("tools") if isinstance( + self._streaming_request, dict + ) else None + param_config = _get_arguments_config( + self.current_function_name or "", tools + ) + converted = _convert_param_value( + pv, current_param_name, param_config, self.current_function_name or "" + ) + serialized = json.dumps(converted, ensure_ascii=False) + + if self.param_count == 0: + frag = f'"{current_param_name}": {serialized}' + else: + frag = f', "{current_param_name}": {serialized}' + self.param_count += 1 + json_fragments.append(frag) + + if json_fragments: + combined = "".join(json_fragments) + return { + "tool_calls": [ + { + "index": self.current_tool_index, + "function": {"arguments": combined}, + } + ] + } + + # Check for function end + if not self.json_closed and self.function_end_token in tool_text: + self.json_closed = True + + # Update prev_tool_call_arr with final arguments + tools = None + if self._streaming_request: + tools = self._streaming_request.get("tools") if isinstance( + self._streaming_request, dict + ) else None + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix + ) + func_content_end = 
tool_text.find(self.function_end_token, func_start) + if func_content_end != -1: + fc = tool_text[func_start:func_content_end] + try: + parsed = self._parse_xml_function_call(fc, tools) + if parsed and self.current_tool_index < len( + self.prev_tool_call_arr + ): + self.prev_tool_call_arr[self.current_tool_index][ + "arguments" + ] = parsed["arguments"] + except Exception: + pass + + self.in_function = False + self.accumulated_params = {} + return { + "tool_calls": [ + { + "index": self.current_tool_index, + "function": {"arguments": "}"}, + } + ] + } + + return None diff --git a/vllm_mlx/tool_parsers/seed_oss_tool_parser.py b/vllm_mlx/tool_parsers/seed_oss_tool_parser.py new file mode 100644 index 0000000..ebc8c81 --- /dev/null +++ b/vllm_mlx/tool_parsers/seed_oss_tool_parser.py @@ -0,0 +1,537 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Seed-OSS tool call parser for vllm-mlx. + +Ported from vLLM upstream (vllm/tool_parsers/seed_oss_tool_parser.py). + +Format: + + + VALUE + + + +Thinking: + ... 
+""" + +import ast +import json +import logging +import re +import uuid +from collections.abc import Sequence +from typing import Any + +from .abstract_tool_parser import ( + ExtractedToolCallInformation, + ToolParser, + ToolParserManager, +) + +logger = logging.getLogger(__name__) + + +def _generate_tool_id() -> str: + return f"call_{uuid.uuid4().hex[:8]}" + + +def _get_arguments_config(func_name: str, tools: list[dict] | None) -> dict: + """Extract argument config from tools list for type conversion.""" + if tools is None: + return {} + for tool in tools: + if not isinstance(tool, dict): + continue + func = tool.get("function", {}) + if func.get("name") == func_name: + params = func.get("parameters", {}) + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + return {} + return {} + + +def _convert_param_value( + param_value: str, param_name: str, param_config: dict, func_name: str +) -> Any: + """Convert parameter value based on its type in the schema.""" + if param_value.lower() == "null": + return None + + if param_name not in param_config: + return param_value + + cfg = param_config[param_name] + if isinstance(cfg, dict) and "type" in cfg: + param_type = str(cfg["type"]).strip().lower() + else: + param_type = "string" + + if param_type in ("string", "str", "text", "varchar", "char", "enum"): + return param_value + elif param_type.startswith(("int", "uint", "long", "short", "unsigned")): + try: + return int(param_value) + except (ValueError, TypeError): + return param_value + elif param_type.startswith(("num", "float", "double")): + try: + return float(param_value) + except (ValueError, TypeError): + return param_value + elif param_type in ("boolean", "bool", "binary"): + return param_value.lower() == "true" + else: + if param_type == "object" or param_type.startswith("dict"): + try: + return json.loads(param_value) + except (ValueError, TypeError, json.JSONDecodeError): + pass 
+        try:
+            return ast.literal_eval(param_value)
+        except (ValueError, SyntaxError):
+            return param_value
+
+
+@ToolParserManager.register_module(["seed_oss", "seed", "gpt_oss"])
+class SeedOssToolParser(ToolParser):
+    """
+    Tool call parser for Seed-OSS / GPT-OSS models.
+
+    Supports the XML-based tool call format with the <seed:tool_call> wrapper
+    and <seed:think> thinking blocks.
+
+    Used when --enable-auto-tool-choice --tool-call-parser seed_oss are set.
+    """
+
+    SUPPORTS_NATIVE_TOOL_FORMAT = True
+
+    TOOL_CALL_START = "<seed:tool_call>"
+    TOOL_CALL_END = "</seed:tool_call>"
+
+    def __init__(self, tokenizer=None):
+        super().__init__(tokenizer)
+
+        self.tool_call_start_token = self.TOOL_CALL_START
+        self.tool_call_end_token = self.TOOL_CALL_END
+        self.tool_call_prefix = "<function="
+        self.function_end_token = "</function>"
+        self.parameter_prefix = "<parameter="
+        self.parameter_end_token = "</parameter>"
+        self.think_start_token = "<seed:think>"
+        self.think_end_token = "</seed:think>"
+        self.tool_call_regex = re.compile(
+            r"<seed:tool_call>(.*?)</seed:tool_call>|<seed:tool_call>(.*?)$",
+            re.DOTALL,
+        )
+        self.tool_call_function_regex = re.compile(
+            r"<function=(.*?)</function>|<function=(.*?)$", re.DOTALL
+        )
+        self.tool_call_parameter_regex = re.compile(
+            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
+        )
+
+        # Token IDs for streaming (graceful fallback if tokenizer absent)
+        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
+        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+        self.think_end_token_id = self.vocab.get(self.think_end_token)
+
+        self._reset_streaming_state()
+
+    def _reset_streaming_state(self):
+        self.current_tool_index = 0
+        self.is_tool_call_started = False
+        self.is_thinking_end = False
+        self.header_sent = False
+        self.current_tool_id = None
+        self.current_function_name = None
+        self.param_count = 0
+        self.in_param = False
+        self.in_function = False
+        self.accumulated_text = ""
+        self.json_started = False
+        self.json_closed = False
+        self.prev_tool_call_arr = []
+
+    def has_text_format_tool_call(self, text: str) -> bool:
+        # NOTE(review): reconstructed stub — the original body was lost to
+        # extraction corruption. Conservative default; verify against upstream.
+        return False
+
+    def _parse_xml_function_call(
+        self, function_call_str: str, tools: list[dict] | None
+    ) -> dict | None:
+        """Parse a single function call from XML and return a tool call dict."""
+        try:
+            end_index = function_call_str.index(">")
+        except ValueError:
+            return None
+        function_name = function_call_str[:end_index]
+        param_config = _get_arguments_config(function_name, tools)
+        parameters = function_call_str[end_index + 1:]
+        param_dict = {}
+        for match in self.tool_call_parameter_regex.findall(parameters):
+            match_text = match[0] if match[0] else match[1]
+            try:
+                idx = match_text.index(">")
+            except ValueError:
+                continue
+            p_name = match_text[:idx]
+            p_value = str(match_text[idx + 1:])
+            if p_value.startswith("\n"):
+                p_value = p_value[1:]
+            if p_value.endswith("\n"):
+                p_value = p_value[:-1]
+            param_dict[p_name] = _convert_param_value(
+                p_value, p_name, param_config, function_name
+            )
+        return {
+            "id": _generate_tool_id(),
+            "name": function_name,
+            "arguments": json.dumps(param_dict, ensure_ascii=False),
+        }
+
+    def _get_function_calls(self, model_output: str) -> list[str]:
+        matched_ranges = self.tool_call_regex.findall(model_output)
+        raw_tool_calls = [
+            m[0] if m[0] else m[1] for m in matched_ranges
+        ]
+        if not raw_tool_calls:
+            raw_tool_calls = [model_output]
+
+        raw_function_calls = []
+        for tc in raw_tool_calls:
+            
raw_function_calls.extend(self.tool_call_function_regex.findall(tc)) + return [m[0] if m[0] else m[1] for m in raw_function_calls] + + def extract_tool_calls( + self, model_output: str, request: dict[str, Any] | None = None + ) -> ExtractedToolCallInformation: + if self.tool_call_prefix not in model_output: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + # Handle ... + if ( + self.think_start_token in model_output + and self.think_end_token in model_output + ): + think_end_index = model_output.find(self.think_end_token) + len( + self.think_end_token + ) + result_content = model_output[think_end_index:] + thinking_content = model_output[:think_end_index] + else: + thinking_content = "" + result_content = model_output + + try: + function_calls = self._get_function_calls(result_content) + if not function_calls: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + tools = None + if request and isinstance(request, dict): + tools = request.get("tools") + + tool_calls = [] + for fc_str in function_calls: + tc = self._parse_xml_function_call(fc_str, tools) + if tc: + tool_calls.append(tc) + + # Extract content before tool calls + tc_start = result_content.find(self.tool_call_start_token) + if tc_start < 0: + tc_start = result_content.find(self.tool_call_prefix) + content = thinking_content + result_content[:tc_start] + + return ExtractedToolCallInformation( + tools_called=len(tool_calls) > 0, + tool_calls=tool_calls, + content=content if content else None, + ) + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def has_pending_tool_call(self, text: str) -> bool: + return ( + self.TOOL_CALL_START in text + or self.tool_call_prefix in text + or self.has_text_format_tool_call(text) + ) + + def extract_tool_calls_streaming( + self, 
+ previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int] | None = None, + current_token_ids: Sequence[int] | None = None, + delta_token_ids: Sequence[int] | None = None, + request: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + if not previous_text: + self._reset_streaming_state() + + if not delta_text: + return None + + self.accumulated_text = current_text + delta_token_ids = delta_token_ids or [] + + # Check if we need to advance to next tool + if self.json_closed and not self.in_function: + tool_ends = current_text.count(self.tool_call_end_token) + if tool_ends > self.current_tool_index: + self.current_tool_index += 1 + self.header_sent = False + self.param_count = 0 + self.json_started = False + self.json_closed = False + if self.current_tool_index >= current_text.count( + self.tool_call_start_token + ): + self.is_tool_call_started = False + return None + + # Check if thinking ended (or never started) + if not self.is_thinking_end: + # If there's no in the text at all, skip thinking gate + if self.think_start_token not in current_text: + self.is_thinking_end = True + elif ( + self.think_end_token_id is not None + and self.think_end_token_id in delta_token_ids + ) or self.think_end_token in delta_text: + self.is_thinking_end = True + + if not self.is_thinking_end: + return {"content": delta_text} + + # Handle content before tool calls + if not self.is_tool_call_started: + if ( + self.tool_call_start_token_id is not None + and self.tool_call_start_token_id in delta_token_ids + ) or self.tool_call_start_token in delta_text: + self.is_tool_call_started = True + if self.tool_call_start_token in delta_text: + content_before = delta_text[ + : delta_text.index(self.tool_call_start_token) + ] + if content_before: + return {"content": content_before} + # Fall through to header parsing below instead of returning + # None — the function header may already be in current_text. 
+ else: + if ( + current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == "" + ): + return None + return {"content": delta_text} + + # Find current tool call portion + tool_starts_count = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts_count: + return None + + # Locate tool text + think_end_idx = 0 + if self.think_end_token in current_text: + think_end_idx = current_text.find(self.think_end_token) + len( + self.think_end_token + ) + tool_starts: list[int] = [] + idx = think_end_idx + while True: + idx = current_text.find(self.tool_call_start_token, idx) + if idx == -1: + break + tool_starts.append(idx) + idx += len(self.tool_call_start_token) + + if self.current_tool_index >= len(tool_starts): + return None + + tool_start_idx = tool_starts[self.current_tool_index] + tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx) + if tool_end_idx == -1: + tool_text = current_text[tool_start_idx:] + else: + tool_text = current_text[ + tool_start_idx: tool_end_idx + len(self.tool_call_end_token) + ] + + # Parse function header + if not self.header_sent: + if self.tool_call_prefix in tool_text: + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix + ) + func_end = tool_text.find(">", func_start) + if func_end != -1: + self.current_function_name = tool_text[func_start:func_end] + self.current_tool_id = _generate_tool_id() + self.header_sent = True + self.in_function = True + + # If the function body is already complete, emit the full + # tool call in one chunk. This prevents header-only output + # when coarse deltas (or max_tokens truncation) leave no + # further parser calls to emit the arguments. 
+ if self.function_end_token in tool_text: + tools = None + if request and isinstance(request, dict): + tools = request.get("tools") + fc = tool_text[func_start:tool_text.find( + self.function_end_token, func_start + )] + parsed = self._parse_xml_function_call(fc, tools) + args = parsed["arguments"] if parsed else "{}" + self.json_started = True + self.json_closed = True + self.in_function = False + self.prev_tool_call_arr.append( + {"name": self.current_function_name, "arguments": args} + ) + return { + "tool_calls": [ + { + "index": self.current_tool_index, + "id": self.current_tool_id, + "type": "function", + "function": { + "name": self.current_function_name, + "arguments": args, + }, + } + ] + } + + return { + "tool_calls": [ + { + "index": self.current_tool_index, + "id": self.current_tool_id, + "type": "function", + "function": { + "name": self.current_function_name, + "arguments": "", + }, + } + ] + } + return None + + # Handle function body + if self.in_function: + if not self.json_started: + self.json_started = True + return { + "tool_calls": [ + { + "index": self.current_tool_index, + "function": {"arguments": "{"}, + } + ] + } + + # Check for function end + if not self.json_closed and self.function_end_token in tool_text: + self.json_closed = True + self.in_function = False + + # Extract complete params for prev_tool_call_arr + tools = None + if request and isinstance(request, dict): + tools = request.get("tools") + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix + ) + func_content_end = tool_text.find(self.function_end_token, func_start) + if func_content_end != -1: + fc = tool_text[func_start:func_content_end] + parsed = self._parse_xml_function_call(fc, tools) + if parsed: + self.prev_tool_call_arr.append( + {"name": parsed["name"], "arguments": parsed["arguments"]} + ) + + return { + "tool_calls": [ + { + "index": self.current_tool_index, + "function": {"arguments": "}"}, + } + ] + } + + # Look for complete 
parameters + complete_params = tool_text.count(self.parameter_end_token) + if not self.in_param and self.param_count < complete_params: + param_starts = [] + si = 0 + while True: + si = tool_text.find(self.parameter_prefix, si) + if si == -1: + break + param_starts.append(si) + si += len(self.parameter_prefix) + + if len(param_starts) > self.param_count: + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + if ">" in remaining: + name_end = remaining.find(">") + param_name = remaining[:name_end] + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] + param_end_idx = value_text.find(self.parameter_end_token) + if param_end_idx != -1: + pv = value_text[:param_end_idx] + if pv.endswith("\n"): + pv = pv[:-1] + # Type conversion using tool schema + tools = None + if request and isinstance(request, dict): + tools = request.get("tools") + param_config = _get_arguments_config( + self.current_function_name or "", tools + ) + converted = _convert_param_value( + pv, param_name, param_config, + self.current_function_name or "" + ) + serialized = json.dumps(converted, ensure_ascii=False) + if self.param_count == 0: + frag = f'"{param_name}": {serialized}' + else: + frag = f', "{param_name}": {serialized}' + self.param_count += 1 + return { + "tool_calls": [ + { + "index": self.current_tool_index, + "function": {"arguments": frag}, + } + ] + } + + return None diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index 0a2991a..4d38693 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -58,8 +58,8 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): logger.warning(f"Standard tokenizer loading failed, using fallback: {e}") return _load_with_tokenizer_fallback(model_name) # Fallback for multimodal models loaded as text-only (skip vision 
weights) - elif "parameters not in model" in str(e): - logger.warning(f"Model has extra parameters (likely vision weights), retrying with strict=False: {e}") + elif "parameters not in model" in str(e) or ("Missing" in str(e) and "parameters" in str(e)): + logger.warning(f"Model has extra/missing parameters (likely VLM weights), retrying with strict=False: {e}") return _load_non_strict(model_name, tokenizer_config) else: raise