Merged
68 changes: 63 additions & 5 deletions tests/test_simple_engine.py
@@ -10,6 +10,10 @@
class TestSimpleEngineConcurrency:
"""Test SimpleEngine lock behavior with concurrent requests."""

@pytest.fixture
def anyio_backend(self):
return "asyncio"

Comment on lines +13 to +16

Action required

1. Pytest anyio plugin missing 🐞 Bug ⛯ Reliability

Tests were switched to @pytest.mark.anyio and an anyio_backend fixture was added, but the repo
doesn't declare the anyio pytest plugin (it ships with the `anyio` package) or register the anyio
marker, so these tests can fail or emit warnings in CI/dev environments that don't already have
that plugin installed.
Agent Prompt
### Issue description
`tests/test_simple_engine.py` uses `@pytest.mark.anyio` and defines `anyio_backend`, which requires the anyio pytest plugin (bundled with the `anyio` package) and, where strict markers are enforced without it, an `anyio` marker registration. The repo currently only declares `pytest-asyncio` in dev deps, and `pytest.ini` does not register the `anyio` marker.

### Issue Context
- Current dev deps: `pytest`, `pytest-asyncio`.
- Tests now use anyio marker and backend fixture.

### Fix Focus Areas
- Add `anyio` to dev optional deps and register the `anyio` marker (a minimal sketch follows this list):
  - pyproject.toml[65-72]
  - pytest.ini[10-14]

OR
- Revert tests to pytest-asyncio:
  - tests/test_simple_engine.py[13-16]
  - tests/test_simple_engine.py[72-246]
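
A minimal sketch of the first option, assuming the dev extras live under `[project.optional-dependencies]` in `pyproject.toml`; the package list and the bracketed line ranges above belong to the real repo and are not reproduced here, so treat this as illustrative only:

```toml
# pyproject.toml: hypothetical dev extras; adjust to the repo's actual layout
[project.optional-dependencies]
dev = [
    "pytest",
    "pytest-asyncio",
    "anyio",  # bundles the pytest plugin that handles @pytest.mark.anyio
]
```

Once `anyio` is installed, its pytest plugin registers the `anyio` marker itself; declaring the marker in `pytest.ini` mainly silences unknown-marker errors in environments where the plugin is missing. The alternative fix is mechanical: revert the decorators back to `@pytest.mark.asyncio` and drop the `anyio_backend` fixture shown at the top of this diff.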


@pytest.fixture
def mock_model(self):
"""Create a mock model that tracks concurrent calls."""
@@ -65,7 +69,7 @@ def chat_side_effect(**kwargs):
model.chat = MagicMock(side_effect=chat_side_effect)
return model

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_lock_prevents_concurrent_generate(self, mock_model):
"""Test that the lock prevents concurrent generate calls."""
from vllm_mlx.engine.simple import SimpleEngine
@@ -89,7 +93,7 @@ async def test_lock_prevents_concurrent_generate(self, mock_model):
"The lock is not working correctly."
)

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
"""Test that the lock prevents concurrent chat calls."""
from vllm_mlx.engine.simple import SimpleEngine
@@ -115,7 +119,7 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
"The lock is not working correctly."
)

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_chat_with_tools_aggregates_streaming_path(self, mock_llm_model):
"""Tool-enabled non-stream chat should use the streaming path."""
from vllm_mlx.engine.simple import SimpleEngine

async def fake_stream_chat(*args, **kwargs):
yield MagicMock(
text="partial",
tokens=[],
prompt_tokens=11,
completion_tokens=1,
finish_reason=None,
finished=False,
)
yield MagicMock(
text="<|im_end|><tool_call>{\"name\":\"bash\",\"arguments\":{\"command\":\"pwd\"}}</tool_call>",
tokens=[],
prompt_tokens=11,
completion_tokens=4,
finish_reason="stop",
finished=True,
)

with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
engine = SimpleEngine("test-model")
engine._model = mock_llm_model
engine._loaded = True
engine._model.tokenizer.encode = MagicMock(return_value=[7, 8, 9])
engine.stream_chat = fake_stream_chat # type: ignore[method-assign]

output = await engine.chat(
messages=[{"role": "user", "content": "run pwd"}],
max_tokens=16,
tools=[
{
"type": "function",
"function": {
"name": "bash",
"parameters": {"type": "object", "properties": {}},
},
}
],
)

assert output.text.startswith("<tool_call>")
assert output.tokens == [7, 8, 9]
assert output.prompt_tokens == 11
assert output.completion_tokens == 4
assert output.finish_reason == "stop"
mock_llm_model.chat.assert_not_called()
engine._model.tokenizer.encode.assert_called_once_with(
output.text, add_special_tokens=False
)

@pytest.mark.anyio
async def test_lock_serializes_stream_generate(self, mock_model):
"""Test that stream_generate uses the same lock as other methods."""
from vllm_mlx.engine.simple import SimpleEngine
@@ -178,7 +236,7 @@ async def try_stream():
result = await stream_task
assert len(result) == 3, f"Expected 3 chunks, got {len(result)}"

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_engine_initialization_creates_lock(self):
"""Test that SimpleEngine creates a lock on initialization."""
from vllm_mlx.engine.simple import SimpleEngine
@@ -189,7 +247,7 @@ async def test_engine_initialization_creates_lock(self):
assert hasattr(engine, "_generation_lock")
assert isinstance(engine._generation_lock, asyncio.Lock)

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_requests_complete_in_order(self, mock_model):
"""Test that concurrent requests complete (may be in any order due to lock)."""
from vllm_mlx.engine.simple import SimpleEngine
30 changes: 30 additions & 0 deletions vllm_mlx/engine/simple.py
@@ -453,6 +453,36 @@ async def chat(
if not self._loaded:
await self.start()

# mlx-lm non-streaming chat with tools can stall indefinitely on some
# local models, while the streaming path completes normally. Reuse the
# streaming implementation and aggregate its final state so both chat
# APIs share the same tool-capable execution path.
if tools and not self._is_mllm:
final_output = GenerationOutput(text="")
async for output in self.stream_chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
tools=tools,
images=images,
videos=videos,
**kwargs,
):
final_output = output
text = clean_output_text(final_output.text)
try:
tokens = self._model.tokenizer.encode(text, add_special_tokens=False)
except TypeError:
tokens = self._model.tokenizer.encode(text)
return GenerationOutput(
text=text,
tokens=tokens,
prompt_tokens=final_output.prompt_tokens,
completion_tokens=final_output.completion_tokens,
finish_reason=final_output.finish_reason,
)

# Convert tools for template if provided
template_tools = convert_tools_for_template(tools) if tools else None
