Skip to content

[Bug] Fix torch inductor issue (shape passing through sub-graphs)#30914

Closed
yewentao256 wants to merge 5 commits into
mainfrom
wentao-fix-torch-compile-issue
Closed

[Bug] Fix torch inductor issue (shape passing through sub-graphs)#30914
yewentao256 wants to merge 5 commits into
mainfrom
wentao-fix-torch-compile-issue

Conversation

@yewentao256

@yewentao256 yewentao256 commented Dec 18, 2025

Copy link
Copy Markdown
Member

Purpose

Context: https://vllm-dev.slack.com/archives/C08U97ZRC0J/p1765934670913979

VLLM_ALL2ALL_BACKEND="deepep_high_throughput" VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL="debug" python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2

There are two issues we found:

  1. the shape passed through different sub-graphs, which could be described in this image:
image 2. the address of input tensor changes, which we find the change (very few) and make a copy only when necessary

Note: For part 2, it involves a lot of change in the design, so let's put this in another PR. This PR focus on the first issue.
Basically, we can 1. copy through changed address; 2. refactor moe to support a output tensor.

Test

Originally:

(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   File "/home/yewentao256/.venv/lib64/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1361, in wrapped
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]     out = f(*tensors)  # type:ignore[call-arg]
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]           ^^^^^^^^^^^
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   File "<string>", line 1, in <lambda>
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   File "/home/yewentao256/.venv/lib64/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py", line 70, in inner_f
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]     out, out_descs = call_and_expect_output_descs(f, args)
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   File "/home/yewentao256/.venv/lib64/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 549, in call_and_expect_output_descs
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]     outs_pair = fn(*args)
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]                 ^^^^^^^^^
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   File "/home/yewentao256/.venv/lib64/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1095, in inner_fn
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]     outs, outs_descs = call_and_expect_output_descs(fn, args)
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   File "/home/yewentao256/.venv/lib64/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 549, in call_and_expect_output_descs
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]     outs_pair = fn(*args)
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]                 ^^^^^^^^^
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   File "/home/yewentao256/.venv/lib64/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 801, in _functionalized_f_helper
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]     f_outs, f_outs_descs = call_and_expect_output_descs(fn, f_args)
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   File "/home/yewentao256/.venv/lib64/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 549, in call_and_expect_output_descs
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]     outs_pair = fn(*args)
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]                 ^^^^^^^^^
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   File "/home/yewentao256/.venv/lib64/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 110, in inner_fn
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]     outs, outs_descs = call_and_expect_output_descs(fn, args)
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   File "/home/yewentao256/.venv/lib64/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 555, in call_and_expect_output_descs
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]     assert out_spec == out_desc_spec, (
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866] AssertionError: ([<function aot_stage1_graph_capture.<locals>.orig_flat_fn2 at 0x7ff535c7bf60>, <function create_functional_call.<locals>.functional_call at 0x7ff535cdb920>, GraphModule()], (FunctionalTensor(_to_functional_tensor(FakeTensor(..., device='cuda:0', size=(s72, 2048), dtype=torch.bfloat16),
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]        device='cuda:0')), FunctionalTensor(_to_functional_tensor(FakeTensor(..., device='cuda:0', size=(s72, 60), dtype=torch.bfloat16),
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]        device='cuda:0')), torch.Size([s72, 2048]), FunctionalTensor(_to_functional_tensor(FakeTensor(..., device='cuda:0', size=(s72, 2048), dtype=torch.bfloat16),
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]        device='cuda:0'))), (PlainAOTOutput(idx=0), PlainAOTOutput(idx=1), PlainAOTOutput(idx=2), PlainAOTOutput(idx=3)), TreeSpec(tuple, None, [*,
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   *,
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   TreeSpec(Size, None, [*,
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]     *]),
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   *]), TreeSpec(tuple, None, [*,
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   *,
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   *,
(EngineCore_DP0 pid=2074384) ERROR 12-18 00:32:02 [core.py:866]   *]))

Now everything fixed

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@mergify mergify Bot added the qwen Related to Qwen models label Dec 18, 2025

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a torch.compile issue in Qwen2MoeSparseMoeBlock by refactoring how input tensor shapes are handled. Instead of saving the original shape and using view() at the end, it now uses a boolean flag is_input_1d to conditionally squeeze() the output tensor. This is a good approach to make the code more friendly to torch.compile. I've found one potential issue with the new assertion which could lead to a runtime error.

Comment thread vllm/model_executor/models/qwen2_moe.py Outdated
@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 18, 2025
@LucasWilkinson

Copy link
Copy Markdown
Collaborator

do you get reasonable output? @BoyuanFeng tried this but there was still CG address issues

@yewentao256

Copy link
Copy Markdown
Member Author

@LucasWilkinson Thanks for the feedback, I see the issue, which will trigger in debug mode. It is a cuda address replay issue, I will fix it tomorrow.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to audit:

step3_text.py
qwen3_next.py
phimoe.py
mixtral.py
olmoe.py
lfm2_moe.py
ernie45_moe.py
dbrx.py
granitemoe.py
hunyuan_v1.py
grok1.py
jamba.py
flex_olmo.py
ernie45_vl_moe.py

for the same issues

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alot of these models are old-ish or probably too small to be used with wideEP so I think it would probably be good enough if we just audit

step3_text.py
qwen3_next.py
granitemoe.py

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All fixed, thanks!

@github-project-automation github-project-automation Bot moved this to In review in NVIDIA Dec 18, 2025
@yewentao256 yewentao256 force-pushed the wentao-fix-torch-compile-issue branch from aeeb81d to 546bad2 Compare December 18, 2025 19:44
@mergify

mergify Bot commented Dec 18, 2025

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yewentao256.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@ProExpertProg ProExpertProg left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you describe in more detail why this is needed? We're adding a lot of complexity (and potential bugs) so I want to understand the use case.

Also, I thought there was existing logic to copy inputs into cudagraph addresses, can we reuse that?

My understanding of the issue is that the output of the fused moe op is outside the cudagraph and that's why we have to copy the tensor into the cudagraph address. What we could do instead is make sure that the output address of the Moe op appears in the previous cudagraph/compiled graph; that's what we do for attention. Could we try that as well?

cc @youkaichao @BoyuanFeng as well for review.

Comment thread vllm/config/compilation.py Outdated
Comment thread vllm/compilation/cuda_graph.py Outdated
Comment thread vllm/config/compilation.py Outdated
@yewentao256 yewentao256 marked this pull request as draft December 19, 2025 18:08
Comment thread vllm/compilation/cuda_graph.py Outdated
@BoyuanFeng

Copy link
Copy Markdown
Collaborator

For the changing address in moe, this pr copies it to static tensor address in cuda_graph.py.

Another (probably better) way is to modify the moe op such that the output tensors have the same address as the input tensors. Because input tensor comes from previous piecewise cudagraph and has static address, this guarantees that the output tensor has static address.

For example, in attention op, instead of out = attention(q,k,v), we can have attention_out(q,k,v,out=q) to write the output to the q tensor. Could we do something similar for the moe op?

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 force-pushed the wentao-fix-torch-compile-issue branch from e877962 to d9cba20 Compare December 22, 2025 21:09
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 marked this pull request as ready for review December 22, 2025 21:14

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread vllm/model_executor/models/qwen3_next.py
@yewentao256

Copy link
Copy Markdown
Member Author

Hi @LucasWilkinson @ProExpertProg @BoyuanFeng

Thanks for the review — after digging deeper, we realized there are two separate issues:
(1) Shape issue across piecewise sub-graphs (illustrated in image 2).
(2) Tensor address/pointer can change across sub-graphs in some cases (e.g. MoE output allocated outside the captured region), which would require either a copy into stable buffers or a design change.

Given the complexity/risk of (2), I’ve reverted the current CUDA-graph input-copy change from this PR and will keep this PR scoped to (1) only.

For (2), Two possible directions:

  • Copy-on-address-change: keep a minimal, well-tested mechanism that copies only when the address differs.
  • MoE out= / preallocated output buffer: refactor MoE to accept an explicit output tensor (similar to unified_attention_with_output) so the address is stable across partitions.

I do prefer MOE_out but seems really a lot of changes, so perhaps discussing in another issue/PR

@yewentao256 yewentao256 changed the title [Bug] Fix torch inductor issue [Bug] Fix torch inductor issue (shape passing through sub-graphs) Dec 22, 2025
@ProExpertProg

Copy link
Copy Markdown
Collaborator

Thanks for clarifying! For issue 1), could we just update the Dynamo/FX partitioning logic like described in #31043: basically make sure that size is never passed through subgraph boundaries? That way we don't have to update model definitions for that?

@yewentao256

Copy link
Copy Markdown
Member Author

The original cause is that DeepEPHT doesn't support MOE with cudagraph, after consideration, I don't see the benefits of supporting this feature specifically for DeepEPHT, closing this PR

@yewentao256 yewentao256 closed this Jan 9, 2026
@github-project-automation github-project-automation Bot moved this from In review to Done in NVIDIA Jan 9, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

nvidia qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants