Skip to content

[Bugfix] Fix Triton stream capture error on A100 in GDN attention with MTP speculative decoding#39483

Open
jacob-crux wants to merge 1 commit into
vllm-project:mainfrom
jacob-crux:fix/gdn-mtp-cudagraph-warmup
Open

[Bugfix] Fix Triton stream capture error on A100 in GDN attention with MTP speculative decoding#39483
jacob-crux wants to merge 1 commit into
vllm-project:mainfrom
jacob-crux:fix/gdn-mtp-cudagraph-warmup

Conversation

@jacob-crux
Copy link
Copy Markdown

@jacob-crux jacob-crux commented Apr 10, 2026

Purpose

Fix a Triton "operation not permitted when stream is capturing" error during FULL CUDA graph capture for models using GDN (Gated Delta Net) attention with MTP speculative decoding.

Reproduced with Qwen/Qwen3.5-35B-A3B-GPTQ-Int4 (Qwen3.5 MoE architecture, GPTQ Int4 quantized) on NVIDIA A100-SXM4-80GB.

When --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}' is enabled, the engine fails to start during CUDA graph capture with:

RuntimeError: Triton Error [CUDA]: operation not permitted when stream is capturing
  at vllm/model_executor/layers/mamba/ops/causal_conv1d.py → _causal_conv1d_update_kernel
RuntimeError: Engine core initialization failed.

The spec-decode Triton kernels (causal_conv1d_update with IS_SPEC_DECODING=True, fused_sigmoid_gating_delta_rule_update, fused_gdn_gating, l2norm_fwd_kernel2, etc.) attempt to JIT-compile during the actual CUDA graph capture, calling Triton's load_binary while a CUDA stream capture is active — which CUDA forbids.

Root cause: During _warmup_and_capture, the warmup _dummy_run calls do not pass is_graph_capturing=True. As a result, attention metadata builders use build() instead of build_for_cudagraph_capture(). For GDN attention, build() receives dummy num_decode_draft_tokens=0, so spec_sequence_masks=None and the non-spec decode path is taken during warmup. The spec-decode Triton kernels are therefore never JIT-compiled during warmup. When the actual capture then runs (which always uses is_graph_capturing=True), build_for_cudagraph_capture() auto-generates spec-decode metadata from query_start_loc and triggers the spec-decode code path — at which point Triton tries to JIT-compile the kernels for the first time inside an active stream capture and crashes.

Fix: Pass is_graph_capturing=force_attention to the warmup _dummy_run so that FULL-mode warmup exercises the same build_for_cudagraph_capture() path as the actual capture, ensuring all required Triton kernels are pre-compiled before capture begins. Only affects FULL cudagraph warmup; PIECEWISE mode is unchanged because GDN attention runs as a splitting op outside the captured graph there.

Test Plan

  • Server startup with FULL CUDA graph + MTP — Verify the engine initializes and serves requests with qwen3_next_mtp speculative decoding enabled:
# Start server
vllm serve Qwen/Qwen3.5-35B-A3B-GPTQ-Int4 \
  --gpu-memory-utilization 0.85 \
  --language-model-only \
  --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}' \
  --port 8000

# Smoke test
curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen3.5-35B-A3B-GPTQ-Int4",
    "messages": [{"role":"user","content":"What is the capital of South Korea?"}],
    "max_tokens": 100
  }'

Test Result

  • Environment: NVIDIA A100-SXM4-80GB x 1

Before fix — Engine fails to start during CUDA graph capture:

(EngineCore_DP0) ERROR [core.py:1108] EngineCore failed to start.
(EngineCore_DP0) ERROR [core.py:1108] RuntimeError: Triton Error [CUDA]: operation not permitted when stream is capturing
(EngineCore_DP0) ERROR [core.py:1108]   at vllm/model_executor/layers/mamba/ops/causal_conv1d.py
(EngineCore_DP0) ERROR [core.py:1108]      → _causal_conv1d_update_kernel
RuntimeError: Engine core initialization failed.

After fix — Engine starts successfully and serves requests normally with MTP speculative decoding enabled.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

…h MTP speculative decoding

Signed-off-by: jacob-crux <jacob.crux@kakaocorp.com>
@jacob-crux jacob-crux requested a review from njhill as a code owner April 10, 2026 07:24
@mergify mergify Bot added v1 bug Something isn't working labels Apr 10, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the _warmup_and_capture method in vllm/v1/worker/gpu_model_runner.py to pass the is_graph_capturing argument, set to the value of force_attention, during the model warmup phase. I have no feedback to provide as there are no review comments to evaluate.

@JaheimLee
Copy link
Copy Markdown

any update?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants