[Gemma4] Add test for MTP models #24552

Open
kpham-sgl wants to merge 7 commits into main from gemma4-mtp-add-test

Collaborator

@kpham-sgl kpham-sgl commented May 6, 2026

Motivation

Add tests for PR #24436. Requires transformers v5.8.0.

Modifications

Accuracy Tests

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

CI States

Latest PR Test (Base): ❌ Run #26135220858
Latest PR Test (Extra): ❌ Run #26135220756

@kpham-sgl kpham-sgl changed the base branch from main to gemma4-mtp-fin May 6, 2026 21:44
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements Frozen-KV MTP (Multi-Token Prediction) speculative decoding for Gemma 4 models, introducing a new assistant model class, worker, and CUDA graph runner. The review highlights critical issues regarding Tensor Parallelism compatibility in embedding and linear layers, suggests avoiding torch.cuda.empty_cache() to prevent performance degradation, recommends copying server_args for isolation, and advises using specific exceptions for better error handling.

I am having trouble creating individual review comments, so my feedback is listed below.

python/sglang/srt/models/gemma4_mtp.py (234)

high

Using torch.nn.functional.embedding directly on self.target_embed_weight will fail when Tensor Parallelism (TP) is enabled (TP > 1). In SGLang, target_embed_weight is typically a partitioned weight from a VocabParallelEmbedding. A direct lookup will only work for tokens within the local GPU's vocabulary range and will not perform the necessary communication (all-reduce) to produce the full embedding vector on all GPUs. This will lead to incorrect results or IndexError on GPUs where the token IDs are out of the local partition range.
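
To illustrate the concern, here is a minimal pure-Python simulation of a vocab-parallel lookup. It is a sketch, not SGLang's `VocabParallelEmbedding`: the function names are hypothetical, the vocabulary is assumed to divide evenly across ranks, and the all-reduce is modelled as an elementwise sum over per-rank partial results.

```python
# Pure-Python simulation of a vocab-parallel embedding lookup.
# All names are illustrative; this is not SGLang's implementation.

def shard_table(table, tp_size):
    """Split the embedding rows evenly across tp_size ranks."""
    per_rank = len(table) // tp_size  # assumes len(table) % tp_size == 0
    return [table[r * per_rank:(r + 1) * per_rank] for r in range(tp_size)]


def local_lookup(shard, rank, token_ids):
    """Look up tokens on one rank; tokens owned by other ranks yield zeros."""
    per_rank = len(shard)
    start = rank * per_rank
    dim = len(shard[0])
    out = []
    for tok in token_ids:
        local = tok - start
        if 0 <= local < per_rank:
            out.append(list(shard[local]))
        else:
            out.append([0.0] * dim)  # masked: this rank does not own the row
    return out


def all_reduce_sum(per_rank_outputs):
    """Stand-in for the all-reduce: elementwise sum across ranks."""
    return [
        [sum(vals) for vals in zip(*rows)]
        for rows in zip(*per_rank_outputs)
    ]


def vocab_parallel_embedding(table, tp_size, token_ids):
    """Masked local lookups on every rank, then a sum across ranks."""
    shards = shard_table(table, tp_size)
    partials = [local_lookup(s, r, token_ids) for r, s in enumerate(shards)]
    return all_reduce_sum(partials)
```

Without the mask and the reduction, a rank indexing its local shard with a global token ID either reads the wrong row or goes out of range, which is the failure mode described above.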

python/sglang/srt/models/gemma4_mtp.py (111)

medium

Using nn.Linear for lm_head bypasses Tensor Parallelism. For large vocabulary sizes (e.g., Gemma's 256k), this replicates a large weight matrix (approx. 1.5GB for bf16) on every GPU, which is memory-inefficient and inconsistent with SGLang's standard use of ColumnParallelLinear for output heads. This also applies to the centroids layer on line 125.
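
The memory figure can be checked with a short calculation. The sketch below assumes a vocabulary of 262,144 entries (one reading of "256k") and a hidden size of 3072; the hidden size is an assumption chosen because it reproduces the quoted 1.5 GB, not a value taken from this PR.

```python
def lm_head_bytes(vocab_size, hidden_size, bytes_per_param=2, tp_size=1):
    """Per-GPU bytes for a (vocab_size x hidden_size) head sharded over the vocab dim."""
    rows_per_rank = -(-vocab_size // tp_size)  # ceiling division
    return rows_per_rank * hidden_size * bytes_per_param


GIB = 2 ** 30

# Assumed sizes (see lead-in): vocab 262,144, hidden 3072, bf16 = 2 bytes/param.
replicated = lm_head_bytes(262_144, 3072)            # full copy on every GPU
sharded = lm_head_bytes(262_144, 3072, tp_size=2)    # vocab split across 2 GPUs

print(f"replicated: {replicated / GIB:.2f} GiB per GPU")
print(f"sharded (TP=2): {sharded / GIB:.2f} GiB per GPU")
```

Under these assumptions a replicated head costs 1.5 GiB per GPU, while sharding over the vocabulary dimension divides that by the TP size.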

python/sglang/srt/models/gemma4_mtp.py (211)

medium

Calling torch.cuda.empty_cache() inside a model method is generally discouraged as it can cause significant performance overhead due to GPU synchronization and fragmentation. It is better to let the high-level memory manager or the user handle cache clearing if necessary.

python/sglang/srt/speculative/frozen_kv_mtp_worker.py (117)

medium

Modifying server_args.context_length in-place can have unintended side effects if the server_args object is shared with other components (like the target worker). It is safer to create a copy of the arguments for the draft worker to ensure isolation.

        import copy
        server_args = copy.copy(server_args)
        server_args.context_length = target_worker.model_runner.model_config.context_len
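
The effect of the copy can be demonstrated in isolation. In this sketch `ToyServerArgs` and `make_draft_args` are hypothetical stand-ins, not SGLang classes or functions; the point is only that the caller's object keeps its original `context_length`.

```python
import copy
from dataclasses import dataclass


@dataclass
class ToyServerArgs:  # hypothetical stand-in for the real ServerArgs
    context_length: int
    model_path: str


def make_draft_args(shared_args, target_context_len):
    """Return a private copy with context_length overridden; the input is not mutated."""
    draft_args = copy.copy(shared_args)  # shallow copy: top-level fields are independent
    draft_args.context_length = target_context_len
    return draft_args


shared = ToyServerArgs(context_length=4096, model_path="toy-model")
draft = make_draft_args(shared, 8192)
print(shared.context_length, draft.context_length)
```

Note that `copy.copy` is shallow, so nested mutable fields (lists, dicts) would still be shared between the two objects; `copy.deepcopy` would be needed if those are also modified.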

python/sglang/srt/server_args.py (3465)

medium

Hardcoding max_running_requests to 48 when using FROZEN_KV_MTP seems arbitrary and might be too restrictive for some hardware configurations. Consider making this a configurable default or providing a more detailed rationale for this specific limit.
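
One way to make the limit a default rather than a hard cap is to apply it only when the user has not set a value. This is a sketch under assumptions: the function and constant names are hypothetical, and it is not known from this page how `server_args.py` currently handles a user-supplied value.

```python
# Hypothetical names; the value 48 is the one quoted in the review comment.
DEFAULT_FROZEN_KV_MTP_MAX_RUNNING = 48


def resolve_max_running_requests(user_value, algorithm):
    """Prefer an explicit user setting; fall back to the algorithm default."""
    if user_value is not None:
        return user_value
    if algorithm == "FROZEN_KV_MTP":
        return DEFAULT_FROZEN_KV_MTP_MAX_RUNNING
    return None  # no algorithm-specific default; leave unset
```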

python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py (146-149)

medium

Raising a generic Exception is discouraged. Using a more specific exception like RuntimeError is preferred for better error handling and clarity.

        except RuntimeError as e:
            raise RuntimeError(
                f"Capture frozen-KV MTP cuda graph failed: {e}\n"
                f"{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            ) from e
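
The `from e` clause in the suggestion keeps the original error attached as `__cause__`, so the full traceback chain survives the re-raise. A minimal standalone sketch, with `CAPTURE_FAILED_HINT` and `capture_with_context` as hypothetical stand-ins for the real constant and capture routine:

```python
# Hypothetical stand-in for CUDA_GRAPH_CAPTURE_FAILED_MSG.
CAPTURE_FAILED_HINT = "Possible fixes: lower the batch size or disable graph capture."


def capture_with_context(capture_fn):
    """Run capture_fn, re-raising RuntimeError with extra context and the cause chained."""
    try:
        return capture_fn()
    except RuntimeError as e:
        raise RuntimeError(
            f"Capture failed: {e}\n{CAPTURE_FAILED_HINT}"
        ) from e
```

Catching only `RuntimeError` means other exception types propagate unchanged, which may or may not be the intended behavior for the capture path.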

Base automatically changed from gemma4-mtp-fin to main May 7, 2026 21:08
@kpham-sgl kpham-sgl force-pushed the gemma4-mtp-add-test branch from bced3e0 to b780757 Compare May 10, 2026 08:03
Per #25197, stage-shaped per-commit CUDA suites must register via
stage=/runner_config= kwargs; the legacy suite= form is reserved for
nightly/stress/weekly and non-stage backends.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@kpham-sgl kpham-sgl changed the title [WIP][Gemma4] Add test for MTP models [Gemma4] Add test for MTP models May 19, 2026
@kpham-sgl
Collaborator Author

/tag-run-ci-label

kpham-sgl and others added 5 commits May 19, 2026 15:31
The per-commit CUDA suites were renamed from stage-* to base-* on main;
test_frozen_kv_mtp was still registering stage="stage-b", which fails
validate_all_suites with "Tests registered to invalid suites".

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
These tests are gemma4-specific and only relevant for PRs touching that
model path. Move them off the nightly schedule onto extra-a so they:
- gate on per-commit signal when a PR opts in with run-ci-extra label
- stop consuming nightly slot for changes unrelated to gemma4

Both fit extra-a-test-2-gpu-large (TP=2, ~720s each).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
eval/ skews nightly (its other 2 files are nightly=True); these two are
now extra-a per-commit speculative-decoding tests. Their sibling
test_frozen_kv_mtp.py already lives in spec/, and spec/ uses the _extra
suffix to mark extra-* registrations (test_spec_ngram_extra.py,
test_spec_standalone_extra.py).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@kpham-sgl
Collaborator Author

/rerun-failed-ci
