From 53252733687a073321a0b8990017ffd5e4addce9 Mon Sep 17 00:00:00 2001 From: Vinay Vobbilichetty Date: Sun, 10 May 2026 12:25:48 -0400 Subject: [PATCH 1/9] feat: extend system-prompt KV cache to pure-LLM stream_chat path The existing system-prefix KV cache (added in _stream_generate_text) covered the MLLM text-only path but not the pure-LLM stream_chat path, so non-MLLM models routing through stream_generate() re-prefilled the full system block on every turn. This mirrors the same hash-keyed, single-slot cache logic into stream_chat, after apply_chat_template: - detect a system prefix via ChatML markers - HIT: restore the snapshot and prefill only the suffix tokens - MISS: prefill the system tokens, snapshot per-layer KV state, then continue with the suffix - fallback: if anything looks off (no prefix, encode mismatch, cache call raises), drop through to the original uncached self.stream_generate() path unchanged Reuses the engine's existing _system_kv_snapshot / _system_kv_hash / _system_kv_token_count attributes - no __init__ changes, no new public surface. Holds no extra locks (the inner _run_blocking_serialized already takes _generation_lock). Measured locally on Qwen2.5-Coder-32B-Instruct-8bit driving Claude Code on Apple Silicon: ~100s+ follow-up-turn prefill -> ~7s once the system prefix is cached (~23K-token system+tools prefix). --- vllm_mlx/engine/simple.py | 167 +++++++++++++++++++++++++++++++++++++- 1 file changed, 166 insertions(+), 1 deletion(-) diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 47343892..819a7fd0 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -860,7 +860,172 @@ def run_stream(): prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages) prompt += "\nassistant:" - # Stream generate + # --- System prompt KV caching for pure-LLM (mirrors _stream_generate_text) --- + # Detect system prefix via ChatML markers. 
If we have one, route through + # the cache-aware path: HIT restores the snapshot and prefills only the + # suffix; MISS prefills system, snapshots it, then continues with suffix. + # Falls back to the original self.stream_generate() if no system prefix + # is detected or anything else looks wrong. + import hashlib as _hashlib + + cache_hit = False + backbone_cache = None + suffix_tokens = None + system_tokens = None + system_token_count = 0 + system_hash = None + full_tokens_list = None + kv_cache_eligible = False + + has_system = any(m.get("role") == "system" for m in messages) + if has_system and hasattr(tokenizer, "apply_chat_template"): + system_prefix_end = -1 + for marker in ("<|im_start|>user\n", "<|im_start|>assistant\n"): + idx = prompt.find(marker) + if idx > 0: + system_prefix_end = idx + break + + if system_prefix_end > 0: + system_prefix_text = prompt[:system_prefix_end] + system_hash = _hashlib.sha256( + system_prefix_text.encode() + ).hexdigest()[:16] + + add_special = ( + tokenizer.bos_token is None + or not prompt.startswith(tokenizer.bos_token) + ) + full_tokens_list = tokenizer.encode( + prompt, add_special_tokens=add_special + ) + system_tokens_list = tokenizer.encode( + system_prefix_text, add_special_tokens=add_special + ) + system_token_count = len(system_tokens_list) + + if ( + len(full_tokens_list) > system_token_count + and full_tokens_list[:system_token_count] == system_tokens_list + ): + system_tokens = system_tokens_list + suffix_tokens = full_tokens_list[system_token_count:] + kv_cache_eligible = True + if ( + system_hash == self._system_kv_hash + and self._system_kv_snapshot is not None + and system_token_count == self._system_kv_token_count + ): + cache_hit = True + logger.info( + "System KV cache HIT (pure-LLM): reusing %d tokens, " + "prefilling %d new (hash=%s)", + system_token_count, len(suffix_tokens), system_hash, + ) + else: + logger.info( + "System KV cache MISS (pure-LLM): will prefill %d system + " + "%d suffix tokens 
(hash=%s)", + system_token_count, len(suffix_tokens), system_hash, + ) + + if kv_cache_eligible: + # Cache-aware path: drive mlx-lm directly with a pre-populated cache. + # NOTE: do NOT wrap in `async with self._generation_lock` — the inner + # _run_blocking_serialized already acquires it. Double-acquire deadlocks. + if True: + def _run_with_cache(): + import mlx.core as mx + from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.sample_utils import make_sampler + + model = self._model.model + sampler = make_sampler(temp=temperature, top_p=top_p) + + if cache_hit: + bc = make_prompt_cache(model) + for i, saved_state in enumerate(self._system_kv_snapshot): + bc[i].state = saved_state + else: + bc = make_prompt_cache(model) + sys_arr = mx.array(system_tokens) + step = self._prefill_step_size + while sys_arr.size > step: + model(sys_arr[:step][None], cache=bc) + mx.eval([c.state for c in bc]) + sys_arr = sys_arr[step:] + mx.clear_cache() + if sys_arr.size > 0: + model(sys_arr[None], cache=bc) + mx.eval([c.state for c in bc]) + snapshot = [c.state for c in bc] + mx.eval([s for pair in snapshot for s in pair]) + self._system_kv_snapshot = snapshot + self._system_kv_hash = system_hash + self._system_kv_token_count = system_token_count + try: + cache_mb = sum(c.nbytes for c in bc) / 1e6 + except Exception: + cache_mb = -1 + logger.info( + "System KV cache STORED (pure-LLM): %d tokens (%.1f MB)", + system_token_count, cache_mb, + ) + + prompt_arr = mx.array(suffix_tokens) + return list(mlx_stream_generate( + model, + tokenizer, + prompt=prompt_arr, + max_tokens=max_tokens, + sampler=sampler, + prompt_cache=bc, + )) + + try: + all_resps = await self._run_blocking_serialized(_run_with_cache) + except Exception as exc: + logger.warning( + "Pure-LLM KV-cache path failed (%s); falling back to " + "uncached stream_generate", exc, + ) + async for output in self.stream_generate( + prompt=prompt, + 
max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + **kwargs, + ): + yield output + return + + accumulated_text = "" + token_count = 0 + finished = False + full_token_count = ( + len(full_tokens_list) if full_tokens_list is not None else 0 + ) + for i, resp in enumerate(all_resps): + token_count += 1 + new_text = resp.text if hasattr(resp, "text") else str(resp) + accumulated_text += new_text + is_last = i == len(all_resps) - 1 + finished = is_last or token_count >= max_tokens + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=full_token_count, + completion_tokens=token_count, + finished=finished, + finish_reason=getattr(resp, "finish_reason", None) + or ("stop" if finished else None), + ) + if finished: + break + return + + # Fallback: no system prefix detected → original uncached path async for output in self.stream_generate( prompt=prompt, max_tokens=max_tokens, From 560237c78419d7f8d24881c602e20d212040a08b Mon Sep 17 00:00:00 2001 From: Vinay Vobbilichetty Date: Sun, 10 May 2026 14:33:08 -0400 Subject: [PATCH 2/9] review: address @janhilgard feedback on stream_chat KV cache MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Split-point detection (#2): replace ChatML-specific marker scan (`<|im_start|>...`) with probe-divergence — render the chat template with two different user contents and take the shared prefix. Mirrors `prompt_warmup._build_strict_prefix_string` and works across Qwen/ChatML, Llama, Gemma, and any other chat format with no per-model marker list. 2. Incremental streaming (#1): replace `list(mlx_stream_generate(...))` with the `asyncio.Queue` producer pattern used in `_stream_generate_text` — chunks are emitted via `loop.call_soon_threadsafe` from the thread running mlx-lm and yielded immediately to the caller. Adds a thread-safe `abort_event` tied to `_run_blocking_serialized`'s `on_cancel` hook. 3. 
Move `import hashlib` to the module top (#4); use `hashlib.sha256` directly instead of the local alias. 4. Remove the dead `if True:` indent block (#5) left over from the earlier removal of an async-with wrapper. The cache-aware-path failure fallback now distinguishes pre-first-token errors (safe to retry uncached) from mid-stream errors (re-raise — the client has already received partial output and switching paths would duplicate tokens). Canonicalization-before-hashing (#3) is deferred to the registry being designed in #524. --- vllm_mlx/engine/simple.py | 348 ++++++++++++++++++++++++-------------- 1 file changed, 218 insertions(+), 130 deletions(-) diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 819a7fd0..bd68d388 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -7,6 +7,7 @@ """ import asyncio +import hashlib import logging import os import threading @@ -860,172 +861,259 @@ def run_stream(): prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages) prompt += "\nassistant:" - # --- System prompt KV caching for pure-LLM (mirrors _stream_generate_text) --- - # Detect system prefix via ChatML markers. If we have one, route through - # the cache-aware path: HIT restores the snapshot and prefills only the - # suffix; MISS prefills system, snapshots it, then continues with suffix. - # Falls back to the original self.stream_generate() if no system prefix - # is detected or anything else looks wrong. - import hashlib as _hashlib - + # --- System-prompt KV caching on the pure-LLM stream_chat path --- + # Mirrors the cache in _stream_generate_text. Locates the system prefix + # via probe-divergence (cf. prompt_warmup._build_strict_prefix_string): + # render the template with two different user contents and take the + # shared prefix. Works across Qwen/ChatML, Llama, Gemma, and any other + # chat format -- no per-model marker list. 
Falls back to the original + # uncached self.stream_generate() if the system prefix can't be + # isolated or any step of the cache-aware path raises. cache_hit = False - backbone_cache = None suffix_tokens = None system_tokens = None system_token_count = 0 + full_token_count = 0 system_hash = None - full_tokens_list = None kv_cache_eligible = False has_system = any(m.get("role") == "system" for m in messages) if has_system and hasattr(tokenizer, "apply_chat_template"): - system_prefix_end = -1 - for marker in ("<|im_start|>user\n", "<|im_start|>assistant\n"): - idx = prompt.find(marker) - if idx > 0: - system_prefix_end = idx - break + def _with_user(user_content: str) -> list[dict[str, Any]]: + msgs = [dict(m) for m in messages] + if msgs and msgs[-1].get("role") == "user": + msgs[-1] = {**msgs[-1], "content": user_content} + else: + msgs = [*msgs, {"role": "user", "content": user_content}] + return msgs - if system_prefix_end > 0: - system_prefix_text = prompt[:system_prefix_end] - system_hash = _hashlib.sha256( - system_prefix_text.encode() - ).hexdigest()[:16] - - add_special = ( - tokenizer.bos_token is None - or not prompt.startswith(tokenizer.bos_token) - ) - full_tokens_list = tokenizer.encode( - prompt, add_special_tokens=add_special + rendered_a: Any = None + rendered_b: Any = None + try: + rendered_a = tokenizer.apply_chat_template( + _with_user("Alpha"), **template_kwargs ) - system_tokens_list = tokenizer.encode( - system_prefix_text, add_special_tokens=add_special + rendered_b = tokenizer.apply_chat_template( + _with_user("Bravo"), **template_kwargs ) - system_token_count = len(system_tokens_list) + except Exception: + pass + + if isinstance(rendered_a, str) and isinstance(rendered_b, str): + boundary = 0 + diverged = False + for i in range(min(len(rendered_a), len(rendered_b))): + if rendered_a[i] != rendered_b[i]: + diverged = True + break + boundary = i + 1 + + if diverged and boundary >= 16: + system_prefix_text = rendered_a[:boundary] + 
system_hash = hashlib.sha256( + system_prefix_text.encode() + ).hexdigest()[:16] + + add_special = ( + tokenizer.bos_token is None + or not prompt.startswith(tokenizer.bos_token) + ) + full_tokens_list = tokenizer.encode( + prompt, add_special_tokens=add_special + ) + system_tokens_list = tokenizer.encode( + system_prefix_text, add_special_tokens=add_special + ) + full_token_count = len(full_tokens_list) + system_token_count = len(system_tokens_list) - if ( - len(full_tokens_list) > system_token_count - and full_tokens_list[:system_token_count] == system_tokens_list - ): - system_tokens = system_tokens_list - suffix_tokens = full_tokens_list[system_token_count:] - kv_cache_eligible = True if ( - system_hash == self._system_kv_hash - and self._system_kv_snapshot is not None - and system_token_count == self._system_kv_token_count + len(full_tokens_list) > system_token_count + and full_tokens_list[:system_token_count] == system_tokens_list ): - cache_hit = True - logger.info( - "System KV cache HIT (pure-LLM): reusing %d tokens, " - "prefilling %d new (hash=%s)", - system_token_count, len(suffix_tokens), system_hash, - ) - else: - logger.info( - "System KV cache MISS (pure-LLM): will prefill %d system + " - "%d suffix tokens (hash=%s)", - system_token_count, len(suffix_tokens), system_hash, - ) + system_tokens = system_tokens_list + suffix_tokens = full_tokens_list[system_token_count:] + kv_cache_eligible = True + if ( + system_hash == self._system_kv_hash + and self._system_kv_snapshot is not None + and system_token_count == self._system_kv_token_count + ): + cache_hit = True + logger.info( + "System KV cache HIT (stream_chat): reusing %d " + "tokens, prefilling %d new (hash=%s)", + system_token_count, + len(suffix_tokens), + system_hash, + ) + else: + logger.info( + "System KV cache MISS (stream_chat): will " + "prefill %d system + %d suffix tokens (hash=%s)", + system_token_count, + len(suffix_tokens), + system_hash, + ) if kv_cache_eligible: # Cache-aware path: 
drive mlx-lm directly with a pre-populated cache. - # NOTE: do NOT wrap in `async with self._generation_lock` — the inner - # _run_blocking_serialized already acquires it. Double-acquire deadlocks. - if True: - def _run_with_cache(): - import mlx.core as mx - from mlx_lm import stream_generate as mlx_stream_generate - from mlx_lm.models.cache import make_prompt_cache - from mlx_lm.sample_utils import make_sampler - - model = self._model.model - sampler = make_sampler(temp=temperature, top_p=top_p) - - if cache_hit: - bc = make_prompt_cache(model) - for i, saved_state in enumerate(self._system_kv_snapshot): - bc[i].state = saved_state - else: - bc = make_prompt_cache(model) - sys_arr = mx.array(system_tokens) - step = self._prefill_step_size - while sys_arr.size > step: - model(sys_arr[:step][None], cache=bc) - mx.eval([c.state for c in bc]) - sys_arr = sys_arr[step:] - mx.clear_cache() - if sys_arr.size > 0: - model(sys_arr[None], cache=bc) - mx.eval([c.state for c in bc]) - snapshot = [c.state for c in bc] - mx.eval([s for pair in snapshot for s in pair]) - self._system_kv_snapshot = snapshot - self._system_kv_hash = system_hash - self._system_kv_token_count = system_token_count - try: - cache_mb = sum(c.nbytes for c in bc) / 1e6 - except Exception: - cache_mb = -1 - logger.info( - "System KV cache STORED (pure-LLM): %d tokens (%.1f MB)", - system_token_count, cache_mb, - ) + # Stream chunks back to the caller via an asyncio.Queue (mirrors + # _stream_generate_text) so the client sees tokens as they arrive + # rather than after the full generation finishes. 
+ loop = asyncio.get_running_loop() + response_queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue() + abort_event = threading.Event() + + def _emit_response(resp: Any) -> None: + if abort_event.is_set(): + return + loop.call_soon_threadsafe( + response_queue.put_nowait, ("resp", resp) + ) - prompt_arr = mx.array(suffix_tokens) - return list(mlx_stream_generate( - model, - tokenizer, - prompt=prompt_arr, - max_tokens=max_tokens, - sampler=sampler, - prompt_cache=bc, - )) + def _emit_done() -> None: + loop.call_soon_threadsafe( + response_queue.put_nowait, ("done", None) + ) + + def _emit_error(exc: BaseException) -> None: + loop.call_soon_threadsafe( + response_queue.put_nowait, ("error", exc) + ) + + def _run_with_cache() -> None: + from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.sample_utils import make_sampler + + model = self._model.model + sampler = make_sampler(temp=temperature, top_p=top_p) + + if cache_hit: + bc = make_prompt_cache(model) + for i, saved_state in enumerate(self._system_kv_snapshot): + bc[i].state = saved_state + else: + bc = make_prompt_cache(model) + sys_arr = mx.array(system_tokens) + step = self._prefill_step_size + while sys_arr.size > step: + model(sys_arr[:step][None], cache=bc) + mx.eval([c.state for c in bc]) + sys_arr = sys_arr[step:] + mx.clear_cache() + if sys_arr.size > 0: + model(sys_arr[None], cache=bc) + mx.eval([c.state for c in bc]) + + snapshot = [c.state for c in bc] + mx.eval([s for pair in snapshot for s in pair]) + self._system_kv_snapshot = snapshot + self._system_kv_hash = system_hash + self._system_kv_token_count = system_token_count + try: + cache_mb = sum(c.nbytes for c in bc) / 1e6 + except Exception: + cache_mb = -1 + logger.info( + "System KV cache STORED (stream_chat): %d tokens " + "(%.1f MB)", + system_token_count, + cache_mb, + ) + + prompt_arr = mx.array(suffix_tokens) + for resp in mlx_stream_generate( + model, + tokenizer, + 
prompt=prompt_arr, + max_tokens=max_tokens, + sampler=sampler, + prompt_cache=bc, + ): + if abort_event.is_set(): + break + _emit_response(resp) + async def _produce_responses() -> None: try: - all_resps = await self._run_blocking_serialized(_run_with_cache) - except Exception as exc: - logger.warning( - "Pure-LLM KV-cache path failed (%s); falling back to " - "uncached stream_generate", exc, + await self._run_blocking_serialized( + _run_with_cache, + on_cancel=abort_event.set, ) - async for output in self.stream_generate( - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - **kwargs, - ): - yield output - return + except asyncio.CancelledError: + raise + except BaseException as exc: + _emit_error(exc) + else: + _emit_done() - accumulated_text = "" - token_count = 0 - finished = False - full_token_count = ( - len(full_tokens_list) if full_tokens_list is not None else 0 - ) - for i, resp in enumerate(all_resps): + producer_task = asyncio.create_task(_produce_responses()) + + accumulated_text = "" + token_count = 0 + finished = False + cache_path_failed_before_first_token = False + try: + while True: + kind, payload = await response_queue.get() + if kind == "done": + break + if kind == "error": + if token_count == 0: + logger.warning( + "Pure-LLM KV-cache path failed before first " + "token (%s); falling back to uncached " + "stream_generate", + payload, + ) + cache_path_failed_before_first_token = True + break + # Already streamed partial output; can't cleanly + # restart on the uncached path, so surface the error. 
+ raise payload + resp = payload token_count += 1 new_text = resp.text if hasattr(resp, "text") else str(resp) accumulated_text += new_text - is_last = i == len(all_resps) - 1 - finished = is_last or token_count >= max_tokens + finish_reason = getattr(resp, "finish_reason", None) + finished = ( + finish_reason is not None or token_count >= max_tokens + ) + if finish_reason is None and finished: + finish_reason = "stop" + yield GenerationOutput( text=accumulated_text, new_text=new_text, prompt_tokens=full_token_count, completion_tokens=token_count, finished=finished, - finish_reason=getattr(resp, "finish_reason", None) - or ("stop" if finished else None), + finish_reason=finish_reason, ) if finished: break - return + finally: + if not producer_task.done(): + abort_event.set() + try: + await producer_task + except BaseException: + pass + + if cache_path_failed_before_first_token: + async for output in self.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + **kwargs, + ): + yield output + return - # Fallback: no system prefix detected → original uncached path + # Fallback: no system prefix detected -> original uncached path async for output in self.stream_generate( prompt=prompt, max_tokens=max_tokens, From aeb6f9189a7b1bb22e7a42c9fc52fa86368db709 Mon Sep 17 00:00:00 2001 From: Vinay Vobbilichetty Date: Sun, 10 May 2026 15:34:17 -0400 Subject: [PATCH 3/9] fix: normalize Pydantic Message objects in stream_chat cache path `stream_chat` is typed as accepting `list[dict[str, Any]]` but some internal callers (server.py's streaming endpoint, the streaming-chat test in test_server.py) pass Pydantic `Message` objects directly. Those don't expose dict's `.get()`, so the cache-eligibility detection raised `'Message' object has no attribute 'get'` and the test `TestStreamChatCompletion::test_streaming_chat_no_stream_thread_error_after_residency_preload` failed. 
Normalize each message to a plain dict (via `model_dump()` / `dict()` / getattr fallback) before the role lookup and the probe-divergence renders. Also apply black for the lint step. The error surfaces only when the chat-completion request goes through `stream_chat` with raw Pydantic Message objects (i.e. not pre-converted by the caller). The MLLM path (`_stream_generate_text`) has the same `m.get(...)` pattern, but the test patches `is_mllm_model=False` so this fix is enough to green the CI. --- vllm_mlx/engine/simple.py | 46 +++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index bd68d388..ce28fb34 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -877,10 +877,28 @@ def run_stream(): system_hash = None kv_cache_eligible = False - has_system = any(m.get("role") == "system" for m in messages) + # Normalize messages to plain dicts. The public stream_chat signature + # types messages as list[dict], but internal callers (server.py, + # tests) sometimes pass Pydantic Message objects directly; those + # don't expose a dict-style .get() interface. 
+ def _to_msg_dict(m: Any) -> dict[str, Any]: + if isinstance(m, dict): + return m + if hasattr(m, "model_dump"): + return m.model_dump() + if hasattr(m, "dict"): + return m.dict() + return { + "role": getattr(m, "role", None), + "content": getattr(m, "content", ""), + } + + messages_for_cache = [_to_msg_dict(m) for m in messages] + has_system = any(m.get("role") == "system" for m in messages_for_cache) if has_system and hasattr(tokenizer, "apply_chat_template"): + def _with_user(user_content: str) -> list[dict[str, Any]]: - msgs = [dict(m) for m in messages] + msgs = [dict(m) for m in messages_for_cache] if msgs and msgs[-1].get("role") == "user": msgs[-1] = {**msgs[-1], "content": user_content} else: @@ -914,9 +932,8 @@ def _with_user(user_content: str) -> list[dict[str, Any]]: system_prefix_text.encode() ).hexdigest()[:16] - add_special = ( - tokenizer.bos_token is None - or not prompt.startswith(tokenizer.bos_token) + add_special = tokenizer.bos_token is None or not prompt.startswith( + tokenizer.bos_token ) full_tokens_list = tokenizer.encode( prompt, add_special_tokens=add_special @@ -968,19 +985,13 @@ def _with_user(user_content: str) -> list[dict[str, Any]]: def _emit_response(resp: Any) -> None: if abort_event.is_set(): return - loop.call_soon_threadsafe( - response_queue.put_nowait, ("resp", resp) - ) + loop.call_soon_threadsafe(response_queue.put_nowait, ("resp", resp)) def _emit_done() -> None: - loop.call_soon_threadsafe( - response_queue.put_nowait, ("done", None) - ) + loop.call_soon_threadsafe(response_queue.put_nowait, ("done", None)) def _emit_error(exc: BaseException) -> None: - loop.call_soon_threadsafe( - response_queue.put_nowait, ("error", exc) - ) + loop.call_soon_threadsafe(response_queue.put_nowait, ("error", exc)) def _run_with_cache() -> None: from mlx_lm import stream_generate as mlx_stream_generate @@ -1017,8 +1028,7 @@ def _run_with_cache() -> None: except Exception: cache_mb = -1 logger.info( - "System KV cache STORED (stream_chat): 
%d tokens " - "(%.1f MB)", + "System KV cache STORED (stream_chat): %d tokens " "(%.1f MB)", system_token_count, cache_mb, ) @@ -1078,9 +1088,7 @@ async def _produce_responses() -> None: new_text = resp.text if hasattr(resp, "text") else str(resp) accumulated_text += new_text finish_reason = getattr(resp, "finish_reason", None) - finished = ( - finish_reason is not None or token_count >= max_tokens - ) + finished = finish_reason is not None or token_count >= max_tokens if finish_reason is None and finished: finish_reason = "stop" From 9d9506c685f1496b02bd05e6139e4046383e0e6f Mon Sep 17 00:00:00 2001 From: Vinay Vobbilichetty Date: Sun, 10 May 2026 15:44:42 -0400 Subject: [PATCH 4/9] test: add stream_chat KV-cache unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three tests covering the new pure-LLM cache path: - `test_stream_chat_cache_path_accepts_pydantic_message_objects` — regression test for the Pydantic `Message` handling that the prior CI run caught. Verifies that callers passing `Message(role=..., content=...)` (as `server.py`'s streaming endpoint does) don't raise `AttributeError` at the `.get('role')` boundary. - `test_stream_chat_skips_cache_path_when_no_system_message` — `has_system=False` must short-circuit; `apply_chat_template` should only be called once (for the initial prompt), not three times (initial + Alpha probe + Bravo probe). - `test_stream_chat_cache_path_falls_back_when_mlx_raises` — when `_run_with_cache` raises before the first token (here forced via patching `make_prompt_cache` to raise), the pre-first-token error branch must route to the uncached `stream_generate` fallback rather than propagating the exception. 
--- tests/test_simple_engine.py | 210 ++++++++++++++++++++++++++++++++++++ 1 file changed, 210 insertions(+) diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index 1c8ca29b..07465769 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -197,6 +197,216 @@ async def fake_stream_chat(*args, **kwargs): assert output.finish_reason == "stop" mock_llm_model.chat.assert_not_called() + @pytest.mark.anyio + async def test_stream_chat_cache_path_accepts_pydantic_message_objects(self): + """`stream_chat`'s declared signature is ``list[dict]`` but real callers + (``server.py``'s streaming endpoint, ``test_server.py``'s direct + invocations) pass Pydantic ``Message`` objects. The system-prefix + KV-cache eligibility check on this path uses ``.get('role')`` / + ``dict(m)`` semantics; without normalisation the iteration raises + ``'Message' object has no attribute 'get'`` before the call ever + reaches the underlying ``stream_generate``.""" + from vllm_mlx.api.models import Message + from vllm_mlx.engine.simple import SimpleEngine + + # ``apply_chat_template`` returns identical strings for both + # probe-divergence renders → boundary stays at 0 → the cache path + # is correctly skipped and execution falls through to + # ``self.stream_generate``. The test's value is asserting no + # ``AttributeError`` leaks out of the message-normalisation step. 
+ tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>system\nyou are an assistant<|im_end|>\n" + "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + captured_stream_generate = [] + + async def fake_stream_generate(*, prompt, **kwargs): + captured_stream_generate.append({"prompt": prompt, "kwargs": kwargs}) + out = MagicMock( + text="hi back", + new_text="hi back", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + messages = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="hi"), + ] + + chunks = [c async for c in engine.stream_chat(messages=messages)] + + # No AttributeError was raised → normalisation worked. + # apply_chat_template was called at least 3 times: once for the + # initial ``prompt`` build and twice more for the Alpha/Bravo + # probe-divergence renders. 
+ assert tokenizer.apply_chat_template.call_count >= 3 + assert len(captured_stream_generate) == 1 + assert chunks and chunks[0].text == "hi back" + + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_no_system_message(self): + """If the message list has no system role, the cache-eligibility + check must short-circuit ``has_system = False`` without entering the + probe-divergence step or the cache-aware streaming branch.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + called_stream_generate = [] + + async def fake_stream_generate(*, prompt, **kwargs): + called_stream_generate.append(prompt) + out = MagicMock( + text="hi", + new_text="hi", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[{"role": "user", "content": "hello"}], + ) + ] + + # No system → cache path skipped → only the initial apply_chat_template + # call (for ``prompt``) happens, no probe renders. 
+ assert tokenizer.apply_chat_template.call_count == 1 + assert called_stream_generate + assert chunks and chunks[0].text == "hi" + + @pytest.mark.anyio + async def test_stream_chat_cache_path_falls_back_when_mlx_raises(self): + """When the cache-aware ``_run_with_cache`` body raises *before* the + first generated token, the path must surface the failure as a + pre-first-token error and fall back to the uncached + ``self.stream_generate`` instead of bubbling the exception out.""" + from vllm_mlx.engine.simple import SimpleEngine + + # Probe-divergence renders that DO diverge, so the cache path is + # entered. The boundary lies after the system block, well past the + # 16-char minimum. + def apply_chat_template_side_effect(messages, **kwargs): + # Find the last user message content to make probes diverge. + user_content = "" + for m in reversed(messages): + role = ( + m.get("role") if isinstance(m, dict) else getattr(m, "role", None) + ) + if role == "user": + content = ( + m.get("content") + if isinstance(m, dict) + else getattr(m, "content", "") + ) + user_content = content or "" + break + return ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n{user_content}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + tokenizer = MagicMock() + tokenizer.apply_chat_template.side_effect = apply_chat_template_side_effect + tokenizer.bos_token = None + # Return long enough token lists that the system-prefix slice is + # a proper prefix of the full sequence and ``kv_cache_eligible`` + # becomes True. + tokenizer.encode = MagicMock( + side_effect=[ + list(range(50)), # full prompt + list(range(20)), # system prefix (proper prefix of above) + ] + ) + + model = MagicMock() + model.tokenizer = tokenizer + # ``self._model.model`` is dereferenced inside _run_with_cache. 
+ model.model = MagicMock() + + fallback_calls = [] + + async def fake_stream_generate(*, prompt, **kwargs): + fallback_calls.append(prompt) + out = MagicMock( + text="fallback-response", + new_text="fallback-response", + prompt_tokens=50, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + # Force the cache-aware path to raise before the first emit so we + # exercise the pre-first-token error → uncached fallback branch. + def make_prompt_cache_raises(*args, **kwargs): + raise RuntimeError("simulated mlx-lm failure") + + with ( + patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False), + patch( + "mlx_lm.models.cache.make_prompt_cache", + side_effect=make_prompt_cache_raises, + ), + patch("mlx_lm.sample_utils.make_sampler", return_value=MagicMock()), + ): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + # The cache-path failure must NOT propagate; we should see the + # fallback's chunk instead. + assert fallback_calls, "uncached stream_generate fallback was not invoked" + assert chunks and chunks[0].text == "fallback-response" + @pytest.mark.anyio async def test_lock_serializes_stream_generate(self, mock_model): """Test that stream_generate uses the same lock as other methods.""" From 0577644650902d117bb4c18cadf8d9f5ec7a1f49 Mon Sep 17 00:00:00 2001 From: Vinay Vobbilichetty Date: Sun, 10 May 2026 16:46:24 -0400 Subject: [PATCH 5/9] perf: clear MLX intermediate cache before snapshotting system KV Per @janhilgard's non-blocking review suggestion. After chunked prefill of the system prefix, intermediate activations can hold ~1-2 GB of temporary memory on long (~23K-token) prefixes. 
Calling mx.clear_cache() before capturing the per-layer state mirrors the existing MLLM path's pattern between chunked prefill and decode, and prevents a peak-memory spike on the first cache MISS after startup. The snapshot itself is unaffected (it reads c.state on each cache slot, which is preserved). --- vllm_mlx/engine/simple.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index ce28fb34..9edf0e7c 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -1018,6 +1018,10 @@ def _run_with_cache() -> None: model(sys_arr[None], cache=bc) mx.eval([c.state for c in bc]) + # Free intermediate prefill activations before snapshotting; + # mirrors the MLLM path between chunked prefill and decode. + mx.clear_cache() + snapshot = [c.state for c in bc] mx.eval([s for pair in snapshot for s in pair]) self._system_kv_snapshot = snapshot From d76885057194c72b55eb1021bd99cf62c0da5a74 Mon Sep 17 00:00:00 2001 From: Vinay Vobbilichetty Date: Mon, 11 May 2026 20:55:53 -0400 Subject: [PATCH 6/9] fix: gate stream_chat KV cache on absence of decode controls The cache-eligible branch in stream_chat called mlx_lm.stream_generate directly with only prompt/max_tokens/sampler/prompt_cache, silently dropping `stop`, request-local `logits_processors` (parser stop tokens and JSON-constrained decoding attached by server.py), and the `top_k`/`min_p`/`presence_penalty`/`repetition_penalty` sampling controls that the uncached path threads through self.stream_generate. Same request could decode under different constraints depending on whether it hit the cache branch. Gate the cache branch on absence of those controls so both paths share identical decode semantics. The gate compares against server.py's no-op defaults (top_k=0, min_p=0.0, presence_penalty=0.0, repetition_penalty=1.0) rather than `key in kwargs`, so the common path still hits the cache. 
Adds two regression tests: stop+logits_processors forces the uncached fallback (probe-divergence never runs), and the no-op defaults still exercise the cache path. Reported by @Thump604. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_simple_engine.py | 140 ++++++++++++++++++++++++++++++++++++ vllm_mlx/engine/simple.py | 42 ++++++++++- 2 files changed, 181 insertions(+), 1 deletion(-) diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index 07465769..366e203b 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -407,6 +407,146 @@ def make_prompt_cache_raises(*args, **kwargs): assert fallback_calls, "uncached stream_generate fallback was not invoked" assert chunks and chunks[0].text == "fallback-response" + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_decode_controls_present(self): + """If the request carries ``stop`` or ``logits_processors`` (or any non-default + ``top_k`` / ``min_p`` / ``presence_penalty`` / ``repetition_penalty``), the + cache branch must be skipped so those controls flow through + ``self.stream_generate``. + The cache branch drives ``mlx_lm.stream_generate`` directly with only + prompt/max_tokens/sampler/prompt_cache, silently dropping any other decode + controls. + Gating here keeps cache-eligible and uncached requests on identical decode + semantics.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + # Render contains a system block so ``has_system`` would otherwise send us + # into the cache branch. 
+ tokenizer.apply_chat_template.return_value = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + "<|im_start|>user\nhello<|im_end|>\n" + "<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + fallback_kwargs: list[dict] = [] + + async def fake_stream_generate(*, prompt, **kw): + fallback_kwargs.append(kw) + out = MagicMock( + text="ok", + new_text="ok", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + sentinel_processor = MagicMock(name="logits_processor") + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + stop=["<|im_end|>"], + logits_processors=[sentinel_processor], + ) + ] + + # Cache-path probe (``apply_chat_template`` called a second + third time for + # probe-divergence) must NOT happen — only the initial prompt render. + assert tokenizer.apply_chat_template.call_count == 1 + # The uncached fallback must have been invoked. + # The decode-control kwargs must have been threaded through. + assert fallback_kwargs, "uncached stream_generate fallback was not invoked" + assert fallback_kwargs[0].get("stop") == ["<|im_end|>"] + assert fallback_kwargs[0].get("logits_processors") == [sentinel_processor] + assert chunks and chunks[0].text == "ok" + + @pytest.mark.anyio + async def test_stream_chat_takes_cache_path_when_decode_controls_are_no_ops(self): + """server.py always sets ``top_k=0``, ``min_p=0.0``, ``presence_penalty=0.0``, + ``repetition_penalty=1.0`` (no-ops) in ``chat_kwargs``. 
+ The gate must compare against those defaults so the common path still hits the + cache. + Only *active* controls should block.""" + from vllm_mlx.engine.simple import SimpleEngine + + def apply_chat_template_side_effect(messages, **kwargs): + user_content = "" + for m in reversed(messages): + role = ( + m.get("role") if isinstance(m, dict) else getattr(m, "role", None) + ) + if role == "user": + content = ( + m.get("content") + if isinstance(m, dict) + else getattr(m, "content", "") + ) + user_content = content or "" + break + return ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + f"<|im_start|>user\n{user_content}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + tokenizer = MagicMock() + tokenizer.apply_chat_template.side_effect = apply_chat_template_side_effect + tokenizer.bos_token = None + tokenizer.encode = MagicMock(side_effect=[list(range(50)), list(range(20))]) + + model = MagicMock() + model.tokenizer = tokenizer + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + + # No fallback needed since cache path should be exercised. + # Patch _run_blocking_serialized to short-circuit cache execution cleanly + # without needing a real mlx_lm. + async def short_circuit(func, *args, on_cancel=None, **kw): + # Simulate immediate completion with no responses. + # The producer-task harness will fire _emit_done(). + return None + + engine._run_blocking_serialized = short_circuit # type: ignore[method-assign] + + _ = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + top_k=0, + min_p=0.0, + presence_penalty=0.0, + repetition_penalty=1.0, + ) + ] + + # Probe-divergence ran ⇒ apply_chat_template called for prompt + 2 probes. 
+ assert tokenizer.apply_chat_template.call_count == 3 + @pytest.mark.anyio async def test_lock_serializes_stream_generate(self, mock_model): """Test that stream_generate uses the same lock as other methods.""" diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 9edf0e7c..386e430e 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -877,6 +877,42 @@ def run_stream(): system_hash = None kv_cache_eligible = False + # Decode-control gate. + # The cache branch below drives ``mlx_lm.stream_generate`` directly with only + # ``prompt``, ``max_tokens``, ``sampler`` (built from temperature+top_p), and + # ``prompt_cache``. + # The uncached fallback threads ``**kwargs`` through ``self.stream_generate``, + # which preserves ``stop``, request-local ``logits_processors`` (parser stop + # tokens and JSON-constrained decoding attached by server.py per request), and + # the ``top_k`` / ``min_p`` / ``presence_penalty`` / ``repetition_penalty`` + # sampling controls. + # If the cache branch ran with any of those active, cache-eligible and uncached + # requests would silently decode under different constraints. + # Skip the cache branch in that case so both paths share identical decode + # semantics. + # server.py always supplies the no-op defaults (``top_k=0``, ``min_p=0.0``, + # ``presence_penalty=0.0``, ``repetition_penalty=1.0``); compare against those + # rather than ``key in kwargs`` so the common path still hits the cache. 
+ cache_blocking_controls: list[str] = [] + if kwargs.get("stop"): + cache_blocking_controls.append("stop") + if kwargs.get("logits_processors"): + cache_blocking_controls.append("logits_processors") + if (kwargs.get("top_k") or 0) > 0: + cache_blocking_controls.append("top_k") + if (kwargs.get("min_p") or 0.0) > 0.0: + cache_blocking_controls.append("min_p") + if (kwargs.get("presence_penalty") or 0.0) != 0.0: + cache_blocking_controls.append("presence_penalty") + if (kwargs.get("repetition_penalty") or 1.0) != 1.0: + cache_blocking_controls.append("repetition_penalty") + if cache_blocking_controls: + logger.info( + "System KV cache SKIP (stream_chat): request has decode controls the " + "cache branch cannot honor (%s); using uncached path", + cache_blocking_controls, + ) + # Normalize messages to plain dicts. The public stream_chat signature # types messages as list[dict], but internal callers (server.py, # tests) sometimes pass Pydantic Message objects directly; those @@ -895,7 +931,11 @@ def _to_msg_dict(m: Any) -> dict[str, Any]: messages_for_cache = [_to_msg_dict(m) for m in messages] has_system = any(m.get("role") == "system" for m in messages_for_cache) - if has_system and hasattr(tokenizer, "apply_chat_template"): + if ( + has_system + and not cache_blocking_controls + and hasattr(tokenizer, "apply_chat_template") + ): def _with_user(user_content: str) -> list[dict[str, Any]]: msgs = [dict(m) for m in messages_for_cache] From 429f75c0b957e26e1b2e119ad3e3c37f57cc255f Mon Sep 17 00:00:00 2001 From: Vinay Vobbilichetty Date: Tue, 12 May 2026 07:33:40 -0400 Subject: [PATCH 7/9] fix: gate stream_chat KV cache on engine-feature compatibility The cache branch drives ``mlx_lm.stream_generate`` directly and bypasses engine-level features the ``self.stream_generate`` wrapper layers on top: ``self._mtp`` (multi-token prediction), loaded SpecPrefill draft model, per-request ``specprefill`` override from ``extra_body``, and configured ``self._max_kv_size``. 
Expand the existing decode-control gate to also skip the cache branch when any of those are active so cache-eligible and uncached requests decode under identical engine semantics. Engine-init defaults (``mtp=False``, no draft model, ``max_kv_size=0``) and the common per-request shape keep hitting the cache as before; the gate only fires when a feature/limit is actually configured. Adds three regression tests modeled on the existing decode-control ones: MTP active, SpecPrefill draft loaded, and ``max_kv_size`` configured. Each asserts the cache probes are skipped (``apply_chat_template`` called only once for the prompt render) and the uncached wrapper runs. --- tests/test_simple_engine.py | 173 ++++++++++++++++++++++++++++++++++++ vllm_mlx/engine/simple.py | 37 +++++++- 2 files changed, 208 insertions(+), 2 deletions(-) diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index 366e203b..9dcfb0bd 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -547,6 +547,179 @@ async def short_circuit(func, *args, on_cancel=None, **kw): # Probe-divergence ran ⇒ apply_chat_template called for prompt + 2 probes. assert tokenizer.apply_chat_template.call_count == 3 + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_mtp_active(self): + """When ``self._mtp`` is configured, the cache branch must be skipped. + The branch calls ``mlx_lm.stream_generate`` directly with no ``mtp`` / + ``num_draft_tokens`` kwargs, while ``MLXLanguageModel.stream_generate`` + attaches them from ``self._mtp`` / ``self._mtp_num_draft_tokens``. 
+ Running the same request through the cache branch would silently drop + speculative decoding for cache-eligible turns while keeping it on + uncached turns — different engine semantics for the same request.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + "<|im_start|>user\nhello<|im_end|>\n" + "<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + fallback_kwargs: list[dict] = [] + + async def fake_stream_generate(*, prompt, **kw): + fallback_kwargs.append(kw) + out = MagicMock( + text="ok", + new_text="ok", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model", mtp=True, mtp_num_draft_tokens=4) + engine._model = model + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + # Cache-path probes must NOT run — only the initial prompt render. + assert tokenizer.apply_chat_template.call_count == 1 + # The uncached wrapper must have been invoked. + # MTP kwargs are layered on inside ``MLXLanguageModel.stream_generate``, + # not at this seam, so this test only proves the cache branch was + # bypassed; the wrapper attaches MTP itself when ``self._mtp`` is set. 
+ assert fallback_kwargs, "uncached stream_generate fallback was not invoked" + assert chunks and chunks[0].text == "ok" + + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_specprefill_loaded(self): + """A loaded SpecPrefill draft model (``self._draft_model is not None``) + triggers ``_stream_generate_specprefill`` routing inside the wrapper for + large prompts. + The cache branch has no equivalent routing, so it must be skipped + whenever a draft model is loaded so all requests go through the + wrapper's SpecPrefill decision.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + "<|im_start|>user\nhello<|im_end|>\n" + "<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + fallback_kwargs: list[dict] = [] + + async def fake_stream_generate(*, prompt, **kw): + fallback_kwargs.append(kw) + out = MagicMock( + text="ok", + new_text="ok", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._draft_model = MagicMock(name="specprefill_draft_model") + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + assert tokenizer.apply_chat_template.call_count == 1 + assert fallback_kwargs, "uncached stream_generate fallback was not invoked" + assert chunks and chunks[0].text == "ok" + + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_max_kv_size_set(self): + """Configured 
``max_kv_size`` caps the prompt cache. + The cache branch builds its cache with ``make_prompt_cache(model)`` + without forwarding ``max_kv_size``, so a non-zero engine-level bound + must force the uncached path.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + "<|im_start|>user\nhello<|im_end|>\n" + "<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + fallback_kwargs: list[dict] = [] + + async def fake_stream_generate(*, prompt, **kw): + fallback_kwargs.append(kw) + out = MagicMock( + text="ok", + new_text="ok", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model", max_kv_size=4096) + engine._model = model + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + assert tokenizer.apply_chat_template.call_count == 1 + assert fallback_kwargs, "uncached stream_generate fallback was not invoked" + assert chunks and chunks[0].text == "ok" + @pytest.mark.anyio async def test_lock_serializes_stream_generate(self, mock_model): """Test that stream_generate uses the same lock as other methods.""" diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 386e430e..df68a173 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -906,10 +906,43 @@ def run_stream(): cache_blocking_controls.append("presence_penalty") if (kwargs.get("repetition_penalty") or 1.0) != 1.0: cache_blocking_controls.append("repetition_penalty") 
+ + # Engine-feature gate. + # The cache branch also bypasses engine-level features that + # ``self.stream_generate`` (and the ``MLXLanguageModel.stream_generate`` + # wrapper underneath it) layer on top of ``mlx_lm.stream_generate``. + # Same correctness reasoning as the decode-control gate: cache-eligible + # and uncached requests must decode under identical engine semantics, so + # skip the cache branch when any of these are active. + # Specifically: + # - ``self._mtp`` injects ``mtp=True`` and ``num_draft_tokens`` into + # the mlx-lm call (see ``MLXLanguageModel.stream_generate``). + # - A loaded SpecPrefill draft model (``self._draft_model is not None``, + # set when ``specprefill_enabled`` + ``specprefill_draft_model`` are + # configured at engine init) routes large prompts through + # ``_stream_generate_specprefill`` instead of the plain stream path. + # - A per-request ``specprefill`` override from ``extra_body`` (popped + # by the wrapper from ``kwargs``) can force or suppress SpecPrefill + # for a single request. + # ``specprefill=False`` is a meaningful suppression signal — gate on + # ``is not None`` rather than truthiness so the wrapper sees it. + # - ``self._max_kv_size`` (when > 0) caps the prompt cache; the cache + # branch builds its cache with ``make_prompt_cache(model)`` and has + # no equivalent bound. 
+ if self._mtp: + cache_blocking_controls.append("mtp") + if self._draft_model is not None: + cache_blocking_controls.append("specprefill_loaded") + if kwargs.get("specprefill") is not None: + cache_blocking_controls.append("specprefill_request_override") + if (self._max_kv_size or 0) > 0: + cache_blocking_controls.append("max_kv_size") + if cache_blocking_controls: logger.info( - "System KV cache SKIP (stream_chat): request has decode controls the " - "cache branch cannot honor (%s); using uncached path", + "System KV cache SKIP (stream_chat): request or engine has " + "controls/features the cache branch cannot honor (%s); using " + "uncached path", cache_blocking_controls, ) From 00717b944e102b0d592d8766ddc9036c51c85548 Mon Sep 17 00:00:00 2001 From: Vinay Vobbilichetty Date: Thu, 14 May 2026 08:23:51 -0400 Subject: [PATCH 8/9] fix: harden stream_chat KV cache against TOCTOU + rotating caches Addresses three concerns raised in maintainer review of #523: 1. TOCTOU between HIT gate and snapshot restore. The gate at L1027-L1032 reads ``self._system_kv_snapshot`` outside ``_run_blocking_serialized`` while the restore re-reads it later, inside the serialized worker. A concurrent MISS that reassigns the attribute in between would load a snapshot for a different system prefix under the hash that decided HIT. Fixed by capturing the snapshot into a closure-local ``hit_snapshot`` at gate time and using that for the restore. 2. ``_max_kv_size`` gate alone doesn't catch sliding-window models. ``gemma3_text``, ``olmo3``, and ``recurrent_gemma`` return ``RotatingKVCache`` from ``make_cache()`` regardless of ``max_kv_size``; ``.state`` aliases buffers that ``update_and_fetch`` mutates in place. Probe once in ``start()`` via ``make_prompt_cache(model)`` and require every entry to be a plain ``KVCache``; the engine-feature gate adds ``non_kv_cache_class`` to ``cache_blocking_controls`` when the probe says no. 3. 
Comment on ``mx.clear_cache()`` claimed it mirrored the MLLM path, but the MLLM path doesn't ``mx.clear_cache()`` between its last prefill chunk and the snapshot. Reworded to call out that this path is intentionally stricter. Tests: - Add ``test_stream_chat_uses_gate_time_snapshot_under_concurrent_mutation``: simulates a concurrent MISS by reassigning ``_system_kv_snapshot`` inside the ``_run_blocking_serialized`` hook, then asserts the restore wrote the gate-time entries. - Add ``test_stream_chat_skips_cache_path_when_model_has_non_kv_cache``: proves the gate fires when ``_supports_system_kv_cache=False``. - Existing positive/negative tests set ``_supports_system_kv_cache=True`` to isolate the feature under test (otherwise the new gate would short-circuit them for the wrong reason). --- tests/test_simple_engine.py | 191 ++++++++++++++++++++++++++++++++++++ vllm_mlx/engine/simple.py | 73 +++++++++++++- 2 files changed, 260 insertions(+), 4 deletions(-) diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index 9dcfb0bd..0a2cb849 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -519,6 +519,9 @@ def apply_chat_template_side_effect(messages, **kwargs): engine = SimpleEngine("test-model") engine._model = model engine._loaded = True + # Probe normally runs in start(); short-circuit it here so the + # gate doesn't trip on the synthetic engine state. + engine._supports_system_kv_cache = True # No fallback needed since cache path should be exercised. # Patch _run_blocking_serialized to short-circuit cache execution cleanly @@ -588,6 +591,8 @@ async def fake_stream_generate(*, prompt, **kw): engine = SimpleEngine("test-model", mtp=True, mtp_num_draft_tokens=4) engine._model = model engine._loaded = True + # Isolate the gate to the feature under test. 
+ engine._supports_system_kv_cache = True engine.stream_generate = fake_stream_generate # type: ignore[method-assign] chunks = [ @@ -650,6 +655,7 @@ async def fake_stream_generate(*, prompt, **kw): engine._model = model engine._draft_model = MagicMock(name="specprefill_draft_model") engine._loaded = True + engine._supports_system_kv_cache = True engine.stream_generate = fake_stream_generate # type: ignore[method-assign] chunks = [ @@ -704,6 +710,7 @@ async def fake_stream_generate(*, prompt, **kw): engine = SimpleEngine("test-model", max_kv_size=4096) engine._model = model engine._loaded = True + engine._supports_system_kv_cache = True engine.stream_generate = fake_stream_generate # type: ignore[method-assign] chunks = [ @@ -720,6 +727,190 @@ async def fake_stream_generate(*, prompt, **kw): assert fallback_kwargs, "uncached stream_generate fallback was not invoked" assert chunks and chunks[0].text == "ok" + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_model_has_non_kv_cache(self): + """Models whose ``make_prompt_cache`` returns ``RotatingKVCache`` + (sliding-window models like gemma3_text, olmo3, recurrent_gemma) cannot + be safely snapshotted: ``.state`` aliases buffers that + ``update_and_fetch`` mutates in place, so restoring a captured snapshot + on the next turn would silently desynchronize from the running cache. 
+ ``start()`` probes ``make_prompt_cache`` once and sets + ``_supports_system_kv_cache=False`` for those models; the gate must + then skip the cache branch.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + "<|im_start|>user\nhello<|im_end|>\n" + "<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + fallback_kwargs: list[dict] = [] + + async def fake_stream_generate(*, prompt, **kw): + fallback_kwargs.append(kw) + out = MagicMock( + text="ok", + new_text="ok", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + # Simulate the probe finding non-KVCache entries (rotating cache). + engine._supports_system_kv_cache = False + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + # Cache-path probes must NOT run when the model isn't snapshot-safe. + assert tokenizer.apply_chat_template.call_count == 1 + assert fallback_kwargs, "uncached stream_generate fallback was not invoked" + assert chunks and chunks[0].text == "ok" + + @pytest.mark.anyio + async def test_stream_chat_uses_gate_time_snapshot_under_concurrent_mutation( + self, + ): + """A concurrent MISS that reassigns ``self._system_kv_snapshot`` between + the cache-hit gate (which runs outside ``_run_blocking_serialized``) and + the snapshot restore (which runs inside the serialized worker) must not + corrupt the HIT. 
The restore must use the snapshot reference captured at + gate time, not re-read ``self._system_kv_snapshot`` later — otherwise a + different system prefix's KV would be loaded under the hash that decided + HIT. + + Simulates the race by reassigning ``engine._system_kv_snapshot`` inside + the ``_run_blocking_serialized`` hook (executed after the gate but + before the worker enters the cache branch), then asserts the restore + loop wrote the gate-time entries, not the post-gate intruder.""" + import hashlib + + from vllm_mlx.engine.simple import SimpleEngine + + # Same template the positive test uses: divergence falls at the user + # content so the detected system prefix is the leading frame up through + # ``<|im_start|>user\n``. + def apply_chat_template_side_effect(messages, **kwargs): + user_content = "" + for m in reversed(messages): + role = ( + m.get("role") if isinstance(m, dict) else getattr(m, "role", None) + ) + if role == "user": + content = ( + m.get("content") + if isinstance(m, dict) + else getattr(m, "content", "") + ) + user_content = content or "" + break + return ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + f"<|im_start|>user\n{user_content}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + expected_prefix = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" "<|im_start|>user\n" + ) + expected_hash = hashlib.sha256(expected_prefix.encode()).hexdigest()[:16] + + tokenizer = MagicMock() + tokenizer.apply_chat_template.side_effect = apply_chat_template_side_effect + tokenizer.bos_token = None + # First encode = full prompt tokens; second = system prefix tokens. + # range(20) is a prefix of range(50), so prefix-match validation passes. 
+ tokenizer.encode = MagicMock(side_effect=[list(range(50)), list(range(20))]) + + model = MagicMock() + model.tokenizer = tokenizer + + original_snapshot = [("ORIGINAL_K", "ORIGINAL_V")] + intruder_snapshot = [("INTRUDER_K", "INTRUDER_V")] + + captured_states: list = [] + + class MockCacheEntry: + def __init__(self) -> None: + self._state = None + + @property + def state(self): + return self._state + + @state.setter + def state(self, value) -> None: + captured_states.append(value) + self._state = value + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + engine._supports_system_kv_cache = True + # Pre-seed HIT state matching the divergence-detected prefix. + engine._system_kv_hash = expected_hash + engine._system_kv_token_count = 20 + engine._system_kv_snapshot = original_snapshot + + async def serialized_with_race(func, *args, on_cancel=None, **kw): + # Simulate a concurrent MISS overwriting the instance attribute + # AFTER the gate's HIT decision but BEFORE the worker restores. + engine._system_kv_snapshot = intruder_snapshot + await asyncio.to_thread(func) + return None + + engine._run_blocking_serialized = ( + serialized_with_race # type: ignore[method-assign] + ) + + with ( + patch("mlx_lm.stream_generate", return_value=iter([])), + patch( + "mlx_lm.models.cache.make_prompt_cache", + return_value=[MockCacheEntry()], + ), + patch("mlx_lm.sample_utils.make_sampler"), + ): + _ = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + # Restore wrote the gate-time snapshot exactly once. + # If the worker had re-read ``self._system_kv_snapshot`` we would see + # ``("INTRUDER_K", "INTRUDER_V")`` instead — that's the TOCTOU bug. 
+ assert captured_states == [("ORIGINAL_K", "ORIGINAL_V")], ( + "Snapshot restore did not use the gate-time reference; " + f"captured={captured_states}" + ) + @pytest.mark.anyio async def test_lock_serializes_stream_generate(self, mock_model): """Test that stream_generate uses the same lock as other methods.""" diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index df68a173..14c7760b 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -179,6 +179,12 @@ def __init__( self._system_kv_snapshot = None # List of (keys, values) per backbone layer self._system_kv_hash = None # Hash of system prefix text self._system_kv_token_count = 0 # Tokens in cached prefix + # True only when the model's prompt cache is composed entirely of + # plain ``KVCache`` entries. Sliding-window models (gemma3_text, + # olmo3, recurrent_gemma) return ``RotatingKVCache`` whose ``.state`` + # aliases buffers ``update_and_fetch`` mutates in place — snapshot + # restore would silently desynchronize. Probed once in ``start()``. + self._supports_system_kv_cache: bool = False @property def model_name(self) -> str: @@ -254,6 +260,36 @@ async def start(self) -> None: self._mtp_num_draft_tokens, ) + # Probe whether this model's prompt cache is snapshot-safe for the + # stream_chat system-prefix cache branch. Sliding-window models + # (gemma3_text, olmo3, recurrent_gemma) return RotatingKVCache + # entries whose ``.state`` aliases in-place-mutated buffers. + # Only relevant for the LLM path; MLLM never enters the cache + # branch. 
+ if not self._is_mllm and self._model is not None: + try: + from mlx_lm.models.cache import KVCache, make_prompt_cache + + probe_cache = make_prompt_cache(self._model.model) + self._supports_system_kv_cache = bool(probe_cache) and all( + isinstance(c, KVCache) for c in probe_cache + ) + if not self._supports_system_kv_cache: + cache_types = sorted({type(c).__name__ for c in probe_cache}) + logger.info( + "System KV cache snapshot disabled: model returned " + "non-KVCache entries (%s); stream_chat will use the " + "uncached path", + cache_types, + ) + except Exception as e: + logger.debug( + "System KV cache support probe failed (%s); " + "disabling snapshot path", + e, + ) + self._supports_system_kv_cache = False + # Build parallel mlx_lm TextModel for text-only routing. # Even when MTP is disabled, text-only requests should not be trapped # on the slower mlx_vlm multimodal path. @@ -339,6 +375,7 @@ async def stop(self) -> None: self._system_kv_snapshot = None self._system_kv_hash = None self._system_kv_token_count = 0 + self._supports_system_kv_cache = False logger.info("SimpleEngine stopped") async def _run_blocking_serialized(self, func, /, *args, on_cancel=None, **kwargs): @@ -876,6 +913,11 @@ def run_stream(): full_token_count = 0 system_hash = None kv_cache_eligible = False + # Snapshot reference captured at gate time so a concurrent MISS that + # reassigns ``self._system_kv_snapshot`` between the gate and the + # restore (which runs later inside ``_run_blocking_serialized``) + # can't desynchronize the restored KV from the hash that decided HIT. + hit_snapshot: Any = None # Decode-control gate. 
# The cache branch below drives ``mlx_lm.stream_generate`` directly with only @@ -937,6 +979,13 @@ def run_stream(): cache_blocking_controls.append("specprefill_request_override") if (self._max_kv_size or 0) > 0: cache_blocking_controls.append("max_kv_size") + # Sliding-window models build their prompt cache from RotatingKVCache + # entries whose ``.state`` aliases buffers that ``update_and_fetch`` + # mutates in place. Snapshot capture would corrupt the cached prefix + # on the next decode. Probed once at start; ``False`` if the model + # exposes any non-KVCache entries or the probe failed. + if not self._supports_system_kv_cache: + cache_blocking_controls.append("non_kv_cache_class") if cache_blocking_controls: logger.info( @@ -1024,12 +1073,20 @@ def _with_user(user_content: str) -> list[dict[str, Any]]: system_tokens = system_tokens_list suffix_tokens = full_tokens_list[system_token_count:] kv_cache_eligible = True + # Read the snapshot reference once. If we promote to + # HIT, ``hit_snapshot`` is the exact list the hash + # check just validated against. A later concurrent + # MISS that reassigns ``self._system_kv_snapshot`` + # before our serialized worker restores it cannot + # alias what we captured here. 
+ candidate_snapshot = self._system_kv_snapshot if ( system_hash == self._system_kv_hash - and self._system_kv_snapshot is not None + and candidate_snapshot is not None and system_token_count == self._system_kv_token_count ): cache_hit = True + hit_snapshot = candidate_snapshot logger.info( "System KV cache HIT (stream_chat): reusing %d " "tokens, prefilling %d new (hash=%s)", @@ -1076,7 +1133,12 @@ def _run_with_cache() -> None: if cache_hit: bc = make_prompt_cache(model) - for i, saved_state in enumerate(self._system_kv_snapshot): + # Restore from the closure-local reference captured at the + # gate, never from ``self._system_kv_snapshot`` directly: + # a concurrent MISS could have replaced the instance + # attribute with a snapshot for a different system prefix + # between the gate check and this point. + for i, saved_state in enumerate(hit_snapshot): bc[i].state = saved_state else: bc = make_prompt_cache(model) @@ -1091,8 +1153,11 @@ def _run_with_cache() -> None: model(sys_arr[None], cache=bc) mx.eval([c.state for c in bc]) - # Free intermediate prefill activations before snapshotting; - # mirrors the MLLM path between chunked prefill and decode. + # Free intermediate prefill activations before snapshotting. + # Intentionally stricter than the MLLM path, which does not + # ``mx.clear_cache()`` between its last prefill chunk and + # the snapshot; here we want the snapshot to reflect only + # the KV state, not residual activations from prefill. 
mx.clear_cache() snapshot = [c.state for c in bc] From dbaf18ba8db7491699568d703fac73fb41e300ac Mon Sep 17 00:00:00 2001 From: Wayner Barrios Date: Thu, 14 May 2026 12:05:32 -0400 Subject: [PATCH 9/9] set _supports_system_kv_cache on the pydantic stream_chat test so the probe-divergence renders actually run --- tests/test_simple_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index 0a2cb849..779d40e4 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -243,6 +243,7 @@ async def fake_stream_generate(*, prompt, **kwargs): engine = SimpleEngine("test-model") engine._model = model engine._loaded = True + engine._supports_system_kv_cache = True engine.stream_generate = fake_stream_generate # type: ignore[method-assign] messages = [