feat: extend system-prompt KV cache to pure-LLM stream_chat path #523
Conversation
The existing system-prefix KV cache (added in `_stream_generate_text`) covered the MLLM text-only path but not the pure-LLM `stream_chat` path, so non-MLLM models routing through `stream_generate()` re-prefilled the full system block on every turn. This mirrors the same hash-keyed, single-slot cache logic into `stream_chat`, after `apply_chat_template` (sketched below):

- detect a system prefix via ChatML markers
- HIT: restore the snapshot and prefill only the suffix tokens
- MISS: prefill the system tokens, snapshot per-layer KV state, then continue with the suffix
- fallback: if anything looks off (no prefix, encode mismatch, cache call raises), drop through to the original uncached `self.stream_generate()` path unchanged

Reuses the engine's existing `_system_kv_snapshot` / `_system_kv_hash` / `_system_kv_token_count` attributes — no `__init__` changes, no new public surface. Holds no extra locks (the inner `_run_blocking_serialized` already takes `_generation_lock`).

Measured locally on Qwen2.5-Coder-32B-Instruct-8bit driving Claude Code on Apple Silicon: ~100s+ follow-up-turn prefill -> ~7s once the system prefix is cached (~23K-token system+tools prefix).
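A minimal sketch of the hit/miss flow described above — the helper and local names are illustrative, only the `_system_kv_*` attributes are the engine's real ones, and this is not the PR's exact code:

```python
import hashlib

import mlx.core as mx
from mlx_lm.models.cache import make_prompt_cache


def prefill_with_system_cache(engine, model, system_text, system_tokens):
    """Single-slot, hash-keyed system-prefix cache (illustrative sketch)."""
    system_hash = hashlib.sha256(system_text.encode("utf-8")).hexdigest()
    bc = make_prompt_cache(model)
    if (
        system_hash == engine._system_kv_hash
        and engine._system_kv_snapshot is not None
        and len(system_tokens) == engine._system_kv_token_count
    ):
        # HIT: restore the per-layer snapshot; only the suffix needs prefilling.
        for i, saved_state in enumerate(engine._system_kv_snapshot):
            bc[i].state = saved_state
    else:
        # MISS: prefill the system tokens, then publish the snapshot for the next turn.
        model(mx.array(system_tokens)[None], cache=bc)
        mx.eval([c.state for c in bc])
        engine._system_kv_snapshot = [c.state for c in bc]
        engine._system_kv_hash = system_hash
        engine._system_kv_token_count = len(system_tokens)
    return bc  # caller continues generation with the suffix tokens
```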
1. Split-point detection (waybarrios#2): replace the ChatML-specific marker scan (`<|im_start|>...`) with probe-divergence — render the chat template with two different user contents and take the shared prefix (sketched below). Mirrors `prompt_warmup._build_strict_prefix_string` and works across Qwen/ChatML, Llama, Gemma, and any other chat format with no per-model marker list.
2. Incremental streaming (waybarrios#1): replace `list(mlx_stream_generate(...))` with the `asyncio.Queue` producer pattern used in `_stream_generate_text` — chunks are emitted via `loop.call_soon_threadsafe` from the thread running mlx-lm and yielded immediately to the caller. Adds a thread-safe `abort_event` tied to `_run_blocking_serialized`'s `on_cancel` hook.
3. Move `import hashlib` to the module top (waybarrios#4); use `hashlib.sha256` directly instead of the local alias.
4. Remove the dead `if True:` indent block (waybarrios#5) left over from an earlier removal of an async-with wrapper.

The cache-aware-path failure fallback now distinguishes pre-first-token errors (safe to retry uncached) from mid-stream errors (re-raise — the client has already received partial output and switching paths would duplicate tokens). Canonicalization-before-hashing (waybarrios#3) is deferred to the registry being designed in waybarrios#524.
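To make item 1 concrete, a hedged sketch of the probe-divergence idea — the function name and structure are illustrative (the real helper mirrors `prompt_warmup._build_strict_prefix_string`); only the "Alpha"/"Bravo" probe names come from the PR's tests:

```python
def find_shared_prefix(tokenizer, messages, **template_kwargs):
    """Render the chat template twice with different user contents and take
    the longest shared prefix -- that prefix is the system/tools block."""

    def render(probe_content: str) -> str:
        probe_msgs = [m for m in messages if m.get("role") == "system"]
        probe_msgs.append({"role": "user", "content": probe_content})
        return tokenizer.apply_chat_template(
            probe_msgs,
            tokenize=False,
            add_generation_prompt=True,
            **template_kwargs,
        )

    a, b = render("Alpha"), render("Bravo")
    split = 0
    for ca, cb in zip(a, b):
        if ca != cb:
            break
        split += 1
    return a[:split]
```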
Thanks @janhilgard for the thorough review and for the context on the three invalidators the MLLM path already handles. Pushed a revision:

#1 streaming — replaced the collect-then-return with the `asyncio.Queue` producer pattern from `_stream_generate_text`, so chunks stream to the caller as they're generated (a sketch of the pattern is below).

#2 model-agnostic split-point — replaced the ChatML marker scan with probe-divergence: render the chat template twice with two different user contents, take the shared prefix. This is the same technique as `prompt_warmup._build_strict_prefix_string`, so it works across chat formats with no per-model marker list.

#4 hashlib import + #5 dead `if True:` block — done.

#3 canonicalize-before-hashing — deferred to #524 as you suggested. Happy to take a swing at the simplified registry you sketched (single list, …).

One nuance on the probe-divergence detection that's worth flagging: it works correctly when there's at least one user message in the conversation (the probes vary only the last user turn). If the message list has system messages followed by an assistant message and nothing else, the function appends a placeholder user message before probing. Edge case in practice but documented in the code path.

Re-tested on the same Qwen2.5-Coder-32B-Instruct-8bit / Claude Code workload — cache hit/miss/store logging fires identically, and the OpenAI Chat Completions client now sees streaming token deltas during generation instead of one big blob at the end. Will rerun the timing benchmark and post the numbers if you'd like a fresh measurement on this revision.
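For reference, a rough sketch of the queue-producer streaming pattern mentioned in #1 above — `run_generation` stands in for the thread that drives mlx-lm, and all names are illustrative, not the PR's exact code:

```python
import asyncio
import threading


async def stream_tokens(run_generation):
    """Yield chunks as a worker thread produces them (illustrative sketch)."""
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    abort_event = threading.Event()
    done = object()

    def producer():
        try:
            for chunk in run_generation(abort_event):
                loop.call_soon_threadsafe(queue.put_nowait, chunk)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, done)

    threading.Thread(target=producer, daemon=True).start()
    try:
        while True:
            item = await queue.get()
            if item is done:
                break
            yield item
    finally:
        # on_cancel-style hook: tell the worker to stop if the client disconnects.
        abort_event.set()
```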
Nice work on the revision. The probe-divergence approach for finding the system prefix boundary is the right call — it's the same technique we validated in `prompt_warmup`.

Deferring canonicalize-before-hashing to #524 makes sense — it keeps this PR focused on the caching mechanism itself. Once #524 lands with even the minimal stripper list, it's a one-liner to wire it in here.

One question on the probe-divergence edge case you flagged: if the message list has only system + assistant (no user message), is this a realistic scenario? In the OpenAI Chat Completions API the first non-system message is always user-originated. The placeholder insertion sounds safe, but I want to make sure it doesn't shift token boundaries in the probe vs the real encode (e.g. if the template generates different separators for 2-message vs 3-message sequences). If you've verified the prefix length matches between probe and real encode in that case, it's fine.

LGTM on the overall design. Happy to give a final approval once the diff is up for a quick re-check.
`stream_chat` is typed as accepting `list[dict[str, Any]]` but some internal callers (server.py's streaming endpoint, the streaming-chat test in test_server.py) pass Pydantic `Message` objects directly. Those don't expose dict's `.get()`, so the cache-eligibility detection raised `'Message' object has no attribute 'get'` and the test `TestStreamChatCompletion::test_streaming_chat_no_stream_thread_error_after_residency_preload` failed. Normalize each message to a plain dict (via `model_dump()` / `dict()` / getattr fallback) before the role lookup and the probe-divergence renders. Also apply black for the lint step. Surfaces only when the chat-completion request goes through `stream_chat` with raw Pydantic Message objects (i.e. not pre-converted by the caller). The MLLM path (`_stream_generate_text`) has the same `m.get(...)` pattern, but the test patches `is_mllm_model=False` so this fix is enough to green the CI.
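A sketch of the normalization described above (the later review calls the helper `_to_msg_dict()`); the exact fallback order here is an assumption based on the commit message:

```python
from typing import Any


def _to_msg_dict(m: Any) -> dict:
    """Return a plain dict for either a dict or a Pydantic Message object."""
    if isinstance(m, dict):
        return m
    if hasattr(m, "model_dump"):  # Pydantic v2
        return m.model_dump()
    if hasattr(m, "dict"):  # Pydantic v1
        return m.dict()
    return {"role": getattr(m, "role", None), "content": getattr(m, "content", None)}
```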
Three tests covering the new pure-LLM cache path:
- `test_stream_chat_cache_path_accepts_pydantic_message_objects` — regression test for the Pydantic `Message` handling that the prior CI run caught. Verifies that callers passing `Message(role=..., content=...)` (as `server.py`'s streaming endpoint does) don't raise `AttributeError` at the `.get('role')` boundary.
- `test_stream_chat_skips_cache_path_when_no_system_message` — `has_system=False` must short-circuit; `apply_chat_template` should only be called once (for the initial prompt), not three times (initial + Alpha probe + Bravo probe).
- `test_stream_chat_cache_path_falls_back_when_mlx_raises` — when `_run_with_cache` raises before the first token (here forced via patching `make_prompt_cache` to raise), the pre-first-token error branch must route to the uncached `stream_generate` fallback rather than propagating the exception.
Thanks @janhilgard — glad the design holds up. Answering the edge case directly: the placeholder-user injection in the probe helper only touches the probe renders, never the real prompt encode. I empirically verified this against the Qwen2.5-Coder template for the two scenarios you flagged — both flow through to the uncached `stream_generate()` path, so the placeholder has no chance to shift token boundaries relative to the real encode. Also pushed …

Ready for re-check whenever you have time.
janhilgard
left a comment
Solid revision — the streaming architecture, error handling, and model-agnostic detection all look correct now.
What works well:
- Probe-divergence correctly reuses the same `template_kwargs` as the real render, so `enable_thinking`, `tools`, etc. don't throw off the boundary
- `asyncio.Queue` + `call_soon_threadsafe` streaming mirrors `_stream_generate_text` exactly
- Pre-first-token vs mid-stream error distinction prevents token duplication on fallback
- `_to_msg_dict()` normalizes Pydantic objects cleanly
- `on_cancel=abort_event.set` integrates with the existing cancellation path
- Tests cover the three critical edge cases (Pydantic messages, no-system skip, error fallback)
One minor optimization suggestion (non-blocking):
On cache MISS, after the chunked system prefill loop, consider adding mx.clear_cache() before starting suffix generation:
if sys_arr.size > 0:
    model(sys_arr[None], cache=bc)
    mx.eval([c.state for c in bc])
    mx.clear_cache()  # free intermediate activations from prefill
snapshot = [c.state for c in bc]

For 23K-token system prefixes, the intermediate activations from prefilling can be ~1-2 GB of temporary memory. The MLLM path does this between chunked prefill and generation. Not critical since it'll be reclaimed eventually by the allocator, but prevents a peak memory spike on the first MISS after startup.
LGTM — no blocking issues.
Per @janhilgard's non-blocking review suggestion. After chunked prefill of the system prefix, intermediate activations can hold ~1-2 GB of temporary memory on long (~23K-token) prefixes. Calling mx.clear_cache() before capturing the per-layer state mirrors the existing MLLM path's pattern between chunked prefill and decode, and prevents a peak-memory spike on the first cache MISS after startup. The snapshot itself is unaffected (it reads c.state on each cache slot, which is preserved).
Thanks @janhilgard — applied in 0577644. Confirmed the snapshot read still picks up the per-layer state cleanly after the clear, since the snapshot reads `c.state` on each cache slot and the clear leaves that untouched.
Thump604
left a comment
I ran the focused SimpleEngine tests locally against this PR:

AI_RUNTIME_BYPASS_SAFETY_GATE=1 PYTHONPATH=/tmp/vllm-mlx-review-523 /opt/ai-runtime/venv-live/bin/python -m pytest tests/test_simple_engine.py -q
# 31 passed
I also reproduced one behavior gap in the cache-eligible stream_chat branch. A request with stop, top_k, min_p, presence_penalty, repetition_penalty, and a request-local logits_processors list reaches the new direct mlx_lm.stream_generate(...) call with only prompt, max_tokens, sampler, and prompt_cache. The uncached path still calls self.stream_generate(..., **kwargs), and that wrapper preserves the request-local decode controls, stop handling, penalty processors, logits processors, and MTP/cache semantics.
That matters because server.py can attach parser stop tokens and JSON/constrained decoding logits processors per request, and it always threads the sampling/penalty fields through the engine call. With this branch enabled, a system-prefix-cache eligible request can silently decode under different constraints than the same request on the uncached path.
Please either preserve the same wrapper semantics in the cache branch, or explicitly skip the cache branch when unsupported request-local decode/constraint controls are present. I would also add a regression test where a cache-eligible stream_chat call with stop or logits_processors proves the control is preserved or the uncached path is used.
I am not claiming the system-prefix cache approach is wrong, only that this PR currently changes request-local decode/parser behavior on the cache branch.
The cache-eligible branch in stream_chat called mlx_lm.stream_generate directly with only prompt/max_tokens/sampler/prompt_cache, silently dropping `stop`, request-local `logits_processors` (parser stop tokens and JSON-constrained decoding attached by server.py), and the `top_k`/`min_p`/`presence_penalty`/`repetition_penalty` sampling controls that the uncached path threads through self.stream_generate. Same request could decode under different constraints depending on whether it hit the cache branch. Gate the cache branch on absence of those controls so both paths share identical decode semantics. The gate compares against server.py's no-op defaults (top_k=0, min_p=0.0, presence_penalty=0.0, repetition_penalty=1.0) rather than `key in kwargs`, so the common path still hits the cache. Adds two regression tests: stop+logits_processors forces the uncached fallback (probe-divergence never runs), and the no-op defaults still exercise the cache path. Reported by @Thump604. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
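Roughly, the gate this commit describes looks like the following — the helper name is hypothetical; the defaults are the server.py no-op values the commit message lists:

```python
SERVER_NOOP_DEFAULTS = {
    "top_k": 0,
    "min_p": 0.0,
    "presence_penalty": 0.0,
    "repetition_penalty": 1.0,
}


def cache_branch_allowed(stop=None, logits_processors=None, **sampling) -> bool:
    """True only when every request-local decode control sits at its no-op default."""
    if stop or logits_processors:
        return False  # parser stop tokens / constrained decoding need the wrapper
    return all(
        sampling.get(name, default) == default
        for name, default in SERVER_NOOP_DEFAULTS.items()
    )
```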
Thanks for catching this, @Thump604 — confirmed and fixed in d768850. You're right on the substance: the cache-eligible branch was calling `mlx_lm.stream_generate` directly with only prompt / max_tokens / sampler / prompt_cache, so request-local decode controls never reached it.

I took the second option you offered (skip the cache branch when those controls are active) over threading them through. Reasoning: parser stop tokens and the request-local logits processors aren't natively accepted by the bare call the cache branch makes, so preserving them there would mean re-growing a copy of the wrapper's handling inside the cache branch; skipping keeps both paths' decode semantics identical.

Gate logic: the gate compares against server.py's no-op defaults (top_k=0, min_p=0.0, presence_penalty=0.0, repetition_penalty=1.0) rather than `key in kwargs`, so the common request shape still hits the cache.

Two regression tests added in the same commit: a cache-eligible `stream_chat` call with `stop` + `logits_processors` forces the uncached fallback (the probe-divergence renders never run), and the no-op defaults still exercise the cache path.

Happy to thread the wrapper semantics through the cache branch instead if you'd prefer the more aggressive version (more cache hits when callers explicitly set sampling penalties, at the cost of more surface area in the cache branch) — let me know.
Thump604
left a comment
Thanks for the quick revision. d768850 fixes the request-local decode controls I raised: stop, logits_processors, and active sampling/penalty controls now force the uncached path, while the server no-op defaults still allow the cache path. I also re-ran the focused suite locally:
AI_RUNTIME_BYPASS_SAFETY_GATE=1 PYTHONPATH=/opt/ai-runtime/worktrees/vllm-mlx/pr-523-review-latest /opt/ai-runtime/venv-live/bin/python -m pytest tests/test_simple_engine.py -q
# 33 passed
uvx ruff check vllm_mlx/engine/simple.py tests/test_simple_engine.py
# pass
/opt/ai-runtime/venv-live/bin/python -m black --check --target-version py312 vllm_mlx/engine/simple.py tests/test_simple_engine.py
# pass
git diff --check
# pass

I still cannot approve this revision yet because the cache branch remains a second generation path for engine-level features/limits. In stream_chat, the cache path calls mlx_lm.stream_generate directly with only max_tokens, sampler, and prompt_cache:
for resp in mlx_stream_generate(
    model,
    tokenizer,
    prompt=prompt_arr,
    max_tokens=max_tokens,
    sampler=sampler,
    prompt_cache=bc,
):

That still diverges from stream_generate, which adds MTP cache/kwargs, SpecPrefill routing, max_kv_size on make_prompt_cache, and prefill_step_size.
Minimal local repro against d768850: configure SimpleEngine(mtp=True, mtp_num_draft_tokens=4), send a cache-eligible system+user chat request with only the server no-op decode defaults, and capture the direct mlx_lm.stream_generate kwargs. The cache path is entered (apply_chat_template called 3 times for prompt + probes), but captured kwargs are only:
{"max_tokens": 256, "sampler": ..., "prompt_cache": [...]}mtp is absent and num_draft_tokens is absent. So the same request changes behavior based only on whether the system-prefix cache branch is eligible when MTP is configured. The same class of gap applies to SpecPrefill when a draft model / per-request specprefill=True is active, and to the configured cache/step limits.
Please either:
- expand the skip gate so this cache branch is bypassed whenever an engine-level feature/limit it cannot honor is active (self._mtp, loaded/forced SpecPrefill, configured max_kv_size, and any other wrapper-only generation behavior), or
- route the cache branch through the same wrapper semantics instead of maintaining a second direct generation path.
If you choose the skip-gate option, I would add at least one MTP regression test proving self._mtp=True forces the uncached wrapper path and preserves mtp / num_draft_tokens, plus a SpecPrefill/per-request specprefill=True regression if practical. The current new tests prove the decode-control fix, but they do not cover this remaining wrapper-semantic surface.
The cache branch drives ``mlx_lm.stream_generate`` directly and bypasses engine-level features the ``self.stream_generate`` wrapper layers on top: ``self._mtp`` (multi-token prediction), loaded SpecPrefill draft model, per-request ``specprefill`` override from ``extra_body``, and configured ``self._max_kv_size``. Expand the existing decode-control gate to also skip the cache branch when any of those are active so cache-eligible and uncached requests decode under identical engine semantics. Engine-init defaults (``mtp=False``, no draft model, ``max_kv_size=0``) and the common per-request shape keep hitting the cache as before; the gate only fires when a feature/limit is actually configured. Adds three regression tests modeled on the existing decode-control ones: MTP active, SpecPrefill draft loaded, and ``max_kv_size`` configured. Each asserts the cache probes are skipped (``apply_chat_template`` called only once for the prompt render) and the uncached wrapper runs.
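A sketch of the widened gate, assuming the attribute names the commit message uses (``self._mtp``, ``self._max_kv_size``); the draft-model attribute and the helper itself are illustrative:

```python
def engine_feature_blockers(engine, extra_body: dict) -> list[str]:
    """Engine-level features the direct mlx_lm.stream_generate call can't honor."""
    blockers = []
    if getattr(engine, "_mtp", False):
        blockers.append("mtp")
    if getattr(engine, "_draft_model", None) is not None:
        blockers.append("specprefill_draft_model")
    if extra_body.get("specprefill"):
        blockers.append("specprefill_request")
    if (getattr(engine, "_max_kv_size", 0) or 0) > 0:
        blockers.append("max_kv_size")
    return blockers  # non-empty -> skip the cache branch, use the uncached wrapper
```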
Thanks @Thump604 — fixed in 429f75c. Same reasoning as the prior decode-control gate: the cache branch and the uncached wrapper must decode under identical engine semantics, and threading the wrapper's wrapper-of-wrapper semantics (MTP, SpecPrefill routing, max_kv_size) through the cache branch isn't worth the extra surface area.

Gate now also blocks the cache branch when any of the following are active: `self._mtp`, a loaded SpecPrefill draft model, a per-request `specprefill` override from `extra_body`, or a configured `self._max_kv_size`.

Default engine config (`mtp=False`, no draft model, `max_kv_size=0`) and the common per-request shape keep hitting the cache as before. Three regression tests added in the same commit, modeled on the existing decode-control pair: MTP active, SpecPrefill draft loaded, and `max_kv_size` configured — each asserts the probe renders are skipped and the uncached wrapper runs.

Each test only proves the cache branch is bypassed at this seam; the wrapper attaches MTP / SpecPrefill / `max_kv_size` behavior itself once the request routes through it. Happy to add the more aggressive option (threading wrapper semantics through the cache branch) as a follow-up if you'd rather have more cache hits when these features are configured — the current change just makes the two paths equivalent on engine semantics, which I read as your blocking concern.
Thump604
left a comment
Thanks for the follow-up. I re-checked 429f75c against the remaining wrapper-semantics concern I raised. The new gate now bypasses the pure-LLM system-prefix cache branch when MTP, a loaded SpecPrefill draft model, per-request SpecPrefill override, or nonzero max_kv_size would otherwise require self.stream_generate / wrapper behavior. That addresses the divergence I reproduced on d768850: cache-eligible turns no longer silently drop wrapper-only generation semantics.

Local verification on this revision:

cd /opt/ai-runtime/worktrees/vllm-mlx/pr-523-review-latest
git diff --check
/opt/ai-runtime/venv-live/bin/python -m pytest tests/test_simple_engine.py -q
# 36 passed
uvx ruff check vllm_mlx/engine/simple.py tests/test_simple_engine.py
# pass
/opt/ai-runtime/venv-live/bin/python -m black --check --target-version py312 vllm_mlx/engine/simple.py tests/test_simple_engine.py
# pass

I am not claiming broader runtime performance behavior from this review, only that the request-local decode controls and engine-feature compatibility issues raised on the prior revisions are addressed by the current diff and covered by focused regression tests.
Checked it — looks good, nice work. Thank you @vinayvobbili. I'll take a few more minutes to revisit it carefully; sorry for being a bit away on this.
Cache path looks solid, but three things worth a second look.

There's a TOCTOU between the cache-hit gate and the restore. The gate at simple.py#L1027-L1032 reads

if (
    system_hash == self._system_kv_hash
    and self._system_kv_snapshot is not None
    and system_token_count == self._system_kv_token_count
):
    cache_hit = True

and the restore at L1077-L1080 re-reads it later under if cache_hit:

bc = make_prompt_cache(model)
for i, saved_state in enumerate(self._system_kv_snapshot):
    bc[i].state = saved_state

A concurrent MISS can overwrite the snapshot in between, so the restored KV no longer matches the hash that set cache_hit. Capturing the snapshot at gate time closes the window:

hit_snapshot = self._system_kv_snapshot if cache_hit else None
# ...later, inside _run_with_cache:
for i, saved_state in enumerate(hit_snapshot):
    bc[i].state = saved_state

The max_kv_size gate only looks at the engine config:

if (self._max_kv_size or 0) > 0:
    cache_blocking_controls.append("max_kv_size")

mlx-lm's gemma3_text, olmo3 and recurrent_gemma return RotatingKVCache entries from make_cache() regardless of max_kv_size, and their .state aliases buffers that are mutated in place. A type check would catch them:

from mlx_lm.models.cache import KVCache

if not all(isinstance(c, KVCache) for c in bc):
    # fall back to uncached path; rotating caches alias their state
    ...

Even on plain KVCache, the snapshot is published while the same bc is handed straight on to suffix generation:

snapshot = [c.state for c in bc]
mx.eval([s for pair in snapshot for s in pair])
self._system_kv_snapshot = snapshot
# ...
for resp in mlx_stream_generate(
    model, tokenizer,
    prompt=prompt_arr,
    max_tokens=max_tokens,
    sampler=sampler,
    prompt_cache=bc,
):
    ...

A turn-2 parity test (HIT vs cache disabled) on a real model would tell us whether the snapshot needs a deep copy before publishing.
Minor nit: the comment at L1094-L1096 says it mirrors the MLLM path, but the MLLM path doesn't call mx.clear_cache() between its last prefill chunk and the snapshot:

# Free intermediate prefill activations before snapshotting;
# mirrors the MLLM path between chunked prefill and decode.
mx.clear_cache()

Either drop the comparison or note that this path is intentionally stricter.
What do you think, @vinayvobbili?
Addresses three concerns raised in maintainer review of waybarrios#523:

1. TOCTOU between HIT gate and snapshot restore. The gate at L1027-L1032 reads ``self._system_kv_snapshot`` outside ``_run_blocking_serialized`` while the restore re-reads it later, inside the serialized worker. A concurrent MISS that reassigns the attribute in between would load a snapshot for a different system prefix under the hash that decided HIT. Fixed by capturing the snapshot into a closure-local ``hit_snapshot`` at gate time and using that for the restore.
2. ``_max_kv_size`` gate alone doesn't catch sliding-window models. ``gemma3_text``, ``olmo3``, and ``recurrent_gemma`` return ``RotatingKVCache`` from ``make_cache()`` regardless of ``max_kv_size``; ``.state`` aliases buffers that ``update_and_fetch`` mutates in place. Probe once in ``start()`` via ``make_prompt_cache(model)`` and require every entry to be a plain ``KVCache``; the engine-feature gate adds ``non_kv_cache_class`` to ``cache_blocking_controls`` when the probe says no (sketched below).
3. Comment on ``mx.clear_cache()`` claimed it mirrored the MLLM path, but the MLLM path doesn't ``mx.clear_cache()`` between its last prefill chunk and the snapshot. Reworded to call out that this path is intentionally stricter.

Tests:
- Add ``test_stream_chat_uses_gate_time_snapshot_under_concurrent_mutation``: simulates a concurrent MISS by reassigning ``_system_kv_snapshot`` inside the ``_run_blocking_serialized`` hook, then asserts the restore wrote the gate-time entries.
- Add ``test_stream_chat_skips_cache_path_when_model_has_non_kv_cache``: proves the gate fires when ``_supports_system_kv_cache=False``.
- Existing positive/negative tests set ``_supports_system_kv_cache=True`` to isolate the feature under test (otherwise the new gate would short-circuit them for the wrong reason).
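A sketch of the start()-time probe item 2 describes — ``_supports_system_kv_cache`` is the flag the tests set; the probe body here is a reconstruction, not the exact diff:

```python
from mlx_lm.models.cache import KVCache, make_prompt_cache


def probe_supports_system_kv_cache(model) -> bool:
    """True only when every cache slot is a plain KVCache.

    Sliding-window models (gemma3_text, olmo3, recurrent_gemma) hand back
    RotatingKVCache entries whose .state aliases buffers mutated in place,
    so snapshots taken from them would be unsafe to restore later.
    """
    return all(isinstance(c, KVCache) for c in make_prompt_cache(model))
```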
Thanks for the careful read @waybarrios — all three are real. Pushed a revision:

TOCTOU on the HIT gate vs restore (fixed). The snapshot is now captured into a closure-local `hit_snapshot` at gate time and the restore uses that, so a concurrent MISS can no longer swap the attribute between the gate and the restore.

Sliding-window models / `RotatingKVCache` (fixed). `start()` now probes `make_prompt_cache(model)` once and requires every entry to be a plain `KVCache`; when the probe says no, the gate adds `non_kv_cache_class` to `cache_blocking_controls` and the cache branch is skipped.

Comment nit (fixed). Dropped the MLLM-path comparison and noted that this path is intentionally stricter — the snapshot here reflects only the KV state, not residual prefill activations.

CI's queued behind workflow approval again on the new commit (same as last time).
… probe-divergence renders actually run
Sorry, went ahead and patched one of the tests that broke with the new logic (missing `_supports_system_kv_cache=True`, so the probe-divergence renders never ran).
Thanks for catching that — that's the Pydantic Message normalization test; I missed it when I was patching the obvious gate-isolation cases. Pulled it into my local branch.
Follow-up to discussion #521 — adding the SimpleEngine prefix-cache work for review as @waybarrios suggested.
What
The system-prefix KV cache already lives in `_stream_generate_text` (the MLLM text path). This extends the same single-slot, hash-keyed cache to the pure-LLM `stream_chat` path, which currently re-prefills the full system block on every turn for non-MLLM models.

How
After `apply_chat_template`, before falling through to the existing uncached `stream_generate(prompt=...)`:

- Detect the system prefix via the ChatML markers (`<|im_start|>user`, `<|im_start|>assistant`).
- Tokenize both `prompt` and `system_prefix_text`; verify the system tokens are an exact prefix of the full token list (see the sketch below).
- HIT: restore the per-layer snapshot and prefill only the suffix tokens.
- MISS: prefill the system tokens, snapshot the per-layer KV state, then continue with the suffix.
- Anything that doesn't match falls through to the original uncached `stream_generate()` path, so the change is opt-in by detection and zero-impact on paths it doesn't apply to.
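A minimal sketch of that exact-prefix check — variable and function names are illustrative, not the PR's code:

```python
def split_cacheable_prefix(tokenizer, prompt: str, system_prefix_text: str):
    full_tokens = tokenizer.encode(prompt)
    system_tokens = tokenizer.encode(system_prefix_text)
    # Cache only when the system tokens are literally the first N tokens of the
    # full prompt encoding; anything else falls back to the uncached path.
    if full_tokens[: len(system_tokens)] != system_tokens:
        return None
    return system_tokens, full_tokens[len(system_tokens):]
```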
Reuses existing engine attrs: `_system_kv_snapshot` / `_system_kv_hash` / `_system_kv_token_count` / `_prefill_step_size` / `_generation_lock` / `_run_blocking_serialized`. No `__init__` changes, no new public surface.

Measured impact
Qwen2.5-Coder-32B-Instruct-8bit driving Claude Code on an M-series Studio, ~23K-token system+tools prefix: follow-up-turn prefill drops from ~100s+ to ~7s once the system prefix is cached.
Discussion #521 has the full writeup including the interaction with the Anthropic billing header (handled in PR #277).
Notes for review
- Single-slot cache, same shape as the existing `_stream_generate_text` cache. If you'd prefer a multi-slot LRU here, easy to swap.
- The system-prefix text is rendered via a second `apply_chat_template` call with a stripped-down message list — happy to refactor if there's a cleaner way.