
fix: replace manual decode loop with pipelined generation in SpecPrefill Phase 4 #248

Merged
Thump604 merged 4 commits into waybarrios:main from Vigilans:fix/specprefill-phase4-decode
Apr 17, 2026

Conversation

@Vigilans (Contributor) commented Apr 2, 2026

Summary

Fixes #247. Trying to make my first reviewed contribution here :)

SpecPrefill Phase 4 called model() + mx.eval() in a plain Python loop
without mx.async_eval pipelining, causing the GPU to idle between steps
(~0.3 tok/s vs 30+ tok/s normal decode).

  • Extend MLXLanguageModel.stream_generate() to accept a prompt_cache
    parameter and mx.array/list[int] prompt types, so SpecPrefill can
    hand off the pre-populated sparse-prefill cache to the engine's standard
    pipelined decode path
  • Replace the manual decode loop in both SpecPrefill paths:
    • Non-MLLM: hand off to self._model.stream_generate(prompt_cache=cache)
    • MLLM text-only: hand off to mlx_lm.stream_generate(prompt_cache=bc)
      (matching its normal path)
  • First-token EOS check is preserved since stream_generate does not
    validate whether the prompt itself is EOS

Design decisions

Handoff target: Model layer self._model.stream_generate

There are three layers that can drive decode: Engine layer
(SimpleEngine.stream_generate), Model layer
(MLXLanguageModel.stream_generate), and mlx_lm.stream_generate directly.

The Engine layer is not suitable because _stream_generate_specprefill already
holds self._generation_lock (asyncio.Lock, non-reentrant) — calling back
into SimpleEngine.stream_generate would deadlock. It would also re-run
SpecPrefill eligibility checks and re-tokenize the prompt unnecessarily
(though no infinite recursion: the single-token prompt falls below the
threshold).
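
For reference, asyncio.Lock is not reentrant, so re-acquiring it from the
same task that already holds it simply never returns. A tiny standalone
illustration (not repo code; the timeout exists only to make the hang
observable):

```python
import asyncio

async def main():
    lock = asyncio.Lock()
    async with lock:
        # The same task tries to acquire the lock it already holds.
        # asyncio.Lock is not reentrant, so this await would block
        # forever; the timeout just lets the demo terminate.
        try:
            await asyncio.wait_for(lock.acquire(), timeout=1.0)
        except asyncio.TimeoutError:
            print("second acquire timed out: this is the deadlock")

asyncio.run(main())
```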

The Model layer (self._model.stream_generate) is the right fit. The
non-MLLM normal decode path already calls it (line 362), so routing
SpecPrefill Phase 4 through the same method maintains symmetry and reuses
its stop sequence handling. For the MLLM text-only path, self._text_model
is a raw mlx_lm TextModel (nn.Module built by build_text_model) with no
stream_generate method, so it calls mlx_lm.stream_generate directly —
again symmetric with the MLLM normal path (line 1090).
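
A rough sketch of the shape of this handoff (illustrative only: the function
and variable names phase4_decode, cache, and first_token_id are placeholders,
and the non-MLLM path actually goes through the self._model.stream_generate
wrapper rather than calling mlx_lm directly):

```python
import mlx.core as mx
from mlx_lm import stream_generate as mlx_stream_generate

def phase4_decode(model, tokenizer, first_token_id, cache, max_tokens):
    """Illustrative Phase 4 handoff: the first token was already sampled
    from SpecPrefill's final logits; everything after it is produced by
    the pipelined stream_generate path using the pre-populated cache."""
    yield tokenizer.decode([first_token_id])  # first token, decoded standalone
    for response in mlx_stream_generate(
        model,
        tokenizer,
        prompt=mx.array([first_token_id]),  # single-token prompt
        prompt_cache=cache,                 # sparse-prefill KV cache from SpecPrefill
        max_tokens=max_tokens,
    ):
        yield response.text
```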

Extending MLXLanguageModel.stream_generate interface

The Model layer originally only accepted prompt: str and had no way to
pass a pre-populated cache. Two changes were needed (a sketch follows the list):

  1. prompt_cache parameter: forwarded via kwargs to
    mlx_lm.stream_generate, which passes it through to generate_step.
    When prompt_cache is not None, generate_step skips creating a new
    cache and uses the provided one directly.
  2. prompt type broadened to str | mx.array | list[int]: Phase 4
    passes mx.array([first_token_id]) as prompt. mlx_lm.stream_generate
    already accepts all three types natively; the only adaptation needed was
    gating tokenizer.encode() behind an isinstance(prompt, str) check
    for the prompt token count.
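
A minimal sketch of what these two changes could look like on the wrapper
(attribute names such as self._mlx_model and self._tokenizer, the max_tokens
default, and the omission of the sampling parameters are all assumptions,
not the repo's actual code):

```python
from typing import Any, Optional, Union

import mlx.core as mx
from mlx_lm import stream_generate as mlx_stream_generate

def stream_generate(
    self,
    prompt: Union[str, mx.array, list[int]],
    max_tokens: int = 256,
    prompt_cache: Optional[Any] = None,
    **kwargs,
):
    # Change 2: only run the tokenizer for string prompts;
    # mlx_lm.stream_generate accepts str, mx.array, and list[int] natively.
    if isinstance(prompt, str):
        num_prompt_tokens = len(self._tokenizer.encode(prompt))
    else:
        num_prompt_tokens = len(prompt)  # already tokenized
    # (num_prompt_tokens feeds prompt-token accounting in the real method)

    # Change 1: forward a pre-populated KV cache; when it is not None,
    # generate_step reuses it instead of allocating a fresh cache.
    if prompt_cache is not None:
        kwargs["prompt_cache"] = prompt_cache

    yield from mlx_stream_generate(
        self._mlx_model, self._tokenizer, prompt,
        max_tokens=max_tokens, **kwargs,
    )
```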

Test plan

  • Existing test_model_stream_generate still passes
  • New tests (one is sketched below): test_model_stream_generate_with_prompt_cache,
    test_model_stream_generate_with_list_prompt
  • End-to-end: SpecPrefill activated on 12k-token prompt, Phase 4
    decode runs at normal speed
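
For illustration, the list-prompt test could look roughly like this (the
loaded_model fixture, the max_tokens argument, and the prompt_tokens field
are assumptions about the test setup, not the repo's actual code):

```python
def test_model_stream_generate_with_list_prompt(loaded_model):
    # A pre-tokenized prompt exercises the non-str path, which must skip
    # tokenizer.encode() yet still report the correct prompt length.
    prompt_ids = [1, 2, 3, 4]
    outputs = list(loaded_model.stream_generate(prompt_ids, max_tokens=8))
    assert outputs
    assert outputs[-1].prompt_tokens == len(prompt_ids)
```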

@Vigilans force-pushed the fix/specprefill-phase4-decode branch from e889470 to d266b20 on April 2, 2026 at 16:31
@Vigilans (Contributor, Author) commented Apr 2, 2026

btw, I noticed SpecPrefill currently only triggers on the streaming path (stream_chat → stream_generate). The non-streaming chat() method calls self._model.chat() directly and bypasses stream_generate entirely, so SpecPrefill never activates for `stream=False` requests. Is this intentional, or will it be implemented in the future?

@Thump604 (Collaborator) commented Apr 7, 2026

@Vigilans, @waybarrios: independent technical review of this PR.

Verification of the fix

Confirmed against current upstream main (b4fa030). The diff modifies two _run_specprefill call sites in vllm_mlx/engine/simple.py (the SimpleEngine streaming path and the BatchedEngine SpecPrefill path) to:

  1. Sample the first token from SpecPrefill's final logits (preserved behavior)
  2. Hand off subsequent decode to mlx_lm.stream_generate with the SpecPrefill-prepared prompt_cache

The new prompt_cache parameter is added to MLXLanguageModel.stream_generate in vllm_mlx/models/llm.py:177 and forwarded into mlx_lm's stream_generate via the existing mtp_kwargs dict.

Technical correctness

The root cause analysis in #247 holds. The manual model() + mx.eval(y) decode loop bypasses the pipelined kernel scheduling that mlx_lm.stream_generate provides, which is why the GPU stays idle (<1W vs 90-93W) and decode crawls at 0.3 tok/s vs 30+. The fix routes through the standard stream_generate path which gets full kernel pipelining and proper async eval batching.
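
To make the failure mode concrete, the problematic pattern looks roughly like
this (illustrative, not the repo's exact loop):

```python
import mlx.core as mx

def manual_decode(model, cache, y, steps):
    # One fully synchronous step per token: mx.eval() blocks until the
    # GPU finishes, then Python builds the next step while the GPU idles.
    # No overlap between graph construction and execution -> ~0.3 tok/s.
    tokens = []
    for _ in range(steps):
        logits = model(y[None], cache=cache)       # (1, 1, vocab)
        y = mx.argmax(logits[:, -1, :], axis=-1)   # greedy next token, shape (1,)
        mx.eval(y)                                 # hard sync point on every token
        tokens.append(y.item())
    return tokens
```

mlx_lm's generate_step avoids this by keeping the next step in flight with
mx.async_eval, so GPU execution of one token overlaps with preparation of the
next; that is what routing through stream_generate restores here.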

Both _run_specprefill call sites get the same treatment, which is correct. Both BatchedEngine and SimpleEngine had the same bug.

Backward compatibility

prompt_cache=None is the default, so existing callers behave identically. The prompt: Union[str, mx.array, list[int]] signature widening is safe: existing string callers still work, and the new token-list and array forms are needed for the SpecPrefill handoff.

Minor observations (not blocking)

  1. mtp_kwargs is now a slight misnomer since it also carries prompt_cache for non-MTP cases. Renaming to extra_kwargs would be clearer in a follow-up.
  2. First-token decode is duplicated between the manual sample-and-decode block and what stream_generate would do internally. This is necessary because the first token must come from SpecPrefill's final logits, not from a fresh forward pass. A code comment noting this would help future readers, but the logic is correct.
  3. The two new tests in tests/test_llm.py exercise the new prompt_cache and list[int] prompt paths. An end-to-end SpecPrefill+stream_generate test that asserts the decode speed regression does not return would close the loop, but the hard production evidence in #247 (SpecPrefill Phase 4 in SimpleEngine: very slow decode using model()) already covers it.

On the related observation in your comment

The non-streaming chat() method bypassing SpecPrefill is a separate, related gap. It is on a different code path (the bypass happens before the streaming pipeline is reached) and would need its own change. The two issues are complementary and the fixes would compose well if both land.

Recommendation

Merge candidate. Real fix to a real performance regression, hard production evidence, sound implementation, reasonable tests.

@Thump604 (Collaborator) left a review comment

This is a correct fix and a big one. The manual model() + mx.eval(y) per-token loop foregoes MLX's kernel pipelining, which is exactly what gets SpecPrefill Phase 4 stuck at ~0.3 tok/s while the same model at the same context runs ~30 tok/s through a normal decode. Routing through mlx_lm.stream_generate(prompt=mx.array([first_token]), prompt_cache=cache) is the clean fix.

On your note about SpecPrefill currently only triggering on the streaming path: that's a separate accumulator gap I'm working on. Once both this PR and that refactor land, non-streaming /v1/chat/completions and /v1/completions will also reach Phase 4, and this fix becomes load-bearing there too. Worth mentioning in the PR description if you want to foreshadow it.

Post-merge I'd like to run this against a 16K/64K/128K needle harness on Qwen 3.5 122B + 2B draft as an independent regression signal. If that's useful I can post the run output here as a follow-up comment.

Approving.

Thump604 added a commit to Thump604/vllm-mlx that referenced this pull request Apr 9, 2026
…l generate+stream_generate

Pre-existing regression from an earlier rebase that dropped bdf7dcc's
llm.py additions. The server.py request handlers still pass top_k,
min_p, presence_penalty, repetition_penalty through to SimpleEngine,
which forwards them via **kwargs to MLXLanguageModel.chat() (which
accepts **kwargs) which then calls self.generate(..., **kwargs). But
MLXLanguageModel.generate() and stream_generate() had been left with
only (temperature, top_p, repetition_penalty) in their signatures, so
any non-MLLM SimpleEngine request crashed with:

    TypeError: MLXLanguageModel.stream_generate() got an unexpected
    keyword argument 'top_k'

Observed as 0/6 on simple-base, simple-mtp, and simple-spec profiles in
the feature matrix regression sweep after the Session 87 cherry-picks
of PRs waybarrios#248, waybarrios#229, waybarrios#218, waybarrios#222 landed. The cherry-picks did not cause
this regression — they exposed it by finally running the LLM-path
tests that no one had exercised since the rebase happened. Confirmed
via stderr.log:

  TypeError: MLXLanguageModel.generate() got an unexpected keyword argument 'top_k'
  TypeError: MLXLanguageModel.stream_generate() got an unexpected keyword argument 'top_k'

Fix: restore the signatures and bodies of _create_sampler,
_create_logits_processors, generate, and stream_generate to match
bdf7dcc's original intent. Preserves PR waybarrios#248's prompt_cache parameter
and non-str prompt support on stream_generate. Adds **kwargs to both
generate and stream_generate so future param additions degrade
gracefully instead of crashing.

This is a runtime-local fix. The equivalent upstream fix lives in
bdf7dcc which was never upstreamed (confirmed via
git merge-base --is-ancestor bdf7dcc upstream/main). A follow-up PR
to upstream could carry this forward.

Verification:
  bin/verify-patches: 33/33 clean
  Full feature matrix regression sweep pending re-run after this commit.

Related: runtime PR waybarrios#265 fixed the CompletionRequest schema
side of the same bdf7dcc drop; this commit fixes the engine-model side.
Thump604 added a commit to Thump604/vllm-mlx that referenced this pull request Apr 9, 2026
…er stream_generate

Closes a real /v1/completions contract gap: prior to this commit,
server.py::create_completion extracted per-request specprefill /
specprefill_keep_pct and passed them as gen_kwargs to
engine.generate(), and SimpleEngine.generate() forwarded them via
**kwargs to MLXLanguageModel.generate() (which now accepts **kwargs
since a8f58bf restored the bdf7dcc sampling params). Both the server
and the engine advertised per-request SpecPrefill on /v1/completions,
but the overrides were silently dropped because the direct
self._model.generate() path never consumes the specprefill kwargs —
they just flowed into **kwargs and landed in mlx_lm.generate which
ignores them.

Fix: make generate() a thin accumulator over stream_generate().
stream_generate() is the only code path that actually pops
`specprefill` / `specprefill_keep_pct` from kwargs, threshold-checks,
and routes to _stream_generate_specprefill. By iterating it and
returning the last GenerationOutput, non-streaming /v1/completions
clients get the same SpecPrefill engagement as streaming clients.

Verified against a live Qwen3.5-4B SimpleEngine runtime (simple-spec
profile) with a 27 KB prompt (~6007 tokens, under the 8192 threshold
so forcibly enabled via extra_body.specprefill=true):

  SpecPrefill: scored 6007 tokens in 5.3s,
  sparse prefill 1815/6007 (keep=30%) in 1.1s

prompt_tokens reporting is now accurate (was always 0 on the old
direct LLM path because mlx_lm.generate() never sets it).

Scope: only generate() gets the accumulator treatment in this commit.
chat() stays on the direct path for non-tool LLM cases; the
tool-enabled accumulator from PR waybarrios#222 cherry-pick handles the chat
side. Full chat() accumulator refactor is a follow-up that will
layer on top of PR waybarrios#222 once it merges upstream.

This commit is independently upstreamable as a follow-up to PR waybarrios#248
(Phase 4 decode fix). Worth filing as a small PR on
waybarrios/vllm-mlx to close the same gap upstream.
@Thump604 (Collaborator) commented

Hey @Vigilans - the pipelined decode approach is correct and we've validated the same pattern independently on our local staging. Big speedup. Currently conflicting with main - can you rebase? Good to merge after that.

@Vigilans force-pushed the fix/specprefill-phase4-decode branch 2 times, most recently from 3b728b0 to 8f3b753 on April 14, 2026 at 05:21
@Vigilans (Contributor, Author) commented

Rebased on latest main and fixed the ruff lint errors from the previous CI run.

@janhilgard (Collaborator) left a review comment

+1, agreeing with Thump604's approval. The fix is correct and well-scoped.

The manual model() + mx.eval(y) per-token loop was indeed the bottleneck -- it serializes every decode step and prevents MLX's kernel pipelining from overlapping compute and memory operations. Routing through mlx_lm.stream_generate(prompt=mx.array([first_token_id]), prompt_cache=cache) is the right fix.

A few minor observations (none blocking):

  1. stream_generate signature broadening in llm.py -- accepting Union[str, mx.array, list[int]] is a nice generalization. The TYPE_CHECKING guard for import mlx.core as mx is correct since mx is only needed for the type annotation.

  2. First token is decoded standalone (self._text_tokenizer.decode([first_token_id])) before handing off to stream_generate. This means the first token's text may differ slightly from what stream_generate's internal detokenizer would produce (e.g., for tokens that need surrounding context for correct decode, like byte-fallback tokens). In practice this is unlikely to matter for the first token after SpecPrefill, but worth being aware of.

  3. The MLLM path (second _run_specprefill at ~line 1204) uses mlx_stream_generate (the raw mlx_lm function) directly, while the LLM path uses self._model.stream_generate (the wrapper). This asymmetry is fine since the MLLM path needs the raw model + tokenizer, but it means MTP is not available in the MLLM SpecPrefill path (the wrapper handles MTP kwargs). Not a regression since MTP was not available there before either.

CI has no checks running -- would be good to trigger a run to confirm green before merge. Branch is mergeable with no conflicts.

@Thump604 (Collaborator) commented

This is approved by both reviewers and ready to go. Could you push a rebase onto current main to trigger CI? The branch hasn't had checks run.

…prompt

Add `prompt_cache` parameter to `MLXLanguageModel.stream_generate()` so
callers can pass a pre-populated KV cache (e.g. from SpecPrefill sparse
prefill).  Broaden the `prompt` type from `str` to
`str | mx.array | list[int]` since `mlx_lm.stream_generate` already
supports all three; gate `tokenizer.encode()` behind an `isinstance`
check to avoid encoding non-string prompts.
…SpecPrefill Phase 4

The Phase 4 decode loop in SpecPrefill called model() + mx.eval()
synchronously in a plain Python loop, resulting in ~0.3 tok/s decode
(~100x slower than normal generation). The GPU idled while Python
prepared the next step because there was no async pipelining.

Replace the manual loop with the engine's standard pipelined generation:
- Non-MLLM path: hand off to self._model.stream_generate() with the
  pre-populated sparse-prefill cache via the new prompt_cache parameter
- MLLM text-only path: hand off to mlx_lm.stream_generate() (matching
  the normal path which also calls mlx_lm directly)

Both paths keep the first-token EOS check since stream_generate does
not validate whether the prompt token itself is EOS.

Fixes waybarrios#247
Cover the new prompt_cache and non-string prompt parameters added to
MLXLanguageModel.stream_generate():
- prompt_cache with a pre-populated KV cache and mx.array prompt
- list[int] prompt verifying prompt_tokens count
@Vigilans force-pushed the fix/specprefill-phase4-decode branch from 8f3b753 to 34023e1 on April 17, 2026 at 12:42
@Thump604 Thump604 merged commit b0a79f5 into waybarrios:main Apr 17, 2026