[compile][graph_partition] Add tensor size handling #36038
vllm-bot merged 1 commit into vllm-project:main
Conversation
Code Review
This pull request introduces an effective optimization by adding a pre-pass to split_graph that repositions sym_size.int nodes. This change prevents unnecessary tensor propagation across subgraph boundaries, which should improve memory efficiency. The implementation is clean and the accompanying tests are relevant. I've identified a minor issue in one of the new tests where an assertion was missing and have suggested a fix. Overall, this is a solid contribution.
Code Review
This pull request introduces an effective optimization by adding a pre-pass to split_graph that moves sym_size.int operations into the producer subgraph. This prevents tensors from being unnecessarily passed across subgraph boundaries just for shape information, which should improve memory efficiency during compilation. The implementation is clean and the new tests correctly validate the core logic. I've included one suggestion to strengthen a check in the tests to make it an explicit assertion, improving its robustness.
Code Review
This pull request introduces an important optimization by moving sym_size.int nodes into the producer subgraph during graph partitioning. This prevents tensors from being unnecessarily passed to consumer subgraphs just for shape information, improving memory efficiency. The implementation in vllm/compilation/backends.py is clean and follows FX best practices. The accompanying tests are mostly thorough, though I've pointed out a small issue in one of the new test cases where a verification loop is ineffective and should be removed.
This pull request has merge conflicts that must be resolved before it can be merged.
@fxdawnn I don't think this PR is solving the right problem. The problem is when we have a sym_size node in the graph, not a sym_size.int node, e.g.:

```python
#!/usr/bin/env python
import torch
import torch.fx as fx
from torch._inductor import standalone_compile
from vllm.compilation.backends import split_graph

captured_graph = None

def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
    global captured_graph
    captured_graph = gm
    return gm

def model_fn(x: torch.Tensor) -> torch.Tensor:
    shape = x.shape
    x = torch.ops.aten.sigmoid.default(x)
    x = x.clone().view(shape)
    return x

x = torch.randn(4, 8)
torch._dynamo.mark_dynamic(x, 0)
compiled_fn = torch.compile(model_fn, backend=capturing_backend)
compiled_fn(x)

split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])
assert len(split_items) == 3

# the shape error
submod_0 = split_gm.submod_0
print(submod_0)
example_input = torch.randn(4, 8)
compiled = standalone_compile(
    submod_0, [example_input, 4], dynamic_shapes="from_example_inputs"
)
```
This method decomposes size() into a list of valid SymInt/int inputs. That costs less memory than adding the tensor as an input to every subgraph that uses it. The trade-off for the memory saving is runtime, so we measured the runtime overhead of the torch.size() decomposition across the major models. After benchmarking on Llama/OpenAI/ZAI/Mistral, the runtime overhead is minimal (all below 10ms, some under 1ms, on 8x H100).
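The decomposition trade-off described above can be sketched in plain Python. This is an illustrative model only, not the vLLM implementation: the `Dim` class and `decompose_size` helper are hypothetical stand-ins for `sym_size.int` FX nodes and the real pre-pass.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Dim:
    """Hypothetical stand-in for a sym_size.int(x, dim) FX node."""
    tensor: str
    index: int

def decompose_size(tensor: str, shape, dynamic_dims):
    """Flatten one size() result into per-dimension scalar values.

    Dynamic dims become symbolic per-dim references (cheap scalar
    inputs to a subgraph); static dims are inlined as literal ints,
    so no torch.Size tuple ever has to cross a subgraph boundary.
    """
    out = []
    for i, extent in enumerate(shape):
        if i in dynamic_dims:
            out.append(Dim(tensor, i))  # like sym_size.int(x, i)
        else:
            out.append(extent)          # inlined constant
    return out

# A (batch, 2048) activation whose batch dimension is dynamic:
print(decompose_size("x", (4, 2048), dynamic_dims={0}))
# [Dim(tensor='x', index=0), 2048]
```

Only the dynamic batch dimension survives as a symbolic input; the static 2048 is inlined, which is why the extra inputs stay cheap.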
vllm/compilation/backends.py
Outdated
    - Dynamic dims (SymInt) → new sym_size.int node
    - Static dims (plain int) → inlined as literal constant
    """
    # torch.compile captures x.size()/x.shape as call_method target="size".
nit: "Dynamo captures ..."
vllm/compilation/backends.py
Outdated
    if skip:
        continue
we don't need the skip case if we raise AssertionError, right?
vllm/compilation/backends.py
Outdated
    elif isinstance(arg, (list, tuple)):
        expanded = []
        for a in arg:
            if a is node:
                expanded.extend(dims)
            else:
                expanded.append(a)
        new_args.append(type(arg)(expanded))
I don't think this case can happen?
great catch! tuples are not valid for crossing...
zou3519
left a comment
this lgtm but had some minor comments, please read
Documentation preview: https://vllm--36038.org.readthedocs.build/en/36038/
…ry crossing Signed-off-by: Xiao Fu <xiaofu@meta.com>
Purpose
Fix #31043
Redo of #32747, since there were some issues with the git sign-off
Problem
When using `torch.compile` with dynamic shapes on models that call `x.size()`/`x.shape` before a splitting op (e.g. sigmoid) and use the shape after it, the `torch.Size` object crosses the split boundary as a submodule output. `aot_autograd`/`standalone_compile` cannot handle `torch.Size` as a submodule output: it expects flat tensors and scalars, so compilation fails. Observed in production with MoE models (e.g. DeepSeek) where `torch.Size([s72, 2048])` crossed a split boundary.
Root Cause
`torch.compile` captures `x.size()`/`x.shape` as a `call_method` node with `target="size"`, which returns a `torch.Size` object (a tuple of ints/SymInts). When this node is in the producer subgraph but its consumer (e.g. `view(x, shape)`) is in a later subgraph after a split point, `split_module` threads the `torch.Size` across the boundary. `aot_autograd` then sees `TreeSpec(Size, ...)` in the output spec instead of flat scalars and raises an assertion error.
Fix
Add a pre-pass (`_decompose_size_nodes`) at the start of `split_graph` that decomposes every `x.size()` call into individual `sym_size.int(x, dim)` calls, one per dimension, with each `sym_size.int(x, dim)` node placed in the producer subgraph. `split_module` automatically handles cross-boundary data flow: when it sees a node in subgraph 0 used by a node in subgraph 2, it makes the result an output of subgraph 0, creates a placeholder (input) in subgraph 2, and wires them in the top-level orchestrator. We don't need to manually thread SymInt inputs; `split_module` does this for any scalar or tensor that crosses a boundary.
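The cross-boundary wiring behavior described above can be modeled in a few lines. This is a toy sketch assuming string node names; the real `split_module` operates on `fx.Node` objects, and `wire_partitions` is a hypothetical helper for illustration.

```python
from collections import defaultdict

def wire_partitions(assignment, edges):
    """Given node -> partition assignments and (producer, consumer)
    edges, record every value that crosses a partition boundary as an
    output of the producer's partition and a placeholder input of the
    consumer's partition, mimicking split_module's wiring."""
    outputs = defaultdict(set)
    inputs = defaultdict(set)
    for src, dst in edges:
        p, c = assignment[src], assignment[dst]
        if p != c:  # value crosses a boundary
            outputs[p].add(src)
            inputs[c].add(src)
    return dict(outputs), dict(inputs)

# A sym_size.int result produced in subgraph 0 and consumed by a view
# in subgraph 2 becomes an output of 0 and a placeholder input of 2:
outs, ins = wire_partitions(
    {"sym_size": 0, "view": 2}, [("sym_size", "view")]
)
print(outs, ins)  # {0: {'sym_size'}} {2: {'sym_size'}}
```

The same wiring applies uniformly to any scalar or tensor, which is why placing the `sym_size.int` nodes in the producer subgraph is all the pre-pass has to do.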
sym_size.intnodes are placed right after their tensor operand, sosplit_modulenaturally puts them in the producer subgraph.example_valuemetadata is propagated to each new node so downstream code can introspect placeholder types.Debug logging (
VLLM_LOGGING_LEVEL=DEBUG) prints the graph before and after decomposition.Tests
5 new tests in `tests/compile/test_graph_partition.py`:

- `test_sym_size_whole_shape_boundary`: basic repro of `x.size()` used across a split boundary; validates that `standalone_compile` passes
- `test_symint_crosses_split_boundary`: SymInt placeholders from `mark_dynamic` thread through multiple split boundaries correctly
- `test_shape_boundary_standalone_compile`: repro of the production MoE error (TreeSpec mismatch); validates that the consumer has SymInt placeholders (not static int placeholders) and that `standalone_compile` works
- `test_size_used_in_multiple_consumer_subgraphs`: the same `x.size()` consumed by two subgraphs across two split points; validates functional correctness
- `test_sym_size_metadata_propagated`: `example_value` metadata is set on all new nodes; `standalone_compile` works on every submodule

Compile Time Assurance
Our changes shouldn't increase runtime overhead. To verify this, we benchmarked gpt-oss-120b and llama3-70b before and after the change. The changes in overhead are marginal and can be considered negligible. The TLParse analysis of the decomposition also showed under 10ms consistently across 4 models.
Graph changes