From 8a782bb4c69c9aac6ed3e8750ae547999cf1d2de Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Tue, 24 Mar 2026 15:34:07 +0100 Subject: [PATCH 1/5] fix: keep simple engine serialized across cancellation (#8) --- ...test_simple_engine_cancel_serialization.py | 80 +++++++ vllm_mlx/engine/simple.py | 208 +++++++++--------- 2 files changed, 184 insertions(+), 104 deletions(-) create mode 100644 tests/test_simple_engine_cancel_serialization.py diff --git a/tests/test_simple_engine_cancel_serialization.py b/tests/test_simple_engine_cancel_serialization.py new file mode 100644 index 000000000..5f50367be --- /dev/null +++ b/tests/test_simple_engine_cancel_serialization.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Regression test for cancellation-safe SimpleEngine serialization.""" + +from __future__ import annotations + +import asyncio +import threading +import unittest +from unittest.mock import MagicMock, patch + + +class SimpleEngineCancelSerializationTests(unittest.IsolatedAsyncioTestCase): + async def test_cancellation_does_not_release_lock_before_worker_finishes(self): + """A cancelled request must not let a second MLX worker overlap.""" + from vllm_mlx.engine.simple import SimpleEngine + + model = MagicMock() + model.tokenizer = MagicMock() + model.tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + model._concurrent_count = 0 + model._max_concurrent = 0 + + first_started = threading.Event() + release_workers = threading.Event() + call_count = 0 + call_lock = threading.Lock() + + def generate_side_effect(**kwargs): + nonlocal call_count + with call_lock: + call_count += 1 + current_call = call_count + model._concurrent_count += 1 + model._max_concurrent = max( + model._max_concurrent, model._concurrent_count + ) + if current_call == 1: + first_started.set() + + release_workers.wait(timeout=1.0) + + with call_lock: + model._concurrent_count -= 1 + + result = MagicMock() + result.text = f"response-{current_call}" + result.tokens = [1, 2, 3] + result.finish_reason = "stop" + return result + + model.generate = MagicMock(side_effect=generate_side_effect) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + + task1 = asyncio.create_task(engine.generate(prompt="first", max_tokens=8)) + await asyncio.to_thread(first_started.wait, 1.0) + + task1.cancel() + task2 = asyncio.create_task(engine.generate(prompt="second", max_tokens=8)) + + await asyncio.sleep(0.05) + release_workers.set() + + with self.assertRaises(asyncio.CancelledError): + await task1 + result2 = await task2 + + self.assertEqual(result2.text, "response-2") + self.assertEqual( + model._max_concurrent, + 1, + "cancellation released the generation lock before the first worker finished", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index da3ccfc18..5d1994f75 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -226,6 +226,24 @@ async def stop(self) -> None: self._system_kv_token_count = 0 logger.info("SimpleEngine stopped") + async def _run_blocking_serialized(self, func, /, *args, **kwargs): + """Run a blocking MLX operation under the generation lock. + + Cancellation must not release the async lock before the worker thread + finishes, or a follow-up request can enter MLX/Metal concurrently and + corrupt the command-buffer state. 
+ """ + async with self._generation_lock: + task = asyncio.create_task(asyncio.to_thread(func, *args, **kwargs)) + try: + return await asyncio.shield(task) + except asyncio.CancelledError: + try: + await task + except Exception: + pass + raise + async def generate( self, prompt: str, @@ -252,30 +270,28 @@ async def generate( if not self._loaded: await self.start() - async with self._generation_lock: - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.generate, - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - stop=stop, - **kwargs, - ) - - # Clean output text - text = clean_output_text(output.text) + output = await self._run_blocking_serialized( + self._model.generate, + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ) - return GenerationOutput( - text=text, - tokens=getattr(output, "tokens", []), - prompt_tokens=getattr(output, "prompt_tokens", 0), - completion_tokens=getattr( - output, "completion_tokens", len(getattr(output, "tokens", [])) - ), - finish_reason=output.finish_reason, - ) + # Clean output text + text = clean_output_text(output.text) + + return GenerationOutput( + text=text, + tokens=getattr(output, "tokens", []), + prompt_tokens=getattr(output, "prompt_tokens", 0), + completion_tokens=getattr( + output, "completion_tokens", len(getattr(output, "tokens", [])) + ), + finish_reason=output.finish_reason, + ) async def stream_generate( self, @@ -440,55 +456,40 @@ async def chat( # Convert tools for template if provided template_tools = convert_tools_for_template(tools) if tools else None - async with self._generation_lock: - if self._is_mllm: - # For MLLM, use the chat method which handles images/videos - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.chat, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - tools=template_tools, - **kwargs, - ) - text = clean_output_text(output.text) - return GenerationOutput( - text=text, - prompt_tokens=output.prompt_tokens, - completion_tokens=output.completion_tokens, - finish_reason=output.finish_reason, - ) - else: - # For LLM, use the chat method - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.chat, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - tools=template_tools, - **kwargs, - ) - text = clean_output_text(output.text) - # Count prompt tokens from the full templated prompt - tokenizer = self._model.tokenizer - template_kwargs = { - "tokenize": True, - "add_generation_prompt": True, - } - if template_tools: - template_kwargs["tools"] = template_tools - prompt_ids = tokenizer.apply_chat_template(messages, **template_kwargs) - prompt_token_count = len(prompt_ids) - return GenerationOutput( - text=text, - tokens=output.tokens, - prompt_tokens=prompt_token_count, - completion_tokens=len(output.tokens), - finish_reason=output.finish_reason, - ) + if self._is_mllm: + output = await self._run_blocking_serialized( + self._model.chat, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + tools=template_tools, + **kwargs, + ) + text = clean_output_text(output.text) + return GenerationOutput( + text=text, + prompt_tokens=output.prompt_tokens, + completion_tokens=output.completion_tokens, + finish_reason=output.finish_reason, + ) + else: + output = await self._run_blocking_serialized( + 
self._model.chat, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=template_tools, + **kwargs, + ) + text = clean_output_text(output.text) + return GenerationOutput( + text=text, + tokens=output.tokens, + prompt_tokens=prompt_token_count, + completion_tokens=len(output.tokens), + finish_reason=output.finish_reason, + ) async def stream_chat( self, @@ -548,42 +549,41 @@ async def stream_chat( # For MLLM, use stream_chat which yields tokens incrementally. # Must hold _generation_lock to prevent concurrent Metal access # (e.g. OpenCode sends title + main request simultaneously). - async with self._generation_lock: - accumulated_text = "" - token_count = 0 - - # Run stream_chat in thread pool since it's synchronous - def run_stream(): - return list( - self._model.stream_chat( - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - tools=template_tools, - **kwargs, - ) + accumulated_text = "" + token_count = 0 + + # Run stream_chat in thread pool since it's synchronous + def run_stream(): + return list( + self._model.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + tools=template_tools, + **kwargs, ) + ) - chunks = await asyncio.to_thread(run_stream) + chunks = await self._run_blocking_serialized(run_stream) - for chunk in chunks: - token_count += 1 - new_text = chunk.text if hasattr(chunk, "text") else str(chunk) - accumulated_text += new_text + for chunk in chunks: + token_count += 1 + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + accumulated_text += new_text - finished = chunk.finish_reason is not None + finished = chunk.finish_reason is not None - yield GenerationOutput( - text=accumulated_text, - new_text=new_text, - prompt_tokens=getattr(chunk, "prompt_tokens", 0), - completion_tokens=token_count, - finished=finished, - finish_reason=chunk.finish_reason if finished else None, - ) + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=getattr(chunk, "prompt_tokens", 0), + completion_tokens=token_count, + finished=finished, + finish_reason=chunk.finish_reason if finished else None, + ) - if finished: - break + if finished: + break return # For LLM, apply chat template and stream @@ -769,7 +769,7 @@ def _run_normal(): ) return results - all_resps = await asyncio.to_thread(_run_all) + all_resps = await self._run_blocking_serialized(_run_all) # Yield results as GenerationOutput accumulated_text = "" @@ -1197,7 +1197,7 @@ def _run_specprefill(model, bc): finally: cleanup_rope(model) - all_resps = await asyncio.to_thread(_run_all) + all_resps = await self._run_blocking_serialized(_run_all) # Yield results as GenerationOutput accumulated_text = "" From e22052bdf5ace728a758f74026923cb46a614ea7 Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Tue, 24 Mar 2026 19:11:17 +0100 Subject: [PATCH 2/5] fix: avoid nested simple engine generation locks --- ...test_simple_engine_cancel_serialization.py | 63 ++++++ vllm_mlx/engine/simple.py | 188 +++++++++--------- 2 files changed, 155 insertions(+), 96 deletions(-) diff --git a/tests/test_simple_engine_cancel_serialization.py b/tests/test_simple_engine_cancel_serialization.py index 5f50367be..28c25868e 100644 --- a/tests/test_simple_engine_cancel_serialization.py +++ b/tests/test_simple_engine_cancel_serialization.py @@ -75,6 +75,69 @@ def generate_side_effect(**kwargs): "cancellation released the generation lock before the first worker finished", ) + async def 
test_specprefill_path_does_not_prelock_serialized_runner(self): + """Specprefill streaming must let _run_blocking_serialized own the lock.""" + from vllm_mlx.engine.simple import SimpleEngine + + async def fake_serialized(func, *args, **kwargs): + self.assertFalse(engine._generation_lock.locked()) + return [] + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._model = MagicMock() + engine._model.model = MagicMock() + engine._model.tokenizer = MagicMock() + engine._draft_model = MagicMock() + engine._run_blocking_serialized = fake_serialized # type: ignore[method-assign] + + outputs = [] + async for chunk in engine._stream_generate_specprefill( + prompt="hello", + tokens=[1, 2, 3, 4], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + outputs.append(chunk) + + self.assertEqual(len(outputs), 1) + self.assertTrue(outputs[0].finished) + self.assertEqual(outputs[0].completion_tokens, 0) + + async def test_text_mtp_path_does_not_prelock_serialized_runner(self): + """Text-only MTP streaming must let _run_blocking_serialized own the lock.""" + from vllm_mlx.engine.simple import SimpleEngine + + async def fake_serialized(func, *args, **kwargs): + self.assertFalse(engine._generation_lock.locked()) + return [] + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=True): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.make_mtp_cache = MagicMock(return_value=[]) + engine._text_tokenizer = MagicMock() + engine._text_tokenizer.apply_chat_template = MagicMock(return_value="hello") + engine._text_tokenizer.bos_token = None + engine._draft_model = None + engine._run_blocking_serialized = fake_serialized # type: ignore[method-assign] + + outputs = [] + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + outputs.append(chunk) + + self.assertEqual(len(outputs), 1) + self.assertTrue(outputs[0].finished) + self.assertEqual(outputs[0].completion_tokens, 0) + if __name__ == "__main__": unittest.main() diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 5d1994f75..b2473fc60 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -647,16 +647,14 @@ async def _stream_generate_specprefill( tokenizer = self._model.tokenizer n_tokens = len(tokens) - async with self._generation_lock: - - def _run_all(): - try: - return _run_specprefill() - except Exception as e: - logger.error( - "SpecPrefill failed, falling back to normal path: %s", e - ) - return _run_normal() + def _run_all(): + try: + return _run_specprefill() + except Exception as e: + logger.error( + "SpecPrefill failed, falling back to normal path: %s", e + ) + return _run_normal() def _run_specprefill(): """Score tokens, sparse prefill, generate autoregressively.""" @@ -769,7 +767,7 @@ def _run_normal(): ) return results - all_resps = await self._run_blocking_serialized(_run_all) + all_resps = await self._run_blocking_serialized(_run_all) # Yield results as GenerationOutput accumulated_text = "" @@ -1010,96 +1008,94 @@ async def _stream_generate_text( ) use_specprefill = False - # Run under generation lock, all Metal ops in single thread - async with self._generation_lock: + # Run all Metal ops in a single serialized thread. 
+ def _run_all(): + nonlocal backbone_cache, prompt_to_send - def _run_all(): - nonlocal backbone_cache, prompt_to_send + model = self._text_model - model = self._text_model + # Cache MISS with valid prefix: prefill system tokens and snapshot + if ( + not cache_hit + and system_token_count > 0 + and system_tokens is not None + and suffix_tokens is not None + ): + mc = make_prompt_cache(model) + sys_arr = mx.array(system_tokens) + + # Prefill system tokens in chunks (matching generate_step) + step = self._prefill_step_size + while sys_arr.size > step: + model(sys_arr[:step][None], cache=mc) + mx.eval([c.state for c in mc]) + sys_arr = sys_arr[step:] + mx.clear_cache() + if sys_arr.size > 0: + model(sys_arr[None], cache=mc) + mx.eval([c.state for c in mc]) + + # Snapshot backbone cache (immutable mx.arrays, safe to reuse) + snapshot = [c.state for c in mc] + mx.eval([s for pair in snapshot for s in pair]) + + self._system_kv_snapshot = snapshot + self._system_kv_hash = system_hash + self._system_kv_token_count = system_token_count + + backbone_cache = mc + prompt_to_send = mx.array(suffix_tokens) + logger.info( + "System KV cache: stored %d-token snapshot (%.1f MB), " + "prefilling %d remaining", + system_token_count, + sum(c.nbytes for c in mc) / 1e6, + len(suffix_tokens), + ) - # Cache MISS with valid prefix: prefill system tokens and snapshot - if ( - not cache_hit - and system_token_count > 0 - and system_tokens is not None - and suffix_tokens is not None - ): - mc = make_prompt_cache(model) - sys_arr = mx.array(system_tokens) - - # Prefill system tokens in chunks (matching generate_step) - step = self._prefill_step_size - while sys_arr.size > step: - model(sys_arr[:step][None], cache=mc) - mx.eval([c.state for c in mc]) - sys_arr = sys_arr[step:] - mx.clear_cache() - if sys_arr.size > 0: - model(sys_arr[None], cache=mc) - mx.eval([c.state for c in mc]) - - # Snapshot backbone cache (immutable mx.arrays, safe to reuse) - snapshot = [c.state for c in mc] - mx.eval([s for pair in snapshot for s in pair]) - - self._system_kv_snapshot = snapshot - self._system_kv_hash = system_hash - self._system_kv_token_count = system_token_count - - backbone_cache = mc - prompt_to_send = mx.array(suffix_tokens) - logger.info( - "System KV cache: stored %d-token snapshot (%.1f MB), " - "prefilling %d remaining", - system_token_count, - sum(c.nbytes for c in mc) / 1e6, - len(suffix_tokens), + # --- SpecPrefill path (with fallback to normal on failure) --- + if use_specprefill: + try: + return _run_specprefill(model, backbone_cache) + except Exception as e: + logger.error( + "SpecPrefill failed, falling back to normal MTP path: %s", + e, ) + # Discard potentially corrupted cache + backbone_cache = None + prompt_to_send = full_prompt + + # --- Normal path (MTP via mlx_lm stream_generate) --- + prompt_cache = None + if backbone_cache is not None: + # Add MTP cache on top of backbone + if hasattr(model, "make_mtp_cache"): + mtp_cache = model.make_mtp_cache() + prompt_cache = backbone_cache + mtp_cache + else: + prompt_cache = backbone_cache - # --- SpecPrefill path (with fallback to normal on failure) --- - if use_specprefill: - try: - return _run_specprefill(model, backbone_cache) - except Exception as e: - logger.error( - "SpecPrefill failed, falling back to normal MTP path: %s", - e, - ) - # Discard potentially corrupted cache - backbone_cache = None - prompt_to_send = full_prompt - - # --- Normal path (MTP via mlx_lm stream_generate) --- - prompt_cache = None - if backbone_cache is not None: - # Add MTP 
cache on top of backbone - if hasattr(model, "make_mtp_cache"): - mtp_cache = model.make_mtp_cache() - prompt_cache = backbone_cache + mtp_cache - else: - prompt_cache = backbone_cache - - results = [] - gen_kwargs = dict( - max_tokens=max_tokens, - sampler=sampler, - mtp=True, - prefill_step_size=self._prefill_step_size, - ) - if prompt_cache is not None: - gen_kwargs["prompt_cache"] = prompt_cache - - for resp in mlx_stream_generate( - model, - self._text_tokenizer, - prompt=prompt_to_send, - **gen_kwargs, - ): - results.append(resp) - return results + results = [] + gen_kwargs = dict( + max_tokens=max_tokens, + sampler=sampler, + mtp=True, + prefill_step_size=self._prefill_step_size, + ) + if prompt_cache is not None: + gen_kwargs["prompt_cache"] = prompt_cache + + for resp in mlx_stream_generate( + model, + self._text_tokenizer, + prompt=prompt_to_send, + **gen_kwargs, + ): + results.append(resp) + return results - def _run_specprefill(model, bc): + def _run_specprefill(model, bc): """Score tokens, sparse prefill, generate without MTP.""" from types import SimpleNamespace @@ -1197,7 +1193,7 @@ def _run_specprefill(model, bc): finally: cleanup_rope(model) - all_resps = await self._run_blocking_serialized(_run_all) + all_resps = await self._run_blocking_serialized(_run_all) # Yield results as GenerationOutput accumulated_text = "" From 81a0c26d8b8b0b162e8f8609fb75060d1424f934 Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Thu, 26 Mar 2026 01:15:26 +0100 Subject: [PATCH 3/5] fix: catch BaseException in cancellation handler, fix async test markers _run_blocking_serialized catches CancelledError (a BaseException subclass) from the outer scope, but the inner try/except used Exception which would let a second CancelledError during await task escape unhandled. Changed to BaseException to suppress any exception from the draining await. Also fix test_simple_engine.py to use pytest.mark.anyio instead of pytest.mark.asyncio (pytest-asyncio is not configured), and add the anyio_backend fixture to conftest.py restricting to asyncio only since trio is not installed. --- tests/conftest.py | 6 ++++++ tests/test_simple_engine.py | 12 +++++++----- vllm_mlx/engine/simple.py | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d0c7f026b..f699c08bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,3 +50,9 @@ def pytest_collection_modifyitems(config, items): def server_url(request): """Get server URL from command line.""" return request.config.getoption("--server-url") + + +@pytest.fixture(params=["asyncio"]) +def anyio_backend(request): + """Run anyio-marked tests on asyncio only (trio is not installed).""" + return request.param diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index cce42bfc3..7202f625f 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -6,6 +6,8 @@ import pytest +pytestmark = pytest.mark.anyio + class TestSimpleEngineConcurrency: """Test SimpleEngine lock behavior with concurrent requests.""" @@ -65,7 +67,7 @@ def chat_side_effect(**kwargs): model.chat = MagicMock(side_effect=chat_side_effect) return model - @pytest.mark.asyncio + @pytest.mark.anyio async def test_lock_prevents_concurrent_generate(self, mock_model): """Test that the lock prevents concurrent generate calls.""" from vllm_mlx.engine.simple import SimpleEngine @@ -89,7 +91,7 @@ async def test_lock_prevents_concurrent_generate(self, mock_model): "The lock is not working correctly." 
) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_lock_prevents_concurrent_chat(self, mock_llm_model): """Test that the lock prevents concurrent chat calls.""" from vllm_mlx.engine.simple import SimpleEngine @@ -115,7 +117,7 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model): "The lock is not working correctly." ) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_lock_serializes_stream_generate(self, mock_model): """Test that stream_generate uses the same lock as other methods.""" from vllm_mlx.engine.simple import SimpleEngine @@ -178,7 +180,7 @@ async def try_stream(): result = await stream_task assert len(result) == 3, f"Expected 3 chunks, got {len(result)}" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_engine_initialization_creates_lock(self): """Test that SimpleEngine creates a lock on initialization.""" from vllm_mlx.engine.simple import SimpleEngine @@ -189,7 +191,7 @@ async def test_engine_initialization_creates_lock(self): assert hasattr(engine, "_generation_lock") assert isinstance(engine._generation_lock, asyncio.Lock) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_requests_complete_in_order(self, mock_model): """Test that concurrent requests complete (may be in any order due to lock).""" from vllm_mlx.engine.simple import SimpleEngine diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index b2473fc60..04dd32edf 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -240,7 +240,7 @@ async def _run_blocking_serialized(self, func, /, *args, **kwargs): except asyncio.CancelledError: try: await task - except Exception: + except BaseException: pass raise From aedb342ace82767e013e12d468d49eb28409ff05 Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Thu, 9 Apr 2026 08:27:56 +0200 Subject: [PATCH 4/5] fix: preserve prompt token accounting after upstream refresh --- vllm_mlx/engine/simple.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 04dd32edf..f0e75510b 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -483,6 +483,17 @@ async def chat( **kwargs, ) text = clean_output_text(output.text) + # Preserve upstream prompt accounting while routing the blocking + # chat call through the cancellation-safe serialized runner. 
+ tokenizer = self._model.tokenizer + template_kwargs = { + "tokenize": True, + "add_generation_prompt": True, + } + if template_tools: + template_kwargs["tools"] = template_tools + prompt_ids = tokenizer.apply_chat_template(messages, **template_kwargs) + prompt_token_count = len(prompt_ids) return GenerationOutput( text=text, tokens=output.tokens, From f2526ffd88aef7fbb17dbfb9f3adb3bddd2e4f45 Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Thu, 9 Apr 2026 09:31:08 +0200 Subject: [PATCH 5/5] fix: restore specprefill fallback helper scope --- vllm_mlx/engine/simple.py | 360 +++++++++++++++++++------------------- 1 file changed, 179 insertions(+), 181 deletions(-) diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index f0e75510b..8c7212ca4 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -662,121 +662,119 @@ def _run_all(): try: return _run_specprefill() except Exception as e: - logger.error( - "SpecPrefill failed, falling back to normal path: %s", e - ) + logger.error("SpecPrefill failed, falling back to normal path: %s", e) return _run_normal() - def _run_specprefill(): - """Score tokens, sparse prefill, generate autoregressively.""" - import time - from types import SimpleNamespace + def _run_specprefill(): + """Score tokens, sparse prefill, generate autoregressively.""" + import time + from types import SimpleNamespace - from ..specprefill import ( - cleanup_rope, - score_tokens, - select_chunks, - sparse_prefill, - ) + from ..specprefill import ( + cleanup_rope, + score_tokens, + select_chunks, + sparse_prefill, + ) - cache = make_prompt_cache(model) + cache = make_prompt_cache(model) - try: - # Phase 1: Score with draft model - t0 = time.monotonic() - importance = score_tokens( - self._draft_model, - tokens, - prefill_step_size=self._prefill_step_size, - ) - t_score = time.monotonic() - t0 - - # Phase 2: Select important chunks - effective_keep = specprefill_keep_pct or self._specprefill_keep_pct - selected = select_chunks(importance, keep_pct=effective_keep) - n_selected = selected.shape[0] - - # Phase 3: Sparse prefill on target model - t0 = time.monotonic() - logits = sparse_prefill( - model, - tokens, - selected, - cache, - step_size=self._prefill_step_size, - ) - t_prefill = time.monotonic() - t0 + try: + # Phase 1: Score with draft model + t0 = time.monotonic() + importance = score_tokens( + self._draft_model, + tokens, + prefill_step_size=self._prefill_step_size, + ) + t_score = time.monotonic() - t0 + + # Phase 2: Select important chunks + effective_keep = specprefill_keep_pct or self._specprefill_keep_pct + selected = select_chunks(importance, keep_pct=effective_keep) + n_selected = selected.shape[0] + + # Phase 3: Sparse prefill on target model + t0 = time.monotonic() + logits = sparse_prefill( + model, + tokens, + selected, + cache, + step_size=self._prefill_step_size, + ) + t_prefill = time.monotonic() - t0 - logger.info( - "SpecPrefill: scored %d tokens in %.1fs, " - "sparse prefill %d/%d (keep=%.0f%%) in %.1fs", - n_tokens, - t_score, - n_selected, - n_tokens, - n_selected / n_tokens * 100, - t_prefill, - ) + logger.info( + "SpecPrefill: scored %d tokens in %.1fs, " + "sparse prefill %d/%d (keep=%.0f%%) in %.1fs", + n_tokens, + t_score, + n_selected, + n_tokens, + n_selected / n_tokens * 100, + t_prefill, + ) - # Phase 4: Generate (simple autoregressive, no MTP) - sampler = make_sampler(temp=temperature, top_p=top_p) - eos_id = tokenizer.eos_token_id - y = sampler(logits[:, -1, :]) - mx.eval(y) + # Phase 4: Generate 
(simple autoregressive, no MTP) + sampler = make_sampler(temp=temperature, top_p=top_p) + eos_id = tokenizer.eos_token_id + y = sampler(logits[:, -1, :]) + mx.eval(y) - results = [] - generated_ids = [] - prev_decoded = "" + results = [] + generated_ids = [] + prev_decoded = "" - for _ in range(max_tokens): - tok_id = y.item() - generated_ids.append(tok_id) + for _ in range(max_tokens): + tok_id = y.item() + generated_ids.append(tok_id) - decoded = tokenizer.decode(generated_ids) - new_text = decoded[len(prev_decoded) :] - prev_decoded = decoded + decoded = tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded - is_eos = tok_id == eos_id - results.append( - SimpleNamespace( - text=new_text, - finish_reason="stop" if is_eos else None, - ) + is_eos = tok_id == eos_id + results.append( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, ) + ) - if is_eos: - break + if is_eos: + break - logits = model(y.reshape(1, -1), cache=cache) - y = sampler(logits[:, -1, :]) - mx.eval(y) + logits = model(y.reshape(1, -1), cache=cache) + y = sampler(logits[:, -1, :]) + mx.eval(y) - return results + return results - finally: - cleanup_rope(model) + finally: + cleanup_rope(model) - def _run_normal(): - """Fallback: normal generation without specprefill.""" - from types import SimpleNamespace + def _run_normal(): + """Fallback: normal generation without specprefill.""" + from types import SimpleNamespace - results = [] - for chunk in self._model.stream_generate( - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - stop=stop, - **kwargs, - ): - new_text = chunk.text if hasattr(chunk, "text") else str(chunk) - results.append( - SimpleNamespace( - text=new_text, - finish_reason=getattr(chunk, "finish_reason", None), - ) + results = [] + for chunk in self._model.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ): + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + results.append( + SimpleNamespace( + text=new_text, + finish_reason=getattr(chunk, "finish_reason", None), ) - return results + ) + return results all_resps = await self._run_blocking_serialized(_run_all) @@ -1107,102 +1105,102 @@ def _run_all(): return results def _run_specprefill(model, bc): - """Score tokens, sparse prefill, generate without MTP.""" - from types import SimpleNamespace - - from ..specprefill import ( - cleanup_rope, - score_tokens, - select_chunks, - sparse_prefill, + """Score tokens, sparse prefill, generate without MTP.""" + from types import SimpleNamespace + + from ..specprefill import ( + cleanup_rope, + score_tokens, + select_chunks, + sparse_prefill, + ) + + # Create backbone cache if not already from system KV + if bc is None: + bc = make_prompt_cache(model) + + try: + # Phase 1: Score with draft model + import time + + t0 = time.monotonic() + importance = score_tokens( + self._draft_model, + specprefill_tokens, + prefill_step_size=self._prefill_step_size, ) + t_score = time.monotonic() - t0 + + # Phase 2: Select important chunks + effective_keep = specprefill_keep_pct or self._specprefill_keep_pct + selected = select_chunks(importance, keep_pct=effective_keep) + n_selected = selected.shape[0] + n_total = len(specprefill_tokens) + + # Phase 3: Sparse prefill on target model + t0 = time.monotonic() + logits = sparse_prefill( + model, + specprefill_tokens, + selected, + bc, + step_size=self._prefill_step_size, + 
position_offset=specprefill_offset, + ) + t_prefill = time.monotonic() - t0 - # Create backbone cache if not already from system KV - if bc is None: - bc = make_prompt_cache(model) + logger.info( + "SpecPrefill: scored %d tokens in %.1fs, " + "sparse prefill %d/%d (keep=%.0f%%) in %.1fs " + "(offset=%d, effective_keep=%.2f)", + n_total, + t_score, + n_selected, + n_total, + n_selected / n_total * 100, + t_prefill, + specprefill_offset, + effective_keep, + ) - try: - # Phase 1: Score with draft model - import time - - t0 = time.monotonic() - importance = score_tokens( - self._draft_model, - specprefill_tokens, - prefill_step_size=self._prefill_step_size, - ) - t_score = time.monotonic() - t0 - - # Phase 2: Select important chunks - effective_keep = specprefill_keep_pct or self._specprefill_keep_pct - selected = select_chunks(importance, keep_pct=effective_keep) - n_selected = selected.shape[0] - n_total = len(specprefill_tokens) - - # Phase 3: Sparse prefill on target model - t0 = time.monotonic() - logits = sparse_prefill( - model, - specprefill_tokens, - selected, - bc, - step_size=self._prefill_step_size, - position_offset=specprefill_offset, - ) - t_prefill = time.monotonic() - t0 + # Phase 4: Generate (simple autoregressive, no MTP) + eos_id = self._text_tokenizer.eos_token_id + y = sampler(logits[:, -1, :]) + mx.eval(y) - logger.info( - "SpecPrefill: scored %d tokens in %.1fs, " - "sparse prefill %d/%d (keep=%.0f%%) in %.1fs " - "(offset=%d, effective_keep=%.2f)", - n_total, - t_score, - n_selected, - n_total, - n_selected / n_total * 100, - t_prefill, - specprefill_offset, - effective_keep, - ) + results = [] + generated_ids = [] + prev_decoded = "" - # Phase 4: Generate (simple autoregressive, no MTP) - eos_id = self._text_tokenizer.eos_token_id - y = sampler(logits[:, -1, :]) - mx.eval(y) + for _ in range(max_tokens): + tok_id = y.item() + generated_ids.append(tok_id) - results = [] - generated_ids = [] - prev_decoded = "" - - for _ in range(max_tokens): - tok_id = y.item() - generated_ids.append(tok_id) - - # Incremental text decode - decoded = self._text_tokenizer.decode(generated_ids) - new_text = decoded[len(prev_decoded) :] - prev_decoded = decoded - - is_eos = tok_id == eos_id - results.append( - SimpleNamespace( - text=new_text, - finish_reason="stop" if is_eos else None, - ) + # Incremental text decode + decoded = self._text_tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded + + is_eos = tok_id == eos_id + results.append( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, ) + ) - if is_eos: - break + if is_eos: + break - # Next token - logits = model(y.reshape(1, -1), cache=bc) - y = sampler(logits[:, -1, :]) - mx.eval(y) + # Next token + logits = model(y.reshape(1, -1), cache=bc) + y = sampler(logits[:, -1, :]) + mx.eval(y) - return results + return results - finally: - cleanup_rope(model) + finally: + cleanup_rope(model) all_resps = await self._run_blocking_serialized(_run_all)
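
For reference, below is a minimal standalone sketch of the cancellation-safe serialization pattern that these patches route every blocking MLX call through: the worker-thread task is shielded so that cancelling the awaiting request does not release the generation lock until the thread has actually finished, and any exception from the drained worker is suppressed (as in patch 3, `BaseException` rather than `Exception`). The names `run_serialized` and `blocking_work` are illustrative stand-ins, not vllm_mlx APIs.

```python
# Standalone sketch of the shield-and-drain serialization pattern.
# run_serialized / blocking_work are illustrative names, not vllm_mlx APIs.
import asyncio
import time

_lock = asyncio.Lock()


async def run_serialized(func, /, *args, **kwargs):
    """Run a blocking callable in a thread, holding the lock across cancellation."""
    async with _lock:
        task = asyncio.create_task(asyncio.to_thread(func, *args, **kwargs))
        try:
            return await asyncio.shield(task)
        except asyncio.CancelledError:
            # Drain the worker thread before releasing the lock so a
            # follow-up request cannot overlap with it on the backend.
            try:
                await task
            except BaseException:
                pass
            raise


def blocking_work(tag: str) -> str:
    time.sleep(0.2)  # stand-in for a blocking MLX/Metal call
    return f"done-{tag}"


async def main() -> None:
    first = asyncio.create_task(run_serialized(blocking_work, "first"))
    await asyncio.sleep(0.05)
    first.cancel()
    # "second" only enters blocking_work after the first worker thread returns,
    # even though the first request was cancelled mid-flight.
    print(await run_serialized(blocking_work, "second"))


if __name__ == "__main__":
    asyncio.run(main())
```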