fix: replace manual decode loop with pipelined generation in SpecPrefill Phase 4 #248
Conversation
Force-pushed from e889470 to d266b20.
btw, I noticed SpecPrefill currently only triggers on the streaming path (…)
@Vigilans, @waybarrios: independent technical review of this PR.

Verification of the fix: Confirmed against current upstream main (b4fa030). The diff modifies two […]

Technical correctness: The root cause analysis in #247 holds. […]

Backward compatibility: […]

Minor observations (not blocking): […]

On the related observation in your comment: The non-streaming […]

Recommendation: Merge candidate. Real fix to a real performance regression, hard production evidence, sound implementation, reasonable tests.
Thump604 left a comment
This is a correct fix and a big one. The manual model() + mx.eval(y) per-token loop foregoes MLX's kernel pipelining, which is exactly what gets SpecPrefill Phase 4 stuck at ~0.3 tok/s while the same model at the same context runs ~30 tok/s through a normal decode. Routing through mlx_lm.stream_generate(prompt=mx.array([first_token]), prompt_cache=cache) is the clean fix.
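Roughly, the before/after decode shapes compare like this; this is a sketch with placeholder names (model, tokenizer, cache, first_token_id), not the PR's literal code, and greedy sampling is used only to keep the old-loop sketch short:

```python
import mlx.core as mx
from mlx_lm import stream_generate


def decode_manual(model, tokenizer, cache, first_token_id, max_tokens):
    """Old Phase 4 shape: one synchronous model() + mx.eval() per token.

    Each step blocks on mx.eval(), so the GPU idles while Python sets up
    the next step; there is no kernel pipelining.
    """
    y = mx.array([first_token_id])
    token_ids = []
    for _ in range(max_tokens):
        logits = model(y[None], cache=cache)
        y = mx.argmax(logits[:, -1, :], axis=-1)
        mx.eval(y)  # hard sync on every single token
        token_ids.append(y.item())
    return tokenizer.decode(token_ids)


def decode_pipelined(model, tokenizer, cache, first_token_id, max_tokens):
    """New Phase 4 shape: hand the pre-populated KV cache to mlx_lm's
    pipelined generator and let it overlap compute with the Python loop."""
    text = ""
    for response in stream_generate(
        model,
        tokenizer,
        prompt=mx.array([first_token_id]),
        max_tokens=max_tokens,
        prompt_cache=cache,
    ):
        text += response.text
    return text
```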
On your note about SpecPrefill currently only triggering on the streaming path: that's a separate accumulator gap I'm working on. Once both this PR and that refactor land, non-streaming /v1/chat/completions and /v1/completions will also reach Phase 4, and this fix becomes load-bearing there too. Worth mentioning in the PR description if you want to foreshadow it.
Post-merge I'd like to run this against a 16K/64K/128K needle harness on Qwen 3.5 122B + 2B draft as an independent regression signal. If that's useful I can post the run output here as a follow-up comment.
Approving.
…l generate+stream_generate

Pre-existing regression from an earlier rebase that dropped bdf7dcc's llm.py additions. The server.py request handlers still pass top_k, min_p, presence_penalty, repetition_penalty through to SimpleEngine, which forwards them via **kwargs to MLXLanguageModel.chat() (which accepts **kwargs), which then calls self.generate(..., **kwargs). But MLXLanguageModel.generate() and stream_generate() had been left with only (temperature, top_p, repetition_penalty) in their signatures, so any non-MLLM SimpleEngine request crashed with:

TypeError: MLXLanguageModel.stream_generate() got an unexpected keyword argument 'top_k'

Observed as 0/6 on simple-base, simple-mtp, and simple-spec profiles in the feature matrix regression sweep after the Session 87 cherry-picks of PRs waybarrios#248, waybarrios#229, waybarrios#218, waybarrios#222 landed. The cherry-picks did not cause this regression; they exposed it by finally running the LLM-path tests that no one had exercised since the rebase happened. Confirmed via stderr.log:

TypeError: MLXLanguageModel.generate() got an unexpected keyword argument 'top_k'
TypeError: MLXLanguageModel.stream_generate() got an unexpected keyword argument 'top_k'

Fix: restore the signatures and bodies of _create_sampler, _create_logits_processors, generate, and stream_generate to match bdf7dcc's original intent. Preserves PR waybarrios#248's prompt_cache parameter and non-str prompt support on stream_generate. Adds **kwargs to both generate and stream_generate so future param additions degrade gracefully instead of crashing.

This is a runtime-local fix. The equivalent upstream fix lives in bdf7dcc, which was never upstreamed (confirmed via git merge-base --is-ancestor bdf7dcc upstream/main). A follow-up PR to upstream could carry this forward.

Verification: bin/verify-patches: 33/33 clean. Full feature matrix regression sweep pending re-run after this commit.

Related: runtime PR waybarrios#265 fixed the CompletionRequest schema side of the same bdf7dcc drop; this commit fixes the engine-model side.
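For context, the tolerant signature shape described here looks roughly like the sketch below; the class skeleton and the parameter defaults are assumptions, not the actual llm.py code:

```python
class MLXLanguageModel:  # skeleton only; the real class lives in llm.py
    def stream_generate(
        self,
        prompt,
        *,
        temperature: float = 0.0,
        top_p: float = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        presence_penalty: float = 0.0,
        repetition_penalty: float | None = None,
        prompt_cache=None,   # preserved from PR waybarrios#248
        **kwargs,            # future/unknown per-request params degrade gracefully
    ):
        ...

    def generate(self, prompt, **kwargs):
        ...                  # same treatment: explicit sampling params plus **kwargs
```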
…er stream_generate

Closes a real /v1/completions contract gap: prior to this commit, server.py::create_completion extracted per-request specprefill / specprefill_keep_pct and passed them as gen_kwargs to engine.generate(), and SimpleEngine.generate() forwarded them via **kwargs to MLXLanguageModel.generate() (which now accepts **kwargs since a8f58bf restored the bdf7dcc sampling params). Both the server and the engine advertised per-request SpecPrefill on /v1/completions, but the overrides were silently dropped because the direct self._model.generate() path never consumes the specprefill kwargs; they just flowed into **kwargs and landed in mlx_lm.generate, which ignores them.

Fix: make generate() a thin accumulator over stream_generate(). stream_generate() is the only code path that actually pops `specprefill` / `specprefill_keep_pct` from kwargs, threshold-checks, and routes to _stream_generate_specprefill. By iterating it and returning the last GenerationOutput, non-streaming /v1/completions clients get the same SpecPrefill engagement as streaming clients.

Verified against a live Qwen3.5-4B SimpleEngine runtime (simple-spec profile) with a 27 KB prompt (~6007 tokens, under the 8192 threshold, so forcibly enabled via extra_body.specprefill=true):

SpecPrefill: scored 6007 tokens in 5.3s, sparse prefill 1815/6007 (keep=30%) in 1.1s

prompt_tokens reporting is now accurate (it was always 0 on the old direct LLM path because mlx_lm.generate() never sets it).

Scope: only generate() gets the accumulator treatment in this commit. chat() stays on the direct path for non-tool LLM cases; the tool-enabled accumulator from the PR waybarrios#222 cherry-pick handles the chat side. A full chat() accumulator refactor is a follow-up that will layer on top of PR waybarrios#222 once it merges upstream.

This commit is independently upstreamable as a follow-up to PR waybarrios#248 (Phase 4 decode fix). Worth filing as a small PR on waybarrios/vllm-mlx to close the same gap upstream.
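The accumulator pattern, sketched; it leans on the commit's statement that the last yielded GenerationOutput carries what the non-streaming response needs, and the class skeleton and docstring wording are illustrative rather than the actual code:

```python
class MLXLanguageModel:  # sketch of the accumulator shape only
    def generate(self, prompt, **kwargs):
        """Non-streaming generate as a thin loop over stream_generate().

        stream_generate() is the only path that pops specprefill /
        specprefill_keep_pct from kwargs and routes to the sparse-prefill
        phase, so draining it gives non-streaming /v1/completions callers
        the same SpecPrefill engagement (and accurate prompt_tokens).
        """
        output = None
        for output in self.stream_generate(prompt, **kwargs):
            pass  # drain the stream; keep the last GenerationOutput
        return output
```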
Hey @Vigilans - the pipelined decode approach is correct and we've validated the same pattern independently on our local staging. Big speedup. Currently conflicting with main - can you rebase? Good to merge after that.
Force-pushed from 3b728b0 to 8f3b753.
Rebased on latest main and fixed the ruff lint errors from the previous CI run.
janhilgard left a comment
+1, agreeing with Thump604's approval. The fix is correct and well-scoped.
The manual model() + mx.eval(y) per-token loop was indeed the bottleneck -- it serializes every decode step and prevents MLX's kernel pipelining from overlapping compute and memory operations. Routing through mlx_lm.stream_generate(prompt=mx.array([first_token_id]), prompt_cache=cache) is the right fix.
A few minor observations (none blocking):
- `stream_generate` signature broadening in `llm.py` -- accepting `Union[str, mx.array, list[int]]` is a nice generalization. The `TYPE_CHECKING` guard for `import mlx.core as mx` is correct since `mx` is only needed for the type annotation.
- First token is decoded standalone (`self._text_tokenizer.decode([first_token_id])`) before handing off to `stream_generate`. This means the first token's text may differ slightly from what `stream_generate`'s internal detokenizer would produce (e.g., for tokens that need surrounding context for correct decode, like byte-fallback tokens). In practice this is unlikely to matter for the first token after SpecPrefill, but worth being aware of.
- The MLLM path (second `_run_specprefill` at ~line 1204) uses `mlx_stream_generate` (the raw `mlx_lm` function) directly, while the LLM path uses `self._model.stream_generate` (the wrapper). This asymmetry is fine since the MLLM path needs the raw model + tokenizer, but it means MTP is not available in the MLLM SpecPrefill path (the wrapper handles MTP kwargs). Not a regression since MTP was not available there before either.
CI has no checks running -- would be good to trigger a run to confirm green before merge. Branch is mergeable with no conflicts.
This is approved by both reviewers and ready to go. Could you push a rebase onto current main to trigger CI? The branch hasn't had checks run. |
…prompt

Add `prompt_cache` parameter to `MLXLanguageModel.stream_generate()` so callers can pass a pre-populated KV cache (e.g. from SpecPrefill sparse prefill). Broaden the `prompt` type from `str` to `str | mx.array | list[int]` since `mlx_lm.stream_generate` already supports all three; gate `tokenizer.encode()` behind an `isinstance` check to avoid encoding non-string prompts.
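In sketch form, the broadened entry point might look like the following; only the prompt_cache parameter, the Union prompt type, the TYPE_CHECKING import guard, and the isinstance gate come from this PR, while attribute names such as self._tokenizer and the **kwargs catch-all are assumptions:

```python
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    import mlx.core as mx  # only needed for the type annotation


class MLXLanguageModel:  # sketch; the real implementation lives in llm.py
    def stream_generate(
        self,
        prompt: "Union[str, mx.array, list[int]]",
        *,
        prompt_cache=None,  # pre-populated KV cache, e.g. from SpecPrefill sparse prefill
        **kwargs,
    ):
        # Only string prompts need encoding; arrays / token-id lists pass through.
        if isinstance(prompt, str):
            num_prompt_tokens = len(self._tokenizer.encode(prompt))
        else:
            num_prompt_tokens = len(prompt)
        ...
```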
…SpecPrefill Phase 4

The Phase 4 decode loop in SpecPrefill called model() + mx.eval() synchronously in a plain Python loop, resulting in ~0.3 tok/s decode (~100x slower than normal generation). The GPU idled while Python prepared the next step because there was no async pipelining.

Replace the manual loop with the engine's standard pipelined generation:

- Non-MLLM path: hand off to self._model.stream_generate() with the pre-populated sparse-prefill cache via the new prompt_cache parameter
- MLLM text-only path: hand off to mlx_lm.stream_generate() (matching the normal path, which also calls mlx_lm directly)

Both paths keep the first-token EOS check since stream_generate does not validate whether the prompt token itself is EOS.

Fixes waybarrios#247
Cover the new prompt_cache and non-string prompt parameters added to MLXLanguageModel.stream_generate():

- prompt_cache with a pre-populated KV cache and mx.array prompt
- list[int] prompt verifying prompt_tokens count
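A rough shape of those tests; how the suite builds the model is project-specific, so the model argument, the mlx_model attribute, and the max_tokens kwarg are placeholders, while make_prompt_cache is mlx_lm's standard cache helper:

```python
import mlx.core as mx
from mlx_lm.models.cache import make_prompt_cache


def test_model_stream_generate_with_prompt_cache(model):
    # `model` is an MLXLanguageModel instance (supplied by a fixture in the
    # real suite); the cache here stands in for a sparse-prefill cache.
    cache = make_prompt_cache(model.mlx_model)  # attribute name is a placeholder
    outputs = list(
        model.stream_generate(mx.array([1]), prompt_cache=cache, max_tokens=4)
    )
    assert outputs  # decode proceeds from the provided cache


def test_model_stream_generate_with_list_prompt(model):
    prompt = [1, 2, 3, 4]
    outputs = list(model.stream_generate(prompt, max_tokens=4))
    assert outputs[-1].prompt_tokens == len(prompt)  # field name per the test plan
```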
Force-pushed from 8f3b753 to 34023e1.
Summary
Fixes #247. Trying to make my first reviewed contribution here :)
SpecPrefill Phase 4 called `model()` + `mx.eval()` in a plain Python loop without `mx.async_eval` pipelining, causing the GPU to idle between steps (~0.3 tok/s vs 30+ tok/s normal decode).

This PR extends `MLXLanguageModel.stream_generate()` to accept a `prompt_cache` parameter and `mx.array` / `list[int]` prompt types, so SpecPrefill can hand off the pre-populated sparse-prefill cache to the engine's standard pipelined decode path:

- Non-MLLM path: `self._model.stream_generate(prompt_cache=cache)`
- MLLM text-only path: `mlx_lm.stream_generate(prompt_cache=bc)` (matching its normal path)

The first-token EOS check is kept in both paths, since `stream_generate` does not validate whether the prompt itself is EOS.
Design decisions
Handoff target: the Model layer (`self._model.stream_generate`)

There are three layers that can drive decode: the Engine layer (`SimpleEngine.stream_generate`), the Model layer (`MLXLanguageModel.stream_generate`), and `mlx_lm.stream_generate` directly.

The Engine layer is not suitable because `_stream_generate_specprefill` already holds `self._generation_lock` (an asyncio.Lock, which is non-reentrant), so calling back into `SimpleEngine.stream_generate` would deadlock. It would also re-run the SpecPrefill eligibility checks and re-tokenize the prompt unnecessarily (though with no infinite recursion: the single-token prompt falls below the threshold).

The Model layer (`self._model.stream_generate`) is the right fit. The non-MLLM normal decode path already calls it (line 362), so routing SpecPrefill Phase 4 through the same method maintains symmetry and reuses its stop-sequence handling. For the MLLM text-only path, `self._text_model` is a raw mlx_lm `TextModel` (an `nn.Module` built by `build_text_model`) with no `stream_generate` method, so it calls `mlx_lm.stream_generate` directly, again symmetric with the MLLM normal path (line 1090).

Extending the `MLXLanguageModel.stream_generate` interface

The Model layer originally only accepted `prompt: str` and had no way to pass a pre-populated cache. Two changes were needed:

- `prompt_cache` parameter: forwarded via kwargs to `mlx_lm.stream_generate`, which passes it through to `generate_step`. When `prompt_cache` is not None, `generate_step` skips creating a new cache and uses the provided one directly.
- `prompt` type broadened to `str | mx.array | list[int]`: Phase 4 passes `mx.array([first_token_id])` as the prompt. `mlx_lm.stream_generate` already accepts all three types natively; the only adaptation needed was gating `tokenizer.encode()` behind an `isinstance(prompt, str)` check for the prompt token count.
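The deadlock argument above is just the standard non-reentrancy of asyncio.Lock; here is a self-contained illustration, unrelated to the project code:

```python
import asyncio


async def main():
    lock = asyncio.Lock()  # asyncio.Lock is not reentrant

    async def inner():
        async with lock:  # waits forever: this task already holds the lock
            return "never reached"

    async with lock:      # analogous to _stream_generate_specprefill holding
        await inner()     # self._generation_lock and calling back into the engine


# asyncio.run(main()) would hang, which is why Phase 4 must not re-enter
# SimpleEngine.stream_generate while the generation lock is held.
```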
Test plan
- `test_model_stream_generate` still passes
- `test_model_stream_generate_with_prompt_cache`, `test_model_stream_generate_with_list_prompt`
- decode runs at normal speed