[Gemma4] Add test for MTP models #24552
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements Frozen-KV MTP (Multi-Token Prediction) speculative decoding for Gemma 4 models, introducing a new assistant model class, worker, and CUDA graph runner. The review highlights critical issues regarding Tensor Parallelism compatibility in embedding and linear layers, suggests avoiding torch.cuda.empty_cache() to prevent performance degradation, recommends copying server_args for isolation, and advises using specific exceptions for better error handling.
I am having trouble creating individual review comments. Click here to see my feedback.
python/sglang/srt/models/gemma4_mtp.py (234)
Using torch.nn.functional.embedding directly on self.target_embed_weight will fail when Tensor Parallelism (TP) is enabled (TP > 1). In SGLang, target_embed_weight is typically a partitioned weight from a VocabParallelEmbedding. A direct lookup will only work for tokens within the local GPU's vocabulary range and will not perform the necessary communication (all-reduce) to produce the full embedding vector on all GPUs. This will lead to incorrect results or IndexError on GPUs where the token IDs are out of the local partition range.
python/sglang/srt/models/gemma4_mtp.py (111)
Using nn.Linear for lm_head bypasses Tensor Parallelism. For large vocabulary sizes (e.g., Gemma's 256k), this replicates a large weight matrix (approx. 1.5GB for bf16) on every GPU, which is memory-inefficient and inconsistent with SGLang's standard use of ColumnParallelLinear for output heads. This also applies to the centroids layer on line 125.
python/sglang/srt/models/gemma4_mtp.py (211)
Calling torch.cuda.empty_cache() inside a model method is generally discouraged as it can cause significant performance overhead due to GPU synchronization and fragmentation. It is better to let the high-level memory manager or the user handle cache clearing if necessary.
python/sglang/srt/speculative/frozen_kv_mtp_worker.py (117)
Modifying server_args.context_length in-place can have unintended side effects if the server_args object is shared with other components (like the target worker). It is safer to create a copy of the arguments for the draft worker to ensure isolation.
import copy
server_args = copy.copy(server_args)
server_args.context_length = target_worker.model_runner.model_config.context_len
python/sglang/srt/server_args.py (3465)
Hardcoding max_running_requests to 48 when using FROZEN_KV_MTP seems arbitrary and might be too restrictive for some hardware configurations. Consider making this a configurable default or providing a more detailed rationale for this specific limit.
python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py (146-149)
Raising a generic Exception is discouraged. Using a more specific exception like RuntimeError is preferred for better error handling and clarity.
except RuntimeError as e:
raise RuntimeError(
f"Capture frozen-KV MTP cuda graph failed: {e}\n"
f"{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
) from e
bced3e0 to
b780757
Compare
Per #25197, stage-shaped per-commit CUDA suites must register via stage=/runner_config= kwargs; the legacy suite= form is reserved for nightly/stress/weekly and non-stage backends. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
/tag-run-ci-label |
The per-commit CUDA suites were renamed from stage-* to base-* on main; test_frozen_kv_mtp was still registering stage="stage-b", which fails validate_all_suites with "Tests registered to invalid suites". Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
These tests are gemma4-specific and only relevant for PRs touching that model path. Move them off the nightly schedule onto extra-a so they: - gate on per-commit signal when a PR opts in with run-ci-extra label - stop consuming nightly slot for changes unrelated to gemma4 Both fit extra-a-test-2-gpu-large (TP=2, ~720s each). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
eval/ skews nightly (its other 2 files are nightly=True); these two are now extra-a per-commit speculative-decoding tests. Their sibling test_frozen_kv_mtp.py already lives in spec/, and spec/ uses the _extra suffix to mark extra-* registrations (test_spec_ngram_extra.py, test_spec_standalone_extra.py). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
/rerun-failed-ci |
Motivation
Add test for PR #24436. Require transformer v5.8.0
Modifications
Accuracy Tests
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ciCI States
Latest PR Test (Base): ❌ Run #26135220858
Latest PR Test (Extra): ❌ Run #26135220756