diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bc3f19b3a..3f8799c85 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,6 +82,7 @@ jobs: tests/test_request.py \ tests/test_anthropic_models.py \ tests/test_anthropic_adapter.py \ + tests/test_harmony_parsers.py \ -v --tb=short \ -k "not Integration and not InjectJson and not TestMLXMultimodalLMCache" \ --cov=vllm_mlx \ diff --git a/tests/test_harmony_parsers.py b/tests/test_harmony_parsers.py new file mode 100644 index 000000000..9ca509f61 --- /dev/null +++ b/tests/test_harmony_parsers.py @@ -0,0 +1,734 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for Harmony format parsers (GPT-OSS models). + +Tests cover: +- HarmonyToolParser: tool call extraction from commentary channel +- HarmonyReasoningParser: reasoning extraction from analysis channel +- convert_tools_to_typescript: OpenAI JSON Schema to TypeScript conversion + +Usage: + pytest tests/test_harmony_parsers.py -v +""" + +import json + +import pytest + +from vllm_mlx.api.harmony_tools import convert_tools_to_typescript +from vllm_mlx.reasoning import get_parser +from vllm_mlx.reasoning.harmony_parser import HarmonyReasoningParser +from vllm_mlx.tool_parsers import ToolParserManager +from vllm_mlx.tool_parsers.harmony_tool_parser import HarmonyToolParser + +# ============================================================================ +# Tool Parser Tests +# ============================================================================ + + +class TestHarmonyToolParser: + """Tests for HarmonyToolParser.""" + + @pytest.fixture() + def parser(self): + return HarmonyToolParser() + + def test_registration(self): + """Parser is registered under harmony and gpt-oss names.""" + assert ToolParserManager.get_tool_parser("harmony") is HarmonyToolParser + assert ToolParserManager.get_tool_parser("gpt-oss") is HarmonyToolParser + + def test_single_tool_call(self, parser): + """Parse a single tool call from commentary 
channel.""" + text = ( + "<|start|>\n" + "<|channel|>commentary to=functions.get_weather\n" + "<|constrain|>json\n" + '<|message|>{"location": "San Francisco", "unit": "celsius"}\n' + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_weather" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["location"] == "San Francisco" + assert args["unit"] == "celsius" + + def test_tool_call_with_analysis_and_final(self, parser): + """Parse tool call when analysis and final channels are present.""" + text = ( + "<|start|>\n" + "<|channel|>analysis\n" + "<|message|>The user wants weather. I should call get_weather.\n" + "<|end|>\n" + "<|channel|>commentary to=functions.get_weather\n" + "<|constrain|>json\n" + '<|message|>{"location": "SF"}\n' + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_weather" + + def test_final_response_only(self, parser): + """Parse response with no tool calls (final channel only).""" + text = ( + "<|start|>\n" + "<|channel|>final\n" + "<|message|>The weather in San Francisco is 72F and sunny!\n" + "<|return|>" + ) + result = parser.extract_tool_calls(text) + + assert not result.tools_called + assert result.tool_calls == [] + assert result.content == "The weather in San Francisco is 72F and sunny!" 
+ + def test_multiple_tool_calls(self, parser): + """Parse multiple tool calls from separate commentary blocks.""" + text = ( + "<|start|>\n" + "<|channel|>commentary to=functions.get_weather\n" + "<|constrain|>json\n" + '<|message|>{"location": "SF"}\n' + "<|call|>\n" + "<|channel|>commentary to=functions.get_time\n" + "<|constrain|>json\n" + '<|message|>{"timezone": "PST"}\n' + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert len(result.tool_calls) == 2 + assert result.tool_calls[0]["name"] == "get_weather" + assert result.tool_calls[1]["name"] == "get_time" + + def test_tool_call_without_constrain(self, parser): + """Parse tool call without <|constrain|>json tag.""" + text = ( + "<|channel|>commentary to=functions.simple_func\n" + '<|message|>{"arg": "value"}\n' + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "simple_func" + + def test_malformed_json_arguments(self, parser): + """Handle malformed JSON gracefully by keeping raw string.""" + text = ( + "<|channel|>commentary to=functions.broken_func\n" + "<|constrain|>json\n" + "<|message|>{invalid json here}\n" + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "broken_func" + assert result.tool_calls[0]["arguments"] == "{invalid json here}" + + def test_tool_call_with_final_content(self, parser): + """Tool calls coexist with final channel content.""" + text = ( + "<|channel|>commentary to=functions.search\n" + "<|constrain|>json\n" + '<|message|>{"query": "python"}\n' + "<|call|>\n" + "<|channel|>final\n" + "<|message|>Here are the results.\n" + "<|return|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.content == "Here are the results." 
+ + def test_empty_input(self, parser): + """Handle empty input.""" + result = parser.extract_tool_calls("") + assert not result.tools_called + assert result.tool_calls == [] + + def test_plain_text_input(self, parser): + """Handle plain text with no Harmony tokens.""" + result = parser.extract_tool_calls("Just a regular response.") + assert not result.tools_called + assert result.content == "Just a regular response." + + def test_unique_tool_ids(self, parser): + """Each tool call gets a unique ID.""" + text = ( + "<|channel|>commentary to=functions.func_a\n" + "<|constrain|>json\n" + "<|message|>{}\n" + "<|call|>\n" + "<|channel|>commentary to=functions.func_b\n" + "<|constrain|>json\n" + "<|message|>{}\n" + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + ids = [tc["id"] for tc in result.tool_calls] + assert len(set(ids)) == 2 + assert all(id_.startswith("call_") for id_ in ids) + + def test_nested_json_arguments(self, parser): + """Parse tool call with nested JSON arguments.""" + args = {"filter": {"type": "range", "min": 0, "max": 100}, "sort": "asc"} + text = ( + "<|channel|>commentary to=functions.query\n" + "<|constrain|>json\n" + f"<|message|>{json.dumps(args)}\n" + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + parsed_args = json.loads(result.tool_calls[0]["arguments"]) + assert parsed_args["filter"]["type"] == "range" + + def test_streaming_no_tool_markers(self, parser): + """Streaming: plain text passes through as content.""" + result = parser.extract_tool_calls_streaming("", "Hello", "Hello") + assert result == {"content": "Hello"} + + def test_streaming_tool_call_complete(self, parser): + """Streaming: emit tool calls when <|call|> appears.""" + current = ( + "<|channel|>commentary to=functions.func\n" + "<|constrain|>json\n" + '<|message|>{"a": 1}\n' + "<|call|>" + ) + result = parser.extract_tool_calls_streaming("", current, "<|call|>") + + assert result is not None + assert "tool_calls" in 
result + assert result["tool_calls"][0]["function"]["name"] == "func" + + def test_streaming_building_tool_call(self, parser): + """Streaming: suppress output while building tool call.""" + current = ( + "<|channel|>commentary to=functions.func\n" + "<|constrain|>json\n" + '<|message|>{"a":' + ) + result = parser.extract_tool_calls_streaming("", current, '{"a":') + assert result is None + + +# ============================================================================ +# Reasoning Parser Tests +# ============================================================================ + + +class TestHarmonyReasoningParser: + """Tests for HarmonyReasoningParser.""" + + @pytest.fixture() + def parser(self): + return HarmonyReasoningParser() + + def test_registration(self): + """Parser is registered under the harmony name.""" + parser_cls = get_parser("harmony") + assert parser_cls is HarmonyReasoningParser + + def test_extract_analysis_and_final(self, parser): + """Extract reasoning from analysis and content from final.""" + output = ( + "<|channel|>analysis\n" + "<|message|>Let me think step by step.\n" + "<|end|>\n" + "<|channel|>final\n" + "<|message|>The answer is 42.\n" + "<|return|>" + ) + reasoning, content = parser.extract_reasoning(output) + + assert reasoning == "Let me think step by step." + assert content == "The answer is 42." + + def test_multiple_analysis_blocks(self, parser): + """Concatenate multiple analysis blocks.""" + output = ( + "<|channel|>analysis\n" + "<|message|>First thought.\n" + "<|end|>\n" + "<|channel|>analysis\n" + "<|message|>Second thought.\n" + "<|end|>\n" + "<|channel|>final\n" + "<|message|>Result.\n" + "<|return|>" + ) + reasoning, content = parser.extract_reasoning(output) + + assert "First thought." in reasoning + assert "Second thought." in reasoning + assert content == "Result." 
+ + def test_no_analysis_channel(self, parser): + """Output with no analysis returns None reasoning.""" + output = "<|channel|>final\n" "<|message|>Direct answer.\n" "<|return|>" + reasoning, content = parser.extract_reasoning(output) + + assert reasoning is None + assert content == "Direct answer." + + def test_analysis_only_no_final(self, parser): + """Output with only analysis returns None content.""" + output = "<|channel|>analysis\n" "<|message|>Just thinking...\n" "<|end|>" + reasoning, content = parser.extract_reasoning(output) + + assert reasoning == "Just thinking..." + assert content is None + + def test_empty_input(self, parser): + """Handle empty input.""" + reasoning, content = parser.extract_reasoning("") + assert reasoning is None + assert content is None + + def test_analysis_with_commentary_and_final(self, parser): + """Ignore commentary channel, extract analysis and final.""" + output = ( + "<|channel|>analysis\n" + "<|message|>Need to call a tool.\n" + "<|end|>\n" + "<|channel|>commentary to=functions.search\n" + "<|constrain|>json\n" + '<|message|>{"q": "test"}\n' + "<|call|>\n" + "<|channel|>final\n" + "<|message|>Found results.\n" + "<|return|>" + ) + reasoning, content = parser.extract_reasoning(output) + + assert reasoning == "Need to call a tool." + assert content == "Found results." 
+ + def test_streaming_analysis_to_final(self, parser): + """Streaming: emit reasoning for analysis, content for final.""" + parser.reset_state() + + # Channel switch to analysis + r1 = parser.extract_reasoning_streaming( + "", "<|channel|>analysis\n", "<|channel|>analysis\n" + ) + assert r1 is None # channel switch, no content yet + + # Message start + r2 = parser.extract_reasoning_streaming( + "<|channel|>analysis\n", + "<|channel|>analysis\n<|message|>", + "<|message|>", + ) + assert r2 is None # message start token + + # Reasoning content + r3 = parser.extract_reasoning_streaming( + "<|channel|>analysis\n<|message|>", + "<|channel|>analysis\n<|message|>Thinking", + "Thinking", + ) + assert r3 is not None + assert r3.reasoning == "Thinking" + assert r3.content is None + + # End of analysis + r4 = parser.extract_reasoning_streaming( + "<|channel|>analysis\n<|message|>Thinking", + "<|channel|>analysis\n<|message|>Thinking<|end|>", + "<|end|>", + ) + assert r4 is None # end token + + # Switch to final + r5 = parser.extract_reasoning_streaming( + "<|channel|>analysis\n<|message|>Thinking<|end|>", + "<|channel|>analysis\n<|message|>Thinking<|end|>\n<|channel|>final\n", + "\n<|channel|>final\n", + ) + assert r5 is None # channel switch + + # Final message content + prev = "<|channel|>analysis\n<|message|>Thinking<|end|>\n<|channel|>final\n<|message|>" + parser.extract_reasoning_streaming( + "<|channel|>analysis\n<|message|>Thinking<|end|>\n<|channel|>final\n", + prev, + "<|message|>", + ) + r6 = parser.extract_reasoning_streaming( + prev, + prev + "Answer", + "Answer", + ) + assert r6 is not None + assert r6.content == "Answer" + assert r6.reasoning is None + + def test_streaming_reset(self, parser): + """Reset clears internal state.""" + parser._current_channel = "analysis" + parser._in_message = True + parser.reset_state() + assert parser._current_channel is None + assert parser._in_message is False + + def test_streaming_commentary_suppressed(self, parser): + 
"""Streaming: commentary channel output is suppressed.""" + parser.reset_state() + + parser.extract_reasoning_streaming( + "", + "<|channel|>commentary to=functions.f\n", + "<|channel|>commentary to=functions.f\n", + ) + parser.extract_reasoning_streaming( + "<|channel|>commentary to=functions.f\n", + "<|channel|>commentary to=functions.f\n<|message|>", + "<|message|>", + ) + r = parser.extract_reasoning_streaming( + "<|channel|>commentary to=functions.f\n<|message|>", + '<|channel|>commentary to=functions.f\n<|message|>{"a":1}', + '{"a":1}', + ) + assert r is None + + +# ============================================================================ +# TypeScript Tool Converter Tests +# ============================================================================ + + +class TestHarmonyToolDefinitionConverter: + """Tests for convert_tools_to_typescript.""" + + def test_simple_tool(self): + """Convert a simple tool with required parameters.""" + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"}, + }, + "required": ["location"], + }, + }, + } + ] + + result = convert_tools_to_typescript(tools) + + assert "namespace functions" in result + assert "get_weather" in result + assert "location: string," in result + assert "// Get weather for a location" in result + + def test_optional_parameters(self): + """Optional parameters get ? 
suffix.""" + tools = [ + { + "type": "function", + "function": { + "name": "func", + "parameters": { + "type": "object", + "properties": { + "required_param": {"type": "string"}, + "optional_param": {"type": "number"}, + }, + "required": ["required_param"], + }, + }, + } + ] + + result = convert_tools_to_typescript(tools) + + assert "required_param: string," in result + assert "optional_param?: number," in result + + def test_enum_type(self): + """Enums become TypeScript union types.""" + tools = [ + { + "type": "function", + "function": { + "name": "set_unit", + "parameters": { + "type": "object", + "properties": { + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + }, + }, + } + ] + + result = convert_tools_to_typescript(tools) + + assert '"celsius" | "fahrenheit"' in result + + def test_multiple_tools(self): + """Multiple tools in one namespace.""" + tools = [ + { + "type": "function", + "function": { + "name": "func_a", + "description": "First function", + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "func_b", + "description": "Second function", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ] + + result = convert_tools_to_typescript(tools) + + assert "func_a" in result + assert "func_b" in result + assert "// First function" in result + assert "// Second function" in result + + def test_no_tools(self): + """None input returns None.""" + assert convert_tools_to_typescript(None) is None + assert convert_tools_to_typescript([]) is None + + def test_no_parameters(self): + """Tool with no parameters uses empty signature.""" + tools = [ + { + "type": "function", + "function": { + "name": "ping", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + + result = convert_tools_to_typescript(tools) + + assert "type ping = () => any;" in result + + def test_array_type(self): + """Array types convert to Array.""" + tools = [ + { + "type": 
"function", + "function": { + "name": "process", + "parameters": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": {"type": "string"}, + }, + }, + }, + }, + } + ] + + result = convert_tools_to_typescript(tools) + + assert "Array" in result + + def test_boolean_and_integer_types(self): + """Boolean and integer map correctly.""" + tools = [ + { + "type": "function", + "function": { + "name": "config", + "parameters": { + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + "count": {"type": "integer"}, + }, + }, + }, + } + ] + + result = convert_tools_to_typescript(tools) + + assert "enabled?: boolean," in result + assert "count?: number," in result + + def test_no_description(self): + """Tool without description has no comment line.""" + tools = [ + { + "type": "function", + "function": { + "name": "no_desc", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + + result = convert_tools_to_typescript(tools) + + assert "//" not in result + assert "no_desc" in result + + def test_skips_non_function_types(self): + """Non-function tools are skipped.""" + tools = [ + {"type": "retrieval"}, + { + "type": "function", + "function": { + "name": "real_func", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ] + + result = convert_tools_to_typescript(tools) + + assert "real_func" in result + assert "retrieval" not in result + + +# ============================================================================ +# Edge Case Tests +# ============================================================================ + + +class TestHarmonyEdgeCases: + """Edge case tests for Harmony parsers.""" + + def test_tool_parser_incomplete_call(self): + """Incomplete tool call (missing <|call|>) is not parsed.""" + parser = HarmonyToolParser() + text = "<|channel|>commentary to=functions.func\n" '<|message|>{"arg": "value"}' + result = parser.extract_tool_calls(text) + assert not result.tools_called + + def 
test_tool_parser_unicode_content(self): + """Handle unicode in tool arguments.""" + parser = HarmonyToolParser() + text = ( + "<|channel|>commentary to=functions.translate\n" + "<|constrain|>json\n" + '<|message|>{"text": "日本語テスト"}\n' + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["text"] == "日本語テスト" + + def test_reasoning_parser_unicode_content(self): + """Handle unicode in reasoning and content.""" + parser = HarmonyReasoningParser() + output = ( + "<|channel|>analysis\n" + "<|message|>让我想想...\n" + "<|end|>\n" + "<|channel|>final\n" + "<|message|>答案是42。\n" + "<|return|>" + ) + reasoning, content = parser.extract_reasoning(output) + + assert reasoning == "让我想想..." + assert content == "答案是42。" + + def test_mixed_channels_full_flow(self): + """Full flow: analysis -> commentary -> analysis -> final.""" + text = ( + "<|start|>\n" + "<|channel|>analysis\n" + "<|message|>Think 1.\n" + "<|end|>\n" + "<|channel|>commentary to=functions.search\n" + "<|constrain|>json\n" + '<|message|>{"q": "test"}\n' + "<|call|>\n" + "<|channel|>analysis\n" + "<|message|>Think 2.\n" + "<|end|>\n" + "<|channel|>final\n" + "<|message|>Done.\n" + "<|return|>" + ) + + # Tool parser finds tool calls + tool_parser = HarmonyToolParser() + tool_result = tool_parser.extract_tool_calls(text) + assert tool_result.tools_called + assert len(tool_result.tool_calls) == 1 + assert tool_result.tool_calls[0]["name"] == "search" + assert tool_result.content == "Done." + + # Reasoning parser finds both analysis blocks + reasoning_parser = HarmonyReasoningParser() + reasoning, content = reasoning_parser.extract_reasoning(text) + assert "Think 1." in reasoning + assert "Think 2." in reasoning + assert content == "Done." 
+ + def test_tool_parser_empty_arguments(self): + """Tool call with empty JSON arguments.""" + parser = HarmonyToolParser() + text = ( + "<|channel|>commentary to=functions.ping\n" + "<|constrain|>json\n" + "<|message|>{}\n" + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert json.loads(result.tool_calls[0]["arguments"]) == {} + + def test_tool_parser_whitespace_handling(self): + """Handle extra whitespace in Harmony format.""" + parser = HarmonyToolParser() + text = ( + "<|channel|>commentary to=functions.func\n" + "<|constrain|>json\n" + '<|message|> {"key": "value"} \n' + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["key"] == "value" diff --git a/vllm_mlx/api/harmony_tools.py b/vllm_mlx/api/harmony_tools.py new file mode 100644 index 000000000..f9f98ba97 --- /dev/null +++ b/vllm_mlx/api/harmony_tools.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +TypeScript-style tool definition converter for Harmony/GPT-OSS models. + +Harmony models expect tool definitions in TypeScript namespace format: + + namespace functions { + // Get weather for a location + type get_weather = (_: { + location: string, + unit?: "celsius" | "fahrenheit" + }) => any; + } + +This module converts OpenAI JSON Schema tool definitions to that format. +""" + +from typing import Any + +# JSON Schema type to TypeScript type mapping +_TYPE_MAP = { + "string": "string", + "number": "number", + "integer": "number", + "boolean": "boolean", + "null": "null", + "object": "object", +} + + +def _convert_type(prop: dict[str, Any]) -> str: + """ + Convert a JSON Schema property to a TypeScript type string. + + Args: + prop: JSON Schema property definition. + + Returns: + TypeScript type string. 
+ """ + # Enum: union of literal values + if "enum" in prop: + literals = [f'"{v}"' for v in prop["enum"]] + return " | ".join(literals) + + schema_type = prop.get("type", "any") + + # Array with items + if schema_type == "array": + items = prop.get("items", {}) + item_type = _convert_type(items) if items else "any" + return f"Array<{item_type}>" + + return _TYPE_MAP.get(schema_type, "any") + + +def convert_tools_to_typescript(tools: list[dict[str, Any]] | None) -> str | None: + """ + Convert OpenAI JSON Schema tool definitions to TypeScript namespace format. + + Args: + tools: List of tool definitions in OpenAI format, e.g.: + [{"type": "function", "function": {"name": "...", ...}}] + + Returns: + TypeScript namespace string, or None if no tools. + """ + if not tools: + return None + + functions = [] + for tool in tools: + if tool.get("type") != "function": + continue + func = tool.get("function", {}) + name = func.get("name", "") + if not name: + continue + + description = func.get("description", "") + parameters = func.get("parameters", {}) + properties = parameters.get("properties", {}) + required = set(parameters.get("required", [])) + + # Build parameter list + params = [] + for prop_name, prop_schema in properties.items(): + ts_type = _convert_type(prop_schema) + optional = "?" 
if prop_name not in required else "" + params.append(f" {prop_name}{optional}: {ts_type},") + + # Build function type + lines = [] + if description: + lines.append(f" // {description}") + + if params: + params_block = "\n".join(params) + lines.append(f" type {name} = (_: {{\n{params_block}\n }}) => any;") + else: + lines.append(f" type {name} = () => any;") + + functions.append("\n".join(lines)) + + if not functions: + return None + + body = "\n\n".join(functions) + return f"namespace functions {{\n{body}\n}}" diff --git a/vllm_mlx/reasoning/__init__.py b/vllm_mlx/reasoning/__init__.py index eef9c62fb..55daa9e8d 100644 --- a/vllm_mlx/reasoning/__init__.py +++ b/vllm_mlx/reasoning/__init__.py @@ -76,10 +76,12 @@ def list_parsers() -> list[str]: def _register_builtin_parsers(): """Register built-in parsers.""" from .deepseek_r1_parser import DeepSeekR1ReasoningParser + from .harmony_parser import HarmonyReasoningParser from .qwen3_parser import Qwen3ReasoningParser register_parser("qwen3", Qwen3ReasoningParser) register_parser("deepseek_r1", DeepSeekR1ReasoningParser) + register_parser("harmony", HarmonyReasoningParser) # Register built-in parsers on module load diff --git a/vllm_mlx/reasoning/harmony_parser.py b/vllm_mlx/reasoning/harmony_parser.py new file mode 100644 index 000000000..73b94f0ea --- /dev/null +++ b/vllm_mlx/reasoning/harmony_parser.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Reasoning parser for GPT-OSS models using Harmony format. + +Harmony uses channels for reasoning vs final content: + + <|channel|>analysis + <|message|>Let me think about this... + <|end|> + <|channel|>final + <|message|>The answer is 42. + <|return|> + +The analysis channel contains reasoning, and the final channel +contains the user-facing response. 
import re

from .base import DeltaMessage, ReasoningParser

# Analysis channel blocks: <|channel|>analysis<|message|>...<|end|>
# The block also ends at the next <|start|>/<|channel|> or at end of text, so
# output that was cut off before <|end|> still yields its reasoning.
_ANALYSIS_PATTERN = re.compile(
    r"<\|channel\|>analysis\s*<\|message\|>(.*?)"
    r"(?:<\|end\|>|(?=<\|start\|>|<\|channel\|>)|\Z)",
    re.DOTALL,
)

# Final channel content: <|channel|>final<|message|>...<|return|>
# Same tolerance for a missing terminator as the analysis pattern.
_FINAL_PATTERN = re.compile(
    r"<\|channel\|>final\s*<\|message\|>(.*?)"
    r"(?:<\|return\|>|<\|end\|>|(?=<\|start\|>|<\|channel\|>)|\Z)",
    re.DOTALL,
)

_CHANNEL_TOKEN = "<|channel|>"
_MESSAGE_TOKEN = "<|message|>"
# Tokens that close the message body currently being streamed.
_MESSAGE_TERMINATORS = frozenset({"<|end|>", "<|return|>", "<|call|>", "<|start|>"})
_KNOWN_CHANNELS = ("analysis", "final", "commentary")
# Any <|name|> marker; used to split a delta into control tokens and text.
_CONTROL_TOKEN_PATTERN = re.compile(r"<\|[a-z_]+\|>")


def _channel_after_last_marker(text: str) -> str | None:
    """Return the known channel named right after the last <|channel|> in text.

    Returns None when there is no marker or the name that follows is not (yet)
    one of the known channels.
    """
    marker_at = text.rfind(_CHANNEL_TOKEN)
    if marker_at < 0:
        return None
    header = text[marker_at + len(_CHANNEL_TOKEN) :].lstrip()
    for name in _KNOWN_CHANNELS:
        if header.startswith(name):
            return name
    return None


class HarmonyReasoningParser(ReasoningParser):
    """
    Reasoning parser for GPT-OSS models using Harmony format.

    Extracts reasoning from the 'analysis' channel and content from
    the 'final' channel. Commentary channels (tool calls) are ignored
    since they are handled by the tool parser.

    Example:
        Input: "<|channel|>analysis<|message|>Thinking...<|end|>
                <|channel|>final<|message|>Result.<|return|>"
        Output: reasoning="Thinking...", content="Result."
    """

    def __init__(self, tokenizer=None):
        super().__init__(tokenizer)
        # Streaming state: channel of the block being read (None while the
        # channel name has not been seen) and whether we are past <|message|>.
        self._current_channel: str | None = None
        self._in_message: bool = False

    def extract_reasoning(
        self,
        model_output: str,
    ) -> tuple[str | None, str | None]:
        """
        Extract reasoning from complete Harmony output.

        Collects all analysis channel blocks as reasoning (joined with
        newlines) and the first final channel block as content.

        Args:
            model_output: Complete model output text.

        Returns:
            (reasoning, content) tuple. Either may be None.
        """
        # Collect all analysis blocks
        analysis_blocks = _ANALYSIS_PATTERN.findall(model_output)
        reasoning = "\n".join(block.strip() for block in analysis_blocks) or None

        # Extract final channel content
        final_match = _FINAL_PATTERN.search(model_output)
        content = final_match.group(1).strip() if final_match else None

        return reasoning, content

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
    ) -> DeltaMessage | None:
        """
        Extract reasoning from streaming Harmony output.

        The delta is split into control tokens and plain text. Control tokens
        update the parser state; plain text inside a message body is routed
        to ``reasoning`` (analysis channel) or ``content`` (final channel).
        Header text, commentary bodies and text outside any message are
        suppressed.

        The channel is resolved from the text following the most recent
        ``<|channel|>`` marker, so it is found even when the marker and the
        channel name arrive in different deltas, and a recipient such as
        ``to=functions.analysis_tool`` cannot be mistaken for a channel name.

        Args:
            previous_text: Accumulated text before this delta.
            current_text: Accumulated text including this delta.
            delta_text: The new text in this streaming chunk.

        Returns:
            DeltaMessage with reasoning and/or content, or None.
        """
        # Split the delta into (piece, is_control_token, end_offset) segments.
        segments: list[tuple[str, bool, int]] = []
        cursor = 0
        for match in _CONTROL_TOKEN_PATTERN.finditer(delta_text):
            if match.start() > cursor:
                segments.append((delta_text[cursor : match.start()], False, match.start()))
            segments.append((match.group(0), True, match.end()))
            cursor = match.end()
        if cursor < len(delta_text):
            segments.append((delta_text[cursor:], False, len(delta_text)))

        reasoning_parts: list[str] = []
        content_parts: list[str] = []

        for piece, is_token, end in segments:
            if is_token:
                if piece == _CHANNEL_TOKEN:
                    # A new header starts: the previous channel no longer
                    # applies and the name may only arrive in a later delta.
                    self._current_channel = None
                    self._in_message = False
                elif piece == _MESSAGE_TOKEN:
                    if self._current_channel is None:
                        # NOTE(review): assumes previous_text + delta_text ==
                        # current_text - confirm against the caller.
                        self._current_channel = _channel_after_last_marker(
                            previous_text + delta_text[:end]
                        )
                    self._in_message = True
                elif piece in _MESSAGE_TERMINATORS:
                    self._in_message = False
                # Other control tokens (e.g. <|constrain|>) carry no payload.
                continue

            if self._in_message:
                if self._current_channel == "analysis":
                    reasoning_parts.append(piece)
                elif self._current_channel == "final":
                    content_parts.append(piece)
                # In commentary or unknown channel, suppress
                continue

            # Header text: try to resolve the channel name once enough of it
            # has arrived.
            if self._current_channel is None:
                self._current_channel = _channel_after_last_marker(
                    previous_text + delta_text[:end]
                )

        fields: dict[str, str] = {}
        if reasoning_parts:
            fields["reasoning"] = "".join(reasoning_parts)
        if content_parts:
            fields["content"] = "".join(content_parts)
        return DeltaMessage(**fields) if fields else None

    def reset_state(self):
        """Reset streaming state for a new request."""
        self._current_channel = None
        self._in_message = False
# Pattern: <|channel|>commentary to=functions.tool_name ... <|call|>
# Group 1 is the tool name (word characters and "-"), group 2 the raw
# argument text between <|message|> and <|call|>.
_COMMENTARY_BLOCK_PATTERN = re.compile(
    r"<\|channel\|>commentary\s+to=functions\.([\w-]+)"
    r"(?:\s*<\|constrain\|>\w+)?"
    r"\s*<\|message\|>(.*?)<\|call\|>",
    re.DOTALL,
)

# Pattern: <|channel|>final ... <|return|>
# Also ends at <|end|> or end of text, so an answer whose terminator is
# missing is still recognised as the final block.
_FINAL_BLOCK_PATTERN = re.compile(
    r"<\|channel\|>final\s*<\|message\|>(.*?)(?:<\|return\|>|<\|end\|>|\Z)",
    re.DOTALL,
)

_CALL_TOKEN = "<|call|>"
_MESSAGE_TOKEN = "<|message|>"
_RETURN_TOKEN = "<|return|>"
_FINAL_HEADER = "<|channel|>final"


def _normalize_arguments(raw: str) -> str:
    """Return tool arguments as a JSON string.

    Valid JSON of any type is re-serialised with ``json.dumps``; text that
    does not parse is returned unchanged so the caller can still see it.
    """
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Keep the raw arguments string
        return raw
    return json.dumps(parsed, ensure_ascii=False)


@ToolParserManager.register_module(["harmony", "gpt-oss"])
class HarmonyToolParser(ToolParser):
    """
    Tool call parser for GPT-OSS models using Harmony format.

    Harmony uses control tokens and 3 channels:
    - analysis: internal reasoning (handled by reasoning parser)
    - commentary: tool calls addressed with to=functions.{name}
    - final: user-facing response

    Used when --enable-auto-tool-choice --tool-call-parser harmony are set.
    """

    SUPPORTS_NATIVE_TOOL_FORMAT = False

    def extract_tool_calls(
        self, model_output: str, request: dict[str, Any] | None = None
    ) -> ExtractedToolCallInformation:
        """
        Extract tool calls from a complete Harmony model response.

        Parses commentary channel blocks for tool calls and the final
        channel for the user-facing content.

        Args:
            model_output: Complete model output text.
            request: Unused by this parser.

        Returns:
            ExtractedToolCallInformation. With tool calls, ``content`` is the
            final-channel text or None. Without tool calls, ``content`` is
            the final-channel text, or the output with Harmony control
            tokens stripped when there is no final channel.
        """
        # Extract tool calls from commentary channel blocks
        tool_calls = [
            {
                "id": _generate_tool_id(),
                "name": match.group(1),
                "arguments": _normalize_arguments(match.group(2).strip()),
            }
            for match in _COMMENTARY_BLOCK_PATTERN.finditer(model_output)
        ]

        # Extract final channel content
        final_match = _FINAL_BLOCK_PATTERN.search(model_output)
        content = final_match.group(1).strip() if final_match else None

        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content,
            )

        # No tool calls: return all text as content. If there's a final
        # channel, use that; otherwise return the raw output stripped of
        # control tokens.
        if content is None:
            content = _strip_control_tokens(model_output)

        return ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int] | None = None,
        current_token_ids: Sequence[int] | None = None,
        delta_token_ids: Sequence[int] | None = None,
        request: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """
        Extract tool calls from streaming Harmony model output.

        Waits for <|call|> to complete a tool call, and emits final
        channel content as regular content deltas.

        Args:
            previous_text: Accumulated text before this delta.
            current_text: Accumulated text including this delta.
            delta_text: The new text in this streaming chunk.
            previous_token_ids: Unused by this parser.
            current_token_ids: Unused by this parser.
            delta_token_ids: Unused by this parser.
            request: Unused by this parser.

        Returns:
            ``{"tool_calls": [...]}`` for tool calls completed by this delta,
            ``{"content": ...}`` for user-facing text, or None while output
            is suppressed (tool call under construction, analysis channel).
        """
        # If we see a tool call completion marker in the delta
        if _CALL_TOKEN in delta_text:
            # Calls already complete in previous_text were emitted by earlier
            # deltas; emit only the new ones and keep their indexes stable.
            already_emitted = sum(
                1 for _ in _COMMENTARY_BLOCK_PATTERN.finditer(previous_text)
            )
            new_calls = self.extract_tool_calls(current_text).tool_calls[
                already_emitted:
            ]
            if new_calls:
                return {
                    "tool_calls": [
                        {
                            "index": index,
                            "id": call["id"],
                            "type": "function",
                            "function": {
                                "name": call["name"],
                                "arguments": call["arguments"],
                            },
                        }
                        for index, call in enumerate(new_calls, start=already_emitted)
                    ]
                }

        # If we're in the final channel, emit content.
        # NOTE(review): final text is not streamed once any <|call|> has been
        # seen in current_text - confirm that this is intended.
        if _FINAL_HEADER in current_text and _CALL_TOKEN not in current_text:
            # Only emit content after <|message|> in the final channel
            final_start = current_text.rfind(_FINAL_HEADER)
            msg_start = current_text.find(_MESSAGE_TOKEN, final_start)
            if msg_start >= 0:
                msg_content = current_text[msg_start + len(_MESSAGE_TOKEN) :]
                # Strip trailing control tokens
                msg_content = msg_content.replace(_RETURN_TOKEN, "").strip()
                if msg_content and not _is_control_token(delta_text):
                    # Emit only the payload part of the delta: drop a leading
                    # header that arrived in the same chunk and a trailing
                    # <|return|>.
                    chunk = delta_text
                    if _MESSAGE_TOKEN in chunk:
                        chunk = chunk.rsplit(_MESSAGE_TOKEN, 1)[1]
                    chunk = chunk.replace(_RETURN_TOKEN, "")
                    if chunk:
                        return {"content": chunk}

        # If no tool markers at all, pass through as content
        if "<|channel|>" not in current_text:
            return {"content": delta_text}

        # Building tool call or in analysis channel, suppress output
        return None
# Every Harmony control token this module recognises.
_CONTROL_TOKENS = frozenset(
    {
        "<|start|>",
        "<|end|>",
        "<|message|>",
        "<|channel|>",
        "<|constrain|>",
        "<|return|>",
        "<|call|>",
    }
)

# A message header: everything from <|start|> or <|channel|> up to and
# including the next <|message|> (or to end of text if the header was cut
# off). This covers the channel name, "to=functions.x" and "<|constrain|>json".
_HEADER_PATTERN = re.compile(
    r"<\|(?:start|channel)\|>.*?(?:<\|message\|>|\Z)",
    re.DOTALL,
)


def _strip_control_tokens(text: str) -> str:
    """Remove Harmony control tokens and message headers from text.

    Headers are removed as whole units, so ordinary prose that merely
    contains words such as "final", "analysis" or "json" is left untouched.

    Args:
        text: Raw model output.

    Returns:
        The message bodies with all control tokens removed and surrounding
        whitespace stripped.
    """
    result = _HEADER_PATTERN.sub("", text)
    # Remove whatever control tokens remain (terminators, stray markers).
    for token in _CONTROL_TOKENS:
        result = result.replace(token, "")
    return result.strip()


def _is_control_token(text: str) -> bool:
    """Check if text, ignoring surrounding whitespace, is exactly one Harmony control token."""
    return text.strip() in _CONTROL_TOKENS