[Bugfix] Remove nested torch.compile in GDN rearrange_mixed_qkv causing CUDA graph capture failure #42070

Merged
vllm-bot merged 2 commits into vllm-project:main from tdoublep:fix/gdn-cudagraph-capture on May 9, 2026

Conversation

@tdoublep
Member

@tdoublep tdoublep commented May 8, 2026

Summary

  • Remove @torch.compile(fullgraph=True) from rearrange_mixed_qkv, which triggers Triton autotuning (and hence torch.cuda.synchronize()) during CUDA graph capture
  • The method is already compiled by the outer AOT compilation pass, making the nested decorator redundant and harmful (see the sketch below)
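
A minimal sketch of the pattern being removed, assuming a plain nn.Module method (the body here is a placeholder, not the real rearrange logic):

```python
import torch

class GDNLinearAttention(torch.nn.Module):
    # Before this PR the method carried its own nested decorator:
    #
    #   @torch.compile(fullgraph=True)
    #   def rearrange_mixed_qkv(self, qkv): ...
    #
    # The module already goes through the outer AOT compilation pass, so
    # the nested decorator only adds a second compilation whose Triton
    # autotuning (torch.cuda.synchronize()) breaks CUDA graph capture.
    # After this PR it is a plain method:
    def rearrange_mixed_qkv(self, qkv: torch.Tensor) -> torch.Tensor:
        return qkv.contiguous()  # placeholder body
```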

Reproduction

vllm serve Qwen/Qwen3.5-35B-A3B   --language-model-only   --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}'

Fails with:

torch.AcceleratorError: CUDA error: operation not permitted when stream is capturing
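
A minimal illustration of the failure mode, independent of vLLM: any torch.cuda.synchronize() issued while a stream is capturing, such as the one Triton autotuning performs, aborts capture with exactly this error.

```python
import torch

# Requires a CUDA device. Synchronization is forbidden during capture:
g = torch.cuda.CUDAGraph()
x = torch.zeros(8, device="cuda")
with torch.cuda.graph(g):
    x += 1
    # Raises: CUDA error: operation not permitted when stream is capturing
    torch.cuda.synchronize()
```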

Related

Test plan

  • Verify Qwen/Qwen3.5-35B-A3B with MTP spec decode starts successfully on GB200

🤖 Generated with Claude Code

Remove nested @torch.compile(fullgraph=True) decorator that triggered
Triton autotuning (torch.cuda.synchronize) during CUDA graph capture.
The method is already compiled by the outer AOT compilation pass.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@mergify mergify Bot added nvidia bug Something isn't working labels May 8, 2026
@tdoublep tdoublep marked this pull request as ready for review May 8, 2026 13:39

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request removes the @torch.compile(fullgraph=True) decorator from the rearrange_mixed_qkv method in vllm/model_executor/layers/mamba/gdn_linear_attn.py. I have no feedback to provide as there were no review comments to evaluate.

Member

@ZJY0516 ZJY0516 left a comment

Doesn't this only run on ROCm?

@github-project-automation github-project-automation Bot moved this to In review in NVIDIA May 8, 2026
@tjtanaa
Collaborator

tjtanaa commented May 8, 2026

@tdoublep a question: the PR that introduced this seemed able to run Qwen3.5 on B200: https://buildkite.com/vllm/ci/builds/65043/canvas?jid=019e05df-d38b-4432-b813-5f66b42e419a&tab=output

If I understand correctly, rearrange_mixed_qkv is invoked within the scope of the custom op gdn_attention_core, which is opaque to torch.compile.
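
A rough sketch of that structure, using stock PyTorch custom ops (the two function names come from the PR; both bodies are placeholders):

```python
import torch

@torch.compile(fullgraph=True)  # the nested decorator this PR removes
def rearrange_mixed_qkv(qkv: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real method rearranges the mixed q/k/v layout.
    return qkv.transpose(0, 1).mul(1.0)

@torch.library.custom_op("demo::gdn_attention_core", mutates_args=())
def gdn_attention_core(qkv: torch.Tensor) -> torch.Tensor:
    # The op body is opaque to the outer torch.compile pass, so the inner
    # compiled function still compiles (and Triton-autotunes) lazily on
    # first call, even if that call happens during CUDA graph capture.
    return rearrange_mixed_qkv(qkv)
```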

@tdoublep
Member Author

tdoublep commented May 8, 2026

@ZJY0516 @tjtanaa I am seeing latest main failing on GB200 without this change.

@github-project-automation github-project-automation Bot moved this from In review to Ready in NVIDIA May 8, 2026
@tpopp
Contributor

tpopp commented May 8, 2026

Sorry about this, and I give my spiritual approval. When this decorator was added, torch.compile produced some noticeable additional fusions (with an older vLLM, and before some other GDN changes), so there may be some lost perf, but of course this change should go in to fix the breakage.

@ZJY0516
Member

ZJY0516 commented May 8, 2026

Reproduced using:

vllm serve Qwen/Qwen3.5-35B-A3B   --language-model-only   --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}'

It only happens with spec decoding.

@ZJY0516
Member

ZJY0516 commented May 8, 2026

> Sorry about this, and I give my spiritual approval. When this decorator was added, torch.compile produced some noticeable additional fusions (with an older vLLM, and before some other GDN changes), so there may be some lost perf, but of course this change should go in to fix the breakage.

I think the best fix is to change the kernel to accept non-contiguous input.
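
A hedged sketch of that direction: a toy Triton kernel that takes explicit strides, so a non-contiguous (e.g. transposed) input can be read in place instead of forcing a .contiguous() copy. None of this is the actual GDN kernel.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scale_rows_kernel(x_ptr, out_ptr,
                      stride_xm, stride_xn,  # input strides: any layout
                      stride_om, stride_on,  # output strides
                      N, scale, BLOCK_N: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(x_ptr + row * stride_xm + cols * stride_xn, mask=mask)
    tl.store(out_ptr + row * stride_om + cols * stride_on, x * scale, mask=mask)

def scale_rows(x: torch.Tensor, scale: float) -> torch.Tensor:
    M, N = x.shape
    out = torch.empty_like(x)
    scale_rows_kernel[(M,)](x, out,
                            x.stride(0), x.stride(1),
                            out.stride(0), out.stride(1),
                            N, scale, BLOCK_N=triton.next_power_of_2(N))
    return out

# Works on a transposed (non-contiguous) view without a copy:
# y = scale_rows(torch.randn(4, 8, device="cuda").t(), 2.0)
```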

@ZJY0516 ZJY0516 added the ready ONLY add when PR is ready to merge/full CI is needed label May 8, 2026
@ZJY0516 ZJY0516 enabled auto-merge (squash) May 8, 2026 15:37
@tjtanaa
Collaborator

tjtanaa commented May 8, 2026

> Reproduced using:
>
> vllm serve Qwen/Qwen3.5-35B-A3B   --language-model-only   --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}'
>
> It only happens with spec decoding.

That's also why the CI didn't catch this issue.

@tjtanaa
Collaborator

tjtanaa commented May 8, 2026

> @ZJY0516 @tjtanaa I am seeing latest main failing on GB200 without this change.

@tdoublep may I know which test group that is?

@tdoublep
Member Author

tdoublep commented May 8, 2026

@tjtanaa No test group; I was just deploying the model with MTP, similar to the example above. Surprised it is not caught by tests, though.

@tjtanaa
Collaborator

tjtanaa commented May 8, 2026

> @tjtanaa No test group; I was just deploying the model with MTP, similar to the example above. Surprised it is not caught by tests, though.

It seems there is no Qwen3.5 test in CI. There is only one test group, lm-eval Qwen3.5 (B200), which I had to trigger manually.

@ZJY0516
Member

ZJY0516 commented May 8, 2026

@tjtanaa related failure https://buildkite.com/vllm/ci/builds/65133/canvas?sid=019e06ae-7db5-4cdb-b6ad-1d49ae6cb583&tab=output

@SoluMilken
Contributor

Should we rebase this on latest main and rerun the relevant Buildkite jobs? Thanks.

@SoluMilken
Contributor

Could a maintainer please rerun the two failing Buildkite jobs?

  • fusion-e2e-tp2-b200 appears to be infra: pytest never started, and the B200 k8s pod stayed Pending for 15m with imagecheck-0 incomplete.
  • amd-multi-modal-models-standard-2-qwen3-plus-gemm is not reproducible for me locally because I do not have AMD/MI300 access; the log ends with termination/cancellation rather than a clear assertion failure.

The PR change is limited to removing one nested torch.compile(fullgraph=True), so these look worth rerunning first.

Thanks.

@vllm-bot vllm-bot merged commit 3dda9ae into vllm-project:main May 9, 2026
57 of 60 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA May 9, 2026
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request May 11, 2026
[Bugfix] Remove nested torch.compile in GDN rearrange_mixed_qkv causing CUDA graph capture failure (vllm-project#42070)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>

Labels

bug (Something isn't working), nvidia, ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done


6 participants