Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions tests/test_engine_tool_output_preservation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests that tool-enabled chat preserves raw parser-visible output."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest


@pytest.fixture
def anyio_backend():
    """Restrict anyio-marked tests in this module to the asyncio backend."""
    backend_name = "asyncio"
    return backend_name


class TestSimpleEngineToolOutputPreservation:
    """Non-streaming tool chat on SimpleEngine must surface raw Harmony markup."""

    @pytest.mark.anyio
    async def test_chat_with_tools_preserves_raw_harmony_output(self):
        from vllm_mlx.engine.simple import SimpleEngine

        raw_harmony = (
            "<|channel|>commentary to=functions.get_weather"
            '<|message|>{"city":"Paris"}<|call|>'
        )

        # Stand-in for the streaming path: one finished chunk carrying the
        # raw, parser-visible Harmony output.
        async def fake_stream_chat(*args, **kwargs):
            yield MagicMock(
                text=raw_harmony,
                tokens=[],
                prompt_tokens=11,
                completion_tokens=4,
                finish_reason="stop",
                finished=True,
            )

        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {"type": "object", "properties": {}},
            },
        }

        with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
            engine = SimpleEngine("test-model")
            engine._model = MagicMock()
            engine._loaded = True
            engine.stream_chat = fake_stream_chat  # type: ignore[method-assign]

            output = await engine.chat(
                messages=[{"role": "user", "content": "Weather in Paris?"}],
                tools=[weather_tool],
            )

        # The channel/call markers must survive for downstream tool parsing.
        assert "<|channel|>commentary" in output.text
        assert "<|call|>" in output.text


class TestBatchedEngineToolOutputPreservation:
    """Non-streaming tool chat on BatchedEngine must return raw generator output."""

    @pytest.mark.anyio
    async def test_chat_with_tools_preserves_raw_output(self):
        from vllm_mlx.engine.batched import BatchedEngine

        raw_output = (
            "<|channel|>commentary to=functions.get_weather"
            '<|message|>{"city":"Paris"}<|call|>'
        )

        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {"type": "object", "properties": {}},
            },
        }

        with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=False):
            engine = BatchedEngine("test-model")
            engine._loaded = True
            engine._tokenizer = MagicMock()
            engine._apply_chat_template = MagicMock(return_value="prompt")
            # Inner batched engine returns the raw Harmony text untouched.
            fake_result = MagicMock(
                output_text=raw_output,
                prompt_tokens=9,
                completion_tokens=3,
                finish_reason="stop",
            )
            engine._engine = MagicMock()
            engine._engine.generate = AsyncMock(return_value=fake_result)

            output = await engine.chat(
                messages=[{"role": "user", "content": "Weather in Paris?"}],
                tools=[weather_tool],
            )

        assert output.text == raw_output
        engine._engine.generate.assert_called_once()
80 changes: 75 additions & 5 deletions tests/test_harmony_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,23 @@ def test_nested_json_arguments(self, parser):
parsed_args = json.loads(result.tool_calls[0]["arguments"])
assert parsed_args["filter"]["type"] == "range"

def test_consecutive_duplicate_tool_calls_are_deduped(self, parser):
"""Repeated identical commentary blocks should collapse to one call."""
text = (
"<|channel|>commentary to=functions.get_weather\n"
"<|constrain|>json\n"
'<|message|>{"city":"Paris"}\n'
"<|call|>\n"
"<|channel|>commentary to=functions.get_weather\n"
"<|constrain|>json\n"
'<|message|>{"city":"Paris"}\n'
"<|call|>"
)
result = parser.extract_tool_calls(text)

assert result.tools_called
assert len(result.tool_calls) == 1

def test_streaming_no_tool_markers(self, parser):
"""Streaming: plain text passes through as content."""
result = parser.extract_tool_calls_streaming("", "Hello", "Hello")
Expand Down Expand Up @@ -227,6 +244,25 @@ def test_streaming_building_tool_call(self, parser):
result = parser.extract_tool_calls_streaming("", current, '{"a":')
assert result is None

def test_streaming_duplicate_tool_call_is_not_reemitted(self, parser):
"""Streaming should only emit newly completed Harmony tool calls."""
previous = (
"<|channel|>commentary to=functions.func\n"
"<|constrain|>json\n"
'<|message|>{"a": 1}\n'
"<|call|>"
)
current = (
previous
+ "\n<|channel|>commentary to=functions.func\n"
+ "<|constrain|>json\n"
+ '<|message|>{"a": 1}\n'
+ "<|call|>"
)

result = parser.extract_tool_calls_streaming(previous, current, "<|call|>")
assert result is None


# ============================================================================
# Reasoning Parser Tests
Expand Down Expand Up @@ -388,26 +424,60 @@ def test_streaming_reset(self, parser):
assert parser._current_channel is None
assert parser._in_message is False

def test_streaming_commentary_suppressed(self, parser):
"""Streaming: commentary channel output is suppressed."""
def test_streaming_commentary_routed_as_content(self, parser):
"""Streaming: commentary channel is forwarded for downstream tool parsing."""
parser.reset_state()

parser.extract_reasoning_streaming(
r1 = parser.extract_reasoning_streaming(
"",
"<|channel|>commentary to=functions.f\n",
"<|channel|>commentary to=functions.f\n",
)
parser.extract_reasoning_streaming(
assert r1 is not None
assert r1.content == "<|channel|>commentary to=functions.f\n"

r2 = parser.extract_reasoning_streaming(
"<|channel|>commentary to=functions.f\n",
"<|channel|>commentary to=functions.f\n<|message|>",
"<|message|>",
)
assert r2 is not None
assert r2.content == "<|message|>"

r = parser.extract_reasoning_streaming(
"<|channel|>commentary to=functions.f\n<|message|>",
'<|channel|>commentary to=functions.f\n<|message|>{"a":1}',
'{"a":1}',
)
assert r is None
assert r is not None
assert r.content == '{"a":1}'

def test_streaming_split_commentary_header(self, parser):
"""Split commentary headers should still be routed to downstream tool parsing."""
parser.reset_state()

accumulated = ""
content_parts = []
for token in [
"<|channel|>",
"comment",
"ary to",
"=functions.get_weather",
" <|constrain|>",
"json",
"<|message|>",
'{"city":"Paris"}',
"<|call|>",
]:
prev = accumulated
accumulated += token
result = parser.extract_reasoning_streaming(prev, accumulated, token)
if result and result.content:
content_parts.append(result.content)

combined = "".join(content_parts)
assert "<|channel|>commentary to=functions.get_weather" in combined
assert '<|message|>{"city":"Paris"}' in combined


# ============================================================================
Expand Down
64 changes: 59 additions & 5 deletions tests/test_simple_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
class TestSimpleEngineConcurrency:
"""Test SimpleEngine lock behavior with concurrent requests."""

@pytest.fixture
def anyio_backend(self):
return "asyncio"

@pytest.fixture
def mock_model(self):
"""Create a mock model that tracks concurrent calls."""
Expand Down Expand Up @@ -65,7 +69,7 @@ def chat_side_effect(**kwargs):
model.chat = MagicMock(side_effect=chat_side_effect)
return model

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_lock_prevents_concurrent_generate(self, mock_model):
"""Test that the lock prevents concurrent generate calls."""
from vllm_mlx.engine.simple import SimpleEngine
Expand All @@ -89,7 +93,7 @@ async def test_lock_prevents_concurrent_generate(self, mock_model):
"The lock is not working correctly."
)

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
"""Test that the lock prevents concurrent chat calls."""
from vllm_mlx.engine.simple import SimpleEngine
Expand All @@ -115,7 +119,57 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
"The lock is not working correctly."
)

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_chat_with_tools_aggregates_streaming_path(self, mock_llm_model):
"""Tool-enabled non-stream chat should use the streaming path."""
from vllm_mlx.engine.simple import SimpleEngine

async def fake_stream_chat(*args, **kwargs):
yield MagicMock(
text="partial",
tokens=[],
prompt_tokens=11,
completion_tokens=1,
finish_reason=None,
finished=False,
)
yield MagicMock(
text="<|im_end|><tool_call>{\"name\":\"bash\",\"arguments\":{\"command\":\"pwd\"}}</tool_call>",
tokens=[],
prompt_tokens=11,
completion_tokens=4,
finish_reason="stop",
finished=True,
)

with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
engine = SimpleEngine("test-model")
engine._model = mock_llm_model
engine._loaded = True
engine.stream_chat = fake_stream_chat # type: ignore[method-assign]

output = await engine.chat(
messages=[{"role": "user", "content": "run pwd"}],
max_tokens=16,
tools=[
{
"type": "function",
"function": {
"name": "bash",
"parameters": {"type": "object", "properties": {}},
},
}
],
)

assert output.text.startswith("<|im_end|><tool_call>")
assert output.tokens == []
assert output.prompt_tokens == 11
assert output.completion_tokens == 4
assert output.finish_reason == "stop"
mock_llm_model.chat.assert_not_called()

@pytest.mark.anyio
async def test_lock_serializes_stream_generate(self, mock_model):
"""Test that stream_generate uses the same lock as other methods."""
from vllm_mlx.engine.simple import SimpleEngine
Expand Down Expand Up @@ -178,7 +232,7 @@ async def try_stream():
result = await stream_task
assert len(result) == 3, f"Expected 3 chunks, got {len(result)}"

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_engine_initialization_creates_lock(self):
"""Test that SimpleEngine creates a lock on initialization."""
from vllm_mlx.engine.simple import SimpleEngine
Expand All @@ -189,7 +243,7 @@ async def test_engine_initialization_creates_lock(self):
assert hasattr(engine, "_generation_lock")
assert isinstance(engine._generation_lock, asyncio.Lock)

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_requests_complete_in_order(self, mock_model):
"""Test that concurrent requests complete (may be in any order due to lock)."""
from vllm_mlx.engine.simple import SimpleEngine
Expand Down
Loading