
fix nemotron capture for non attention layers #21436

Merged
Fridge003 merged 7 commits into sgl-project:main from vedantjh2:vjhaveri/nemotron
Mar 30, 2026

Conversation

@vedantjh2
Contributor

Motivation

Piecewise CUDA graph capture was silently disabled for NemotronH hybrid models (e.g., NVIDIA-Nemotron-Nano-9B-v2). The init_piecewise_cuda_graphs layer discovery loop only appended to attention_layers when it found an attention or mamba layer. NemotronH's pure MLP layers (the "-" entries in the hybrid pattern) and MoE layers expose the mixer attribute but have neither .attn nor ._forward_mamba, so they were skipped entirely. As a result, len(attention_layers) < num_hidden_layers, which triggered the early bail-out "Disable piecewise CUDA graph because some layers do not apply Standard GQA".

Modifications

python/sglang/srt/model_executor/model_runner.py

In the init_piecewise_cuda_graphs method, updated the layer discovery logic to handle NemotronH-style hybrid models where some layers are pure MLP/MoE (accessed via mixer but without .attn or ._forward_mamba):

  • When a layer has a mixer attribute but is neither attention nor mamba, append None as a positional placeholder to attention_layers. This keeps the list aligned with layer indices so that split ops like nemotron_mamba2_with_output can correctly index by layer_id.
  • The None placeholder is only appended for layers that enter the mixer branch — models without a mixer attribute (e.g., LFM2 conv layers) are unaffected, preserving the existing safety check that disables piecewise CUDA graph for unsupported architectures.
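The logic above can be sketched as follows. This is a minimal illustration based on the PR description: the attribute names mixer, attn, and _forward_mamba come from the text above, while the helper name and loop structure are hypothetical, not the exact model_runner.py code.

```python
def collect_attention_layers(layers):
    """Sketch of the updated piecewise-CUDA-graph layer discovery.

    Illustrative only: attribute probing follows the PR description,
    not the exact sglang implementation.
    """
    attention_layers = []
    for layer in layers:
        mixer = getattr(layer, "mixer", None)
        if mixer is None:
            # No mixer attribute (e.g., LFM2 conv layers): skip, so the
            # existing safety check len(attention_layers) < num_hidden_layers
            # still disables piecewise CUDA graph for unsupported models.
            continue
        if hasattr(mixer, "attn") or hasattr(mixer, "_forward_mamba"):
            # Attention or mamba layer: record it as before.
            attention_layers.append(mixer)
        else:
            # Pure MLP / MoE layer: append None as a positional placeholder
            # so split ops like nemotron_mamba2_with_output can still index
            # attention_layers by layer_id.
            attention_layers.append(None)
    return attention_layers
```

With this placeholder, a NemotronH model whose hybrid pattern mixes attention, mamba, and MLP layers yields a list whose length matches num_hidden_layers, so the bail-out no longer fires.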

Accuracy Tests

GSM8K accuracy is identical with and without piecewise CUDA graph:

Configuration   Accuracy   Invalid
Without PCG     0.895      0.005
With PCG        0.895      0.005

Benchmarking and Profiling

Model: NVIDIA-Nemotron-Nano-9B-v2 on a single GPU.

Benchmark: benchmark/gsm8k/bench_sglang.py (200 samples)

Configuration                                  Latency (s)   Output Throughput (token/s)   Speedup
Without PCG (--disable-piecewise-cuda-graph)   16.159        1506.922                      -
With PCG (this PR)                             13.723        1763.156                      ~17%
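The ~17% figure follows directly from the reported throughput numbers; a quick check:

```python
# Verify the reported speedup from the benchmark table above.
baseline_tps = 1506.922  # output throughput without PCG
pcg_tps = 1763.156       # output throughput with PCG (this PR)
speedup = pcg_tps / baseline_tps - 1
print(f"{speedup:.1%}")  # prints "17.0%"
```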

Baseline:

python benchmark/gsm8k/bench_sglang.py --data-path /shared/public/data/gsm8k/test.jsonl
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:16<00:00, 12.38it/s]
Accuracy: 0.895
Invalid: 0.005
Latency: 16.159 s
Output throughput: 1506.922 token/s

PCG Enabled:

python benchmark/gsm8k/bench_sglang.py --data-path /shared/public/data/gsm8k/test.jsonl
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:13<00:00, 14.58it/s]
Accuracy: 0.895
Invalid: 0.005
Latency: 13.723 s
Output throughput: 1763.156 token/s

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@vedantjh2 vedantjh2 marked this pull request as ready for review March 26, 2026 00:13

@Oasis-Git
Collaborator

/tag-run-ci-label

@he-weiwen

he-weiwen commented Mar 26, 2026

I just tried applying this fix on my local RTX 4090 and I got this:

Piecewise CUDA Graph failed with error: q.shape[0] (8) does not match qo_indptr[-1] (7). For paged prefill, q must have shape [total_tokens, num_heads, head_dim] where total_tokens = qo_indptr[-1].
Piecewise CUDA Graph is enabled by default as an experimental feature.
To work around this error, add --disable-piecewise-cuda-graph to your launch command.
Please report this issue at https://github.com/sgl-project/sglang/issues/new/choose

Is it working on your side?

*Edit: Seems to be a flashinfer backend issue described in #21218

@vedantjh2
Contributor Author

@he-weiwen Yes, it works on my end. Does the issue you linked need to be merged before I can successfully run CI for this?

@Fridge003
Collaborator

@he-weiwen This issue is fixed in #21452

@Fridge003
Collaborator

/rerun-ut test_nvidia_nemotron_nano_v2.py

@Fridge003
Collaborator

/rerun-ut test_nvidia_nemotron_3_nano.py

@github-actions
Contributor

2-gpu-h100: View workflow run

cd test/ && python3 registered/models/test_nvidia_nemotron_nano_v2.py

@Fridge003
Collaborator

/rerun-ut test_nvidia_nemotron_3_super_bf16.py

@github-actions
Contributor

2-gpu-h100: View workflow run

cd test/ && python3 registered/models/test_nvidia_nemotron_3_nano.py

@github-actions
Contributor

8-gpu-h200: View workflow run

cd test/ && python3 registered/8-gpu-models/test_nvidia_nemotron_3_super_bf16.py

@Fridge003
Collaborator

/rerun-ut test_nvidia_nemotron_3_super_nvfp4.py

@github-actions
Contributor

4-gpu-b200: View workflow run

cd test/ && python3 registered/4-gpu-models/test_nvidia_nemotron_3_super_nvfp4.py

@Fridge003 Fridge003 merged commit 4a9ffc3 into sgl-project:main Mar 30, 2026
76 of 100 checks passed
LucQueen pushed a commit to LucQueen/sglang that referenced this pull request Mar 31, 2026