diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py
index cce42bfc3..a5399993b 100644
--- a/tests/test_simple_engine.py
+++ b/tests/test_simple_engine.py
@@ -10,6 +10,10 @@ class TestSimpleEngineConcurrency:
     """Test SimpleEngine lock behavior with concurrent requests."""
 
+    @pytest.fixture
+    def anyio_backend(self):
+        return "asyncio"
+
     @pytest.fixture
     def mock_model(self):
         """Create a mock model that tracks concurrent calls."""
@@ -65,7 +69,7 @@ def chat_side_effect(**kwargs):
         model.chat = MagicMock(side_effect=chat_side_effect)
         return model
 
-    @pytest.mark.asyncio
+    @pytest.mark.anyio
     async def test_lock_prevents_concurrent_generate(self, mock_model):
         """Test that the lock prevents concurrent generate calls."""
         from vllm_mlx.engine.simple import SimpleEngine
@@ -89,7 +93,7 @@ async def test_lock_prevents_concurrent_generate(self, mock_model):
             "The lock is not working correctly."
         )
 
-    @pytest.mark.asyncio
+    @pytest.mark.anyio
     async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
         """Test that the lock prevents concurrent chat calls."""
         from vllm_mlx.engine.simple import SimpleEngine
@@ -115,7 +119,61 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
             "The lock is not working correctly."
         )
 
-    @pytest.mark.asyncio
+    @pytest.mark.anyio
+    async def test_chat_with_tools_aggregates_streaming_path(self, mock_llm_model):
+        """Tool-enabled non-stream chat should use the streaming path."""
+        from vllm_mlx.engine.simple import SimpleEngine
+
+        async def fake_stream_chat(*args, **kwargs):
+            yield MagicMock(
+                text="partial",
+                tokens=[],
+                prompt_tokens=11,
+                completion_tokens=1,
+                finish_reason=None,
+                finished=False,
+            )
+            yield MagicMock(
+                text="<|im_end|>{\"name\":\"bash\",\"arguments\":{\"command\":\"pwd\"}}",
+                tokens=[],
+                prompt_tokens=11,
+                completion_tokens=4,
+                finish_reason="stop",
+                finished=True,
+            )
+
+        with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
+            engine = SimpleEngine("test-model")
+            engine._model = mock_llm_model
+            engine._loaded = True
+            engine._model.tokenizer.encode = MagicMock(return_value=[7, 8, 9])
+            engine.stream_chat = fake_stream_chat  # type: ignore[method-assign]
+
+            output = await engine.chat(
+                messages=[{"role": "user", "content": "run pwd"}],
+                max_tokens=16,
+                tools=[
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "bash",
+                            "parameters": {"type": "object", "properties": {}},
+                        },
+                    }
+                ],
+            )
+
+        assert output.text.startswith("")
+        assert output.tokens == [7, 8, 9]
+        assert output.prompt_tokens == 11
+        assert output.completion_tokens == 4
+        assert output.finish_reason == "stop"
+        mock_llm_model.chat.assert_not_called()
+        engine._model.tokenizer.encode.assert_called_once_with(
+            output.text, add_special_tokens=False
+        )
+
+    @pytest.mark.anyio
     async def test_lock_serializes_stream_generate(self, mock_model):
         """Test that stream_generate uses the same lock as other methods."""
         from vllm_mlx.engine.simple import SimpleEngine
@@ -178,7 +236,7 @@ async def try_stream():
         result = await stream_task
         assert len(result) == 3, f"Expected 3 chunks, got {len(result)}"
 
-    @pytest.mark.asyncio
+    @pytest.mark.anyio
     async def test_engine_initialization_creates_lock(self):
         """Test that SimpleEngine creates a lock on initialization."""
         from vllm_mlx.engine.simple import SimpleEngine
@@ -189,7 +247,7 @@ async def test_engine_initialization_creates_lock(self):
         assert hasattr(engine, "_generation_lock")
         assert isinstance(engine._generation_lock, asyncio.Lock)
 
-    @pytest.mark.asyncio
+    @pytest.mark.anyio
     async def test_requests_complete_in_order(self, mock_model):
         """Test that concurrent requests complete (may be in any order due to lock)."""
         from vllm_mlx.engine.simple import SimpleEngine
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index 4df2f0e54..5d7077141 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -453,6 +453,36 @@ async def chat(
         if not self._loaded:
             await self.start()
 
+        # mlx-lm non-streaming chat with tools can stall indefinitely on some
+        # local models, while the streaming path completes normally. Reuse the
+        # streaming implementation and aggregate its final state so both chat
+        # APIs share the same tool-capable execution path.
+        if tools and not self._is_mllm:
+            final_output = GenerationOutput(text="")
+            async for output in self.stream_chat(
+                messages=messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                tools=tools,
+                images=images,
+                videos=videos,
+                **kwargs,
+            ):
+                final_output = output
+            text = clean_output_text(final_output.text)
+            try:
+                tokens = self._model.tokenizer.encode(text, add_special_tokens=False)
+            except TypeError:
+                tokens = self._model.tokenizer.encode(text)
+            return GenerationOutput(
+                text=text,
+                tokens=tokens,
+                prompt_tokens=final_output.prompt_tokens,
+                completion_tokens=final_output.completion_tokens,
+                finish_reason=final_output.finish_reason,
+            )
+
         # Convert tools for template if provided
         template_tools = convert_tools_for_template(tools) if tools else None
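Note on the chat() fallback above: the new branch is, at bottom, "drain the async stream, keep the final chunk's state, re-tokenize the cleaned text." Below is a minimal self-contained sketch of that aggregation pattern, assuming (as the diff's logic implies) that each streamed chunk carries the accumulated generation state; Chunk and demo_stream are hypothetical stand-ins for illustration, not part of vllm_mlx:

    import asyncio
    from dataclasses import dataclass
    from typing import AsyncIterator, Optional

    @dataclass
    class Chunk:
        # Hypothetical stand-in for GenerationOutput: cumulative text plus counters.
        text: str = ""
        prompt_tokens: int = 0
        completion_tokens: int = 0
        finish_reason: Optional[str] = None

    async def demo_stream() -> AsyncIterator[Chunk]:
        # Stand-in for SimpleEngine.stream_chat: each chunk reflects the
        # accumulated state, so only the last chunk matters to the caller.
        yield Chunk(text="partial", prompt_tokens=11, completion_tokens=1)
        yield Chunk(
            text='{"name":"bash","arguments":{"command":"pwd"}}',
            prompt_tokens=11,
            completion_tokens=4,
            finish_reason="stop",
        )

    async def aggregate() -> Chunk:
        final = Chunk()  # safe default if the stream yields nothing
        async for chunk in demo_stream():
            final = chunk  # keep only the final (aggregated) chunk
        return final

    print(asyncio.run(aggregate()).finish_reason)  # -> "stop"

The try/except TypeError around tokenizer.encode in the diff appears to guard against tokenizer implementations whose encode() does not accept an add_special_tokens keyword; the fallback simply retries without it.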