80 changes: 80 additions & 0 deletions tests/test_simple_engine_cancel_serialization.py
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
"""Regression test for cancellation-safe SimpleEngine serialization."""

from __future__ import annotations

import asyncio
import threading
import unittest
from unittest.mock import MagicMock, patch


class SimpleEngineCancelSerializationTests(unittest.IsolatedAsyncioTestCase):
async def test_cancellation_does_not_release_lock_before_worker_finishes(self):
"""A cancelled request must not let a second MLX worker overlap."""
from vllm_mlx.engine.simple import SimpleEngine

model = MagicMock()
model.tokenizer = MagicMock()
model.tokenizer.encode = MagicMock(return_value=[1, 2, 3])
model._concurrent_count = 0
model._max_concurrent = 0

first_started = threading.Event()
release_workers = threading.Event()
call_count = 0
call_lock = threading.Lock()

def generate_side_effect(**kwargs):
nonlocal call_count
with call_lock:
call_count += 1
current_call = call_count
model._concurrent_count += 1
model._max_concurrent = max(
model._max_concurrent, model._concurrent_count
)
if current_call == 1:
first_started.set()

release_workers.wait(timeout=1.0)

with call_lock:
model._concurrent_count -= 1

result = MagicMock()
result.text = f"response-{current_call}"
result.tokens = [1, 2, 3]
result.finish_reason = "stop"
return result

model.generate = MagicMock(side_effect=generate_side_effect)

with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
engine = SimpleEngine("test-model")
engine._model = model
engine._loaded = True

task1 = asyncio.create_task(engine.generate(prompt="first", max_tokens=8))
await asyncio.to_thread(first_started.wait, 1.0)

task1.cancel()
task2 = asyncio.create_task(engine.generate(prompt="second", max_tokens=8))

await asyncio.sleep(0.05)
release_workers.set()

with self.assertRaises(asyncio.CancelledError):
await task1
result2 = await task2

self.assertEqual(result2.text, "response-2")
self.assertEqual(
model._max_concurrent,
1,
"cancellation released the generation lock before the first worker finished",
)


if __name__ == "__main__":
unittest.main()
196 changes: 103 additions & 93 deletions vllm_mlx/engine/simple.py
@@ -226,6 +226,24 @@ async def stop(self) -> None:
self._system_kv_token_count = 0
logger.info("SimpleEngine stopped")

async def _run_blocking_serialized(self, func, /, *args, **kwargs):
"""Run a blocking MLX operation under the generation lock.

Cancellation must not release the async lock before the worker thread
finishes, or a follow-up request can enter MLX/Metal concurrently and
corrupt the command-buffer state.
"""
async with self._generation_lock:
task = asyncio.create_task(asyncio.to_thread(func, *args, **kwargs))
try:
return await asyncio.shield(task)
except asyncio.CancelledError:
try:
await task
except Exception:
pass
raise

async def generate(
self,
prompt: str,
@@ -252,30 +270,28 @@ async def generate(
if not self._loaded:
await self.start()

async with self._generation_lock:
# Run in thread pool to allow asyncio timeout to work
output = await asyncio.to_thread(
self._model.generate,
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
**kwargs,
)

# Clean output text
text = clean_output_text(output.text)
output = await self._run_blocking_serialized(
self._model.generate,
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
**kwargs,
)

return GenerationOutput(
text=text,
tokens=getattr(output, "tokens", []),
prompt_tokens=getattr(output, "prompt_tokens", 0),
completion_tokens=getattr(
output, "completion_tokens", len(getattr(output, "tokens", []))
),
finish_reason=output.finish_reason,
)
# Clean output text
text = clean_output_text(output.text)

return GenerationOutput(
text=text,
tokens=getattr(output, "tokens", []),
prompt_tokens=getattr(output, "prompt_tokens", 0),
completion_tokens=getattr(
output, "completion_tokens", len(getattr(output, "tokens", []))
),
finish_reason=output.finish_reason,
)

async def stream_generate(
self,
@@ -440,44 +456,39 @@ async def chat(
# Convert tools for template if provided
template_tools = convert_tools_for_template(tools) if tools else None

async with self._generation_lock:
if self._is_mllm:
# For MLLM, use the chat method which handles images/videos
# Run in thread pool to allow asyncio timeout to work
output = await asyncio.to_thread(
self._model.chat,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
tools=template_tools,
**kwargs,
)
text = clean_output_text(output.text)
return GenerationOutput(
text=text,
prompt_tokens=output.prompt_tokens,
completion_tokens=output.completion_tokens,
finish_reason=output.finish_reason,
)
else:
# For LLM, use the chat method
# Run in thread pool to allow asyncio timeout to work
output = await asyncio.to_thread(
self._model.chat,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
tools=template_tools,
**kwargs,
)
text = clean_output_text(output.text)
return GenerationOutput(
text=text,
tokens=output.tokens,
completion_tokens=len(output.tokens),
finish_reason=output.finish_reason,
)
if self._is_mllm:
output = await self._run_blocking_serialized(
self._model.chat,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
tools=template_tools,
**kwargs,
)
text = clean_output_text(output.text)
return GenerationOutput(
text=text,
prompt_tokens=output.prompt_tokens,
completion_tokens=output.completion_tokens,
finish_reason=output.finish_reason,
)
else:
output = await self._run_blocking_serialized(
self._model.chat,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
tools=template_tools,
**kwargs,
)
text = clean_output_text(output.text)
return GenerationOutput(
text=text,
tokens=output.tokens,
completion_tokens=len(output.tokens),
finish_reason=output.finish_reason,
)

async def stream_chat(
self,
@@ -537,42 +548,41 @@ async def stream_chat(
# For MLLM, use stream_chat which yields tokens incrementally.
# Must hold _generation_lock to prevent concurrent Metal access
# (e.g. OpenCode sends title + main request simultaneously).
async with self._generation_lock:
accumulated_text = ""
token_count = 0

# Run stream_chat in thread pool since it's synchronous
def run_stream():
return list(
self._model.stream_chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
tools=template_tools,
**kwargs,
)
accumulated_text = ""
token_count = 0

# Run stream_chat in thread pool since it's synchronous
def run_stream():
return list(
self._model.stream_chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
tools=template_tools,
**kwargs,
)
)

chunks = await asyncio.to_thread(run_stream)
chunks = await self._run_blocking_serialized(run_stream)

for chunk in chunks:
token_count += 1
new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
accumulated_text += new_text
for chunk in chunks:
token_count += 1
new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
accumulated_text += new_text

finished = chunk.finish_reason is not None
finished = chunk.finish_reason is not None

yield GenerationOutput(
text=accumulated_text,
new_text=new_text,
prompt_tokens=getattr(chunk, "prompt_tokens", 0),
completion_tokens=token_count,
finished=finished,
finish_reason=chunk.finish_reason if finished else None,
)
yield GenerationOutput(
text=accumulated_text,
new_text=new_text,
prompt_tokens=getattr(chunk, "prompt_tokens", 0),
completion_tokens=token_count,
finished=finished,
finish_reason=chunk.finish_reason if finished else None,
)

if finished:
break
if finished:
break
return

# For LLM, apply chat template and stream
@@ -758,7 +768,7 @@ def _run_normal():
)
return results

all_resps = await asyncio.to_thread(_run_all)
all_resps = await self._run_blocking_serialized(_run_all)

# Yield results as GenerationOutput
accumulated_text = ""
@@ -1186,7 +1196,7 @@ def _run_specprefill(model, bc):
finally:
cleanup_rope(model)

all_resps = await asyncio.to_thread(_run_all)
all_resps = await self._run_blocking_serialized(_run_all)

Action required

1. MTP stream deadlock 🐞 Bug ⛯ Reliability

`_stream_generate_text()` holds `_generation_lock` and now awaits `_run_blocking_serialized()`, which tries to acquire the same lock again, deadlocking MLLM+MTP text-only streaming. Text-only MLLM requests routed to this path will hang indefinitely.
Agent Prompt
### Issue description
`_stream_generate_text()` currently acquires `self._generation_lock` and then calls `_run_blocking_serialized()`, which also acquires `self._generation_lock`. Because `asyncio.Lock` is not re-entrant, the second acquisition never succeeds and the request deadlocks.

### Issue Context
This path is used for MLLM+MTP routing (text-only requests). The PR moved lock acquisition into `_run_blocking_serialized()`, so call sites that already hold the lock must be adjusted.

### Fix Focus Areas
- vllm_mlx/engine/simple.py[229-246]
- vllm_mlx/engine/simple.py[1012-1014]
- vllm_mlx/engine/simple.py[1194-1201]

### Implementation direction
Refactor `_stream_generate_text()` so `_run_all` (and thus MLX/Metal calls) run under `_run_blocking_serialized()` without an additional surrounding `async with self._generation_lock:` block; or add a dedicated cancellation-safe helper that does **not** reacquire the lock when the caller already holds it.
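For illustration only: one way to follow the second suggestion is to factor the cancellation shield out of `_run_blocking_serialized()` into a helper that never touches the lock, so callers that already hold `_generation_lock` can reuse it directly. The sketch below assumes it lives on `SimpleEngine` in `vllm_mlx/engine/simple.py` (where `asyncio` is already imported); the helper name `_shielded_to_thread` is hypothetical and not part of this PR.

```python
# Hypothetical sketch, not the PR's implementation: split the shield out of
# _run_blocking_serialized so lock-holding callers can reuse it.
async def _shielded_to_thread(self, func, /, *args, **kwargs):
    """Run func in a worker thread; on cancellation, wait for it to finish."""
    task = asyncio.create_task(asyncio.to_thread(func, *args, **kwargs))
    try:
        return await asyncio.shield(task)
    except asyncio.CancelledError:
        try:
            await task  # drain the worker so the MLX/Metal call has finished
        except Exception:
            pass  # request is already cancelled; ignore the worker's error
        raise

async def _run_blocking_serialized(self, func, /, *args, **kwargs):
    # Same contract as in this PR: serialize MLX work, stay cancellation-safe.
    async with self._generation_lock:
        return await self._shielded_to_thread(func, *args, **kwargs)
```

`_stream_generate_text()`, which already holds `_generation_lock`, would then await `self._shielded_to_thread(_run_all)` instead of `_run_blocking_serialized(_run_all)`, keeping cancellation safety without re-acquiring the lock.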



# Yield results as GenerationOutput
accumulated_text = ""