Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions tests/test_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,74 @@ def test_list_with_no_text_parts(self):

    def test_empty_list(self):
        """An empty content list should collapse to an empty string."""
        assert _content_to_text([]) == ""


class TestGptOssSpecialTokens:
    """Tests for GPT-OSS channel token handling in utils."""

    def test_pattern_matches_channel_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|channel|>")
        assert hit is not None

    def test_pattern_matches_message_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|message|>")
        assert hit is not None

    def test_pattern_matches_start_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|start|>")
        assert hit is not None

    def test_pattern_matches_return_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|return|>")
        assert hit is not None

    def test_pattern_matches_call_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|call|>")
        assert hit is not None

    def test_clean_output_extracts_final_channel(self):
        """Only the final channel's message survives cleaning."""
        raw = (
            "<|channel|>analysis<|message|>Thinking about it"
            "<|start|>assistant<|channel|>final<|message|>The answer is 42<|return|>"
        )
        cleaned = clean_output_text(raw)
        assert cleaned == "The answer is 42"
        assert "<|" not in cleaned

    def test_clean_output_final_only(self):
        """A final-only output reduces to its message text."""
        cleaned = clean_output_text("<|channel|>final<|message|>Just the answer<|return|>")
        assert cleaned == "Just the answer"

    def test_clean_output_strips_return_token(self):
        """The trailing <|return|> marker never reaches the caller."""
        cleaned = clean_output_text("<|channel|>final<|message|>Hello world<|return|>")
        assert "<|return|>" not in cleaned
        assert cleaned == "Hello world"

    def test_clean_output_no_channel_tokens_passthrough(self):
        """Text without channel tokens is returned untouched."""
        raw = "Normal text without any channel tokens."
        assert clean_output_text(raw) == raw

    def test_pattern_matches_constrain_token(self):
        hit = SPECIAL_TOKENS_PATTERN.search("<|constrain|>")
        assert hit is not None

    def test_clean_output_constrain_format(self):
        """Should extract final content from extended constrain format."""
        raw = (
            "<|channel|>analysis<|message|>Thinking"
            "<|end|><|channel|>final <|constrain|>JSON<|message|>"
            '{"hello":"world"}<|return|>'
        )
        cleaned = clean_output_text(raw)
        assert cleaned == '{"hello":"world"}'
        assert "<|constrain|>" not in cleaned
        assert "<|channel|>" not in cleaned

    def test_clean_output_constrain_final_only(self):
        """Should handle constrain format with only final channel."""
        cleaned = clean_output_text(
            '<|channel|>final <|constrain|>JSON<|message|>{"key":"value"}<|return|>'
        )
        assert cleaned == '{"key":"value"}'

    def test_clean_output_no_final_strips_constrain(self):
        """Without a final channel, constrain tokens are still stripped."""
        cleaned = clean_output_text(
            "<|channel|>analysis<|message|>Just thinking <|constrain|>something"
        )
        assert "<|constrain|>" not in cleaned
241 changes: 241 additions & 0 deletions tests/test_reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,3 +679,244 @@ def test_qwen3_whitespace_between_tags(self, parser):
if expected_reasoning is None:
assert reasoning is None or reasoning.strip() == ""
assert expected_content in (content or "")


class TestGptOssParser:
    """Tests for the GPT-OSS reasoning parser (channel-based format)."""

    @pytest.fixture
    def parser(self):
        """Construct a new GPT-OSS parser instance for every test."""
        return get_parser("gpt_oss")()

    @staticmethod
    def _stream(parser, tokens):
        """Feed *tokens* through the streaming API one at a time.

        Returns ``(reasoning, content)`` joined from the deltas the parser
        emits for each channel.
        """
        parser.reset_state()
        seen = ""
        reasoning_parts = []
        content_parts = []
        for tok in tokens:
            before = seen
            seen = before + tok
            delta = parser.extract_reasoning_streaming(before, seen, tok)
            if not delta:
                continue
            if delta.reasoning:
                reasoning_parts.append(delta.reasoning)
            if delta.content:
                content_parts.append(delta.content)
        return "".join(reasoning_parts), "".join(content_parts)

    # --- Non-streaming behaviour ---

    def test_extract_both_channels(self, parser):
        """Analysis channel becomes reasoning; final channel becomes content."""
        raw = (
            "<|channel|>analysis<|message|>Let me think step by step"
            "<|start|>assistant<|channel|>final<|message|>The answer is 42<|return|>"
        )
        reasoning, content = parser.extract_reasoning(raw)
        assert reasoning == "Let me think step by step"
        assert content == "The answer is 42"

    def test_extract_only_final(self, parser):
        """A final-only output yields content and no reasoning."""
        reasoning, content = parser.extract_reasoning(
            "<|channel|>final<|message|>Just the answer<|return|>"
        )
        assert reasoning is None
        assert content == "Just the answer"

    def test_extract_only_analysis(self, parser):
        """An analysis-only output yields reasoning and no content."""
        reasoning, content = parser.extract_reasoning(
            "<|channel|>analysis<|message|>Just thinking out loud"
        )
        assert reasoning == "Just thinking out loud"
        assert content is None

    def test_no_channel_tokens_fallback(self, parser):
        """Plain text without channel tokens passes through as content."""
        raw = "Just a regular response."
        reasoning, content = parser.extract_reasoning(raw)
        assert reasoning is None
        assert content == raw

    def test_empty_analysis_channel(self, parser):
        """An empty analysis channel produces None reasoning."""
        raw = (
            "<|channel|>analysis<|message|>"
            "<|start|>assistant<|channel|>final<|message|>Content here<|return|>"
        )
        reasoning, content = parser.extract_reasoning(raw)
        assert reasoning is None
        assert content == "Content here"

    def test_multiline_analysis(self, parser):
        """Newlines inside the analysis channel are preserved."""
        raw = (
            "<|channel|>analysis<|message|>Step 1: Analyze\nStep 2: Solve\nStep 3: Verify"
            "<|start|>assistant<|channel|>final<|message|>Result: 42<|return|>"
        )
        reasoning, content = parser.extract_reasoning(raw)
        for step in ("Step 1", "Step 2", "Step 3"):
            assert step in reasoning
        assert content == "Result: 42"

    def test_no_return_token(self, parser):
        """Output missing the trailing <|return|> still parses."""
        raw = (
            "<|channel|>analysis<|message|>Thinking"
            "<|start|>assistant<|channel|>final<|message|>Answer"
        )
        reasoning, content = parser.extract_reasoning(raw)
        assert reasoning == "Thinking"
        assert content == "Answer"

    # --- Streaming behaviour ---

    def test_streaming_full_flow(self, parser):
        """Streaming moves through analysis -> transition -> final phases."""
        reasoning, content = self._stream(
            parser,
            [
                "<|channel|>",
                "analysis",
                "<|message|>",
                "Let me ",
                "think",
                "<|start|>",
                "assistant",
                "<|channel|>",
                "final",
                "<|message|>",
                "The answer",
                " is 42",
                "<|return|>",
            ],
        )
        assert "Let me think" in reasoning
        assert "The answer is 42" in content

    def test_streaming_only_final(self, parser):
        """Streaming a final-only response emits content deltas."""
        _, content = self._stream(
            parser,
            [
                "<|channel|>",
                "final",
                "<|message|>",
                "Direct ",
                "answer",
                "<|return|>",
            ],
        )
        assert "Direct answer" in content

    def test_streaming_suppresses_structural_tokens(self, parser):
        """Structural tokens never leak into reasoning or content."""
        reasoning, content = self._stream(
            parser,
            [
                "<|channel|>analysis<|message|>",
                "thinking",
                "<|start|>",
                "assistant",
                "<|channel|>final<|message|>",
                "answer",
                "<|return|>",
            ],
        )
        assert "<|" not in reasoning + content

    def test_registry_includes_gpt_oss(self):
        """The parser registry exposes the gpt_oss entry."""
        assert "gpt_oss" in list(list_parsers())

    def test_extract_constrain_format(self, parser):
        """Extended format with <|constrain|> parses both channels."""
        raw = (
            "<|channel|>analysis<|message|>We need to output JSON"
            "<|end|><|channel|>final <|constrain|>JSON<|message|>"
            '{"hello":"world"}<|return|>'
        )
        reasoning, content = parser.extract_reasoning(raw)
        assert reasoning == "We need to output JSON"
        assert content == '{"hello":"world"}'

    def test_extract_constrain_no_analysis(self, parser):
        """Constrain format with only a final channel yields content only."""
        reasoning, content = parser.extract_reasoning(
            '<|channel|>final <|constrain|>JSON<|message|>{"key":"value"}<|return|>'
        )
        assert reasoning is None
        assert content == '{"key":"value"}'

    def test_streaming_constrain_format(self, parser):
        """Streaming handles a <|constrain|> token in the channel marker."""
        reasoning, content = self._stream(
            parser,
            [
                "<|channel|>analysis<|message|>",
                "Thinking...",
                "<|end|>",
                "<|channel|>final <|constrain|>JSON<|message|>",
                '{"result":',
                '"ok"}',
                "<|return|>",
            ],
        )
        assert "Thinking" in reasoning
        assert '{"result":"ok"}' in content
        assert "<|constrain|>" not in content

    def test_constrain_tokens_stripped(self, parser):
        """<|constrain|> and <|channel|> must not leak into content."""
        reasoning, content = parser.extract_reasoning(
            "<|channel|>final <|constrain|>JSON<|message|>"
            '{"hello":"world"}<|return|>'
        )
        assert "<|constrain|>" not in (content or "")
        assert "<|channel|>" not in (content or "")
54 changes: 54 additions & 0 deletions vllm_mlx/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,65 @@
# Control/special tokens stripped from raw model output.  Covers
# <|im_start|>/<|im_end|>-style chat markers, header/eot markers,
# GPT-OSS channel tokens (<|channel|>, <|message|>, <|start|>, <|return|>,
# <|call|>, <|constrain|>), and classic <s>/<pad>/[PAD]/[SEP]/[CLS] markers.
SPECIAL_TOKENS_PATTERN = re.compile(
    r"<\|im_end\|>|<\|im_start\|>|<\|endoftext\|>|"
    r"<\|end\|>|<\|eot_id\|>|<\|start_header_id\|>|<\|end_header_id\|>|"
    r"<\|channel\|>|<\|message\|>|<\|start\|>|<\|return\|>|<\|call\|>|<\|constrain\|>|"
    r"</s>|<s>|<pad>|\[PAD\]|\[SEP\]|\[CLS\]"
)


# Regex for matching final channel marker with optional constrain token:
# <|channel|>final<|message|>
# <|channel|>final <|constrain|>JSON<|message|>
_FINAL_CHANNEL_RE = re.compile(
r"<\|channel\|>final[^<]*(?:<\|constrain\|>[^<]*)?<\|message\|>"
)


def _clean_gpt_oss_output(text: str) -> str:
"""
Extract final channel content from GPT-OSS channel-based output.

When reasoning parser is not enabled, this provides a fallback that
extracts the 'final' channel content so the API response is usable.

Handles both standard and extended format with constrain token:
<|channel|>final<|message|>...
<|channel|>final <|constrain|>JSON<|message|>...

Args:
text: Raw model output containing channel tokens.

Returns:
Extracted final content, or text with channel tokens stripped.
"""
match = _FINAL_CHANNEL_RE.search(text)
if match:
content = text[match.end() :]
# Strip trailing structural tokens (including <|constrain|>)
content = re.sub(
r"<\|start\|>|<\|end\|>|<\|channel\|>|<\|return\|>|<\|call\|>|<\|message\|>|<\|constrain\|>",
"",
content,
)
return content.strip()

# No final channel — strip all channel/structural tokens (including constrain)
cleaned = re.sub(
r"<\|channel\|>[^<]*(?:<\|constrain\|>[^<]*)?<\|message\|>|<\|start\|>[^<]*|<\|return\|>|<\|call\|>|<\|constrain\|>[^<]*",
"",
text,
)
return cleaned.strip()


def clean_output_text(text: str) -> str:
"""
Clean model output by removing special tokens.

Keeps <think>...</think> blocks intact for reasoning models.
Adds opening <think> tag if missing (happens when thinking is enabled
in the prompt template but the tag is part of the prompt, not output).
Handles GPT-OSS channel-based format as fallback when reasoning parser
is not enabled.

Args:
text: Raw model output
Expand All @@ -36,6 +84,12 @@ def clean_output_text(text: str) -> str:
"""
if not text:
return text

# GPT-OSS channel format — extract final content before general stripping
if "<|channel|>" in text and "<|message|>" in text:
text = _clean_gpt_oss_output(text)
return text

text = SPECIAL_TOKENS_PATTERN.sub("", text)
text = text.strip()

Expand Down
Loading
Loading