class TestGptOssSpecialTokens:
    """Cover GPT-OSS channel tokens in the special-token regex and the text cleaner."""

    def test_pattern_matches_channel_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|channel|>")
        assert hit is not None

    def test_pattern_matches_message_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|message|>")
        assert hit is not None

    def test_pattern_matches_start_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|start|>")
        assert hit is not None

    def test_pattern_matches_return_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|return|>")
        assert hit is not None

    def test_pattern_matches_call_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|call|>")
        assert hit is not None

    def test_clean_output_extracts_final_channel(self):
        """Analysis followed by final: only the final message text survives."""
        analysis_part = "<|channel|>analysis<|message|>Thinking about it"
        final_part = (
            "<|start|>assistant<|channel|>final<|message|>The answer is 42<|return|>"
        )
        cleaned = clean_output_text(analysis_part + final_part)
        assert cleaned == "The answer is 42"
        assert "<|" not in cleaned

    def test_clean_output_final_only(self):
        """A lone final channel is reduced to its message text."""
        raw = "<|channel|>final<|message|>Just the answer<|return|>"
        assert clean_output_text(raw) == "Just the answer"

    def test_clean_output_strips_return_token(self):
        """The trailing <|return|> token is removed from the result."""
        raw = "<|channel|>final<|message|>Hello world<|return|>"
        cleaned = clean_output_text(raw)
        assert "<|return|>" not in cleaned
        assert cleaned == "Hello world"

    def test_clean_output_no_channel_tokens_passthrough(self):
        """Text without channel tokens comes back unchanged."""
        raw = "Normal text without any channel tokens."
        assert clean_output_text(raw) == raw

    def test_pattern_matches_constrain_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|constrain|>")
        assert hit is not None

    def test_clean_output_constrain_format(self):
        """Extended constrain header: the final payload is extracted verbatim."""
        raw = "".join(
            [
                "<|channel|>analysis<|message|>Thinking",
                "<|end|><|channel|>final <|constrain|>JSON<|message|>",
                '{"hello":"world"}<|return|>',
            ]
        )
        cleaned = clean_output_text(raw)
        assert cleaned == '{"hello":"world"}'
        for marker in ("<|constrain|>", "<|channel|>"):
            assert marker not in cleaned

    def test_clean_output_constrain_final_only(self):
        """Constrain header with no preceding analysis channel."""
        raw = '<|channel|>final <|constrain|>JSON<|message|>{"key":"value"}<|return|>'
        assert clean_output_text(raw) == '{"key":"value"}'

    def test_clean_output_no_final_strips_constrain(self):
        """Without a final channel the constrain token must still be removed."""
        raw = "<|channel|>analysis<|message|>Just thinking <|constrain|>something"
        cleaned = clean_output_text(raw)
        assert "<|constrain|>" not in cleaned
"<|start|>assistant<|channel|>final<|message|>The answer is 42<|return|>" + ) + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Let me think step by step" + assert content == "The answer is 42" + + def test_extract_only_final(self, parser): + """Should handle output with only final channel.""" + output = "<|channel|>final<|message|>Just the answer<|return|>" + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == "Just the answer" + + def test_extract_only_analysis(self, parser): + """Should handle output with only analysis channel.""" + output = "<|channel|>analysis<|message|>Just thinking out loud" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Just thinking out loud" + assert content is None + + def test_no_channel_tokens_fallback(self, parser): + """No channel tokens should return pure content.""" + output = "Just a regular response." + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == output + + def test_empty_analysis_channel(self, parser): + """Empty analysis channel should return None reasoning.""" + output = ( + "<|channel|>analysis<|message|>" + "<|start|>assistant<|channel|>final<|message|>Content here<|return|>" + ) + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == "Content here" + + def test_multiline_analysis(self, parser): + """Should preserve multiline reasoning content.""" + output = ( + "<|channel|>analysis<|message|>Step 1: Analyze\nStep 2: Solve\nStep 3: Verify" + "<|start|>assistant<|channel|>final<|message|>Result: 42<|return|>" + ) + reasoning, content = parser.extract_reasoning(output) + assert "Step 1" in reasoning + assert "Step 2" in reasoning + assert "Step 3" in reasoning + assert content == "Result: 42" + + def test_no_return_token(self, parser): + """Should handle missing <|return|> at end.""" + output = ( + 
"<|channel|>analysis<|message|>Thinking" + "<|start|>assistant<|channel|>final<|message|>Answer" + ) + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Thinking" + assert content == "Answer" + + # Streaming tests + + def test_streaming_full_flow(self, parser): + """Test streaming through analysis -> transition -> final phases.""" + parser.reset_state() + + # Simulate token-by-token streaming + tokens = [ + "<|channel|>", + "analysis", + "<|message|>", + "Let me ", + "think", + "<|start|>", + "assistant", + "<|channel|>", + "final", + "<|message|>", + "The answer", + " is 42", + "<|return|>", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + assert "Let me think" in full_reasoning + assert "The answer is 42" in full_content + + def test_streaming_only_final(self, parser): + """Test streaming with only final channel.""" + parser.reset_state() + + tokens = [ + "<|channel|>", + "final", + "<|message|>", + "Direct ", + "answer", + "<|return|>", + ] + + accumulated = "" + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result and result.content: + content_parts.append(result.content) + + assert "Direct answer" in "".join(content_parts) + + def test_streaming_suppresses_structural_tokens(self, parser): + """Structural tokens should not leak into reasoning or content.""" + parser.reset_state() + + tokens = [ + "<|channel|>analysis<|message|>", + "thinking", + "<|start|>", + "assistant", + "<|channel|>final<|message|>", + "answer", + "<|return|>", + 
# Header of the final channel, with or without the optional constrain token:
#   <|channel|>final<|message|>
#   <|channel|>final <|constrain|>JSON<|message|>
_FINAL_CHANNEL_RE = re.compile(
    r"<\|channel\|>final[^<]*(?:<\|constrain\|>[^<]*)?<\|message\|>"
)

# Any single GPT-OSS structural token.
_GPT_OSS_TOKEN_RE = re.compile(
    r"<\|(?:start|end|channel|message|return|call|constrain)\|>"
)

# Markers that are removed together with the header text that follows them
# (channel name, role name, constrain type) when no final channel exists.
_GPT_OSS_HEADER_RE = re.compile(
    r"<\|channel\|>[^<]*(?:<\|constrain\|>[^<]*)?<\|message\|>"
    r"|<\|start\|>[^<]*"
    r"|<\|constrain\|>[^<]*"
)


def _clean_gpt_oss_output(text: str) -> str:
    """
    Extract final channel content from GPT-OSS channel-based output.

    Used as a fallback when no reasoning parser is enabled, so the API
    response does not contain raw channel markup.

    Handles both the standard and the extended (constrain) header:
        <|channel|>final<|message|>...
        <|channel|>final <|constrain|>JSON<|message|>...

    Args:
        text: Raw model output containing channel tokens.

    Returns:
        The text of the first final-channel message, cut at the first
        structural token that follows it. If there is no final channel,
        the input with every channel header and structural token removed.
        The result is whitespace-stripped in both cases.
    """
    final_header = _FINAL_CHANNEL_RE.search(text)
    if final_header:
        message = text[final_header.end() :]
        # The message ends at the first structural token. Cutting here (rather
        # than deleting tokens and keeping the rest) prevents a following
        # role/channel header from being glued onto the answer.
        terminator = _GPT_OSS_TOKEN_RE.search(message)
        if terminator:
            message = message[: terminator.start()]
        return message.strip()

    # No final channel: drop headers first (so their names go with them),
    # then any remaining lone token such as <|end|>.
    without_headers = _GPT_OSS_HEADER_RE.sub("", text)
    return _GPT_OSS_TOKEN_RE.sub("", without_headers).strip()
def _register_builtin_parsers():
    """Add every reasoning parser bundled with this package to the registry."""
    # Imports stay function-local, as in the original.
    from .deepseek_r1_parser import DeepSeekR1ReasoningParser
    from .gpt_oss_parser import GptOssReasoningParser
    from .harmony_parser import HarmonyReasoningParser
    from .qwen3_parser import Qwen3ReasoningParser

    # (registry name, parser class) pairs, in registration order.
    builtin_parsers = (
        ("qwen3", Qwen3ReasoningParser),
        ("deepseek_r1", DeepSeekR1ReasoningParser),
        ("gpt_oss", GptOssReasoningParser),
        ("harmony", HarmonyReasoningParser),
    )
    for parser_name, parser_cls in builtin_parsers:
        register_parser(parser_name, parser_cls)
import re

from .base import DeltaMessage, ReasoningParser

# Tokens that end a channel message; none of them may reach the API client.
_STRUCTURAL_TOKENS = re.compile(
    r"<\|start\|>|<\|end\|>|<\|channel\|>|<\|return\|>|<\|call\|>|<\|constrain\|>"
)

# Header of a recognised channel, standard or extended:
#   <|channel|>analysis<|message|>
#   <|channel|>final<|message|>
#   <|channel|>final <|constrain|>JSON<|message|>
_CHANNEL_RE = re.compile(
    r"<\|channel\|>(analysis|final)(?:[^<]*(?:<\|constrain\|>[^<]*)?)?<\|message\|>"
)

# Any channel header (recognised or not), a role header, or a lone token.
# Used only to scrub output in which no analysis/final text was found.
_MARKUP_RE = re.compile(
    r"<\|channel\|>[^<]*(?:<\|constrain\|>[^<]*)?<\|message\|>"
    r"|<\|start\|>[^<]*"
    r"|<\|(?:end|channel|message|return|call|constrain)\|>"
)


def _channel_segments(text: str) -> list[tuple[str, str]]:
    """
    Split text into (channel_name, raw_message) pairs.

    One pair is produced per complete analysis/final header, in order of
    appearance. Each message runs from the end of its header to the next
    structural token, or to the end of the text when none follows. The
    message is returned unstripped.
    """
    segments = []
    for header in _CHANNEL_RE.finditer(text):
        begin = header.end()
        terminator = _STRUCTURAL_TOKENS.search(text, begin)
        stop = terminator.start() if terminator else len(text)
        segments.append((header.group(1), text[begin:stop]))
    return segments


def _extract_channel(text: str, channel_name: str) -> str | None:
    """
    Extract content from a named channel.

    Args:
        text: Full model output text.
        channel_name: Channel name to extract ("analysis" or "final").

    Returns:
        The whitespace-stripped message of the first header with that name,
        or None if there is no such header or its message is empty.
    """
    for name, message in _channel_segments(text):
        if name == channel_name:
            return message.strip() or None
    return None


class GptOssReasoningParser(ReasoningParser):
    """
    Reasoning parser for GPT-OSS models.

    GPT-OSS uses channel-based tokens:
        <|channel|>analysis<|message|>[reasoning]
        <|start|>assistant<|channel|>final<|message|>[content]<|return|>

    The 'analysis' channel maps to reasoning, 'final' to content.

    Also handles the extended format with a constrain token:
        <|channel|>final <|constrain|>JSON<|message|>[content]<|return|>
    """

    def extract_reasoning(
        self,
        model_output: str,
    ) -> tuple[str | None, str | None]:
        """
        Extract reasoning and content from complete model output.

        Args:
            model_output: Complete text output from the model.

        Returns:
            (reasoning, content) tuple. Either may be None.
        """
        if not model_output:
            return None, None
        if "<|channel|>" not in model_output:
            # Plain text: no channel markup to interpret.
            return None, model_output

        # Messages are already cut at the first structural token, so no
        # further token stripping is needed on either value.
        reasoning = _extract_channel(model_output, "analysis")
        content = _extract_channel(model_output, "final")

        if reasoning is None and content is None:
            # Markup is present but produced no analysis/final text (empty
            # or unrecognised channels). Return what is left after removing
            # the markup instead of the raw token-laden string.
            leftover = _MARKUP_RE.sub("", model_output).strip()
            return None, leftover or None

        return reasoning, content

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
    ) -> DeltaMessage | None:
        """
        Extract reasoning/content carried by one streaming delta.

        Stateless: the visible reasoning and content are computed for the
        text before and after the delta, and only the newly added suffix of
        each is emitted. Markers and structural tokens are therefore never
        emitted, regardless of how the stream is chunked, and nothing is
        emitted once a message has been terminated.

        Args:
            previous_text: Accumulated text before this delta.
            current_text: Accumulated text including this delta; expected to
                extend previous_text.
            delta_text: The new chunk. Unused; kept for the parser interface.

        Returns:
            DeltaMessage with reasoning and/or content, or None when the
            delta adds no visible text.
        """
        seen_reasoning, seen_content = self._visible_text(previous_text)
        now_reasoning, now_content = self._visible_text(current_text)

        fields = {}
        new_reasoning = now_reasoning[len(seen_reasoning) :]
        if new_reasoning:
            fields["reasoning"] = new_reasoning
        new_content = now_content[len(seen_content) :]
        if new_content:
            fields["content"] = new_content

        if not fields:
            return None
        # Only non-empty fields are passed, so DeltaMessage defaults apply
        # to the other one, as in the single-field calls used previously.
        return DeltaMessage(**fields)

    @staticmethod
    def _visible_text(text: str) -> tuple[str, str]:
        """
        Return the (reasoning, content) text visible in accumulated output.

        All analysis messages are concatenated into reasoning and all final
        messages into content, unstripped. Appending to `text` only ever
        extends these strings, which is what makes the prefix diff in
        extract_reasoning_streaming valid.
        """
        reasoning_parts = []
        content_parts = []
        for name, message in _channel_segments(text):
            if name == "analysis":
                reasoning_parts.append(message)
            else:
                content_parts.append(message)
        return "".join(reasoning_parts), "".join(content_parts)
= json.dumps(parsed_json) - if not is_valid: - logger.warning(f"JSON validation failed: {error}") - - # Extract reasoning content if parser is enabled + # Extract reasoning content FIRST (strips channel tokens before JSON extraction) reasoning_text = None if _reasoning_parser and not tool_calls: text_to_parse = cleaned_text or output.text @@ -1406,6 +1395,16 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re text_to_parse ) + # Process response_format if specified (after reasoning parser cleaned the text) + if response_format and not tool_calls: + json_input = cleaned_text or output.text + _, parsed_json, is_valid, error = parse_json_output(json_input, response_format) + if parsed_json is not None: + # Return JSON as string + cleaned_text = json.dumps(parsed_json) + if not is_valid: + logger.warning(f"JSON validation failed: {error}") + # Determine finish reason finish_reason = "tool_calls" if tool_calls else output.finish_reason