[Gemma 4] Add DFLASH speculative decoding support #24985
Open
dssugar wants to merge 1 commit into
Adds set_dflash_layers_to_capture to Gemma4ForCausalLM and
Gemma4ForConditionalGeneration, mirroring the existing EAGLE3 capture path.
This unblocks DFLASH v1 with z-lab/gemma-4-{31B-it,26B-A4B-it}-DFlash on
top of any Gemma 4 target.
Gemma 4 ties lm_head to embed_tokens (a plain nn.Embedding subclass, not a
VocabParallelEmbedding), so dflash_worker._prepare_for_speculative_decoding
rejects it at hasattr(lm_head, "shard_indices"). Adds a tiny helper that
injects a trivial tp=1 VocabParallelEmbeddingShardIndices onto the tied
lm_head; the worker's fast path (tp_size == 1 and num_added == 0) handles
greedy verification without touching TP / added-vocab branches.
The MM class previously didn't expose self.lm_head; this commit makes it an
alias to language_model.embed_tokens so the DFLASH worker can find it via
target_model.lm_head.
Verified locally on a single RTX 5090 (sm120, triton attention backend) with
RedHatAI/gemma-4-31B-it-NVFP4 as target and z-lab/gemma-4-31B-it-DFlash as
drafter:
- code 100w warm: 158.4 tok/s (vs MTP baseline 83.8, 1.89x)
- haiku warm: 92.9 tok/s
- jp warm: 64.0 tok/s
- server-log accept length (code peak): 4.47, accept rate 0.23
Scope verified is short prompts with temperature=0; longer / multi-turn /
streaming and the gibberish symptom reported on vllm-project/vllm#41262
(TP=2) have not been retested.
Refs sgl-project#23000 (comment).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Collaborator
Took a look, and this seems fine to me, though I’ll defer to @hnyls2002 or @kpham-sgl on the preferred style. This PR adds a Gemma 4 compatibility shim for the current DFLASH worker assumptions, while SGLang generally expects models to opt in by exposing the speculative-decoding interface explicitly.
Motivation
Adds DFLASH speculative decoding support for Gemma 4 (both
Gemma4ForCausalLM and Gemma4ForConditionalGeneration), mirroring the
existing EAGLE3 capture hook. Picks up @dcw02's invitation in #23000 to
"put up a PR to add gemma 4 support to v1"; see #23000 (comment).
Once merged, this unblocks z-lab/gemma-4-31B-it-DFlash and
z-lab/gemma-4-26B-A4B-it-DFlash against any Gemma 4 target.
Changes
- gemma4_causal.py / gemma4_mm.py: add set_dflash_layers_to_capture next to the existing set_eagle3_layers_to_capture. Uses the same model.layers_to_capture = [val + 1 for val in layer_ids] convention. Unlike EAGLE3, DFLASH requires explicit layer_ids (the checkpoint's dflash_config.target_layer_ids), so None raises.
- _ensure_dflash_shard_indices helper (in gemma4_causal.py): Gemma 4 ties lm_head to embed_tokens (a plain nn.Embedding subclass via Gemma4TextScaledWordEmbedding), so dflash_worker._prepare_for_speculative_decoding rejects it at hasattr(lm_head, "shard_indices"). The helper setattr-s a trivial VocabParallelEmbeddingShardIndices (tp=1, num_added=0) onto the tied head; the worker's fast path (tp_size == 1 and num_added == 0) then proceeds without touching the TP or added-vocab branches.
- Gemma4ForConditionalGeneration.lm_head: previously not exposed. Aliased to self.language_model.embed_tokens so the DFLASH worker can resolve it via target_model.lm_head. No state-dict impact (it's a pure attribute pointer, not a new submodule).
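For readers who want the shape of the change without opening the diff, here is a minimal sketch of the capture hook and the shard-indices shim described above. It is not the PR's code: every class below is a pure-Python stand-in, the field names on ShardIndicesStandIn are a simplified assumption about what VocabParallelEmbeddingShardIndices carries, and calling the shim from inside the capture hook is also an assumption.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ShardIndicesStandIn:
    """Simplified stand-in for VocabParallelEmbeddingShardIndices (fields assumed)."""
    org_vocab_start_index: int
    org_vocab_end_index: int
    added_vocab_start_index: int
    added_vocab_end_index: int


class TiedHeadStandIn:
    """Stand-in for the tied embed_tokens / lm_head module (plain embedding)."""

    def __init__(self, vocab_size: int) -> None:
        self.num_embeddings = vocab_size


class InnerModelStandIn:
    """Stand-in for the decoder stack that records which layers to capture."""

    def __init__(self) -> None:
        self.layers_to_capture: List[int] = []


def ensure_shard_indices(head: TiedHeadStandIn) -> None:
    """Attach trivial tp=1, num_added=0 shard indices if the head has none."""
    if hasattr(head, "shard_indices"):
        return  # a real vocab-parallel head already carries them
    vocab = head.num_embeddings
    # One shard owns the whole vocabulary; the added-vocab range is empty.
    head.shard_indices = ShardIndicesStandIn(0, vocab, vocab, vocab)


class CausalLMStandIn:
    """Stand-in for a Gemma 4 target model exposing the DFLASH hook."""

    def __init__(self, vocab_size: int) -> None:
        self.model = InnerModelStandIn()
        self.lm_head = TiedHeadStandIn(vocab_size)

    def set_dflash_layers_to_capture(self, layer_ids: Optional[List[int]] = None) -> None:
        if layer_ids is None:
            # Unlike EAGLE3, there is no default layer set to fall back on.
            raise ValueError("DFLASH requires explicit layer_ids")
        # Same +1 convention as the EAGLE3 hook.
        self.model.layers_to_capture = [val + 1 for val in layer_ids]
        ensure_shard_indices(self.lm_head)


target = CausalLMStandIn(vocab_size=1000)
target.set_dflash_layers_to_capture([1, 5, 9])
print(target.model.layers_to_capture)
print(hasattr(target.lm_head, "shard_indices"))
```

The shim is deliberately idempotent: a head that already has shard_indices is left untouched, so the same hook stays harmless if the head is later swapped for a vocab-parallel one.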
Local verification
Verified on a fresh venv built from this branch
(pip install ./python from the branch tip), single RTX 5090 (sm120),
attention-backend=triton, kv-cache-dtype=fp8_e4m3,
RedHatAI/gemma-4-31B-it-NVFP4 target, z-lab/gemma-4-31B-it-DFlash
drafter, temperature=0 short prompts:
Output quality looks healthy, e.g. haiku:
Also smoke-tested extended scenarios that historically triggered the
vllm-project/vllm#41262 repetitive-token symptom for adjacent stacks:
- completion_tokens=600): full essay on speculative decoding, no repetitive substring patterns (30+ char × 4 regex match count = 0).
- user-stated fact across turns.
- per-section labels (**1. アーキテクチャ** etc.), no degradation in the lower-accept-rate regime where JP normally sits.
- chat/completions with stream=true): DFlash multi-token chunks are delivered correctly through the streaming path.
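The "30+ char × 4" repetition check mentioned above can be approximated with a backreference regex. This is a sketch of one way to do it, not the exact expression used for the PR: the function name, the max_len cap (added only to bound backtracking), and the choice to count non-overlapping matches are all assumptions.

```python
import re


def count_repeated_runs(text: str, min_len: int = 30, repeats: int = 4,
                        max_len: int = 200) -> int:
    """Count non-overlapping runs where a substring of min_len..max_len
    characters occurs `repeats` times back to back."""
    # Group 1 is the candidate unit; \1{repeats-1} demands the extra copies.
    pattern = re.compile(
        rf"(.{{{min_len},{max_len}}}?)\1{{{repeats - 1}}}", re.DOTALL
    )
    return sum(1 for _ in pattern.finditer(text))


unit = "the quick brown fox jumps over the lazy dog. "  # 45 characters
looping = "intro " + unit * 4 + " outro"   # unit repeated four times in a row
healthy = "intro " + unit * 3 + " outro"   # only three repeats, below threshold

print(count_repeated_runs(looping))
print(count_repeated_runs(healthy))
```

A count of 0 on a long completion is a cheap signal that the output has not collapsed into a token loop; it says nothing about whether the text is otherwise correct.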
For reference on the same box, NEXTN MTP gives ~84 tok/s on code
(this PR is ~1.79× on that workload), and vLLM main + MTP gives
~164 tok/s (this PR closes most of that gap on SGLang side).
Cross-checked the same patch shape on an earlier bcf8d100 base
(gemma4-mtp-fin branch) where I first found the missing
set_dflash_layers_to_capture and the shard_indices gate: same
behavior (158.4 / 92.9 / 64.0 tok/s, accept length peak 4.47).
Scope not yet exercised
- --context-length 8192, default mem-fraction-static 0.75)
- temperature=0)
- /C-loop / repetitive-token failure mode reported on vllm-project/vllm#41262 ([Bug]: Gemma-4 31B with DFlash speculator produces gibberish/repetitive token loop) for RedHatAI/gemma-4-31B-it-speculator.dflash with TP=2: not retested (this box is single-GPU)
Happy to extend testing if reviewers flag a specific gap.
Related
surface and does not touch V2 / overlap scheduling. V2 should compose
cleanly with this since the target-side hook is shared.