Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@ def pytest_collection_modifyitems(config, items):
def server_url(request):
"""Get server URL from command line."""
return request.config.getoption("--server-url")


@pytest.fixture(params=["asyncio"])
def anyio_backend(request):
"""Run anyio-marked tests on asyncio only (trio is not installed)."""
return request.param
12 changes: 7 additions & 5 deletions tests/test_simple_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import pytest

pytestmark = pytest.mark.anyio


class TestSimpleEngineConcurrency:
"""Test SimpleEngine lock behavior with concurrent requests."""
Expand Down Expand Up @@ -65,7 +67,7 @@ def chat_side_effect(**kwargs):
model.chat = MagicMock(side_effect=chat_side_effect)
return model

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_lock_prevents_concurrent_generate(self, mock_model):
"""Test that the lock prevents concurrent generate calls."""
from vllm_mlx.engine.simple import SimpleEngine
Expand All @@ -89,7 +91,7 @@ async def test_lock_prevents_concurrent_generate(self, mock_model):
"The lock is not working correctly."
)

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
"""Test that the lock prevents concurrent chat calls."""
from vllm_mlx.engine.simple import SimpleEngine
Expand All @@ -115,7 +117,7 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
"The lock is not working correctly."
)

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_lock_serializes_stream_generate(self, mock_model):
"""Test that stream_generate uses the same lock as other methods."""
from vllm_mlx.engine.simple import SimpleEngine
Expand Down Expand Up @@ -178,7 +180,7 @@ async def try_stream():
result = await stream_task
assert len(result) == 3, f"Expected 3 chunks, got {len(result)}"

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_engine_initialization_creates_lock(self):
"""Test that SimpleEngine creates a lock on initialization."""
from vllm_mlx.engine.simple import SimpleEngine
Expand All @@ -189,7 +191,7 @@ async def test_engine_initialization_creates_lock(self):
assert hasattr(engine, "_generation_lock")
assert isinstance(engine._generation_lock, asyncio.Lock)

@pytest.mark.asyncio
@pytest.mark.anyio
async def test_requests_complete_in_order(self, mock_model):
"""Test that concurrent requests complete (may be in any order due to lock)."""
from vllm_mlx.engine.simple import SimpleEngine
Expand Down
143 changes: 143 additions & 0 deletions tests/test_simple_engine_cancel_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
"""Regression test for cancellation-safe SimpleEngine serialization."""

from __future__ import annotations

import asyncio
import threading
import unittest
from unittest.mock import MagicMock, patch


class SimpleEngineCancelSerializationTests(unittest.IsolatedAsyncioTestCase):
    """Regression tests: SimpleEngine must stay serialized across cancellation.

    Two invariants of the engine's generation lock are guarded here:

    1. Cancelling an in-flight request must NOT release the lock while the
       backing MLX worker thread is still running — otherwise a second
       request's worker could overlap the first on the GPU.
    2. The streaming entry points (specprefill and text-only MTP) must not
       pre-acquire the lock themselves; ``_run_blocking_serialized`` is the
       single owner of the lock.
    """

    async def test_cancellation_does_not_release_lock_before_worker_finishes(self):
        """A cancelled request must not let a second MLX worker overlap."""
        from vllm_mlx.engine.simple import SimpleEngine

        model = MagicMock()
        model.tokenizer = MagicMock()
        model.tokenizer.encode = MagicMock(return_value=[1, 2, 3])
        # Instrumentation: count how many fake workers run at the same time.
        # _max_concurrent must stay at 1 for the lock to be working.
        model._concurrent_count = 0
        model._max_concurrent = 0

        first_started = threading.Event()
        release_workers = threading.Event()
        call_count = 0
        call_lock = threading.Lock()

        def generate_side_effect(**kwargs):
            # Runs on a worker thread. Blocks until the test releases it so
            # a second request can be queued while the first is "busy".
            nonlocal call_count
            with call_lock:
                call_count += 1
                current_call = call_count
                model._concurrent_count += 1
                model._max_concurrent = max(
                    model._max_concurrent, model._concurrent_count
                )
            if current_call == 1:
                first_started.set()

            # Hold the worker "busy" until the test allows it to finish.
            release_workers.wait(timeout=1.0)

            with call_lock:
                model._concurrent_count -= 1

            result = MagicMock()
            result.text = f"response-{current_call}"
            result.tokens = [1, 2, 3]
            result.finish_reason = "stop"
            return result

        model.generate = MagicMock(side_effect=generate_side_effect)

        with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
            engine = SimpleEngine("test-model")
            engine._model = model
            engine._loaded = True

            task1 = asyncio.create_task(engine.generate(prompt="first", max_tokens=8))
            # Wait (off the event loop, so the loop keeps running) until the
            # first worker has actually started. Event.wait returns False on
            # timeout — fail loudly here rather than cancelling an idle task
            # and hitting a confusing downstream assertion.
            started = await asyncio.to_thread(first_started.wait, 1.0)
            self.assertTrue(started, "first MLX worker never started generating")

            # Cancel the in-flight request, then immediately queue a second
            # one while the first worker thread is still running.
            task1.cancel()
            task2 = asyncio.create_task(engine.generate(prompt="second", max_tokens=8))

            # Give task2 a chance to (incorrectly) grab the lock before the
            # first worker is released.
            await asyncio.sleep(0.05)
            release_workers.set()

            with self.assertRaises(asyncio.CancelledError):
                await task1
            result2 = await task2

            self.assertEqual(result2.text, "response-2")
            self.assertEqual(
                model._max_concurrent,
                1,
                "cancellation released the generation lock before the first worker finished",
            )

    async def test_specprefill_path_does_not_prelock_serialized_runner(self):
        """Specprefill streaming must let _run_blocking_serialized own the lock."""
        from vllm_mlx.engine.simple import SimpleEngine

        async def fake_serialized(func, *args, **kwargs):
            # If the caller pre-acquired the lock, this would observe it held.
            self.assertFalse(engine._generation_lock.locked())
            return []

        with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
            engine = SimpleEngine("test-model")
            engine._loaded = True
            engine._model = MagicMock()
            engine._model.model = MagicMock()
            engine._model.tokenizer = MagicMock()
            engine._draft_model = MagicMock()
            engine._run_blocking_serialized = fake_serialized  # type: ignore[method-assign]

            outputs = []
            async for chunk in engine._stream_generate_specprefill(
                prompt="hello",
                tokens=[1, 2, 3, 4],
                max_tokens=4,
                temperature=0.7,
                top_p=0.9,
            ):
                outputs.append(chunk)

            # fake_serialized returns no tokens, so only the final (empty,
            # finished) chunk is expected.
            self.assertEqual(len(outputs), 1)
            self.assertTrue(outputs[0].finished)
            self.assertEqual(outputs[0].completion_tokens, 0)

    async def test_text_mtp_path_does_not_prelock_serialized_runner(self):
        """Text-only MTP streaming must let _run_blocking_serialized own the lock."""
        from vllm_mlx.engine.simple import SimpleEngine

        async def fake_serialized(func, *args, **kwargs):
            # If the caller pre-acquired the lock, this would observe it held.
            self.assertFalse(engine._generation_lock.locked())
            return []

        with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=True):
            engine = SimpleEngine("test-model")
            engine._loaded = True
            engine._text_model = MagicMock()
            engine._text_model.make_mtp_cache = MagicMock(return_value=[])
            engine._text_tokenizer = MagicMock()
            engine._text_tokenizer.apply_chat_template = MagicMock(return_value="hello")
            engine._text_tokenizer.bos_token = None
            engine._draft_model = None
            engine._run_blocking_serialized = fake_serialized  # type: ignore[method-assign]

            outputs = []
            async for chunk in engine._stream_generate_text(
                messages=[{"role": "user", "content": "hello"}],
                max_tokens=4,
                temperature=0.7,
                top_p=0.9,
            ):
                outputs.append(chunk)

            # fake_serialized returns no tokens, so only the final (empty,
            # finished) chunk is expected.
            self.assertEqual(len(outputs), 1)
            self.assertTrue(outputs[0].finished)
            self.assertEqual(outputs[0].completion_tokens, 0)


# Allow running this regression-test module directly (outside pytest).
if __name__ == "__main__":
    unittest.main()
Loading
Loading