Skip to content

[Gemma 4] Adding MTP support#24436

Merged
Qiaolin-Yu merged 34 commits into
mainfrom
gemma4-mtp-fin
May 7, 2026
Merged

[Gemma 4] Adding MTP support#24436
Qiaolin-Yu merged 34 commits into
mainfrom
gemma4-mtp-fin

Conversation

@kpham-sgl
Copy link
Copy Markdown
Collaborator

Motivation

Add Multi-Token Prediction (MTP) speculative decoding for Gemma 4. Each Gemma 4 target ships with a small "assistant" checkpoint trained for MTP. This PR introduces a new speculative algorithm — FROZEN_KV_MTP — that runs the assistant against the target's KV cache (the assistant has no KV of its own and a recurrent hidden state across draft steps, so it does not fit cleanly under EAGLE/EAGLE3 or NEXTN).

NEXTN and EAGLE are auto-promoted to FROZEN_KV_MTP when --speculative-draft-model-path resolves to a Gemma4AssistantForCausalLM. EAGLE3 is rejected for this draft architecture.

Supported Targets:

Target Architecture Parameters
google/gemma-4-E2B-it Dense ~2B
google/gemma-4-E4B-it Dense ~4B
google/gemma-4-31B-it Dense 31B
google/gemma-4-26B-A4B-it MoE 26B total / 4B active

Both Gemma4ForCausalLM (text) and Gemma4ForConditionalGeneration (multimodal) targets are supported.

Installation

# Install transformers with Gemma 4 support
pip install 'git+https://github.com/huggingface/transformers.git@2c7d385621c80fee70c1472f3a622fcba2c93fb9'

Usage

Launch Server

# E4B + matched assistant, topk=1
sglang serve --model-path google/gemma-4-E4B-it \
  --speculative-algorithm NEXTN \
  --speculative-draft-model-path google/gemma-4-E4B-it-mtp \
  --speculative-num-steps 5 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 6 \

# Tree verify (topk=3)
  --speculative-eagle-topk 3 --speculative-num-draft-tokens 12

# Tree verify (topk=5)
  --speculative-eagle-topk 5 --speculative-num-draft-tokens 16

When FROZEN_KV_MTP is active, the overlap scheduler and mixed chunked prefill are force-disabled, and --max-running-requests defaults to 48 if unset.

Accuracy Tests

MTP matches the target-only baseline across GPQA Diamond (CoT / generative) and GSM8K (CoT, 0-shot):

Topk = 1

Model Config GPQA-D CoT GPQA-D gen GSM8K CoT
E2B baseline 0.3384 0.3838 0.7369
E2B MTP 0.3586 0.3889 0.7369
E4B baseline 0.5505 0.3939 0.7582
E4B MTP 0.5303 0.3990 0.7597
31B baseline 0.7172 0.6111 0.8802
31B MTP 0.7121 0.6162 0.8817
26B-A4B baseline 0.7172 0.5455 0.8772
26B-A4B MTP 0.7222 0.5758 0.8696

Speed Tests and Profiling

Will follow up with FROZEN_KV_MTP speed benchmark and profiling later

Modifications

  • New assistant model: gemma4_mtp.py (Gemma4AssistantForCausalLM) — target-embed + recurrent hidden via pre/post projection, owns its own lm_head, optional centroid-ordered logits head.
  • New speculative algorithm FROZEN_KV_MTP in speculative/spec_info.py (with FROZEN_KV_MTP_DRAFT / FROZEN_KV_MTP_VERIFY spec input types).
  • New worker + utilities: frozen_kv_mtp_worker.py, frozen_kv_mtp_info.py, frozen_kv_mtp_utils.py, frozen_kv_mtp_cuda_graph_runner.py (supports topk=1 and tree verify for topk>1, with full CUDA graph capture).
  • Frozen-KV binding collapses HF Gemma 4's typed-last-layer + kv_shared_layer_index two-hop into a direct assistant-logical → target-physical layer map. Assistant KV writes are suppressed.
  • server_args.py: alias resolution promotes NEXTN/EAGLEFROZEN_KV_MTP for Gemma 4 assistant drafts and rejects EAGLE3. Force-disables overlap scheduler and mixed chunked prefill.
  • hf_transformers/config.py: recognize model_type == "gemma4_assistant" for SWA attribute remap.
  • gemma4_causal.py / gemma4_mm.py: expose get_embed_and_head() so the assistant can rebind to the target's input embedding at load time.
  • Manual GSM8K validation harness in test/manual/models/test_gemma4_mtp.py (target-only baseline → topk=1 MTP → topk>1 MTP).

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Copy link
Copy Markdown
Collaborator

@Qiaolin-Yu Qiaolin-Yu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's implemented in a separate spec algo, I think we could merge it safely anyway. We could do some refactor/clean later if possible.

@github-actions github-actions Bot added the blackwell SM100/SM120 label May 6, 2026
@kpham-sgl
Copy link
Copy Markdown
Collaborator Author

Added new stage b CI test for E4B variant. Will add the bigger model test to nightly in a separate PR

@kpham-sgl kpham-sgl requested a review from Qiaolin-Yu May 6, 2026 06:13
@kpham-sgl
Copy link
Copy Markdown
Collaborator Author

/rerun-test test/registered/spec/test_frozen_kv_mtp.py

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 6, 2026

1-gpu-h100-h200 (1 test): View workflow run

cd test/ && python3 registered/spec/test_frozen_kv_mtp.py

@kpham-sgl
Copy link
Copy Markdown
Collaborator Author

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label May 6, 2026
@kpham-sgl
Copy link
Copy Markdown
Collaborator Author

/rerun-test test/registered/spec/test_frozen_kv_mtp.py

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 6, 2026

1-gpu-h100 (1 test): View workflow run

cd test/ && python3 registered/spec/test_frozen_kv_mtp.py

@kpham-sgl
Copy link
Copy Markdown
Collaborator Author

kpham-sgl commented May 6, 2026

/rerun-failed-ci two

@Qiaolin-Yu Qiaolin-Yu merged commit d2c1034 into main May 7, 2026
252 of 303 checks passed
@Qiaolin-Yu Qiaolin-Yu deleted the gemma4-mtp-fin branch May 7, 2026 21:08
Dogacel pushed a commit to Dogacel/sglang-fork that referenced this pull request May 8, 2026
Co-authored-by: Pengyu Chen <pychen96@gmail.com>
LLThomas pushed a commit to LLThomas/sglang that referenced this pull request May 8, 2026
Co-authored-by: Pengyu Chen <pychen96@gmail.com>
LucQueen pushed a commit to LucQueen/sglang that referenced this pull request May 12, 2026
Co-authored-by: Pengyu Chen <pychen96@gmail.com>
vroomfondel added a commit to vroomfondel/dgxarley that referenced this pull request May 13, 2026
**Gemma-4 MTP (Frozen-KV) speculative-decoding patch** — the `0.5.11-gemma4-sm121` tag carries a cherry-pick of upstream [PR #24436](sgl-project/sglang#24436), ("Gemma 4 — Adding MTP support", merged 2026-05-07, after the v0.5.11 release tag). Adds the dedicated `Gemma4AssistantForCausalLM` model and a new `FROZEN_KV_MTP` speculative algorithm (recurrent hidden-state draft loop with frozen target KV cache). At runtime SGLang auto-promotes `--speculative-algorithm NEXTN → FROZEN_KV_MTP` once the drafter is detected. Without this patch the stock NEXTN/EAGLE worker crashes with `ValueError: No module or parameter named 'model.language_model' in TransformersMultiModalForCausalLM` during drafter weight-load.

Verified working on the 4-node DGX Spark cluster — see the 31B-it TESTLOG, [Test 07 (`num_steps=2`, `num_draft_tokens=3`)](https://github.com/vroomfondel/dgxarley/blob/main/TESTLOGS/sglang_nn4_tp4_ep1/gemma-4-31b-it/TESTLOG_nv580.142_sglang-0.5.11_gemma-4-31b-it_4n.md#mtp-sweep-tests-711--partial-15-cases-done): +98 % at n=1 (10.49 → 20.83 tok/s), +76 % at n=4 (44.06 → 77.67 tok/s), drafter acceptance rate median ~0.68, 5/5 requests stopped on natural EOS. The 26B-A4B MoE sibling's MTP sweep is still in progress ([TESTLOG](https://github.com/vroomfondel/dgxarley/blob/main/TESTLOGS/sglang_nn4_tp4_ep1/gemma-4-26b-a4b-it/TESTLOG_nv580.142_sglang-0.5.11_gemma-4-26b-a4b-it_4n.md)).

P.S.: "In the end, it's not about how hard you can hit, but how accurately you can score." - Tom Cruise
Jiminator added a commit to Jiminator/sglang that referenced this pull request May 15, 2026
…2c1034

Two findings appended to the bisect report:

1. PR sgl-project#25335 ("Fix gpt oss triton kernels and upgrade flashinfer back
   to 0.6.11.post1") re-bumped flashinfer past PR sgl-project#25310's revert.
   The one-line fix in fp4_utils.py:22 (cute-dsl -> cuda) is therefore
   no longer sufficient on latest main: experiment G reproduces the
   strict cuda-side check from fp4Quantize.cpp:64 ("globalScale should
   have shape [1] or [num_tokens]"), identical to experiment C.
   The proper fix is now at the call site in
   compressed_tensors_w4a4_nvfp4_moe.py:315: collapse
   layer.w13_input_scale_quant (shape [num_experts]) to scalar [1] or
   per-token [num_tokens] before passing as global_scale.

2. The TP8+MTP variant has its own separate pre-existing regression,
   bisected to d2c1034 ("[Gemma 4] Adding MTP support", PR sgl-project#24436).
   That PR added _resolve_speculative_algorithm_alias in
   server_args.py:318-342 which unconditionally calls
   AutoConfig.from_pretrained on the draft path to detect Gemma4
   drafts. It crashes on any draft in Mistral native format (params.json,
   no HF config.json), even when --speculative-algorithm is already
   explicit EAGLE.

Empirical proof for (2):
- d2c1034 + TP8+MTP-only test: FAIL with
  "Unrecognized model in ...Eagle. Should have a model_type key in
  its config.json", total wall time 60.7s (crashes before model load).
- f1395af (parent of d2c1034) + same test: PASS, gsm8k 0.949.

Both with flashinfer 0.6.8.post1, sglang-kernel 0.4.2.post1+cu130,
torch 2.11.0+cu130, SGLANG_IS_IN_CI=true, SGLANG_ENABLE_JIT_DEEPGEMM=0,
SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1.

Minimal fix for (2): wrap the AutoConfig.from_pretrained call in
_resolve_speculative_algorithm_alias with try/except, or
short-circuit when speculative_algorithm is already explicit and
the user did not request NEXTN aliasing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Jiminator added a commit to Jiminator/sglang that referenced this pull request May 15, 2026
…5407

The Mistral-Large-3 B200 nightly partition has been red because of
TWO independent regressions sharing the same job. Keeping them in one
document is misleading — different root causes, different fixes,
different PRs. This split:

- Creates mistral_large3_tp8_mtp_b200_bisect_report.md with all
  TP8+MTP-specific content (root cause d2c1034 / PR sgl-project#24436, the
  _resolve_speculative_algorithm_alias crash on Mistral-native-format
  drafts, the AutoConfig.from_pretrained ValueError, the empirical
  one-commit bisect d2c1034 vs f1395af, the proposed try/except fix,
  the maintainer-ready server log block, and the CI-visibility table).

- Strips the same content out of
  mistral_large3_nvfp4_b200_bisect_report.md, replacing it with
  cross-references in the header, Open Items, follow-up note, and TL;DR.

- Adds a PR sgl-project#25407 verification section to BOTH documents (NVFP4 doc
  records that PR sgl-project#25407 fixes its issue with gsm8k 0.957; TP8+MTP doc
  records that PR sgl-project#25407 explicitly does NOT touch server_args.py and
  the failure remains identical).

Run summary on PR sgl-project#25407 head e3fb4ee (1574s wall time, 8x B200,
flashinfer 0.6.11.post1, sglang-kernel 0.4.2.post2+cu130, torch 2.11.0):
  - TP8        PASS  gsm8k 0.953
  - TP8+MTP    FAIL  unchanged ValueError (server_args.py:329)
  - NVFP4      PASS  gsm8k 0.957

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@yuan-luo
Copy link
Copy Markdown
Collaborator

Now the gemma4 mtp model has renamed to google/gemma-4-E4B-it-assistant. FYI.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants