class TestSimpleEngineConcurrency:
    """Test SimpleEngine lock behavior with concurrent requests."""

    # NOTE(review): reconstructed from a whitespace-mangled diff hunk. Only the
    # methods ADDED by this patch are shown; the pre-existing fixtures/tests of
    # tests/test_simple_engine.py are unchanged by this edit.

    @pytest.fixture
    def anyio_backend(self):
        """Pin anyio-marked tests in this class to the asyncio backend.

        The anyio pytest plugin parametrizes tests over every installed
        backend by default; returning "asyncio" keeps the suite from also
        running under trio (which may not be installed).
        """
        return "asyncio"

    # FIX(review): the patch added this async test WITHOUT @pytest.mark.anyio,
    # unlike its sibling test_lock_serializes_stream_generate in the same hunk.
    # Under the anyio plugin an unmarked `async def` test is not executed — it
    # is collected with an "async def functions are not natively supported"
    # warning and effectively skipped. Adding the marker makes it actually run.
    # TODO(review): confirm there is no module-level `pytestmark` outside this
    # view that already applies the marker; if there is, this decorator is a
    # harmless no-op.
    @pytest.mark.anyio
    async def test_chat_with_tools_aggregates_streaming_path(self, mock_llm_model):
        """Tool-enabled non-stream chat should use the streaming path.

        Verifies that SimpleEngine.chat(), when given `tools`, delegates to
        stream_chat() and aggregates the FINAL chunk's state (text cleaned of
        the end-of-turn sentinel, tokens/counters taken from the last chunk),
        and that the model's own non-streaming chat() is never invoked.
        """
        from vllm_mlx.engine.simple import SimpleEngine

        async def fake_stream_chat(*args, **kwargs):
            # First chunk: mid-stream partial state (should be superseded).
            yield MagicMock(
                text="partial",
                tokens=[1],
                prompt_tokens=11,
                completion_tokens=1,
                finish_reason=None,
                finished=False,
            )
            # Final chunk: cumulative state with the end-of-turn sentinel
            # still attached; chat() is expected to strip it.
            yield MagicMock(
                text='<|im_end|>{"name":"bash","arguments":{"command":"pwd"}}',
                tokens=[7, 8, 9],
                prompt_tokens=11,
                completion_tokens=4,
                finish_reason="stop",
                finished=True,
            )

        with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
            engine = SimpleEngine("test-model")
            engine._model = mock_llm_model
            engine._loaded = True
            # Instance-level override: plain function, so no `self` binding —
            # the fake's *args/**kwargs signature absorbs the call as-is.
            engine.stream_chat = fake_stream_chat  # type: ignore[method-assign]

            output = await engine.chat(
                messages=[{"role": "user", "content": "run pwd"}],
                max_tokens=16,
                tools=[
                    {
                        "type": "function",
                        "function": {
                            "name": "bash",
                            "parameters": {"type": "object", "properties": {}},
                        },
                    }
                ],
            )

        # Aggregation takes the final chunk's state; text is cleaned of the
        # <|im_end|> sentinel by clean_output_text in the chat() path.
        assert output.text == '{"name":"bash","arguments":{"command":"pwd"}}'
        assert output.tokens == [7, 8, 9]
        assert output.prompt_tokens == 11
        assert output.completion_tokens == 4
        assert output.finish_reason == "stop"
        # The non-streaming model API must not be touched on the tools path.
        mock_llm_model.chat.assert_not_called()
"""Test that stream_generate uses the same lock as other methods.""" diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 39cfa849d..b93c20c0a 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -453,6 +453,32 @@ async def chat( if not self._loaded: await self.start() + # mlx-lm non-streaming chat with tools can stall indefinitely on some + # local models, while the streaming path completes normally. Reuse the + # streaming implementation and aggregate its final state so both chat + # APIs share the same tool-capable execution path. + if tools and not self._is_mllm: + final_output = GenerationOutput(text="") + async for output in self.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + images=images, + videos=videos, + **kwargs, + ): + final_output = output + text = clean_output_text(final_output.text) + return GenerationOutput( + text=text, + tokens=list(final_output.tokens), + prompt_tokens=final_output.prompt_tokens, + completion_tokens=final_output.completion_tokens, + finish_reason=final_output.finish_reason, + ) + # Convert tools for template if provided template_tools = convert_tools_for_template(tools) if tools else None