
[compile][graph_partition] Add tensor size handling #36038

Merged
vllm-bot merged 1 commit into vllm-project:main from fxdawnn:symint_cross
Mar 20, 2026

Conversation


@fxdawnn fxdawnn commented Mar 4, 2026

Purpose

Fix #31043
Redo of #32747, since there were some issues with the git sign-off

Problem

When using torch.compile with dynamic shapes on models that call x.size() / x.shape before a splitting op (e.g. sigmoid) and use the shape after it, the torch.Size object crosses the split boundary as a submodule output. aot_autograd / standalone_compile cannot handle torch.Size as a submodule output — it expects flat tensors and scalars. This causes:

AssertionError: output spec mismatch
TreeSpec(tuple, None, [*, *, TreeSpec(Size, None, [*, *]), *])
vs
TreeSpec(tuple, None, [*, *, *, *])

Observed in production with MoE models (e.g. DeepSeek) where torch.Size([s72, 2048]) crossed a split boundary.
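The spec mismatch can be reproduced outside vLLM with PyTorch's pytree utilities (a standalone illustration using the internal torch.utils._pytree module, not vLLM code): torch.Size is registered as a pytree container, so an output tuple containing one flattens to a nested TreeSpec rather than flat leaves.

```python
import torch
from torch.utils._pytree import tree_flatten

# An output carrying a torch.Size, as happens when the Size crosses the split
with_size = (1.0, 2.0, torch.Size([4, 8]), 3.0)
# The flat output aot_autograd expects: scalars and tensors only
flat_only = (1.0, 2.0, 4, 8, 3.0)

leaves_a, spec_a = tree_flatten(with_size)
leaves_b, spec_b = tree_flatten(flat_only)

print(spec_a)            # TreeSpec with a nested Size node
print(spec_a == spec_b)  # False: the specs disagree, triggering the assertion
```

The leaves themselves match; only the spec structure differs, which is exactly the shape of the error above.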

Root Cause

torch.compile captures x.size() / x.shape as a call_method node with target="size", which returns a torch.Size object (a tuple of ints/SymInts). When this node is in the producer subgraph but its consumer (e.g. view(x, shape)) is in a later subgraph after a split point, split_module threads the torch.Size across the boundary. aot_autograd sees TreeSpec(Size, ...) in the output spec instead of flat scalars and raises an assertion error.
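The capture can be seen in isolation with torch.fx (a minimal sketch; fx.symbolic_trace records x.size() the same way, as a call_method node with target="size"):

```python
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        shape = x.size()              # recorded as call_method[target=size]
        x = torch.sigmoid(x)          # imagine a split point after this op
        return x.clone().view(shape)  # consumer of the torch.Size

gm = fx.symbolic_trace(M())
size_nodes = [n for n in gm.graph.nodes
              if n.op == "call_method" and n.target == "size"]
print(gm.graph)  # the size node's result is what would cross the boundary
```

If the split point lands between the sigmoid and the view, the size node's torch.Size output becomes a cross-subgraph value.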

Fix

Add a pre-pass (_decompose_size_nodes) at the start of split_graph that decomposes every x.size() call into individual sym_size.int(x, dim) calls — one per dimension:

Before: view(clone, size)       # size = torch.Size([s77, 8])
After:  view(clone, [sym_size_int, 8])  # s77 as SymInt node, 8 as literal
  • Dynamic dims (SymInt) → new sym_size.int(x, dim) node placed in the producer subgraph. split_module automatically handles cross-boundary data flow: when it sees a node in subgraph 0 used by a node in subgraph 2, it makes the result an output of subgraph 0, creates a placeholder (input) in subgraph 2, and wires them in the top-level orchestrator. We don't need to manually thread SymInt inputs — split_module does this for any scalar or tensor that crosses a boundary.
  • Static dims (plain int) → inlined as literal constant, never crosses the boundary

The new sym_size.int nodes are placed right after their tensor operand, so split_module naturally puts them in the producer subgraph. example_value metadata is propagated to each new node so downstream code can introspect placeholder types.

Debug logging (VLLM_LOGGING_LEVEL=DEBUG) prints the graph before and after decomposition.
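The pre-pass can be sketched as follows (a hypothetical simplification, not the actual _decompose_size_nodes implementation; it assumes example_value metadata is present on the tensor node and ignores kwargs):

```python
import torch
import torch.fx as fx

def decompose_size_nodes(gm: fx.GraphModule) -> None:
    """Replace each whole-tensor `x.size()` node with per-dimension values:
    dynamic dims become sym_size.int nodes, static dims become literals."""
    g = gm.graph
    for node in list(g.nodes):
        # Only whole-shape calls: x.size(), not x.size(dim)
        if node.op != "call_method" or node.target != "size" or len(node.args) != 1:
            continue
        tensor = node.args[0]
        shape = tensor.meta["example_value"].shape  # ints and/or SymInts
        dims = []
        with g.inserting_after(tensor):
            for i, d in enumerate(shape):
                if isinstance(d, torch.SymInt):
                    # Dynamic dim: materialize sym_size.int right after its
                    # tensor so split_module keeps it in the producer subgraph
                    s = g.call_function(torch.ops.aten.sym_size.int, (tensor, i))
                    s.meta["example_value"] = d
                    dims.append(s)
                else:
                    dims.append(int(d))  # static dim: inline the literal
        # Expand the Size argument in every consumer into per-dim values
        for user in list(node.users):
            new_args = []
            for a in user.args:
                if a is node:
                    new_args.extend(dims)
                else:
                    new_args.append(a)
            user.args = tuple(new_args)
        g.erase_node(node)
    g.lint()
    gm.recompile()
```

For a fully static graph every dimension is inlined, so view(clone, size) becomes view(clone, 4, 8); with mark_dynamic the dynamic dimension would instead become a sym_size.int node, as in the debug dumps below.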

Tests

5 new tests in tests/compile/test_graph_partition.py:

  • test_sym_size_whole_shape_boundary: basic repro — x.size() used across a split boundary, validates standalone_compile passes
  • test_symint_crosses_split_boundary: SymInt placeholders from mark_dynamic thread through multiple split boundaries correctly
  • test_shape_boundary_standalone_compile: repro of the production MoE error (TreeSpec mismatch), validates consumer has SymInt placeholders (not static int placeholders) and standalone_compile works
  • test_size_used_in_multiple_consumer_subgraphs: same x.size() consumed by two subgraphs across two split points, validates functional correctness
  • test_sym_size_metadata_propagated: example_value metadata set on all new nodes, standalone_compile works on every submodule

Compile Time Assurance

These changes should not add runtime overhead. To verify this, we benchmarked compile time before and after on gpt-oss-120b and llama3-70b.

  • Before: gpt-oss-120b 16.05 s, llama3-70b 27.43 s
  • After: gpt-oss-120b 15.56 s, llama3-70b 26.70 s

The differences are marginal and can be considered negligible. TLParse analysis of the decomposition pass also showed it consistently under 10 ms across 4 models.

Graph changes

  • test_sym_size_whole_shape_boundary (focus: the size node is decomposed into a SymInt reaching the consumer graph)
DEBUG 03-17 11:43:13 [compilation/backends.py:476] Graph before size decomposition:
DEBUG 03-17 11:43:13 [compilation/backends.py:476] graph():
DEBUG 03-17 11:43:13 [compilation/backends.py:476]     %s77 : torch.SymInt [num_users=0] = placeholder[target=s77]
DEBUG 03-17 11:43:13 [compilation/backends.py:476]     %l_x_ : torch.Tensor [num_users=2] = placeholder[target=L_x_]
DEBUG 03-17 11:43:13 [compilation/backends.py:476]     %size : [num_users=1] = call_method[target=size](args = (%l_x_,), kwargs = {})
DEBUG 03-17 11:43:13 [compilation/backends.py:476]     %x : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%l_x_,), kwargs = {})
DEBUG 03-17 11:43:13 [compilation/backends.py:476]     %clone : [num_users=1] = call_method[target=clone](args = (%x,), kwargs = {})
DEBUG 03-17 11:43:13 [compilation/backends.py:476]     %x_1 : [num_users=1] = call_method[target=view](args = (%clone, %size), kwargs = {})
DEBUG 03-17 11:43:13 [compilation/backends.py:476]     return (x_1,)
DEBUG 03-17 11:43:13 [compilation/backends.py:534] Graph after size decomposition:
DEBUG 03-17 11:43:13 [compilation/backends.py:534] graph():
DEBUG 03-17 11:43:13 [compilation/backends.py:534]     %s77 : torch.SymInt [num_users=0] = placeholder[target=s77]
DEBUG 03-17 11:43:13 [compilation/backends.py:534]     %l_x_ : torch.Tensor [num_users=2] = placeholder[target=L_x_]
DEBUG 03-17 11:43:13 [compilation/backends.py:534]     %sym_size_int : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%l_x_, 0), kwargs = {})
DEBUG 03-17 11:43:13 [compilation/backends.py:534]     %x : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%l_x_,), kwargs = {})
DEBUG 03-17 11:43:13 [compilation/backends.py:534]     %clone : [num_users=1] = call_method[target=clone](args = (%x,), kwargs = {})
DEBUG 03-17 11:43:13 [compilation/backends.py:534]     %x_1 : [num_users=1] = call_method[target=view](args = (%clone, %sym_size_int, 8), kwargs = {})
DEBUG 03-17 11:43:13 [compilation/backends.py:534]     return (x_1,)
  • test_size_used_in_multiple_consumer_subgraphs (the Size is decomposed into a SymInt plus an inlined int in the consumer subgraphs)
DEBUG 03-17 11:46:48 [compilation/backends.py:476] Graph before size decomposition:
DEBUG 03-17 11:46:48 [compilation/backends.py:476] graph():
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     %s77 : torch.SymInt [num_users=0] = placeholder[target=s77]
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     %l_x_ : torch.Tensor [num_users=2] = placeholder[target=L_x_]
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     %l_y_ : torch.Tensor [num_users=2] = placeholder[target=L_y_]
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     %size : [num_users=2] = call_method[target=size](args = (%l_x_,), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     %z1 : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%l_x_,), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     %y1 : [num_users=1] = call_method[target=view](args = (%l_y_, %size), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     %z2 : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%z1,), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     %y2 : [num_users=1] = call_method[target=view](args = (%l_y_, %size), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     %add : [num_users=1] = call_function[target=operator.add](args = (%z2, %y1), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     %add_1 : [num_users=1] = call_function[target=operator.add](args = (%add, %y2), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:476]     return (add_1,)
DEBUG 03-17 11:46:48 [compilation/backends.py:534] Graph after size decomposition:
DEBUG 03-17 11:46:48 [compilation/backends.py:534] graph():
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     %s77 : torch.SymInt [num_users=0] = placeholder[target=s77]
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     %l_x_ : torch.Tensor [num_users=2] = placeholder[target=L_x_]
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%l_x_, 0), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     %l_y_ : torch.Tensor [num_users=2] = placeholder[target=L_y_]
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     %z1 : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%l_x_,), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     %y1 : [num_users=1] = call_method[target=view](args = (%l_y_, %sym_size_int, 8), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     %z2 : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%z1,), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     %y2 : [num_users=1] = call_method[target=view](args = (%l_y_, %sym_size_int, 8), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     %add : [num_users=1] = call_function[target=operator.add](args = (%z2, %y1), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     %add_1 : [num_users=1] = call_function[target=operator.add](args = (%add, %y2), kwargs = {})
DEBUG 03-17 11:46:48 [compilation/backends.py:534]     return (add_1,)



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an effective optimization by adding a pre-pass to split_graph that repositions sym_size.int nodes. This change prevents unnecessary tensor propagation across subgraph boundaries, which should improve memory efficiency. The implementation is clean and the accompanying tests are relevant. I've identified a minor issue in one of the new tests where an assertion was missing and have suggested a fix. Overall, this is a solid contribution.



mergify bot commented Mar 9, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fxdawnn.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


zou3519 commented Mar 11, 2026

@fxdawnn I don't think this PR is solving the right problem. The problem is when we have a sym_size node in the graph, not a sym_size.int node. e.g.:

#!/usr/bin/env python
import torch
import torch.fx as fx
from torch._inductor import standalone_compile
from vllm.compilation.backends import split_graph


captured_graph = None

def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
    global captured_graph
    captured_graph = gm
    return gm

def model_fn(x: torch.Tensor) -> torch.Tensor:
    shape = x.shape
    x = torch.ops.aten.sigmoid.default(x)
    x = x.clone().view(shape)
    return x

x = torch.randn(4, 8)
torch._dynamo.mark_dynamic(x, 0)
compiled_fn = torch.compile(model_fn, backend=capturing_backend)
compiled_fn(x)

split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])
assert len(split_items) == 3

# the shape error
submod_0 = split_gm.submod_0
print(submod_0)
example_input = torch.randn(4, 8)
compiled = standalone_compile(
    submod_0, [example_input, 4], dynamic_shapes="from_example_inputs"
)


fxdawnn commented Mar 18, 2026

This method decomposes size() into a list of valid inputs (SymInt/int). This has a lower memory cost than adding the tensor as an input to every subgraph that uses it. The trade-off for the memory saving is runtime, so we observed the runtime overhead of the torch.Size decomposition across the major models. After benchmarking on Llama/OpenAI/ZAI/Mistral, the runtime overhead is minimal (all below 10 ms, some under 1 ms, on 8x H100).

- Dynamic dims (SymInt) → new sym_size.int node
- Static dims (plain int) → inlined as literal constant
"""
# torch.compile captures x.size()/x.shape as call_method target="size".
Collaborator

nit: "Dynamo captures ..."

Contributor Author

updated! Thanks.

Comment on lines +506 to +507
if skip:
    continue
Collaborator

we don't need the skip case if we raise AssertionError right?

Contributor Author

removed! Thanks!

Comment on lines +516 to +523
elif isinstance(arg, (list, tuple)):
    expanded = []
    for a in arg:
        if a is node:
            expanded.extend(dims)
        else:
            expanded.append(a)
    new_args.append(type(arg)(expanded))
Collaborator

I don't think this case can happen?

Contributor Author

great catch! tuple are not valid for crossing...

Collaborator

@zou3519 zou3519 left a comment

this lgtm but had some minor comments, please read

@zou3519 zou3519 added the ready-run-all-tests Trigger CI with all tests for wide-ranging PRs label Mar 18, 2026

mergify bot commented Mar 19, 2026

Documentation preview: https://vllm--36038.org.readthedocs.build/en/36038/


…ry crossing

Signed-off-by: Xiao Fu <xiaofu@meta.com>


Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[BugFix]: move torch.Size across graphs in split_graph

4 participants