53 changes: 53 additions & 0 deletions tests/test_simple_engine.py
@@ -12,6 +12,10 @@
class TestSimpleEngineConcurrency:
    """Test SimpleEngine lock behavior with concurrent requests."""

    @pytest.fixture
    def anyio_backend(self):
        return "asyncio"

    @pytest.fixture
    def mock_model(self):
        """Create a mock model that tracks concurrent calls."""
@@ -117,6 +121,55 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
"The lock is not working correctly."
)

    @pytest.mark.anyio
    async def test_chat_with_tools_aggregates_streaming_path(self, mock_llm_model):
        """Tool-enabled non-stream chat should use the streaming path."""
        from vllm_mlx.engine.simple import SimpleEngine

        async def fake_stream_chat(*args, **kwargs):
            # In-flight partial chunk; chat() keeps only the last yielded
            # state, so this one must be superseded by the final chunk.
            yield MagicMock(
                text="partial",
                tokens=[1],
                prompt_tokens=11,
                completion_tokens=1,
                finish_reason=None,
                finished=False,
            )
            # Final chunk: raw text still carrying the template stop token
            # and the <tool_call> wrapper that chat() is expected to strip.
            yield MagicMock(
                text='<|im_end|><tool_call>{"name":"bash","arguments":{"command":"pwd"}}</tool_call>',
                tokens=[7, 8, 9],
                prompt_tokens=11,
                completion_tokens=4,
                finish_reason="stop",
                finished=True,
            )

        with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
            engine = SimpleEngine("test-model")
            engine._model = mock_llm_model
            engine._loaded = True
            engine.stream_chat = fake_stream_chat  # type: ignore[method-assign]

            output = await engine.chat(
                messages=[{"role": "user", "content": "run pwd"}],
                max_tokens=16,
                tools=[
                    {
                        "type": "function",
                        "function": {
                            "name": "bash",
                            "parameters": {"type": "object", "properties": {}},
                        },
                    }
                ],
            )

            assert output.text == '{"name":"bash","arguments":{"command":"pwd"}}'
            assert output.tokens == [7, 8, 9]
            assert output.prompt_tokens == 11
            assert output.completion_tokens == 4
            assert output.finish_reason == "stop"
            mock_llm_model.chat.assert_not_called()

    @pytest.mark.anyio
    async def test_lock_serializes_stream_generate(self, mock_model):
        """Test that stream_generate uses the same lock as other methods."""
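The output.text assertion in the new test implies that clean_output_text strips the end-of-turn token and unwraps the <tool_call> tags, keeping only the JSON payload. A minimal sketch of that inferred behavior (the name clean_output_text_sketch is ours; the real helper in vllm_mlx may handle more token formats):

import re

def clean_output_text_sketch(text: str) -> str:
    # Drop the chat-template stop token, then unwrap a <tool_call> block,
    # keeping only the JSON payload, as the test's assertions expect.
    text = text.replace("<|im_end|>", "")
    match = re.search(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
    return match.group(1).strip() if match else text.strip()

assert clean_output_text_sketch(
    '<|im_end|><tool_call>{"name":"bash","arguments":{"command":"pwd"}}</tool_call>'
) == '{"name":"bash","arguments":{"command":"pwd"}}'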
26 changes: 26 additions & 0 deletions vllm_mlx/engine/simple.py
@@ -453,6 +453,32 @@ async def chat(
        if not self._loaded:
            await self.start()

        # mlx-lm non-streaming chat with tools can stall indefinitely on some
        # local models, while the streaming path completes normally. Reuse the
        # streaming implementation and aggregate its final state so both chat
        # APIs share the same tool-capable execution path.
        if tools and not self._is_mllm:
            final_output = GenerationOutput(text="")
            async for output in self.stream_chat(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                tools=tools,
                images=images,
                videos=videos,
                **kwargs,
            ):
                final_output = output
            text = clean_output_text(final_output.text)
            return GenerationOutput(
                text=text,
                tokens=list(final_output.tokens),
                prompt_tokens=final_output.prompt_tokens,
                completion_tokens=final_output.completion_tokens,
                finish_reason=final_output.finish_reason,
            )

        # Convert tools for template if provided
        template_tools = convert_tools_for_template(tools) if tools else None

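For context, a minimal caller-side sketch of the path this hunk unifies. The model id is illustrative and a real run needs a local MLX model; the test above stubs stream_chat instead:

import asyncio

from vllm_mlx.engine.simple import SimpleEngine

async def main() -> None:
    engine = SimpleEngine("test-model")  # illustrative model id
    tools = [
        {
            "type": "function",
            "function": {
                "name": "bash",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]
    # With tools on a text-only model, chat() now drains stream_chat()
    # internally and returns the cleaned final chunk, so both APIs share
    # one tool-capable execution path.
    output = await engine.chat(
        messages=[{"role": "user", "content": "run pwd"}],
        max_tokens=16,
        tools=tools,
    )
    print(output.text, output.finish_reason)

asyncio.run(main())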