diff --git a/tests/test_simple_engine_cancel_serialization.py b/tests/test_simple_engine_cancel_serialization.py
new file mode 100644
index 000000000..5f50367be
--- /dev/null
+++ b/tests/test_simple_engine_cancel_serialization.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Regression test for cancellation-safe SimpleEngine serialization."""
+
+from __future__ import annotations
+
+import asyncio
+import threading
+import unittest
+from unittest.mock import MagicMock, patch
+
+
+class SimpleEngineCancelSerializationTests(unittest.IsolatedAsyncioTestCase):
+    async def test_cancellation_does_not_release_lock_before_worker_finishes(self):
+        """A cancelled request must not let a second MLX worker overlap."""
+        from vllm_mlx.engine.simple import SimpleEngine
+
+        model = MagicMock()
+        model.tokenizer = MagicMock()
+        model.tokenizer.encode = MagicMock(return_value=[1, 2, 3])
+        model._concurrent_count = 0
+        model._max_concurrent = 0
+
+        first_started = threading.Event()
+        release_workers = threading.Event()
+        call_count = 0
+        call_lock = threading.Lock()
+
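+        # Fake MLX worker: each call blocks until release_workers is set and
+        # records how many calls are inside model.generate simultaneously.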
+ """ + async with self._generation_lock: + task = asyncio.create_task(asyncio.to_thread(func, *args, **kwargs)) + try: + return await asyncio.shield(task) + except asyncio.CancelledError: + try: + await task + except Exception: + pass + raise + async def generate( self, prompt: str, @@ -252,30 +270,28 @@ async def generate( if not self._loaded: await self.start() - async with self._generation_lock: - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.generate, - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - stop=stop, - **kwargs, - ) - - # Clean output text - text = clean_output_text(output.text) + output = await self._run_blocking_serialized( + self._model.generate, + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ) - return GenerationOutput( - text=text, - tokens=getattr(output, "tokens", []), - prompt_tokens=getattr(output, "prompt_tokens", 0), - completion_tokens=getattr( - output, "completion_tokens", len(getattr(output, "tokens", [])) - ), - finish_reason=output.finish_reason, - ) + # Clean output text + text = clean_output_text(output.text) + + return GenerationOutput( + text=text, + tokens=getattr(output, "tokens", []), + prompt_tokens=getattr(output, "prompt_tokens", 0), + completion_tokens=getattr( + output, "completion_tokens", len(getattr(output, "tokens", [])) + ), + finish_reason=output.finish_reason, + ) async def stream_generate( self, @@ -440,44 +456,39 @@ async def chat( # Convert tools for template if provided template_tools = convert_tools_for_template(tools) if tools else None - async with self._generation_lock: - if self._is_mllm: - # For MLLM, use the chat method which handles images/videos - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.chat, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - tools=template_tools, - **kwargs, - ) - text = clean_output_text(output.text) - return GenerationOutput( - text=text, - prompt_tokens=output.prompt_tokens, - completion_tokens=output.completion_tokens, - finish_reason=output.finish_reason, - ) - else: - # For LLM, use the chat method - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.chat, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - tools=template_tools, - **kwargs, - ) - text = clean_output_text(output.text) - return GenerationOutput( - text=text, - tokens=output.tokens, - completion_tokens=len(output.tokens), - finish_reason=output.finish_reason, - ) + if self._is_mllm: + output = await self._run_blocking_serialized( + self._model.chat, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + tools=template_tools, + **kwargs, + ) + text = clean_output_text(output.text) + return GenerationOutput( + text=text, + prompt_tokens=output.prompt_tokens, + completion_tokens=output.completion_tokens, + finish_reason=output.finish_reason, + ) + else: + output = await self._run_blocking_serialized( + self._model.chat, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=template_tools, + **kwargs, + ) + text = clean_output_text(output.text) + return GenerationOutput( + text=text, + tokens=output.tokens, + completion_tokens=len(output.tokens), + finish_reason=output.finish_reason, + ) async def stream_chat( self, @@ -537,42 +548,41 @@ async def 
+                try:
+                    await task
+                except Exception:
+                    pass
+                raise
+
     async def generate(
         self,
         prompt: str,
@@ -252,30 +270,28 @@ async def generate(
         if not self._loaded:
             await self.start()

-        async with self._generation_lock:
-            # Run in thread pool to allow asyncio timeout to work
-            output = await asyncio.to_thread(
-                self._model.generate,
-                prompt=prompt,
-                max_tokens=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                stop=stop,
-                **kwargs,
-            )
-
-            # Clean output text
-            text = clean_output_text(output.text)
+        output = await self._run_blocking_serialized(
+            self._model.generate,
+            prompt=prompt,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            stop=stop,
+            **kwargs,
+        )

-            return GenerationOutput(
-                text=text,
-                tokens=getattr(output, "tokens", []),
-                prompt_tokens=getattr(output, "prompt_tokens", 0),
-                completion_tokens=getattr(
-                    output, "completion_tokens", len(getattr(output, "tokens", []))
-                ),
-                finish_reason=output.finish_reason,
-            )
+        # Clean output text
+        text = clean_output_text(output.text)
+
+        return GenerationOutput(
+            text=text,
+            tokens=getattr(output, "tokens", []),
+            prompt_tokens=getattr(output, "prompt_tokens", 0),
+            completion_tokens=getattr(
+                output, "completion_tokens", len(getattr(output, "tokens", []))
+            ),
+            finish_reason=output.finish_reason,
+        )

     async def stream_generate(
         self,
@@ -440,44 +456,39 @@ async def chat(
         # Convert tools for template if provided
         template_tools = convert_tools_for_template(tools) if tools else None

-        async with self._generation_lock:
-            if self._is_mllm:
-                # For MLLM, use the chat method which handles images/videos
-                # Run in thread pool to allow asyncio timeout to work
-                output = await asyncio.to_thread(
-                    self._model.chat,
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    tools=template_tools,
-                    **kwargs,
-                )
-                text = clean_output_text(output.text)
-                return GenerationOutput(
-                    text=text,
-                    prompt_tokens=output.prompt_tokens,
-                    completion_tokens=output.completion_tokens,
-                    finish_reason=output.finish_reason,
-                )
-            else:
-                # For LLM, use the chat method
-                # Run in thread pool to allow asyncio timeout to work
-                output = await asyncio.to_thread(
-                    self._model.chat,
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    tools=template_tools,
-                    **kwargs,
-                )
-                text = clean_output_text(output.text)
-                return GenerationOutput(
-                    text=text,
-                    tokens=output.tokens,
-                    completion_tokens=len(output.tokens),
-                    finish_reason=output.finish_reason,
-                )
+        if self._is_mllm:
+            output = await self._run_blocking_serialized(
+                self._model.chat,
+                messages=messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                tools=template_tools,
+                **kwargs,
+            )
+            text = clean_output_text(output.text)
+            return GenerationOutput(
+                text=text,
+                prompt_tokens=output.prompt_tokens,
+                completion_tokens=output.completion_tokens,
+                finish_reason=output.finish_reason,
+            )
+        else:
+            output = await self._run_blocking_serialized(
+                self._model.chat,
+                messages=messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                tools=template_tools,
+                **kwargs,
+            )
+            text = clean_output_text(output.text)
+            return GenerationOutput(
+                text=text,
+                tokens=output.tokens,
+                completion_tokens=len(output.tokens),
+                finish_reason=output.finish_reason,
+            )

     async def stream_chat(
         self,
@@ -537,42 +548,41 @@ async def stream_chat(
             # For MLLM, use stream_chat which yields tokens incrementally.
             # Must hold _generation_lock to prevent concurrent Metal access
             # (e.g. OpenCode sends title + main request simultaneously).
-            async with self._generation_lock:
-                accumulated_text = ""
-                token_count = 0
-
-                # Run stream_chat in thread pool since it's synchronous
-                def run_stream():
-                    return list(
-                        self._model.stream_chat(
-                            messages=messages,
-                            max_tokens=max_tokens,
-                            temperature=temperature,
-                            tools=template_tools,
-                            **kwargs,
-                        )
-                    )
+            accumulated_text = ""
+            token_count = 0
+
+            # Run stream_chat in thread pool since it's synchronous
+            def run_stream():
+                return list(
+                    self._model.stream_chat(
+                        messages=messages,
+                        max_tokens=max_tokens,
+                        temperature=temperature,
+                        tools=template_tools,
+                        **kwargs,
+                    )
+                )

-                chunks = await asyncio.to_thread(run_stream)
+            chunks = await self._run_blocking_serialized(run_stream)

-                for chunk in chunks:
-                    token_count += 1
-                    new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
-                    accumulated_text += new_text
+            for chunk in chunks:
+                token_count += 1
+                new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
+                accumulated_text += new_text

-                    finished = chunk.finish_reason is not None
+                finished = chunk.finish_reason is not None

-                    yield GenerationOutput(
-                        text=accumulated_text,
-                        new_text=new_text,
-                        prompt_tokens=getattr(chunk, "prompt_tokens", 0),
-                        completion_tokens=token_count,
-                        finished=finished,
-                        finish_reason=chunk.finish_reason if finished else None,
-                    )
+                yield GenerationOutput(
+                    text=accumulated_text,
+                    new_text=new_text,
+                    prompt_tokens=getattr(chunk, "prompt_tokens", 0),
+                    completion_tokens=token_count,
+                    finished=finished,
+                    finish_reason=chunk.finish_reason if finished else None,
+                )

-                    if finished:
-                        break
+                if finished:
+                    break

             return

         # For LLM, apply chat template and stream
@@ -758,7 +768,7 @@ def _run_normal():
             )
             return results

-        all_resps = await asyncio.to_thread(_run_all)
+        all_resps = await self._run_blocking_serialized(_run_all)

         # Yield results as GenerationOutput
         accumulated_text = ""
@@ -1186,7 +1196,7 @@ def _run_specprefill(model, bc):
         finally:
             cleanup_rope(model)

-        all_resps = await asyncio.to_thread(_run_all)
+        all_resps = await self._run_blocking_serialized(_run_all)

         # Yield results as GenerationOutput
         accumulated_text = ""