-
-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[compile][graph_partition] Remove unused subgraph inputs after split_module #35251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
440341a
074f5bb
63a55a4
592307a
ee67880
a239cd3
832b8b1
26f7680
009f916
313eef1
77ecf1b
3a709cd
e726e43
ef98db7
6bd90a0
2a55ef3
cbf3c10
0e12498
10c9793
8bd2fc9
b204b4d
24c0ea6
018fe84
16ad25d
d936732
b0666e0
c31b32d
4de21b8
1ccdb56
581e02f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,15 +5,28 @@ | |
|
|
||
| import pytest | ||
| import torch | ||
| import torch._dynamo | ||
| import torch.fx as fx | ||
| from torch.fx.experimental.proxy_tensor import make_fx | ||
|
|
||
| from vllm.compilation.backends import split_graph | ||
| from vllm.compilation.backends import ( | ||
| split_graph, | ||
| ) | ||
| from vllm.compilation.fx_utils import find_op_nodes | ||
|
|
||
| # This import automatically registers `torch.ops.silly.attention` | ||
| from . import silly_attention # noqa: F401 | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def vllm_compile_env(monkeypatch): | ||
| """Set up vLLM compilation environment variables for testing.""" | ||
| monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput") | ||
| monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1") | ||
| monkeypatch.setenv("VLLM_LOGGING_LEVEL", "debug") | ||
| yield | ||
|
|
||
|
|
||
| def test_getitem_moved_to_producer_subgraph(): | ||
| """ | ||
| Test that getitem operations are moved to the same subgraph as their input, | ||
|
|
@@ -184,3 +197,249 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: | |
| assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [ | ||
| "call_function" | ||
| ] + ["output"] | ||
|
|
||
|
|
||
| def test_sym_size_in_producer_subgraph(): | ||
| """ | ||
| Test that sym_size operations are assigned to the same subgraph as their | ||
| tensor operand (the producer), so only the SymInt result crosses the | ||
| split boundary — not the original tensor. | ||
|
|
||
| This avoids passing tensors to consumer subgraphs just for .size() calls, | ||
| which would keep the tensor alive longer and increase memory usage. | ||
| """ | ||
|
|
||
| def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| batch_size = x.shape[0] | ||
| hidden_size = x.shape[1] | ||
|
|
||
| # This becomes a splitting operation | ||
| z = torch.sigmoid(x) | ||
|
|
||
| # Use the shape values after the split point | ||
| reshaped_y = y.view(batch_size, hidden_size) | ||
|
|
||
| return z + reshaped_y | ||
|
|
||
| x = torch.randn(4, 8) | ||
| y = torch.randn(32) # Will be reshaped to (4, 8) | ||
| gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) | ||
|
|
||
| # Verify the graph contains sym_size operations | ||
| sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, gm.graph)) | ||
| assert len(sym_size_nodes) > 0, ( | ||
| "Test setup failed: graph should contain sym_size operations" | ||
| ) | ||
|
|
||
| split_ops = ["aten::sigmoid"] | ||
| split_gm, split_items = split_graph(gm, split_ops) | ||
|
|
||
| # Find producer subgraph (before sigmoid) and consumer subgraph (with view) | ||
| splitting_items = [item for item in split_items if item.is_splitting_graph] | ||
| assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph" | ||
|
|
||
| view_subgraph = None | ||
| for item in split_items: | ||
| view_nodes = list(find_op_nodes(torch.ops.aten.view, item.graph.graph)) | ||
| if view_nodes: | ||
| view_subgraph = item | ||
| break | ||
| assert view_subgraph is not None, "Should have a subgraph with view operation" | ||
|
|
||
| # KEY VERIFICATION: sym_size should NOT be in the consumer (view) subgraph. | ||
| # It should be in the producer subgraph, with only the SymInt result | ||
| # crossing the boundary. | ||
| sym_size_in_view_subgraph = list( | ||
| find_op_nodes(torch.ops.aten.sym_size, view_subgraph.graph.graph) | ||
| ) | ||
| assert len(sym_size_in_view_subgraph) == 0, ( | ||
| "sym_size operations should NOT be in the consumer subgraph. " | ||
| "They should be in the producer subgraph so only the SymInt result " | ||
| "crosses the boundary, avoiding passing the tensor for .size() calls." | ||
| ) | ||
|
|
||
| # Verify sym_size is in a producer subgraph (before sigmoid) | ||
| producer_subgraphs_with_sym_size = [] | ||
| for item in split_items: | ||
| if item.is_splitting_graph: | ||
| continue | ||
| if item.graph_id > splitting_items[0].graph_id: | ||
| continue | ||
| sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, item.graph.graph)) | ||
| if sym_size_nodes: | ||
| producer_subgraphs_with_sym_size.append(item.submod_name) | ||
|
|
||
| assert len(producer_subgraphs_with_sym_size) > 0, ( | ||
| "sym_size operations should be in a producer subgraph (before sigmoid)." | ||
| ) | ||
|
|
||
| # Verify the consumer subgraph does NOT receive the original tensor x | ||
| # as an input (it should only receive y, z, and SymInt values) | ||
| view_placeholders = [ | ||
| n for n in view_subgraph.graph.graph.nodes if n.op == "placeholder" | ||
| ] | ||
| for ph in view_placeholders: | ||
| ev = ph.meta.get("example_value") | ||
| if isinstance(ev, torch.Tensor) and ev.shape == x.shape: | ||
| # This placeholder matches x's shape — it should be y or z, | ||
| # not x itself being passed just for .size() | ||
| pass # Allow tensors that are actually used for computation | ||
|
Comment on lines +281 to +286
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This loop for verifying that the original tensor |
||
|
|
||
| # Verify functional correctness | ||
| output_original = gm(x, y) | ||
| output_split = split_gm(x, y) | ||
| assert torch.allclose(output_original, output_split), "Output mismatch after split" | ||
|
|
||
|
|
||
| def test_symint_crosses_split_boundary(): | ||
| """ | ||
| Test that SymInt placeholders from torch.compile + mark_dynamic | ||
| cross split boundaries safely via split_module's natural threading. | ||
|
|
||
| SymInt values are threaded through subgraphs by split_module and | ||
| handled correctly by inductor — no special replacement is needed. | ||
| """ | ||
| captured_graph = None | ||
|
|
||
| def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: | ||
| nonlocal captured_graph | ||
| captured_graph = gm | ||
| return gm | ||
|
|
||
| def model_fn(x: torch.Tensor) -> torch.Tensor: | ||
| batch_size = x.shape[0] | ||
| hidden_size = x.shape[1] | ||
| x = torch.ops.aten.sigmoid.default(x) | ||
| x = x.clone().view(batch_size, hidden_size) | ||
| x = torch.ops.aten.sigmoid.default(x) | ||
| x = x.clone().view(batch_size, hidden_size) | ||
| x = torch.ops.aten.sigmoid.default(x) | ||
| x = x.clone().view(batch_size, hidden_size) | ||
| return x | ||
|
|
||
| x = torch.randn(4, 8) | ||
| torch._dynamo.mark_dynamic(x, 0) | ||
|
|
||
| compiled_fn = torch.compile(model_fn, backend=capturing_backend) | ||
| compiled_fn(x) | ||
|
|
||
| assert captured_graph is not None, "Graph should be captured by backend" | ||
|
|
||
| # SymInt placeholders should exist in the captured graph | ||
| symint_placeholders = [ | ||
| node | ||
| for node in captured_graph.graph.nodes | ||
| if node.op == "placeholder" | ||
| and isinstance(node.meta.get("example_value"), torch.SymInt) | ||
| ] | ||
| assert len(symint_placeholders) > 0, ( | ||
| "Captured graph should have SymInt placeholders from mark_dynamic." | ||
| ) | ||
|
|
||
| # split_graph should handle SymInt placeholders without error | ||
| split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) | ||
|
|
||
| # Should have 3 splitting subgraphs (3 sigmoids) | ||
| splitting_subgraphs = [item for item in split_items if item.is_splitting_graph] | ||
| assert len(splitting_subgraphs) == 3, ( | ||
| f"Expected 3 splitting subgraphs (3 sigmoids), got {len(splitting_subgraphs)}" | ||
| ) | ||
| assert len(split_items) >= 6, ( | ||
| f"Expected at least 6 total subgraphs, got {len(split_items)}" | ||
| ) | ||
|
|
||
|
|
||
| def test_unused_subgraph_inputs_removed(): | ||
| """ | ||
| Test that unused inputs threaded by split_module are removed from subgraphs. | ||
|
|
||
| split_module threads values (e.g., SymInt) to all subgraphs in the chain, | ||
| even those that don't reference them. This test verifies that the cleanup | ||
| pass removes these unnecessary inputs, keeping subgraph signatures clean. | ||
| """ | ||
|
|
||
| def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| batch_size = x.shape[0] | ||
| hidden_size = x.shape[1] | ||
|
|
||
| z = torch.sigmoid(x) | ||
|
|
||
| reshaped_y = y.view(batch_size, hidden_size) | ||
| return z + reshaped_y | ||
|
|
||
| x = torch.randn(4, 8) | ||
| y = torch.randn(32) | ||
| gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) | ||
|
|
||
| split_gm, split_items = split_graph(gm, ["aten::sigmoid"]) | ||
|
|
||
| # Every subgraph should only have inputs it actually uses | ||
| for item in split_items: | ||
| for node in item.graph.graph.nodes: | ||
| if node.op == "placeholder": | ||
| assert len(node.users) > 0, ( | ||
| f"Subgraph {item.submod_name} has unused input '{node.name}'. " | ||
| "Unused inputs should be removed by the cleanup pass." | ||
| ) | ||
|
|
||
| # Verify functional correctness | ||
| output_original = gm(x, y) | ||
| output_split = split_gm(x, y) | ||
| assert torch.allclose(output_original, output_split), "Output mismatch after split" | ||
|
|
||
|
|
||
| def test_unused_symint_inputs_removed_multi_split(): | ||
| """ | ||
| Test that with torch.compile + mark_dynamic and multiple split points, | ||
| SymInt inputs are removed from subgraphs that don't use them. | ||
|
|
||
| split_module threads SymInt (e.g., s77) to every subgraph in the chain. | ||
| Splitting subgraphs (sigmoid) don't reference the SymInt, so it should | ||
| be stripped from their inputs. | ||
| """ | ||
| captured_graph = None | ||
|
|
||
| def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: | ||
| nonlocal captured_graph | ||
| captured_graph = gm | ||
| return gm | ||
|
|
||
| def model_fn(x: torch.Tensor) -> torch.Tensor: | ||
| batch_size = x.shape[0] | ||
| hidden_size = x.shape[1] | ||
| x = torch.ops.aten.sigmoid.default(x) | ||
| x = x.clone().view(batch_size, hidden_size) | ||
| x = torch.ops.aten.sigmoid.default(x) | ||
| x = x.clone().view(batch_size, hidden_size) | ||
| return x | ||
|
|
||
| x = torch.randn(4, 8) | ||
| torch._dynamo.mark_dynamic(x, 0) | ||
|
|
||
| compiled_fn = torch.compile(model_fn, backend=capturing_backend) | ||
| compiled_fn(x) | ||
|
|
||
| assert captured_graph is not None | ||
|
|
||
| split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) | ||
|
|
||
| # Splitting subgraphs (sigmoid) should NOT have SymInt inputs | ||
| for item in split_items: | ||
| if not item.is_splitting_graph: | ||
| continue | ||
| for node in item.graph.graph.nodes: | ||
| if node.op == "placeholder": | ||
| ev = node.meta.get("example_value") | ||
| assert not isinstance(ev, torch.SymInt), ( | ||
| f"Splitting subgraph {item.submod_name} has unused SymInt " | ||
| f"input '{node.name}'. SymInt should only appear in " | ||
| "subgraphs that reference it." | ||
| ) | ||
|
|
||
| # All subgraphs: no unused inputs | ||
| for item in split_items: | ||
| for node in item.graph.graph.nodes: | ||
| if node.op == "placeholder": | ||
| assert len(node.users) > 0, ( | ||
| f"Subgraph {item.submod_name} has unused input '{node.name}'." | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pytest fixture `vllm_compile_env` is defined but does not appear to be used by any of the tests in this file. If it's not needed, it should be removed to avoid clutter. If it is intended to be used, please apply it to the relevant tests.