Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
440341a
[compile][graph_partition]Add tensor size handling
fxdawnn Jan 20, 2026
074f5bb
Add more test
fxdawnn Jan 21, 2026
63a55a4
Add replication to all consumer
fxdawnn Jan 22, 2026
592307a
Revert "remove cuda graph copy"
fxdawnn Jan 23, 2026
ee67880
Add a repro-level test for the cuda_graph address-assignment bug to validate the fix
fxdawnn Jan 23, 2026
a239cd3
Revert "remove cuda graph copy"
fxdawnn Jan 22, 2026
832b8b1
Merge branch 'main' of https://github.com/vllm-project/vllm
fxdawnn Jan 23, 2026
26f7680
Merge branch 'main' of https://github.com/vllm-project/vllm
fxdawnn Jan 26, 2026
009f916
[compile][graph_partition]Add tensor size handling
fxdawnn Jan 20, 2026
313eef1
Add more test
fxdawnn Jan 21, 2026
77ecf1b
Update vllm/compilation/backends.py
fxdawnn Jan 27, 2026
3a709cd
Merge conflict
fxdawnn Jan 27, 2026
e726e43
Modify the test for scenario with torch.tensor()
fxdawnn Jan 28, 2026
ef98db7
Merge branch 'main' of https://github.com/vllm-project/vllm
fxdawnn Feb 2, 2026
6bd90a0
[compile][graph_partition]Add tensor size handling
fxdawnn Jan 20, 2026
2a55ef3
Add more test
fxdawnn Jan 21, 2026
cbf3c10
[compile][graph_partition]Add tensor size handling
fxdawnn Jan 20, 2026
0e12498
Add more test
fxdawnn Jan 21, 2026
10c9793
Add replication to all consumer
fxdawnn Jan 22, 2026
8bd2fc9
Add repro-level bug on cuda_graph address assignment to ensure the fix
fxdawnn Jan 23, 2026
b204b4d
Update vllm/compilation/backends.py
fxdawnn Jan 27, 2026
24c0ea6
Modify the test for scenario with torch.tensor()
fxdawnn Jan 28, 2026
018fe84
Fix rebase and arrangement
fxdawnn Feb 2, 2026
16ad25d
Merge branch 'tensor_size' of https://github.com/fxdawnn/vllm into te…
fxdawnn Feb 2, 2026
d936732
Fix error
fxdawnn Feb 2, 2026
b0666e0
Reduce logic overhead
fxdawnn Feb 13, 2026
c31b32d
Fix pre-commit issue
fxdawnn Feb 13, 2026
4de21b8
Move sym_size.int to producer subgraph to reduce tensor boundary cros…
fxdawnn Feb 25, 2026
1ccdb56
Fix the unclear part of the boundary-crossing allowlist
fxdawnn Feb 25, 2026
581e02f
Remove unused subgraph inputs after split_module
fxdawnn Feb 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 260 additions & 1 deletion tests/compile/test_graph_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,28 @@

import pytest
import torch
import torch._dynamo
import torch.fx as fx
from torch.fx.experimental.proxy_tensor import make_fx

from vllm.compilation.backends import split_graph
from vllm.compilation.backends import (
split_graph,
)
from vllm.compilation.fx_utils import find_op_nodes

# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401


@pytest.fixture
def vllm_compile_env(monkeypatch):
    """Configure vLLM compilation-related environment variables for a test.

    Applied through ``monkeypatch`` so every override is automatically
    undone when the test using this fixture finishes.
    """
    overrides = {
        "VLLM_ALL2ALL_BACKEND": "deepep_high_throughput",
        "VLLM_USE_DEEP_GEMM": "1",
        "VLLM_LOGGING_LEVEL": "debug",
    }
    for env_name, env_value in overrides.items():
        monkeypatch.setenv(env_name, env_value)
    yield
Comment on lines +21 to +27
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This pytest fixture vllm_compile_env is defined but does not appear to be used by any of the tests in this file. If it's not needed, it should be removed to avoid clutter. If it is intended to be used, please apply it to the relevant tests.



def test_getitem_moved_to_producer_subgraph():
"""
Test that getitem operations are moved to the same subgraph as their input,
Expand Down Expand Up @@ -184,3 +197,249 @@ def model_fn(x: torch.Tensor) -> torch.Tensor:
assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [
"call_function"
] + ["output"]


def test_sym_size_in_producer_subgraph():
    """
    Test that sym_size operations are assigned to the same subgraph as their
    tensor operand (the producer), so only the SymInt result crosses the
    split boundary — not the original tensor.

    This avoids passing tensors to consumer subgraphs just for .size() calls,
    which would keep the tensor alive longer and increase memory usage.
    """

    def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        hidden_size = x.shape[1]

        # This becomes a splitting operation
        z = torch.sigmoid(x)

        # Use the shape values after the split point
        reshaped_y = y.view(batch_size, hidden_size)

        return z + reshaped_y

    x = torch.randn(4, 8)
    y = torch.randn(32)  # Will be reshaped to (4, 8)
    gm = make_fx(model_fn, tracing_mode="symbolic")(x, y)

    # Sanity-check the setup: symbolic tracing must have produced sym_size
    # nodes for the .shape accesses, otherwise the test proves nothing.
    sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, gm.graph))
    assert len(sym_size_nodes) > 0, (
        "Test setup failed: graph should contain sym_size operations"
    )

    split_ops = ["aten::sigmoid"]
    split_gm, split_items = split_graph(gm, split_ops)

    # Find producer subgraph (before sigmoid) and consumer subgraph (with view)
    splitting_items = [item for item in split_items if item.is_splitting_graph]
    assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph"

    view_subgraph = None
    for item in split_items:
        view_nodes = list(find_op_nodes(torch.ops.aten.view, item.graph.graph))
        if view_nodes:
            view_subgraph = item
            break
    assert view_subgraph is not None, "Should have a subgraph with view operation"

    # KEY VERIFICATION: sym_size should NOT be in the consumer (view) subgraph.
    # It should be in the producer subgraph, with only the SymInt result
    # crossing the boundary.
    sym_size_in_view_subgraph = list(
        find_op_nodes(torch.ops.aten.sym_size, view_subgraph.graph.graph)
    )
    assert len(sym_size_in_view_subgraph) == 0, (
        "sym_size operations should NOT be in the consumer subgraph. "
        "They should be in the producer subgraph so only the SymInt result "
        "crosses the boundary, avoiding passing the tensor for .size() calls."
    )

    # Verify sym_size is in a producer subgraph (before sigmoid)
    producer_subgraphs_with_sym_size = []
    for item in split_items:
        if item.is_splitting_graph:
            continue
        if item.graph_id > splitting_items[0].graph_id:
            continue
        sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, item.graph.graph))
        if sym_size_nodes:
            producer_subgraphs_with_sym_size.append(item.submod_name)

    assert len(producer_subgraphs_with_sym_size) > 0, (
        "sym_size operations should be in a producer subgraph (before sigmoid)."
    )

    # NOTE: a previous revision iterated over the consumer subgraph's
    # placeholders here "to verify x is not passed through", but the loop body
    # was only `pass` and asserted nothing (flagged in review). The assertion
    # on sym_size_in_view_subgraph above already covers that goal, so the
    # ineffective loop has been removed.

    # Verify functional correctness
    output_original = gm(x, y)
    output_split = split_gm(x, y)
    assert torch.allclose(output_original, output_split), "Output mismatch after split"


def test_symint_crosses_split_boundary():
    """
    Test that SymInt placeholders from torch.compile + mark_dynamic
    cross split boundaries safely via split_module's natural threading.

    SymInt values are threaded through subgraphs by split_module and
    handled correctly by inductor — no special replacement is needed.
    """
    # Collect the dynamo-captured graph(s) via a recording backend.
    recorded_graphs: list = []

    def recording_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        recorded_graphs.append(gm)
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        hidden_size = x.shape[1]
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(batch_size, hidden_size)
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(batch_size, hidden_size)
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(batch_size, hidden_size)
        return x

    sample = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(sample, 0)

    torch.compile(model_fn, backend=recording_backend)(sample)

    captured_graph = recorded_graphs[0] if recorded_graphs else None
    assert captured_graph is not None, "Graph should be captured by backend"

    # mark_dynamic should have introduced SymInt placeholders in the capture.
    symint_placeholders = []
    for node in captured_graph.graph.nodes:
        if node.op != "placeholder":
            continue
        if isinstance(node.meta.get("example_value"), torch.SymInt):
            symint_placeholders.append(node)
    assert len(symint_placeholders) > 0, (
        "Captured graph should have SymInt placeholders from mark_dynamic."
    )

    # split_graph should handle SymInt placeholders without error
    split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])

    # Should have 3 splitting subgraphs (3 sigmoids)
    splitting_subgraphs = [item for item in split_items if item.is_splitting_graph]
    assert len(splitting_subgraphs) == 3, (
        f"Expected 3 splitting subgraphs (3 sigmoids), got {len(splitting_subgraphs)}"
    )
    assert len(split_items) >= 6, (
        f"Expected at least 6 total subgraphs, got {len(split_items)}"
    )


def test_unused_subgraph_inputs_removed():
    """
    Test that unused inputs threaded by split_module are removed from subgraphs.

    split_module threads values (e.g., SymInt) to all subgraphs in the chain,
    even those that don't reference them. This test verifies that the cleanup
    pass removes these unnecessary inputs, keeping subgraph signatures clean.
    """

    def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        hidden_size = x.shape[1]
        activated = torch.sigmoid(x)
        return activated + y.view(batch_size, hidden_size)

    x = torch.randn(4, 8)
    y = torch.randn(32)
    gm = make_fx(model_fn, tracing_mode="symbolic")(x, y)

    split_gm, split_items = split_graph(gm, ["aten::sigmoid"])

    # Every subgraph should only have inputs it actually uses
    for item in split_items:
        placeholders = (
            n for n in item.graph.graph.nodes if n.op == "placeholder"
        )
        for node in placeholders:
            assert len(node.users) > 0, (
                f"Subgraph {item.submod_name} has unused input '{node.name}'. "
                "Unused inputs should be removed by the cleanup pass."
            )

    # Verify functional correctness
    output_original = gm(x, y)
    output_split = split_gm(x, y)
    assert torch.allclose(output_original, output_split), "Output mismatch after split"


def test_unused_symint_inputs_removed_multi_split():
    """
    Test that with torch.compile + mark_dynamic and multiple split points,
    SymInt inputs are removed from subgraphs that don't use them.

    split_module threads SymInt (e.g., s77) to every subgraph in the chain.
    Splitting subgraphs (sigmoid) don't reference the SymInt, so it should
    be stripped from their inputs.
    """
    recorded_graphs: list = []

    def recording_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        recorded_graphs.append(gm)
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        hidden_size = x.shape[1]
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(batch_size, hidden_size)
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(batch_size, hidden_size)
        return x

    sample = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(sample, 0)

    torch.compile(model_fn, backend=recording_backend)(sample)

    captured_graph = recorded_graphs[0] if recorded_graphs else None
    assert captured_graph is not None

    split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])

    # Splitting subgraphs (sigmoid) should NOT have SymInt inputs
    splitting_only = (item for item in split_items if item.is_splitting_graph)
    for item in splitting_only:
        for node in item.graph.graph.nodes:
            if node.op != "placeholder":
                continue
            ev = node.meta.get("example_value")
            assert not isinstance(ev, torch.SymInt), (
                f"Splitting subgraph {item.submod_name} has unused SymInt "
                f"input '{node.name}'. SymInt should only appear in "
                "subgraphs that reference it."
            )

    # All subgraphs: no unused inputs
    for item in split_items:
        for node in item.graph.graph.nodes:
            if node.op != "placeholder":
                continue
            assert len(node.users) > 0, (
                f"Subgraph {item.submod_name} has unused input '{node.name}'."
            )
45 changes: 39 additions & 6 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,21 @@ class SplitItem:
def split_graph(
graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
# Move sym_size.int nodes to right after their tensor operand so they
# end up in the producer subgraph. This avoids passing the tensor to
# consumer subgraphs just for .size() calls — only the SymInt result
# crosses the boundary.
for node in list(graph.graph.nodes):
if node.op == "call_function" and node.target == torch.ops.aten.sym_size.int:
tensor_node = node.args[0]
with graph.graph.inserting_after(tensor_node):
new_node = graph.graph.call_function(
torch.ops.aten.sym_size.int, args=node.args
)
new_node.meta = node.meta.copy()
node.replace_all_uses_with(new_node)
graph.graph.erase_node(node)

# split graph by ops
subgraph_id = 0
node_to_subgraph_id: dict[fx.Node, int] = {}
Expand Down Expand Up @@ -379,19 +394,37 @@ def split_graph(
)

outputs = []
parent_modified = False

names = [name for (name, module) in split_gm.named_modules()]

for name in names:
if "." in name or name == "":
# recursive child module or the root module
for node in split_gm.graph.nodes:
if node.op != "call_module":
continue

name = node.target
module = getattr(split_gm, name)

graph_id = int(name.replace("submod_", ""))

# Remove unused inputs that split_module may have threaded through
# unnecessarily (e.g., SymInt values passed to subgraphs that
# don't reference them).
placeholders = [n for n in module.graph.nodes if n.op == "placeholder"]
unused_indices = [i for i, ph in enumerate(placeholders) if not ph.users]
if unused_indices:
for i in reversed(unused_indices):
module.graph.erase_node(placeholders[i])
node.args = tuple(
arg for i, arg in enumerate(node.args) if i not in unused_indices
)
module.graph.lint()
module.recompile()
parent_modified = True

outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

if parent_modified:
split_gm.graph.lint()
split_gm.recompile()

# sort by integer graph_id, rather than string name
outputs.sort(key=lambda x: x.graph_id)

Expand Down