Enable Piecewise CUDA Graph for NemotronH Hybrid (Mamba+Attention) Models#19903

Merged
ispobock merged 22 commits into sgl-project:main from vedantjh2:vjhaveri/fix_nemotron
Mar 12, 2026
Conversation

Contributor

@vedantjh2 vedantjh2 commented Mar 4, 2026

Motivation

Piecewise CUDA graph (PCG) was previously disabled for NemotronH models because the layer detection logic required every layer to use standard GQA attention. NemotronH is a hybrid architecture (4 attention + 24 Mamba + 24 MLP across 52 layers) in which every layer exposes its sublayer via a mixer attribute instead of self_attn, so detection failed with:

Disable piecewise CUDA graph because some layers do not apply Standard GQA

Changes

model_runner.py

  • Added mixer attribute detection for NemotronH-style hybrid models
  • Every layer now gets an entry in attention_layers (None for non-attention layers), enabling PCG to handle sparse attention architectures
  • Relaxed validation: PCG is only disabled when no attention layers are found, rather than when any layer lacks attention
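Roughly, the relaxed detection can be sketched as follows (the function name and exact attribute checks here are illustrative assumptions, not the actual model_runner.py code):

```python
def collect_attention_layers(layers):
    """Illustrative sketch of the relaxed layer detection.

    Every layer gets a slot: attention layers store their attention
    module, everything else (Mamba, MLP) stores None, so that indexing
    by absolute layer_id stays aligned.
    """
    attention_layers = []
    for layer in layers:
        if hasattr(layer, "self_attn"):  # standard GQA layer
            attention_layers.append(layer.self_attn)
        elif hasattr(layer, "mixer"):  # NemotronH-style hybrid layer
            mixer = layer.mixer
            # Treat the mixer as attention only if it exposes an attention
            # interface; Mamba/MLP mixers become None placeholders.
            attention_layers.append(mixer if hasattr(mixer, "attn") else None)
        else:
            attention_layers.append(None)
    # Relaxed validation: bail out only when *no* attention layers exist,
    # instead of requiring every layer to be an attention layer.
    if not any(l is not None for l in attention_layers):
        raise RuntimeError("Disable piecewise CUDA graph: no attention layers found")
    return attention_layers
```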

nemotron_h.py

  • Extracted _forward_mamba() method from NemotronHMambaDecoderLayer.forward()
  • Added nemotron_mamba2_with_output split op (using register_custom_op + register_split_op) to enable graph breaks around Mamba layers during PCG capture
  • Added token slicing in the split op to handle padded CUDA graph buffers (Mamba asserts on exact token counts)
  • Changed Layers from union type (A | B | C | D) to tuple for torch.compile compatibility
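The union-to-tuple change is a small compatibility fix; a minimal illustration (class names are placeholders, not the real NemotronH types):

```python
# Placeholder classes standing in for the real NemotronH layer types.
class AttentionLayer: ...
class MambaLayer: ...
class MLPLayer: ...

# Before: a PEP 604 union object, e.g. AttentionLayer | MambaLayer | MLPLayer.
# isinstance() accepts it at runtime, but torch.compile's tracer handles a
# plain tuple of types more reliably.

# After: a tuple of types, which isinstance() also accepts directly.
LayerTypes = (AttentionLayer, MambaLayer, MLPLayer)

def is_decoder_layer(obj) -> bool:
    # isinstance with a tuple checks membership in any of the listed types.
    return isinstance(obj, LayerTypes)
```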

Benchmark Results

Model: NVIDIA-Nemotron-Nano-9B-v2 (H100, bfloat16)
Benchmark: GSM8K (200 samples)

| Configuration | Throughput (tok/s) | Latency (s) | Accuracy |
| --- | --- | --- | --- |
| Baseline (no PCG) | 1521 | 15.74 | 88.0% |
| With PCG (inductor) | 1681 | 14.29 | 89.0% |
| With PCG (eager) | 1542.8 | 15.60 | 89.0% |
| Improvement vs baseline (inductor) | +10.5% | -9.2% | |
| Improvement vs baseline (eager) | +1.4% | -0.9% | |

Notes

  • PCG provides the best gains when paired with the inductor compiler backend in this setup.
  • PCG eager achieves essentially the same accuracy as PCG+inductor (89.0%), with smaller performance gains.

Launch commands

# Baseline
python -m sglang.launch_server \
  --model-path nvidia/NVIDIA-Nemotron-Nano-9B-v2 \
  --host 0.0.0.0 --port 30000

# With PCG (inductor)
python -m sglang.launch_server \
  --model-path nvidia/NVIDIA-Nemotron-Nano-9B-v2 \
  --host 0.0.0.0 --port 30000 \
  --enable-piecewise-cuda-graph \
  --piecewise-cuda-graph-compiler inductor

# With PCG (eager)
python -m sglang.launch_server \
  --model-path /shared/public/elr-models/nvidia/NVIDIA-Nemotron-Nano-9B-v2 \
  --host 0.0.0.0 --port 30000 \
  --enable-piecewise-cuda-graph \
  --piecewise-cuda-graph-compiler eager


# captured graph size. Slice to actual token count for Mamba forward.
attn_backend = forward_batch.attn_backend
metadata = attn_backend.linear_attn_backend.forward_metadata
num_actual_tokens = metadata.num_prefill_tokens + (
Collaborator

why does only mamba need this special shape handle, can't we know the exact output shape before?

Contributor Author

During CUDA graph replay, `hidden_states` is padded to the captured graph batch size. Attention handles this naturally via the KV cache and masks, but Mamba processes tokens sequentially through conv/SSM states and asserts `num_actual_tokens == projected_states.shape[0]`. The slicing must happen inside the split op (not the caller) because `torch.compile(fullgraph=True)` requires static tensor shapes within each compiled segment; the split op acts as the graph break where runtime metadata is accessible.
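A minimal pure-Python sketch of that idea (the names and the list-based "tensors" are assumptions for illustration; the real split op operates on torch tensors and backend metadata):

```python
def mamba_split_op(hidden_states, num_prefill_tokens, num_decode_tokens, mamba_forward):
    """Sketch: hidden_states may be padded out to the captured CUDA graph
    size. Slice to the actual token count before calling Mamba, then keep
    the padded tail so downstream compiled segments see a static shape.
    """
    num_actual_tokens = num_prefill_tokens + num_decode_tokens
    actual = hidden_states[:num_actual_tokens]  # slice off graph padding
    out = mamba_forward(actual)                 # Mamba sees the exact count
    # Re-attach the padding (torch.cat in the real tensor version).
    return out + hidden_states[num_actual_tokens:]
```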

elif hasattr(layer, "_forward_mamba"):
# Mamba layer with split op support - store the layer itself
attn_layer = layer
# attn_layer is None for non-attention layers (e.g. Mamba, MLP-only)
Collaborator

why do we need to store non-attention layers as None in attention_layers? Could we only store attention layers as previously, but we could insert mamba attention layers into attention_layers for mamba models or change the field name in nemotron_h to make it compatible with existing logic.

Contributor Author

attention_layers is indexed by layer_id in the split ops (e.g., attention_layers[layer_id]). NemotronH's layer_id is the absolute model layer index (0–51), so each position must map correctly. Without None placeholders, layer_id=10 (a Mamba layer) would index into the wrong entry.

For non-hybrid models this is backward-compatible — every entry is an attention layer and no None values appear. Open to alternatives if you have a preferred approach (e.g., using a dict instead of a list).
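A toy example of the alignment concern (the string labels stand in for real layer modules):

```python
# Hypothetical 5-layer hybrid model: attention_layers must be indexable
# by the absolute layer_id, so non-attention positions need placeholders.
pattern = ["mamba", "mamba", "attn", "mamba", "attn"]

# Without placeholders, positions collapse and layer_id indexing breaks:
compact = [p for p in pattern if p == "attn"]
# compact[1] is an attention layer, but there is no way to tell that it
# corresponds to layer_id 4.

# With None placeholders, attention_layers[layer_id] is always correct:
aligned = [p if p == "attn" else None for p in pattern]
```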

Collaborator

I think this is a good design. Should keep it

Collaborator

The only concern here is that this changes the attention_layers capture logic for all other models as well: we now add None for non-attention decoder layers. In theory this should be fine, since all other models should only have attention or linear-attention decoder layers; only nemotron_h has MLP in its decoder layers. We can verify this by checking that all PCG CIs of other models pass.

Collaborator

Agree @zminglei

Collaborator
zminglei commented Mar 5, 2026

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Mar 5, 2026
@vedantjh2 vedantjh2 changed the title fix nemotron to be able to use pcg fix nemotron to be able to use piecewise cuda graph Mar 5, 2026

use_triton_causal_conv=True, # TODO: investigate need of `use_triton_causal_conv`
)
return output, residual
if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
Collaborator
use is_in_piecewise_cuda_graph() context

@vedantjh2 vedantjh2 changed the title fix nemotron to be able to use piecewise cuda graph Enable Piecewise CUDA Graph for NemotronH Hybrid (Mamba+Attention) Models Mar 5, 2026
Contributor Author

/rerun-failed-ci

@ispobock ispobock merged commit 25bd830 into sgl-project:main Mar 12, 2026
167 of 173 checks passed
liubiyongge pushed a commit to liubiyongge/sglang that referenced this pull request Mar 13, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
@he-weiwen

Hi, I was profiling NVIDIA-Nemotron-Nano-9B-v2-FP8 locally and noticed that PCG is not actually enabled successfully by default; instead I see Disable piecewise CUDA graph because some layers do not apply Standard GQA in my logs.

Reading the source code, it seems that only the Mamba and attention layers are accounted for, while the MLP layers are completely skipped in the hasattr matching, so it falls through on the condition len(self.attention_layers) < self.model_config.num_hidden_layers.

Can someone confirm the actual status of this? I hit another crash after skipping the check on the number of attention layers, so I don't think the fix is trivial.

@Oasis-Git
Collaborator

cc @vedantjh2

@vedantjh2
Contributor Author

The layer discovery loop in init_piecewise_cuda_graphs only appended to attention_layers when it found an attention or Mamba layer. NemotronH's MLP and MoE layers use the mixer attribute but have neither .attn nor ._forward_mamba, so they were silently skipped. This made len(attention_layers) < num_hidden_layers (e.g., ~30 out of 56), triggering the "some layers do not apply Standard GQA" bail-out.

The fix: when a mixer-based layer is neither attention nor mamba, append None as a positional placeholder so the list stays aligned with layer indices and the length check passes.

Fix in PR #21436
