Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
440341a
[compile][graph_partition]Add tensor size handling
fxdawnn Jan 20, 2026
074f5bb
Add more test
fxdawnn Jan 21, 2026
63a55a4
Add replication to all consumer
fxdawnn Jan 22, 2026
592307a
Revert "remove cuda graph copy"
fxdawnn Jan 23, 2026
ee67880
Add repro-level bug on cuda_graph address assignment to ensure the fix
fxdawnn Jan 23, 2026
a239cd3
Revert "remove cuda graph copy"
fxdawnn Jan 22, 2026
832b8b1
Merge branch 'main' of https://github.com/vllm-project/vllm
fxdawnn Jan 23, 2026
26f7680
Merge branch 'main' of https://github.com/vllm-project/vllm
fxdawnn Jan 26, 2026
009f916
[compile][graph_partition]Add tensor size handling
fxdawnn Jan 20, 2026
313eef1
Add more test
fxdawnn Jan 21, 2026
77ecf1b
Update vllm/compilation/backends.py
fxdawnn Jan 27, 2026
3a709cd
Merge conflict
fxdawnn Jan 27, 2026
e726e43
Modify the test for scenario with torch.tensor()
fxdawnn Jan 28, 2026
ef98db7
Merge branch 'main' of https://github.com/vllm-project/vllm
fxdawnn Feb 2, 2026
6bd90a0
[compile][graph_partition]Add tensor size handling
fxdawnn Jan 20, 2026
2a55ef3
Add more test
fxdawnn Jan 21, 2026
cbf3c10
[compile][graph_partition]Add tensor size handling
fxdawnn Jan 20, 2026
0e12498
Add more test
fxdawnn Jan 21, 2026
10c9793
Add replication to all consumer
fxdawnn Jan 22, 2026
8bd2fc9
Add repro-level bug on cuda_graph address assignment to ensure the fix
fxdawnn Jan 23, 2026
b204b4d
Update vllm/compilation/backends.py
fxdawnn Jan 27, 2026
24c0ea6
Modify the test for scenario with torch.tensor()
fxdawnn Jan 28, 2026
018fe84
Fix rebase and arrangement
fxdawnn Feb 2, 2026
16ad25d
Merge branch 'tensor_size' of https://github.com/fxdawnn/vllm into te…
fxdawnn Feb 2, 2026
d936732
Fix error
fxdawnn Feb 2, 2026
b0666e0
Reduce logic overhead
fxdawnn Feb 13, 2026
c31b32d
Fix pre-commit issue
fxdawnn Feb 13, 2026
4de21b8
Move sym_size.int to producer subgraph to reduce tensor boundary cros…
fxdawnn Feb 25, 2026
1ccdb56
Fix the unclear part of the boundary-crossing allowlist
fxdawnn Feb 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 164 additions & 1 deletion tests/compile/test_graph_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,28 @@

import pytest
import torch
import torch._dynamo
import torch.fx as fx
from torch.fx.experimental.proxy_tensor import make_fx

from vllm.compilation.backends import split_graph
from vllm.compilation.backends import (
split_graph,
)
from vllm.compilation.fx_utils import find_op_nodes

# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401


@pytest.fixture
def vllm_compile_env(monkeypatch):
    """Set up vLLM compilation environment variables for testing."""
    # Apply all env vars via monkeypatch so they are reverted after the test.
    env_overrides = {
        "VLLM_ALL2ALL_BACKEND": "deepep_high_throughput",
        "VLLM_USE_DEEP_GEMM": "1",
        "VLLM_LOGGING_LEVEL": "debug",
    }
    for name, value in env_overrides.items():
        monkeypatch.setenv(name, value)
    yield


def test_getitem_moved_to_producer_subgraph():
"""
Test that getitem operations are moved to the same subgraph as their input,
Expand Down Expand Up @@ -184,3 +197,153 @@ def model_fn(x: torch.Tensor) -> torch.Tensor:
assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [
"call_function"
] + ["output"]


def test_sym_size_in_producer_subgraph():
    """
    Test that sym_size operations are assigned to the same subgraph as their
    tensor operand (the producer), so only the SymInt result crosses the
    split boundary — not the original tensor.

    This avoids passing tensors to consumer subgraphs just for .size() calls,
    which would keep the tensor alive longer and increase memory usage.
    """

    def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        hidden_size = x.shape[1]

        # This becomes a splitting operation
        z = torch.sigmoid(x)

        # Use the shape values after the split point
        reshaped_y = y.view(batch_size, hidden_size)

        return z + reshaped_y

    x = torch.randn(4, 8)
    y = torch.randn(32)  # Will be reshaped to (4, 8)
    gm = make_fx(model_fn, tracing_mode="symbolic")(x, y)

    # Sanity check: symbolic tracing must have produced sym_size nodes,
    # otherwise the rest of the test is vacuous.
    sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, gm.graph))
    assert len(sym_size_nodes) > 0, (
        "Test setup failed: graph should contain sym_size operations"
    )

    split_ops = ["aten::sigmoid"]
    split_gm, split_items = split_graph(gm, split_ops)

    # Find producer subgraph (before sigmoid) and consumer subgraph (with view)
    splitting_items = [item for item in split_items if item.is_splitting_graph]
    assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph"

    view_subgraph = None
    for item in split_items:
        view_nodes = list(find_op_nodes(torch.ops.aten.view, item.graph.graph))
        if view_nodes:
            view_subgraph = item
            break
    assert view_subgraph is not None, "Should have a subgraph with view operation"

    # KEY VERIFICATION: sym_size should NOT be in the consumer (view) subgraph.
    # It should be in the producer subgraph, with only the SymInt result
    # crossing the boundary.
    sym_size_in_view_subgraph = list(
        find_op_nodes(torch.ops.aten.sym_size, view_subgraph.graph.graph)
    )
    assert len(sym_size_in_view_subgraph) == 0, (
        "sym_size operations should NOT be in the consumer subgraph. "
        "They should be in the producer subgraph so only the SymInt result "
        "crosses the boundary, avoiding passing the tensor for .size() calls."
    )

    # Verify sym_size is in a producer subgraph (before sigmoid).
    # Use a distinct variable name here — do not shadow the earlier
    # whole-graph `sym_size_nodes` list.
    producer_subgraphs_with_sym_size = []
    for item in split_items:
        if item.is_splitting_graph:
            continue
        if item.graph_id > splitting_items[0].graph_id:
            continue
        item_sym_size_nodes = list(
            find_op_nodes(torch.ops.aten.sym_size, item.graph.graph)
        )
        if item_sym_size_nodes:
            producer_subgraphs_with_sym_size.append(item.submod_name)

    assert len(producer_subgraphs_with_sym_size) > 0, (
        "sym_size operations should be in a producer subgraph (before sigmoid)."
    )

    # NOTE: we intentionally do NOT assert on which tensors appear as
    # placeholders of the consumer subgraph. Both y and z legitimately cross
    # the boundary for computation, and z has the same shape as x, so a
    # shape-based check cannot distinguish "x passed only for .size()" from
    # a tensor that is genuinely consumed. The sym_size placement check above
    # is the meaningful guarantee. (A previous version contained a loop here
    # that inspected placeholders but asserted nothing.)

    # Verify functional correctness
    output_original = gm(x, y)
    output_split = split_gm(x, y)
    assert torch.allclose(output_original, output_split), "Output mismatch after split"


def test_symint_crosses_split_boundary():
    """
    Test that SymInt placeholders from torch.compile + mark_dynamic
    cross split boundaries safely via split_module's natural threading.

    SymInt values are threaded through subgraphs by split_module and
    handled correctly by inductor — no special replacement is needed.
    """
    # Stash the dynamo-captured graph via a mutable container rather than
    # a nonlocal rebinding.
    recorded: dict = {}

    def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        recorded["graph"] = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        hidden_size = x.shape[1]
        # Three sigmoid / clone-view rounds; dynamo unrolls this constant
        # range loop into the same flat graph as three explicit repetitions.
        for _ in range(3):
            x = torch.ops.aten.sigmoid.default(x)
            x = x.clone().view(batch_size, hidden_size)
        return x

    x = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(x, 0)

    torch.compile(model_fn, backend=capturing_backend)(x)

    captured_graph = recorded.get("graph")
    assert captured_graph is not None, "Graph should be captured by backend"

    # SymInt placeholders should exist in the captured graph
    symint_placeholders = [
        node
        for node in captured_graph.graph.nodes
        if node.op == "placeholder"
        and isinstance(node.meta.get("example_value"), torch.SymInt)
    ]
    assert len(symint_placeholders) > 0, (
        "Captured graph should have SymInt placeholders from mark_dynamic."
    )

    # split_graph should handle SymInt placeholders without error
    split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])

    # Should have 3 splitting subgraphs (3 sigmoids)
    splitting_subgraphs = [item for item in split_items if item.is_splitting_graph]
    assert len(splitting_subgraphs) == 3, (
        f"Expected 3 splitting subgraphs (3 sigmoids), got {len(splitting_subgraphs)}"
    )
    assert len(split_items) >= 6, (
        f"Expected at least 6 total subgraphs, got {len(split_items)}"
    )
15 changes: 15 additions & 0 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,21 @@ class SplitItem:
def split_graph(
graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
# Move sym_size.int nodes to right after their tensor operand so they
# end up in the producer subgraph. This avoids passing the tensor to
# consumer subgraphs just for .size() calls — only the SymInt result
# crosses the boundary.
for node in list(graph.graph.nodes):
if node.op == "call_function" and node.target == torch.ops.aten.sym_size.int:
tensor_node = node.args[0]
with graph.graph.inserting_after(tensor_node):
new_node = graph.graph.call_function(
torch.ops.aten.sym_size.int, args=node.args
)
new_node.meta = node.meta.copy()
node.replace_all_uses_with(new_node)
graph.graph.erase_node(node)

# split graph by ops
subgraph_id = 0
node_to_subgraph_id: dict[fx.Node, int] = {}
Expand Down