feat: extend system-prompt KV cache to pure-LLM stream_chat path#523

Merged
waybarrios merged 9 commits into waybarrios:main from vinayvobbili:feat/system-kv-cache-stream-chat
May 14, 2026

Conversation

@vinayvobbili
Contributor

@vinayvobbili vinayvobbili commented May 10, 2026

Follow-up to discussion #521 — adding the SimpleEngine prefix-cache work for review as @waybarrios suggested.

What

The system-prefix KV cache already lives in _stream_generate_text (the MLLM text path). This extends the same single-slot, hash-keyed cache to the pure-LLM stream_chat path, which currently re-prefills the full system block on every turn for non-MLLM models.

How

After apply_chat_template, before falling through to the existing uncached stream_generate(prompt=...):

  1. Detect a system prefix via ChatML markers (<|im_start|>user, <|im_start|>assistant).
  2. Encode prompt and system_prefix_text; verify the system tokens are an exact prefix of the full token list.
  3. Cache HIT (same hash + same token count): restore the per-layer KV snapshot and prefill only the suffix. MISS: prefill the system tokens, snapshot, then continue with the suffix.
  4. Fallback: any unexpected condition (no system prefix, encode mismatch, exception inside the cache-aware runner) yields control back to the original stream_generate() path, so the change activates only when detection succeeds and has zero impact on paths it doesn't apply to.

Reuses existing engine attrs: _system_kv_snapshot / _system_kv_hash / _system_kv_token_count / _prefill_step_size / _generation_lock / _run_blocking_serialized. No __init__ changes, no new public surface.
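
A minimal sketch of the hit/miss decision, for reviewers skimming without the diff (function name is illustrative; the real logic is inlined in stream_chat and uses the engine attributes listed above):

import hashlib

def _cache_decision(engine, prompt_tokens, system_tokens, system_text):
    """Decide HIT / MISS / fallback for the single-slot system-prefix cache (sketch)."""
    # Safety net: the system tokens must be an exact prefix of the full prompt.
    if prompt_tokens[: len(system_tokens)] != system_tokens:
        return "fallback"  # drop through to the uncached stream_generate path

    system_hash = hashlib.sha256(system_text.encode("utf-8")).hexdigest()
    if (
        system_hash == engine._system_kv_hash
        and engine._system_kv_snapshot is not None
        and len(system_tokens) == engine._system_kv_token_count
    ):
        return "hit"   # restore the per-layer snapshot, prefill only the suffix
    return "miss"      # prefill the system tokens, snapshot, then continue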

Measured impact

Qwen2.5-Coder-32B-Instruct-8bit driving Claude Code on an M-series Studio, ~23K-token system+tools prefix:

  • before: ~100s+ prefill on every follow-up turn
  • after: ~7s on cache hits

Discussion #521 has the full writeup including the interaction with the Anthropic billing header (handled in PR #277).

Notes for review

  • Happy to add a unit test or benchmark in whatever shape fits the repo — let me know what would be most useful.
  • The cache is single-slot (one system prefix at a time), which matches the existing _stream_generate_text cache. If you'd prefer a multi-slot LRU here, easy to swap.
  • If the marker-based detection feels too heuristic, an alternative is to take the system prefix length from apply_chat_template with a stripped-down message list — happy to refactor.

The existing system-prefix KV cache (added in _stream_generate_text)
covered the MLLM text-only path but not the pure-LLM stream_chat path,
so non-MLLM models routing through stream_generate() re-prefilled the
full system block on every turn.

This mirrors the same hash-keyed, single-slot cache logic into
stream_chat, after apply_chat_template:

- detect a system prefix via ChatML markers
- HIT: restore the snapshot and prefill only the suffix tokens
- MISS: prefill the system tokens, snapshot per-layer KV state,
  then continue with the suffix
- fallback: if anything looks off (no prefix, encode mismatch,
  cache call raises), drop through to the original uncached
  self.stream_generate() path unchanged

Reuses the engine's existing _system_kv_snapshot / _system_kv_hash /
_system_kv_token_count attributes - no __init__ changes, no new public
surface. Holds no extra locks (the inner _run_blocking_serialized
already takes _generation_lock).

Measured locally on Qwen2.5-Coder-32B-Instruct-8bit driving Claude Code
on Apple Silicon: ~100s+ follow-up-turn prefill -> ~7s once the system
prefix is cached (~23K-token system+tools prefix).
@waybarrios
Owner

Related: filed #524 to track generalising the per-client prefix-poison strippers (currently just the Anthropic billing header from #277). The canonicalization registry proposed there interacts directly with the system-prefix caching this PR extends — worth keeping the two in sync.

1. Split-point detection (waybarrios#2): replace ChatML-specific marker scan (`<|im_start|>...`) with probe-divergence — render the chat template with two different user contents and take the shared prefix. Mirrors `prompt_warmup._build_strict_prefix_string` and works across Qwen/ChatML, Llama, Gemma, and any other chat format with no per-model marker list.

2. Incremental streaming (waybarrios#1): replace `list(mlx_stream_generate(...))` with the `asyncio.Queue` producer pattern used in `_stream_generate_text` — chunks are emitted via `loop.call_soon_threadsafe` from the thread running mlx-lm and yielded immediately to the caller. Adds a thread-safe `abort_event` tied to `_run_blocking_serialized`'s `on_cancel` hook.

3. Move `import hashlib` to the module top (waybarrios#4); use `hashlib.sha256` directly instead of the local alias.

4. Remove the dead `if True:` indent block (waybarrios#5) left over from an earlier removal of an async-with wrapper.

The cache-aware-path failure fallback now distinguishes pre-first-token errors (safe to retry uncached) from mid-stream errors (re-raise — the client has already received partial output and switching paths would duplicate tokens).

Canonicalization-before-hashing (waybarrios#3) is deferred to the registry being designed in waybarrios#524.
@vinayvobbili
Contributor Author

Thanks @janhilgard for the thorough review and for the context on the three invalidators the MLLM path already handles (last_query_index template patching and the <think>\n suffix in particular — both were news to me and explain why our measured win on the MLLM path is smaller than what we saw on the non-MLLM Qwen2.5-Coder setup that motivated this PR).

Pushed 560237c addressing four of your five points:

#1 streaming — replaced list(mlx_stream_generate(...)) with the asyncio.Queue + loop.call_soon_threadsafe producer pattern from _stream_generate_text. Chunks are emitted to the queue from the mlx-lm worker thread and yielded to the caller immediately. The cache-path failure fallback now distinguishes pre-first-token errors (safe to retry uncached) from mid-stream errors (re-raise — switching paths after partial output would duplicate tokens). Abort propagation wired through _run_blocking_serialized's on_cancel hook.
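
The producer shape, roughly (a simplified sketch, not the diff itself — the real code routes the worker through _run_blocking_serialized rather than asyncio.to_thread, and the error type here is illustrative):

import asyncio

class PreFirstTokenError(RuntimeError):
    """Sketch-only marker: error raised before any token reached the client."""

async def stream_chunks(generate_chunks):
    """generate_chunks stands in for the mlx-lm generation loop."""
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    done = object()
    first_token_emitted = False

    def producer():  # runs on the worker thread
        try:
            for chunk in generate_chunks():
                loop.call_soon_threadsafe(queue.put_nowait, chunk)
            loop.call_soon_threadsafe(queue.put_nowait, done)
        except Exception as exc:
            loop.call_soon_threadsafe(queue.put_nowait, exc)

    worker = asyncio.create_task(asyncio.to_thread(producer))
    while True:
        item = await queue.get()
        if item is done:
            break
        if isinstance(item, Exception):
            if first_token_emitted:
                raise item  # mid-stream: switching paths would duplicate tokens
            raise PreFirstTokenError(str(item))  # caller retries on the uncached path
        first_token_emitted = True
        yield item
    await worker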

#2 model-agnostic split-point — replaced the ChatML marker scan with probe-divergence: render the chat template twice with two different user contents, take the shared prefix. This is the same technique prompt_warmup._build_strict_prefix_string already uses, which made it the obvious choice. Works on Qwen/ChatML, Llama, Gemma, and any other format without a per-model marker table. Happy to extract the helper into a shared utility (prompt_canonicalize.py or similar) if you'd prefer one source of truth rather than the inline copy — wasn't sure which way you wanted that to land alongside #524.

#4 hashlib import + #5 dead if True: block — both cleaned up; import hashlib at module top, indentation flattened.

#3 canonicalize-before-hashing — deferred to #524 as you suggested. Happy to take a swing at the simplified registry you sketched (single list, sub in a loop, no runtime register API) as a follow-up PR once this one lands or in a single combined PR if you'd prefer that — your call on sequencing.

One nuance on the probe-divergence detection that's worth flagging: it works correctly when there's at least one user message in the conversation (the probes vary only the last user turn). If the message list has system messages followed by an assistant message and nothing else, the function appends a placeholder user message before probing. Edge case in practice but documented in the code path.
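
For reference, a rough sketch of that detection (the outer function name is hypothetical; _with_user mirrors the helper mentioned above, and the Alpha/Bravo probe contents match the probes referenced later in this thread):

import os

def find_system_prefix(tokenizer, messages, **template_kwargs):
    """Probe-divergence: render the template twice with different last-user
    contents and take the shared character prefix (sketch)."""
    def _with_user(content):
        msgs = [dict(m) for m in messages]
        if not msgs or msgs[-1].get("role") != "user":
            msgs.append({"role": "user", "content": content})  # placeholder turn
        else:
            msgs[-1] = {**msgs[-1], "content": content}
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True, **template_kwargs
        )

    probe_a = _with_user("Alpha")
    probe_b = _with_user("Bravo")
    split = len(os.path.commonprefix([probe_a, probe_b]))
    return probe_a[:split]  # candidate prefix; still verified token-wise afterwards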

Re-tested on the same Qwen2.5-Coder-32B-Instruct-8bit / Claude Code workload — cache hit/miss/store logging fires identically and now the OpenAI Chat Completions client sees streaming token deltas during generation instead of one big blob at the end. Will rerun the timing benchmark and post the numbers if you'd like a fresh measurement on this revision.

@janhilgard
Collaborator

Nice work on the revision. The probe-divergence approach for finding the system prefix boundary is the right call — it's the same technique we validated in prompt_warmup and it generalizes cleanly across all template formats without maintaining a marker table.

The asyncio.Queue + call_soon_threadsafe streaming pattern matches what _stream_generate_text already does, so behavior will be consistent across both paths. Good catch on the pre-first-token vs mid-stream error distinction for fallback safety.

Deferring canonicalize-before-hashing to #524 makes sense — keeps this PR focused on the caching mechanism itself. Once #524 lands with even the minimal stripper list, it's a one-liner to wire canonicalize_system_prompt(system_prefix_text) before the sha256() call.

One question on the probe-divergence edge case you flagged: if the message list has only system + assistant (no user message), is this a realistic scenario? In the OpenAI Chat Completions API the first non-system message is always user-originated. The placeholder insertion sounds safe but I want to make sure it doesn't shift token boundaries in the probe vs the real encode (e.g. if the template generates different separators for 2-message vs 3-message sequences). If you've verified the prefix length matches between probe and real encode in that case, it's fine.

LGTM on the overall design. Happy to give a final approval once the diff is up for a quick re-check.

`stream_chat` is typed as accepting `list[dict[str, Any]]` but some internal callers (server.py's streaming endpoint, the streaming-chat test in test_server.py) pass Pydantic `Message` objects directly. Those don't expose dict's `.get()`, so the cache-eligibility detection raised `'Message' object has no attribute 'get'` and the test `TestStreamChatCompletion::test_streaming_chat_no_stream_thread_error_after_residency_preload` failed.

Normalize each message to a plain dict (via `model_dump()` / `dict()` / getattr fallback) before the role lookup and the probe-divergence renders. Also apply black for the lint step.

Surfaces only when the chat-completion request goes through `stream_chat` with raw Pydantic Message objects (i.e. not pre-converted by the caller). The MLLM path (`_stream_generate_text`) has the same `m.get(...)` pattern, but the test patches `is_mllm_model=False` so this fix is enough to green the CI.
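
The normalization is essentially this shape (sketch of the fallback chain; the diff's helper is _to_msg_dict):

from typing import Any

def _to_msg_dict(message: Any) -> dict:
    """Normalize a chat message to a plain dict before the role lookup (sketch)."""
    if isinstance(message, dict):
        return message
    if hasattr(message, "model_dump"):   # Pydantic v2
        return message.model_dump()
    if hasattr(message, "dict"):         # Pydantic v1
        return message.dict()
    # last resort: pull only the fields the cache-eligibility check needs
    return {
        "role": getattr(message, "role", None),
        "content": getattr(message, "content", None),
    }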
Three tests covering the new pure-LLM cache path:

- `test_stream_chat_cache_path_accepts_pydantic_message_objects` — regression test for the Pydantic `Message` handling that the prior CI run caught. Verifies that callers passing `Message(role=..., content=...)` (as `server.py`'s streaming endpoint does) don't raise `AttributeError` at the `.get('role')` boundary.
- `test_stream_chat_skips_cache_path_when_no_system_message` — `has_system=False` must short-circuit; `apply_chat_template` should only be called once (for the initial prompt), not three times (initial + Alpha probe + Bravo probe).
- `test_stream_chat_cache_path_falls_back_when_mlx_raises` — when `_run_with_cache` raises before the first token (here forced via patching `make_prompt_cache` to raise), the pre-first-token error branch must route to the uncached `stream_generate` fallback rather than propagating the exception.
@vinayvobbili
Contributor Author

Thanks @janhilgard — glad the design holds up. Answering the edge case directly:

The placeholder-user injection in _with_user does shift token boundaries in some templates — but the post-probe token-level prefix verification is the safety net. After the boundary is found character-wise on the probes, we encode both prompt (the real, no-placeholder render) and system_prefix_text (the probe-extracted slice) and assert full_tokens_list[:system_token_count] == system_tokens_list. If the placeholder caused the probe to extract a string that isn't actually a prefix of the real render, this equality fails and kv_cache_eligible stays False. Cache silently skips — no incorrect cache state, just no speedup for that turn.
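
Concretely, the safety net boils down to (sketch; function name is illustrative):

def token_prefix_matches(tokenizer, prompt: str, system_prefix_text: str) -> bool:
    """True only if the probe-extracted prefix is a real token prefix of the render."""
    system_tokens = tokenizer.encode(system_prefix_text)
    full_tokens = tokenizer.encode(prompt)
    return bool(system_tokens) and full_tokens[: len(system_tokens)] == system_tokens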

I empirically verified this against the Qwen2.5-Coder template for the two scenarios you flagged:

  • [system] only — probe-extracted prefix ends in <|im_start|>user\n; real prompt ends in <|im_start|>assistant\n (generation prompt); token check: mismatch → cache skipped
  • [system, assistant] — probe-extracted prefix ends in ...assistant\nHow can I help?<|im_end|>\n<|im_start|>user\n; real prompt ends in ...<|im_start|>assistant\n; token check: mismatch → cache skipped

Both flow through to the uncached stream_generate correctly. As you noted, neither is a realistic OpenAI-API-shaped conversation anyway — but the structural safety net means we don't have to rely on caller well-formedness. If a future template format renders boilerplate differently for N vs N+1 messages, the token equality catches that too.

Also pushed 9d9506c with three unit tests covering the new path:

  • regression test for Pydantic Message normalisation (the bug the prior CI run caught — server.py's streaming endpoint passes Message objects, not dicts)
  • has_system=False short-circuit (verifies apply_chat_template isn't called for probe renders when no system message exists)
  • pre-first-token error → uncached fallback (forces make_prompt_cache to raise; verifies the fallback path runs instead of propagating)

Ready for re-check whenever you have time.

Collaborator

@janhilgard janhilgard left a comment


Solid revision — the streaming architecture, error handling, and model-agnostic detection all look correct now.

What works well:

  • Probe-divergence correctly reuses the same template_kwargs as the real render, so enable_thinking, tools, etc. don't throw off the boundary
  • asyncio.Queue + call_soon_threadsafe streaming mirrors _stream_generate_text exactly
  • Pre-first-token vs mid-stream error distinction prevents token duplication on fallback
  • _to_msg_dict() normalizes Pydantic objects cleanly
  • on_cancel=abort_event.set integrates with the existing cancellation path
  • Tests cover the three critical edge cases (Pydantic messages, no-system skip, error fallback)

One minor optimization suggestion (non-blocking):

On cache MISS, after the chunked system prefill loop, consider adding mx.clear_cache() before starting suffix generation:

if sys_arr.size > 0:
    model(sys_arr[None], cache=bc)
    mx.eval([c.state for c in bc])

mx.clear_cache()  # free intermediate activations from prefill

snapshot = [c.state for c in bc]

For 23K-token system prefixes, the intermediate activations from prefilling can be ~1-2 GB of temporary memory. The MLLM path does this between chunked prefill and generation. Not critical since it'll be reclaimed eventually by the allocator, but prevents a peak memory spike on the first MISS after startup.

LGTM — no blocking issues.


Per @janhilgard's non-blocking review suggestion. After chunked prefill of the system prefix, intermediate activations can hold ~1-2 GB of temporary memory on long (~23K-token) prefixes. Calling mx.clear_cache() before capturing the per-layer state mirrors the existing MLLM path's pattern between chunked prefill and decode, and prevents a peak-memory spike on the first cache MISS after startup. The snapshot itself is unaffected (it reads c.state on each cache slot, which is preserved).
@vinayvobbili
Contributor Author

Thanks @janhilgard — applied in 0577644. Confirmed the snapshot read still picks up the per-layer state cleanly after the clear, since c.state returns the cache slot's current arrays rather than relying on intermediate activations. Approval should still hold; CI re-running.

Collaborator

@Thump604 Thump604 left a comment


I ran the focused SimpleEngine tests locally against this PR:

AI_RUNTIME_BYPASS_SAFETY_GATE=1 PYTHONPATH=/tmp/vllm-mlx-review-523 /opt/ai-runtime/venv-live/bin/python -m pytest tests/test_simple_engine.py -q
# 31 passed

I also reproduced one behavior gap in the cache-eligible stream_chat branch. A request with stop, top_k, min_p, presence_penalty, repetition_penalty, and a request-local logits_processors list reaches the new direct mlx_lm.stream_generate(...) call with only prompt, max_tokens, sampler, and prompt_cache. The uncached path still calls self.stream_generate(..., **kwargs), and that wrapper preserves the request-local decode controls, stop handling, penalty processors, logits processors, and MTP/cache semantics.

That matters because server.py can attach parser stop tokens and JSON/constrained decoding logits processors per request, and it always threads the sampling/penalty fields through the engine call. With this branch enabled, a system-prefix-cache eligible request can silently decode under different constraints than the same request on the uncached path.

Please either preserve the same wrapper semantics in the cache branch, or explicitly skip the cache branch when unsupported request-local decode/constraint controls are present. I would also add a regression test where a cache-eligible stream_chat call with stop or logits_processors proves the control is preserved or the uncached path is used.

I am not claiming the system-prefix cache approach is wrong, only that this PR currently changes request-local decode/parser behavior on the cache branch.

The cache-eligible branch in stream_chat called mlx_lm.stream_generate
directly with only prompt/max_tokens/sampler/prompt_cache, silently
dropping `stop`, request-local `logits_processors` (parser stop tokens
and JSON-constrained decoding attached by server.py), and the
`top_k`/`min_p`/`presence_penalty`/`repetition_penalty` sampling
controls that the uncached path threads through self.stream_generate.

Same request could decode under different constraints depending on
whether it hit the cache branch. Gate the cache branch on absence of
those controls so both paths share identical decode semantics. The
gate compares against server.py's no-op defaults (top_k=0, min_p=0.0,
presence_penalty=0.0, repetition_penalty=1.0) rather than
`key in kwargs`, so the common path still hits the cache.

Adds two regression tests: stop+logits_processors forces the uncached
fallback (probe-divergence never runs), and the no-op defaults still
exercise the cache path.

Reported by @Thump604.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@vinayvobbili
Contributor Author

Thanks for catching this, @Thump604 — confirmed and fixed in d768850.

You're right on the substance: the cache-eligible branch was calling mlx_lm.stream_generate directly with only prompt/max_tokens/sampler/prompt_cache, while the uncached path threads **kwargs through self.stream_generate(...). The same request could decode under different constraints depending on which branch it hit — stop, request-local logits_processors (parser stop tokens and JSON-constrained decoding from server.py), and the top_k/min_p/presence_penalty/repetition_penalty sampling controls were all silently dropped on the cache branch.

I took the second option you offered (skip the cache branch when those controls are active) over threading them through. Reasoning: parser stop tokens and the request-local logits processors aren't natively accepted by mlx_lm.stream_generate's signature, so threading them would mean re-implementing the wrapper's semantics on the cache side rather than just delegating. Skip is the smaller surface change, keeps cache and uncached paths converged on the wrapper for anything non-trivial, and only sacrifices cache hits when the request actually uses those controls.

Gate logic, since server.py always sets the sampling fields (even when unused):

  • stop, logits_processors: blocked if truthy (non-empty list / non-empty string).
  • top_k: blocked if > 0 (0 is the unbounded no-op server.py sends).
  • min_p: blocked if > 0.0 (0.0 is the no-op).
  • presence_penalty: blocked if != 0.0 (0.0 is the no-op).
  • repetition_penalty: blocked if != 1.0 (1.0 is the no-op).

Compared against those no-op defaults rather than key in kwargs, so the common path with server.py's always-set 0/0.0/0.0/1.0 still hits the cache. Logs System KV cache SKIP (stream_chat): ... with the list of triggering controls when it falls back, so this is visible in operation.
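
In code the gate is roughly (sketch; key names mirror the request kwargs above, defaults are server.py's no-ops):

def cache_blocking_decode_controls(kwargs: dict) -> list[str]:
    """Return the request-local controls that force the uncached path (sketch)."""
    blocking = []
    if kwargs.get("stop"):
        blocking.append("stop")
    if kwargs.get("logits_processors"):
        blocking.append("logits_processors")
    if kwargs.get("top_k", 0) > 0:
        blocking.append("top_k")
    if kwargs.get("min_p", 0.0) > 0.0:
        blocking.append("min_p")
    if kwargs.get("presence_penalty", 0.0) != 0.0:
        blocking.append("presence_penalty")
    if kwargs.get("repetition_penalty", 1.0) != 1.0:
        blocking.append("repetition_penalty")
    return blocking  # non-empty -> log the SKIP and use the uncached wrapper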

Two regression tests added in the same commit:

  • test_stream_chat_skips_cache_path_when_decode_controls_present: request with stop=["<|im_end|>"] + logits_processors=[sentinel] forces the uncached fallback, probe-divergence never runs (apply_chat_template.call_count == 1), and the controls are confirmed threaded through.
  • test_stream_chat_takes_cache_path_when_decode_controls_are_no_ops: request with the server.py-style top_k=0, min_p=0.0, presence_penalty=0.0, repetition_penalty=1.0 still enters the cache path (probe runs, apply_chat_template.call_count == 3).

Happy to thread the wrapper semantics through the cache branch instead if you'd prefer the more aggressive version (more cache hits when callers explicitly set sampling penalties, at the cost of more surface area in the cache branch) — let me know.

Collaborator

@Thump604 Thump604 left a comment


Thanks for the quick revision. d768850 fixes the request-local decode controls I raised: stop, logits_processors, and active sampling/penalty controls now force the uncached path, while the server no-op defaults still allow the cache path. I also re-ran the focused suite locally:

AI_RUNTIME_BYPASS_SAFETY_GATE=1 PYTHONPATH=/opt/ai-runtime/worktrees/vllm-mlx/pr-523-review-latest /opt/ai-runtime/venv-live/bin/python -m pytest tests/test_simple_engine.py -q
# 33 passed

uvx ruff check vllm_mlx/engine/simple.py tests/test_simple_engine.py
# pass

/opt/ai-runtime/venv-live/bin/python -m black --check --target-version py312 vllm_mlx/engine/simple.py tests/test_simple_engine.py
# pass

git diff --check
# pass

I still cannot approve this revision because the cache branch remains a second generation path for engine-level features/limits. In stream_chat, the cache path calls mlx_lm.stream_generate directly with only max_tokens, sampler, and prompt_cache:

for resp in mlx_stream_generate(
    model,
    tokenizer,
    prompt=prompt_arr,
    max_tokens=max_tokens,
    sampler=sampler,
    prompt_cache=bc,
):

That still diverges from stream_generate, which adds MTP cache/kwargs, SpecPrefill routing, max_kv_size on make_prompt_cache, and prefill_step_size.

Minimal local repro against d768850: configure SimpleEngine(mtp=True, mtp_num_draft_tokens=4), send a cache-eligible system+user chat request with only the server no-op decode defaults, and capture the direct mlx_lm.stream_generate kwargs. The cache path is entered (apply_chat_template called 3 times for prompt + probes), but captured kwargs are only:

{"max_tokens": 256, "sampler": ..., "prompt_cache": [...]}

mtp is absent and num_draft_tokens is absent. So the same request changes behavior based only on whether the system-prefix cache branch is eligible when MTP is configured. The same class of gap applies to SpecPrefill when a draft model / per-request specprefill=True is active, and to the configured cache/step limits.

Please either:

  1. expand the skip gate so this cache branch is bypassed whenever an engine-level feature/limit it cannot honor is active (self._mtp, loaded/forced SpecPrefill, configured max_kv_size, and any other wrapper-only generation behavior), or
  2. route the cache branch through the same wrapper semantics instead of maintaining a second direct generation path.

If you choose the skip-gate option, I would add at least one MTP regression test proving self._mtp=True forces the uncached wrapper path and preserves mtp / num_draft_tokens, plus a SpecPrefill/per-request specprefill=True regression if practical. The current new tests prove the decode-control fix, but they do not cover this remaining wrapper-semantic surface.

The cache branch drives ``mlx_lm.stream_generate`` directly and bypasses
engine-level features the ``self.stream_generate`` wrapper layers on top:
``self._mtp`` (multi-token prediction), loaded SpecPrefill draft model,
per-request ``specprefill`` override from ``extra_body``, and configured
``self._max_kv_size``.

Expand the existing decode-control gate to also skip the cache branch
when any of those are active so cache-eligible and uncached requests
decode under identical engine semantics.

Engine-init defaults (``mtp=False``, no draft model, ``max_kv_size=0``)
and the common per-request shape keep hitting the cache as before; the
gate only fires when a feature/limit is actually configured.

Adds three regression tests modeled on the existing decode-control ones:
MTP active, SpecPrefill draft loaded, and ``max_kv_size`` configured.
Each asserts the cache probes are skipped (``apply_chat_template`` called
only once for the prompt render) and the uncached wrapper runs.
@vinayvobbili
Contributor Author

Thanks @Thump604 — fixed in 429f75c, taking the skip-gate option you offered.

Same reasoning as the prior decode-control gate: the cache branch and the uncached wrapper must decode under identical engine semantics, and threading the wrapper's layered semantics (SimpleEngine.stream_generate → MLXLanguageModel.stream_generate → mlx_lm.stream_generate) into the cache branch means re-implementing both layers on the cache side. Skipping is the smaller surface change and keeps cache and uncached paths converged on the wrapper for any non-trivial generation behavior.

Gate now also blocks the cache branch when any of the following are active:

  • self._mtp — MTP injects mtp=True and num_draft_tokens into the mlx_lm.stream_generate call inside MLXLanguageModel.stream_generate.
  • self._draft_model is not None — a loaded SpecPrefill draft model (set when specprefill_enabled + specprefill_draft_model are configured at engine init) triggers _stream_generate_specprefill routing in the wrapper for prompts above specprefill_threshold.
  • kwargs.get("specprefill") is not None — per-request specprefill override from extra_body. Gated on is not None rather than truthiness so an explicit specprefill=False suppression signal also forces the wrapper path.
  • self._max_kv_size > 0 — caps the prompt cache; the cache branch builds with make_prompt_cache(model) and has no equivalent bound.

Default engine config (mtp=False, no draft model, max_kv_size=0) and the common per-request shape keep hitting the cache as before — the gate only fires when a feature/limit is actually configured.

Three regression tests added in the same commit, modeled on the existing decode-control pair:

  • test_stream_chat_skips_cache_path_when_mtp_active — constructs SimpleEngine("test-model", mtp=True, mtp_num_draft_tokens=4), asserts apply_chat_template.call_count == 1 (probes never run) and the uncached fallback was invoked.
  • test_stream_chat_skips_cache_path_when_specprefill_loaded — sets engine._draft_model = MagicMock() post-construction, asserts the same skip behavior.
  • test_stream_chat_skips_cache_path_when_max_kv_size_set — constructs SimpleEngine("test-model", max_kv_size=4096), asserts the same.

Each test only proves the cache branch is bypassed at this seam; the wrapper attaches MTP/SpecPrefill/max_kv_size semantics itself once requests reach it, and those layers have their own test coverage.

Happy to add the more aggressive option (threading wrapper semantics through the cache branch) as a follow-up if you'd rather have more cache hits when these features are configured — the current change just makes the two paths equivalent on engine semantics, which I read as your blocking concern.

Collaborator

@Thump604 Thump604 left a comment


Thanks for the follow-up. I re-checked 429f75c against the remaining wrapper-semantics concern I raised. The new gate now bypasses the pure-LLM system-prefix cache branch when MTP, a loaded SpecPrefill draft model, a per-request SpecPrefill override, or a nonzero max_kv_size would otherwise require self.stream_generate / wrapper behavior. That addresses the divergence I reproduced on d768850: cache-eligible turns no longer silently drop wrapper-only generation semantics.

Local verification on this revision:

cd /opt/ai-runtime/worktrees/vllm-mlx/pr-523-review-latest
git diff --check
/opt/ai-runtime/venv-live/bin/python -m pytest tests/test_simple_engine.py -q
# 36 passed
uvx ruff check vllm_mlx/engine/simple.py tests/test_simple_engine.py
# pass
/opt/ai-runtime/venv-live/bin/python -m black --check --target-version py312 vllm_mlx/engine/simple.py tests/test_simple_engine.py
# pass

I am not claiming broader runtime performance behavior from this review, only that the request-local decode controls and engine-feature compatibility issues raised on the prior revisions are addressed by the current diff and covered by focused regression tests.

@waybarrios
Owner

waybarrios commented May 14, 2026

Checked it. It looks good! I think it is good work. Thank you @vinayvobbili. However, I will take a few more minutes to revisit it carefully. Sorry for being a bit away on this.

@waybarrios
Owner

Cache path looks solid, but three things worth a second look.

There's a TOCTOU between the cache-hit gate and the restore. The gate at simple.py#L1027-L1032 reads self._system_kv_snapshot without a lock:

if (
    system_hash == self._system_kv_hash
    and self._system_kv_snapshot is not None
    and system_token_count == self._system_kv_token_count
):
    cache_hit = True

and the restore at L1077-L1080 re-reads it later under _run_blocking_serialized:

if cache_hit:
    bc = make_prompt_cache(model)
    for i, saved_state in enumerate(self._system_kv_snapshot):
        bc[i].state = saved_state

A concurrent MISS can overwrite the snapshot in between, so the restored KV no longer matches the hash that set cache_hit=True. Capturing the snapshot into a closure local at gate time fixes it:

hit_snapshot = self._system_kv_snapshot if cache_hit else None
# ...later, inside _run_with_cache:
for i, saved_state in enumerate(hit_snapshot):
    bc[i].state = saved_state

The _max_kv_size gate at L938-L940 misses sliding-window models:

if (self._max_kv_size or 0) > 0:
    cache_blocking_controls.append("max_kv_size")

mlx-lm's gemma3_text, olmo3 and recurrent_gemma return RotatingKVCache from make_cache() regardless of max_kv_size, so make_prompt_cache(model) at L1078/L1082 can hand back a rotating cache whose update_and_fetch mutates in place while .state aliases those same arrays. An isinstance check after building bc would close the gap:

from mlx_lm.models.cache import KVCache
if not all(isinstance(c, KVCache) for c in bc):
    # fall back to uncached path; rotating caches alias their state
    ...

Even on plain KVCache, worth confirming that the slices returned by .state (line 1098) don't share storage with the buffers mlx_stream_generate writes through at L1114-L1121:

snapshot = [c.state for c in bc]
mx.eval([s for pair in snapshot for s in pair])
self._system_kv_snapshot = snapshot
# ...
for resp in mlx_stream_generate(
    model, tokenizer,
    prompt=prompt_arr,
    max_tokens=max_tokens,
    sampler=sampler,
    prompt_cache=bc,
):

A turn-2 parity test (HIT vs cache disabled) on a real model would tell us whether the snapshot needs a deep copy before publishing.

@waybarrios
Owner

Minor nit: the comment at L1094-L1096 says it mirrors the MLLM path, but the MLLM path doesn't call mx.clear_cache() between the last chunk and the snapshot.

# Free intermediate prefill activations before snapshotting;
# mirrors the MLLM path between chunked prefill and decode.
mx.clear_cache()

Either drop the comparison or note that this path is intentionally stricter.

@waybarrios
Owner

what do you think @vinayvobbili ?

Addresses three concerns raised in maintainer review of waybarrios#523:

1. TOCTOU between HIT gate and snapshot restore. The gate at L1027-L1032
   reads ``self._system_kv_snapshot`` outside ``_run_blocking_serialized``
   while the restore re-reads it later, inside the serialized worker.
   A concurrent MISS that reassigns the attribute in between would load
   a snapshot for a different system prefix under the hash that decided
   HIT. Fixed by capturing the snapshot into a closure-local
   ``hit_snapshot`` at gate time and using that for the restore.

2. ``_max_kv_size`` gate alone doesn't catch sliding-window models.
   ``gemma3_text``, ``olmo3``, and ``recurrent_gemma`` return
   ``RotatingKVCache`` from ``make_cache()`` regardless of
   ``max_kv_size``; ``.state`` aliases buffers that ``update_and_fetch``
   mutates in place. Probe once in ``start()`` via
   ``make_prompt_cache(model)`` and require every entry to be a plain
   ``KVCache``; the engine-feature gate adds ``non_kv_cache_class`` to
   ``cache_blocking_controls`` when the probe says no.

3. Comment on ``mx.clear_cache()`` claimed it mirrored the MLLM path,
   but the MLLM path doesn't ``mx.clear_cache()`` between its last
   prefill chunk and the snapshot. Reworded to call out that this path
   is intentionally stricter.

Tests:
- Add ``test_stream_chat_uses_gate_time_snapshot_under_concurrent_mutation``:
  simulates a concurrent MISS by reassigning ``_system_kv_snapshot``
  inside the ``_run_blocking_serialized`` hook, then asserts the restore
  wrote the gate-time entries.
- Add ``test_stream_chat_skips_cache_path_when_model_has_non_kv_cache``:
  proves the gate fires when ``_supports_system_kv_cache=False``.
- Existing positive/negative tests set ``_supports_system_kv_cache=True``
  to isolate the feature under test (otherwise the new gate would
  short-circuit them for the wrong reason).
@vinayvobbili
Contributor Author

Thanks for the careful read @waybarrios — all three are real. Pushed 00717b9 with fixes:

TOCTOU on self._system_kv_snapshot (fixed). Captured the snapshot into a closure-local hit_snapshot at the gate, exactly the shape you suggested. The restore now reads from that local rather than re-reading the instance attribute inside _run_with_cache. Added test_stream_chat_uses_gate_time_snapshot_under_concurrent_mutation which reassigns engine._system_kv_snapshot inside the _run_blocking_serialized hook and asserts the restore wrote the gate-time entries.

Sliding-window models / RotatingKVCache (fixed). Good catch — _max_kv_size is not a complete check because gemma3_text, olmo3, and recurrent_gemma produce RotatingKVCache regardless. Moved this to a one-time probe in start(): builds make_prompt_cache(model) once, sets self._supports_system_kv_cache = all(isinstance(c, KVCache) for c in probe), and the gate appends non_kv_cache_class to cache_blocking_controls when the probe says no. Doing it at init keeps the per-request path cheap. Added test_stream_chat_skips_cache_path_when_model_has_non_kv_cache for the negative path. I considered the inline isinstance check inside _run_with_cache too, but probing once at start avoids re-checking on every cache-eligible request and gives a single clear log line when a model isn't snapshot-safe.
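
The probe amounts to something like this (sketch; it assumes make_prompt_cache is importable alongside KVCache from mlx_lm.models.cache, as the review snippet's import suggests, and the actual log line / placement in start() differ in the diff):

from mlx_lm.models.cache import KVCache, make_prompt_cache

def _probe_system_kv_support(model) -> bool:
    """One-time check: every cache slot must be a plain KVCache (sketch).

    Sliding-window models hand back RotatingKVCache, whose .state aliases
    buffers that update_and_fetch mutates in place, so snapshots would not
    be stable across turns.
    """
    probe = make_prompt_cache(model)
    return all(isinstance(c, KVCache) for c in probe)

# in start() (attribute names illustrative):
# self._supports_system_kv_cache = _probe_system_kv_support(self._model)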

.state aliasing on plain KVCache (open — argued safe, deferred to parity test). Reading mlx-lm's KVCache, .state returns (self.keys, self.values) — the full preallocated buffers. The system region is [..., :system_token_count, :], and decode update_and_fetch writes at prev:offset where prev >= system_token_count. The cases where the system region could be touched are (a) writes overrunning the allocated step, in which case self.keys = mx.concatenate(...) allocates a new array and the snapshot's reference pins the old buffer (whose system region is intact), and (b) realloc going through self.keys[..., :prev, :] slice — that slice creates a view, and the assignment that follows replaces self.keys rather than mutating it. In both cases the snapshot's old reference stays valid. I don't think this needs a code change, but I agree a turn-2 HIT-vs-uncached parity test on a real model is the right way to prove it. Happy to add one on the CI runner if you want — it'd need a small model in the fixture set so the cost stays bounded. Want me to add it in this PR or fold it into a follow-up?

Comment nit (fixed). Dropped the MLLM-path comparison and noted that this path is intentionally stricter — the snapshot here reflects only the KV state, not residual prefill activations.

CI's queued behind workflow approval again on the new commit (same action_required gate as the prior turn).

@waybarrios
Owner

Sorry, went ahead and patched one of the tests that broke with the new logic (missing _supports_system_kv_cache = True in the setup). Should be fine now.

@waybarrios waybarrios merged commit 43bd05f into waybarrios:main May 14, 2026
9 checks passed
@vinayvobbili
Contributor Author

Thanks for catching that — that's the Pydantic Message normalization test; I missed it when I was patching the obvious gate-isolation cases. Pulled it into my local branch.
