
disable graph partition in custom op#26952

Merged
DarkLight1337 merged 8 commits into vllm-project:main from BoyuanFeng:bf/disable-partition-in-custom-op
Oct 17, 2025

Conversation

Collaborator

@BoyuanFeng BoyuanFeng commented Oct 15, 2025

This PR fixes a nested cudagraph capture issue.

Example:

  1. We apply torch.compile directly on some ops (e.g., grouped_topk) wrapped in custom ops. Inductor graph partition applies cudagraph within the custom op.

  2. At the same time, we compile the model that uses these custom ops. Inductor graph partition also wraps each graph partition in a CUDAGraph. Some partitions may include custom ops to which CUDAGraph has already been applied. This leads to nested CUDAGraph capture, which is not supported.

This PR adds a context manager that should wrap torch.compile calls within custom ops to avoid the nested cudagraph capture.
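A minimal sketch of the save/flip/restore pattern such a context manager uses. To keep the example self-contained and runnable without PyTorch, a `SimpleNamespace` stands in for the real `torch._inductor.config` object that the PR actually toggles:

```python
import contextlib
from types import SimpleNamespace

# Stand-in for torch._inductor.config (hypothetical); the actual PR
# toggles the real Inductor config. Partitioning defaults to enabled.
inductor_config = SimpleNamespace(graph_partition=True)

@contextlib.contextmanager
def disable_graph_partition():
    # Save the current flag, disable partitioning for the duration of
    # the `with` block, and restore the old value even on exceptions.
    old_val = inductor_config.graph_partition
    try:
        inductor_config.graph_partition = False
        yield
    finally:
        inductor_config.graph_partition = old_val
```

Inside a custom op, the nested torch.compile call would then run under `with disable_graph_partition(): ...`, so Inductor does not wrap the inner graph in its own CUDAGraph.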

Test:
VLLM_USE_STANDALONE_COMPILE=1 python examples/offline_inference/basic/generate.py --model deepseek-ai/DeepSeek-V2-Lite -O.use_inductor_graph_partition=True --max-model-len 1024

Signed-off-by: Boyuan Feng <boyuan@meta.com>
@BoyuanFeng BoyuanFeng requested a review from mgoin as a code owner October 15, 2025 23:14
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a disable_graph_partition context manager to resolve a nested CUDAGraph capture issue that occurs when torch.compile is used on both a model and its custom operations. The approach is sound and correctly implemented using a context manager to temporarily modify Inductor's configuration. The fix is applied to the grouped_topk operation, which aligns with the problem description. My primary concern is the reliance on a private PyTorch API (torch._inductor.config), which poses a maintainability risk for future PyTorch upgrades. I've added a comment to highlight this.

Comment on lines +104 to +109
old_val = torch._inductor.config.graph_partition
try:
    torch._inductor.config.graph_partition = False
    yield
finally:
    torch._inductor.config.graph_partition = old_val
Contributor


Severity: high

This implementation relies on modifying torch._inductor.config.graph_partition, which is an internal, undocumented API of PyTorch's Inductor backend. While this is a clever solution to the nested CUDAGraph problem, it makes the code brittle and susceptible to breaking with future PyTorch updates. It would be beneficial to add a comment here warning about this dependency to aid future maintenance.

Suggested change
-old_val = torch._inductor.config.graph_partition
-try:
-    torch._inductor.config.graph_partition = False
-    yield
-finally:
-    torch._inductor.config.graph_partition = old_val
+# NOTE: This relies on an internal PyTorch Inductor API.
+# This may break in future PyTorch versions.
+old_val = torch._inductor.config.graph_partition
+try:
+    torch._inductor.config.graph_partition = False
+    yield
+finally:
+    torch._inductor.config.graph_partition = old_val

Collaborator Author


This config will be kept backward-compatible (BC) and is tested in the PyTorch x vLLM CI.

@zou3519 zou3519 added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 15, 2025
@zou3519 zou3519 requested a review from ProExpertProg October 15, 2025 23:36
Collaborator

@ProExpertProg ProExpertProg left a comment


Could we try to make this a decorator so that people can just add it to the same callsite as the torch.compile call?

@ProExpertProg
Collaborator

e.g.

@disable_inductor_partition
@torch.compile(...)
def grouped_topk(...)
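One way such a decorator could be sketched (hypothetical names; a `SimpleNamespace` stands in for the real `torch._inductor.config` so the wrapping logic is runnable without PyTorch):

```python
import contextlib
import functools
from types import SimpleNamespace

# Stand-in for torch._inductor.config (hypothetical).
inductor_config = SimpleNamespace(graph_partition=True)

@contextlib.contextmanager
def disable_graph_partition():
    # Save, disable, and restore the partition flag.
    old_val = inductor_config.graph_partition
    try:
        inductor_config.graph_partition = False
        yield
    finally:
        inductor_config.graph_partition = old_val

def disable_inductor_partition(fn):
    """Run `fn` with Inductor graph partitioning disabled."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with disable_graph_partition():
            return fn(*args, **kwargs)
    return wrapper

@disable_inductor_partition
def fake_grouped_topk():
    # In the real code this would be the torch.compile'd op; here we
    # just report the flag value observed during the call.
    return inductor_config.graph_partition
```

Because torch.compile compiles lazily on first call, flipping the flag at call time (rather than decoration time) should be sufficient for the stacked-decorator usage shown above.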

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Collaborator

@ProExpertProg ProExpertProg left a comment


Silly question: we can't just add options={"graph_partition": False} to the nested torch.compile decorator, can we? Assuming so, can you add a brief comment explaining why not?
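For reference, the alternative being asked about would look like this (a sketch in the style of the earlier snippet; whether the installed PyTorch version accepts the `graph_partition` key in `options` is an assumption):

```
@torch.compile(options={"graph_partition": False})
def grouped_topk(...)
```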

BoyuanFeng and others added 2 commits October 15, 2025 17:22
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <boyuan@meta.com>
@BoyuanFeng
Collaborator Author

@ProExpertProg yeah options should work

@DarkLight1337 DarkLight1337 merged commit 0840560 into vllm-project:main Oct 17, 2025
50 checks passed
Zhuul pushed a commit to Zhuul/vllm that referenced this pull request Oct 17, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025

Labels

ready ONLY add when PR is ready to merge/full CI is needed

4 participants