From 440341ac74c0571f614dea09271d1803e409f196 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 20 Jan 2026 13:37:30 -0800 Subject: [PATCH 01/25] [compile][graph_partition]Add tensor size handling --- tests/compile/test_graph_partition.py | 124 ++++++++++++++++++++++++++ vllm/compilation/backends.py | 84 +++++++++++++++++ 2 files changed, 208 insertions(+) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 1cd783843a62..e634c30f51f2 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -122,3 +122,127 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: output_split = split_gm(new_x) assert torch.allclose(output_original, output_split), "Output mismatch after split" + + +def test_sym_size_moved_across_split_boundary(): + """ + Test that sym_size operations (tensor.shape accesses) are moved to the same + subgraph as their consumers when they would otherwise cross subgraph boundaries. + + This prevents issues where PT2 doesn't fully support torch.Size as submodule + output when sym_size is in one subgraph and its consumer is in another. 
+ + Pattern being tested: + # Original order that causes issues: + size = tensor_a.shape[0] # subgraph 0 + some_cg_unsafe_op # subgraph 1 (split point) + tensor_b = tensor_b.view(size) # subgraph 2 (would fail without fix) + + # After fix, sym_size is moved: + some_cg_unsafe_op # subgraph 1 (split point) + size = tensor_a.shape[0] # moved to subgraph 2 + tensor_b = tensor_b.view(size) # subgraph 2 (works correctly) + """ + + def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get shape before the split point - this creates sym_size ops + batch_size = x.shape[0] + hidden_size = x.shape[1] + + # This becomes a splitting operation + z = torch.sigmoid(x) + + # Use the shape values after the split point + # Without the fix, this would fail because batch_size/hidden_size + # would be outputs of the first subgraph (as torch.Size/SymInt) + reshaped_y = y.view(batch_size, hidden_size) + + return z + reshaped_y + + x = torch.randn(4, 8) + y = torch.randn(32) # Will be reshaped to (4, 8) + # Use symbolic tracing to generate sym_size operations + gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) + + # Verify the graph contains sym_size operations + sym_size_nodes = [ + node + for node in gm.graph.nodes + if node.op == "call_function" and "sym_size" in str(node.target) + ] + assert ( + len(sym_size_nodes) > 0 + ), "Test setup failed: graph should contain sym_size operations" + + # Split on sigmoid which is the split point + split_ops = ["aten::sigmoid"] + split_gm, split_items = split_graph(gm, split_ops) + + # After the fix, we expect 2 submodules: + # - subgraph 1: sigmoid (split point) + # - subgraph 2: sym_size ops + view + add (consumer subgraph) + # The original subgraph 0 becomes empty because sym_size ops are moved + # to the consumer subgraph, so it's not created. 
+ assert len(split_items) == 2, f"Expected 2 submodules, got {len(split_items)}" + + # Verify that one is the splitting graph (sigmoid) and one is not + splitting_items = [item for item in split_items if item.is_splitting_graph] + non_splitting_items = [item for item in split_items if not item.is_splitting_graph] + assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph" + assert len(non_splitting_items) == 1, "Should have exactly 1 non-splitting subgraph" + + # The non-splitting subgraph should contain the sym_size operations + # (they were moved from before the split to after) + consumer_subgraph = non_splitting_items[0].graph + sym_size_in_consumer = [ + node + for node in consumer_subgraph.graph.nodes + if node.op == "call_function" and "sym_size" in str(node.target) + ] + assert len(sym_size_in_consumer) > 0, ( + "sym_size operations should be in the consumer subgraph (after split)" + ) + + # Verify functional correctness with same-shaped inputs + output_original = gm(x, y) + output_split = split_gm(x, y) + assert torch.allclose(output_original, output_split), "Output mismatch after split" + + +def test_sym_size_with_multiple_consumers_in_different_subgraphs(): + """ + Test that when a sym_size result is used by consumers in multiple different + subgraphs, it's placed in the earliest consumer subgraph. 
+ """ + + def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get shape before any split points + size = x.shape[0] + + # First split point + z1 = torch.sigmoid(x) + + # Use size after first split + y1 = y[:size] + + # Second split point + z2 = torch.sigmoid(z1) + + # Use size again after second split + y2 = y[:size] + + return z2 + y1 + y2 + + x = torch.randn(4, 8) + y = torch.randn(8, 8) + # Use symbolic tracing to generate sym_size operations + gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) + + # Split on both sigmoid operations + split_ops = ["aten::sigmoid"] + split_gm, split_items = split_graph(gm, split_ops) + + # Verify functional correctness with same-shaped inputs + output_original = gm(x, y) + output_split = split_gm(x, y) + assert torch.allclose(output_original, output_split), "Output mismatch after split" diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index f06047be61b9..886ca686c60c 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -293,6 +293,84 @@ class SplitItem: graph: fx.GraphModule +def _is_sym_size_op(node: fx.Node) -> bool: + """Check if a node is a sym_size operation (tensor.shape access).""" + if node.op != "call_function": + return False + target = node.target + # Handle both torch.ops.aten.sym_size.int and sym_size.default + if hasattr(torch.ops.aten, "sym_size"): + sym_size_ops = ( + torch.ops.aten.sym_size, + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_size.default, + ) + return target in sym_size_ops + return False + + +def _move_sym_size_nodes_for_split( + graph: fx.GraphModule, + node_to_subgraph_id: dict[fx.Node, int], +) -> None: + """ + Move sym_size operations to the same subgraph as their consumers. + + When splitting a graph, if a sym_size call is in one submodule and its + consumer is in another, PyTorch 2 has issues because torch.Size is not + fully supported as a submodule output. 
This function reorders sym_size + nodes to be just before their consumers when they would otherwise cross + subgraph boundaries. + + Pattern being fixed: + # Old (causes issues): + size = tensor_a.shape # subgraph 0 + some_cg_unsafe_op # subgraph 1 (split point) + tensor_b = tensor_b.view(size) # subgraph 2 (consumes size) + + # New (fixed): + some_cg_unsafe_op # subgraph 1 (split point) + size = tensor_a.shape # moved to subgraph 2 + tensor_b = tensor_b.view(size) # subgraph 2 (consumes size) + """ + # Collect all sym_size nodes that need to be moved + sym_size_nodes_to_move: list[tuple[fx.Node, int]] = [] + + for node in graph.graph.nodes: + if node.op in ("output", "placeholder"): + continue + + if not _is_sym_size_op(node): + continue + + node_subgraph = node_to_subgraph_id.get(node) + if node_subgraph is None: + continue + + # Find the minimum subgraph ID among all consumers of this sym_size + consumer_subgraph_ids: list[int] = [] + for user in node.users: + if user.op == "output": + continue + user_subgraph = node_to_subgraph_id.get(user) + if user_subgraph is not None: + consumer_subgraph_ids.append(user_subgraph) + + if not consumer_subgraph_ids: + continue + + # The minimum consumer subgraph is where we want to move the sym_size + min_consumer_subgraph = min(consumer_subgraph_ids) + + # Only move if the sym_size would cross into a later subgraph + if min_consumer_subgraph > node_subgraph: + sym_size_nodes_to_move.append((node, min_consumer_subgraph)) + + # Update the subgraph assignments for sym_size nodes that need to move + for node, new_subgraph_id in sym_size_nodes_to_move: + node_to_subgraph_id[node] = new_subgraph_id + + def split_graph( graph: fx.GraphModule, splitting_ops: list[str] ) -> tuple[fx.GraphModule, list[SplitItem]]: @@ -324,6 +402,12 @@ def split_graph( else: node_to_subgraph_id[node] = subgraph_id + # Move sym_size operations (tensor.shape accesses) to be closer to their + # consumers. 
This avoids issues where PT2 doesn't support torch.Size as + # submodule output when sym_size is in one subgraph and its consumer is + # in another. + _move_sym_size_nodes_for_split(graph, node_to_subgraph_id) + # `keep_original_order` is important! # otherwise pytorch might reorder the nodes and # the semantics of the graph will change when we From 074f5bb62f237a5668c16861bca37f64ec7c56ed Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 20 Jan 2026 18:37:40 -0800 Subject: [PATCH 02/25] Add more test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Xiao Fu --- tests/compile/test_graph_partition.py | 90 ++++++++++++++++++--------- 1 file changed, 59 insertions(+), 31 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index e634c30f51f2..336140e337e5 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -131,17 +131,6 @@ def test_sym_size_moved_across_split_boundary(): This prevents issues where PT2 doesn't fully support torch.Size as submodule output when sym_size is in one subgraph and its consumer is in another. 
- - Pattern being tested: - # Original order that causes issues: - size = tensor_a.shape[0] # subgraph 0 - some_cg_unsafe_op # subgraph 1 (split point) - tensor_b = tensor_b.view(size) # subgraph 2 (would fail without fix) - - # After fix, sym_size is moved: - some_cg_unsafe_op # subgraph 1 (split point) - size = tensor_a.shape[0] # moved to subgraph 2 - tensor_b = tensor_b.view(size) # subgraph 2 (works correctly) """ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -170,37 +159,76 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: for node in gm.graph.nodes if node.op == "call_function" and "sym_size" in str(node.target) ] - assert ( - len(sym_size_nodes) > 0 - ), "Test setup failed: graph should contain sym_size operations" + assert len(sym_size_nodes) > 0, ( + "Test setup failed: graph should contain sym_size operations" + ) # Split on sigmoid which is the split point split_ops = ["aten::sigmoid"] split_gm, split_items = split_graph(gm, split_ops) - # After the fix, we expect 2 submodules: - # - subgraph 1: sigmoid (split point) - # - subgraph 2: sym_size ops + view + add (consumer subgraph) - # The original subgraph 0 becomes empty because sym_size ops are moved - # to the consumer subgraph, so it's not created. 
- assert len(split_items) == 2, f"Expected 2 submodules, got {len(split_items)}" - - # Verify that one is the splitting graph (sigmoid) and one is not + # Find the sigmoid (splitting) subgraph and the consumer subgraph splitting_items = [item for item in split_items if item.is_splitting_graph] - non_splitting_items = [item for item in split_items if not item.is_splitting_graph] assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph" - assert len(non_splitting_items) == 1, "Should have exactly 1 non-splitting subgraph" - # The non-splitting subgraph should contain the sym_size operations - # (they were moved from before the split to after) - consumer_subgraph = non_splitting_items[0].graph - sym_size_in_consumer = [ + # KEY VERIFICATION: sym_size operations should be in the same subgraph + # as the view operation (their consumer), NOT in an earlier subgraph. + # This prevents torch.Size from crossing subgraph boundaries. + + # Find which subgraph contains the view operation + view_subgraph = None + for item in split_items: + for node in item.graph.graph.nodes: + if node.op == "call_function" and "view" in str(node.target).lower(): + view_subgraph = item + break + if view_subgraph: + break + + assert view_subgraph is not None, "Should have a subgraph with view operation" + + # Verify sym_size operations are in the SAME subgraph as view + sym_size_in_view_subgraph = [ node - for node in consumer_subgraph.graph.nodes + for node in view_subgraph.graph.graph.nodes if node.op == "call_function" and "sym_size" in str(node.target) ] - assert len(sym_size_in_consumer) > 0, ( - "sym_size operations should be in the consumer subgraph (after split)" + assert len(sym_size_in_view_subgraph) > 0, ( + "sym_size operations should be in the same subgraph as their consumer " + "(view). This ensures torch.Size doesn't cross subgraph boundaries." 
+ ) + + # Verify ordering within the consumer subgraph: sym_size before view + consumer_nodes = list(view_subgraph.graph.graph.nodes) + # CRITICAL VERIFICATION: The sigmoid (splitting/unsafe op) subgraph must + # have a LOWER graph_id than the consumer subgraph. Since subgraphs execute + # in order of graph_id, this proves that: + # 1. Sigmoid runs FIRST + # 2. sym_size + view run SECOND (in consumer subgraph) + # Therefore, sym_size now happens AFTER the unsafe op. + sigmoid_subgraph = splitting_items[0] + assert sigmoid_subgraph.graph_id < view_subgraph.graph_id, ( + f"Sigmoid subgraph (graph_id={sigmoid_subgraph.graph_id}) must execute " + f"before consumer subgraph (graph_id={view_subgraph.graph_id}). " + "This ensures sym_size happens AFTER the unsafe operation." + ) + + sym_size_indices = [ + i + for i, node in enumerate(consumer_nodes) + if node.op == "call_function" and "sym_size" in str(node.target) + ] + view_indices = [ + i + for i, node in enumerate(consumer_nodes) + if node.op == "call_function" and "view" in str(node.target).lower() + ] + + max_sym_size_idx = max(sym_size_indices) + min_view_idx = min(view_indices) + assert max_sym_size_idx < min_view_idx, ( + f"sym_size (max index {max_sym_size_idx}) should come before " + f"view (min index {min_view_idx}) in the consumer subgraph." 
) # Verify functional correctness with same-shaped inputs From 63a55a4f66b13cc34c8c43e4e0f532bf209f52bd Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Wed, 21 Jan 2026 22:00:21 -0800 Subject: [PATCH 03/25] Add replication to all consumer Signed-off-by: Xiao Fu --- tests/compile/test_graph_partition.py | 97 ++++++++++++++++----------- vllm/compilation/backends.py | 70 +++++++++---------- 2 files changed, 89 insertions(+), 78 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 336140e337e5..78f2fa8ea8d8 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -8,6 +8,7 @@ from torch.fx.experimental.proxy_tensor import make_fx from vllm.compilation.backends import split_graph +from vllm.compilation.fx_utils import find_op_nodes, is_func def test_getitem_moved_to_producer_subgraph(): @@ -154,11 +155,7 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) # Verify the graph contains sym_size operations - sym_size_nodes = [ - node - for node in gm.graph.nodes - if node.op == "call_function" and "sym_size" in str(node.target) - ] + sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, gm.graph)) assert len(sym_size_nodes) > 0, ( "Test setup failed: graph should contain sym_size operations" ) @@ -178,21 +175,17 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # Find which subgraph contains the view operation view_subgraph = None for item in split_items: - for node in item.graph.graph.nodes: - if node.op == "call_function" and "view" in str(node.target).lower(): - view_subgraph = item - break - if view_subgraph: + view_nodes = list(find_op_nodes(torch.ops.aten.view, item.graph.graph)) + if view_nodes: + view_subgraph = item break assert view_subgraph is not None, "Should have a subgraph with view operation" # Verify sym_size operations are in the SAME subgraph as view - 
sym_size_in_view_subgraph = [ - node - for node in view_subgraph.graph.graph.nodes - if node.op == "call_function" and "sym_size" in str(node.target) - ] + sym_size_in_view_subgraph = list( + find_op_nodes(torch.ops.aten.sym_size, view_subgraph.graph.graph) + ) assert len(sym_size_in_view_subgraph) > 0, ( "sym_size operations should be in the same subgraph as their consumer " "(view). This ensures torch.Size doesn't cross subgraph boundaries." @@ -216,12 +209,12 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: sym_size_indices = [ i for i, node in enumerate(consumer_nodes) - if node.op == "call_function" and "sym_size" in str(node.target) + if is_func(node, torch.ops.aten.sym_size.int) ] view_indices = [ i for i, node in enumerate(consumer_nodes) - if node.op == "call_function" and "view" in str(node.target).lower() + if is_func(node, torch.ops.aten.view.default) ] max_sym_size_idx = max(sym_size_indices) @@ -237,40 +230,66 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: assert torch.allclose(output_original, output_split), "Output mismatch after split" -def test_sym_size_with_multiple_consumers_in_different_subgraphs(): +def test_sym_size_replicated_to_all_consumer_subgraphs(): """ - Test that when a sym_size result is used by consumers in multiple different - subgraphs, it's placed in the earliest consumer subgraph. + Test that sym_size operations are replicated to ALL consumer subgraphs. 
+ + This validates the pattern where each consumer subgraph computes sym_size + locally from the input tensor, rather than receiving it as an input: + + def f(x, y, z): + + sym_size = x.sym_size() # computed locally in subgraph 2 + y2 = y.view(sym_size) + + sym_size = x.sym_size() # computed locally in subgraph 4 + z2 = z.view(sym_size) + + sym_size = x.sym_size() # computed locally in subgraph 6 + w2 = w.view(sym_size) """ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - # Get shape before any split points size = x.shape[0] - - # First split point z1 = torch.sigmoid(x) - - # Use size after first split y1 = y[:size] - - # Second split point z2 = torch.sigmoid(z1) - - # Use size again after second split y2 = y[:size] - - return z2 + y1 + y2 + z3 = torch.sigmoid(z2) + y3 = y[:size] + return z3 + y1 + y2 + y3 x = torch.randn(4, 8) y = torch.randn(8, 8) - # Use symbolic tracing to generate sym_size operations gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) - # Split on both sigmoid operations - split_ops = ["aten::sigmoid"] - split_gm, split_items = split_graph(gm, split_ops) + assert len(list(find_op_nodes(torch.ops.aten.sym_size, gm.graph))) > 0, ( + "Test setup failed: graph should contain sym_size operations" + ) - # Verify functional correctness with same-shaped inputs - output_original = gm(x, y) - output_split = split_gm(x, y) - assert torch.allclose(output_original, output_split), "Output mismatch after split" + split_gm, split_items = split_graph(gm, ["aten::sigmoid"]) + + # Find subgraphs that contain slice operations (consumers of sym_size) + subgraphs_with_slice = [ + item + for item in split_items + if len(list(find_op_nodes(torch.ops.aten.slice, item.graph.graph))) > 0 + ] + + # Find subgraphs that contain sym_size operations + subgraphs_with_sym_size = [ + item + for item in split_items + if len(list(find_op_nodes(torch.ops.aten.sym_size, item.graph.graph))) > 0 + ] + + # KEY VERIFICATION: The number of subgraphs with sym_size 
should equal + # the number of consumer subgraphs (each consumer has its own sym_size) + assert len(subgraphs_with_sym_size) == len(subgraphs_with_slice), ( + f"Expected {len(subgraphs_with_slice)} subgraphs with sym_size " + f"(one per consumer), but found {len(subgraphs_with_sym_size)}. " + "This indicates sym_size was not properly replicated to all consumers." + ) + + # Verify functional correctness + assert torch.allclose(gm(x, y), split_gm(x, y)), "Output mismatch after split" diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 886ca686c60c..bc9017d6e002 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -309,34 +309,19 @@ def _is_sym_size_op(node: fx.Node) -> bool: return False -def _move_sym_size_nodes_for_split( +def _replicate_sym_size_nodes_for_split( graph: fx.GraphModule, node_to_subgraph_id: dict[fx.Node, int], ) -> None: """ - Move sym_size operations to the same subgraph as their consumers. - - When splitting a graph, if a sym_size call is in one submodule and its - consumer is in another, PyTorch 2 has issues because torch.Size is not - fully supported as a submodule output. This function reorders sym_size - nodes to be just before their consumers when they would otherwise cross - subgraph boundaries. - - Pattern being fixed: - # Old (causes issues): - size = tensor_a.shape # subgraph 0 - some_cg_unsafe_op # subgraph 1 (split point) - tensor_b = tensor_b.view(size) # subgraph 2 (consumes size) - - # New (fixed): - some_cg_unsafe_op # subgraph 1 (split point) - size = tensor_a.shape # moved to subgraph 2 - tensor_b = tensor_b.view(size) # subgraph 2 (consumes size) - """ - # Collect all sym_size nodes that need to be moved - sym_size_nodes_to_move: list[tuple[fx.Node, int]] = [] + Replicate sym_size operations to ALL consumer subgraphs. 
- for node in graph.graph.nodes: + When splitting a graph, if a sym_size call has consumers in multiple + subgraphs, we replicate the sym_size operation to each consumer subgraph. + This ensures each subgraph computes sym_size locally rather than receiving + it as an input, avoiding torch.Size crossing subgraph boundaries. + """ + for node in list(graph.graph.nodes): if node.op in ("output", "placeholder"): continue @@ -347,28 +332,35 @@ def _move_sym_size_nodes_for_split( if node_subgraph is None: continue - # Find the minimum subgraph ID among all consumers of this sym_size - consumer_subgraph_ids: list[int] = [] + # Group consumers by their subgraph ID (only those in later subgraphs) + subgraph_to_consumers: dict[int, list[fx.Node]] = {} for user in node.users: if user.op == "output": continue user_subgraph = node_to_subgraph_id.get(user) - if user_subgraph is not None: - consumer_subgraph_ids.append(user_subgraph) + if user_subgraph is not None and user_subgraph > node_subgraph: + if user_subgraph not in subgraph_to_consumers: + subgraph_to_consumers[user_subgraph] = [] + subgraph_to_consumers[user_subgraph].append(user) - if not consumer_subgraph_ids: + if not subgraph_to_consumers: continue - # The minimum consumer subgraph is where we want to move the sym_size - min_consumer_subgraph = min(consumer_subgraph_ids) + # Create a copy of sym_size for EACH consumer subgraph + for subgraph_id, consumer_list in subgraph_to_consumers.items(): + with graph.graph.inserting_before(consumer_list[0]): + new_sym_size = graph.graph.call_function( + node.target, + args=node.args, + kwargs=node.kwargs, + ) + if node.meta: + new_sym_size.meta = node.meta.copy() - # Only move if the sym_size would cross into a later subgraph - if min_consumer_subgraph > node_subgraph: - sym_size_nodes_to_move.append((node, min_consumer_subgraph)) + node_to_subgraph_id[new_sym_size] = subgraph_id - # Update the subgraph assignments for sym_size nodes that need to move - for node, new_subgraph_id 
in sym_size_nodes_to_move: - node_to_subgraph_id[node] = new_subgraph_id + for consumer in consumer_list: + consumer.replace_input_with(node, new_sym_size) def split_graph( @@ -402,11 +394,11 @@ def split_graph( else: node_to_subgraph_id[node] = subgraph_id - # Move sym_size operations (tensor.shape accesses) to be closer to their - # consumers. This avoids issues where PT2 doesn't support torch.Size as + # Replicate sym_size operations (tensor.shape accesses) to all consumer + # subgraphs. This avoids issues where PT2 doesn't support torch.Size as # submodule output when sym_size is in one subgraph and its consumer is # in another. - _move_sym_size_nodes_for_split(graph, node_to_subgraph_id) + _replicate_sym_size_nodes_for_split(graph, node_to_subgraph_id) # `keep_original_order` is important! # otherwise pytorch might reorder the nodes and From 592307abf1351fc5d7a8ae21298b16409ac644c7 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Fri, 23 Jan 2026 09:08:10 -0800 Subject: [PATCH 04/25] Revert "remove cuda graph copy" This reverts commit a409cf42cefaca81bcd866eaaf1c9ad1ffdb3017. 
--- vllm/compilation/cuda_graph.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 7ffa74d0d7e6..729da2c3acad 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -247,7 +247,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: input_addresses = [ x.data_ptr() for x in args if isinstance(x, torch.Tensor) ] - entry.input_addresses = input_addresses cudagraph = torch.cuda.CUDAGraph() with ExitStack() as stack: @@ -293,7 +292,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: # the weak ref of the output, so that pytorch can correctly # manage the memory during cuda graph capture return output - if self.is_debugging_mode: # check if the input addresses are the same new_input_addresses = [ From ee678800235d6b140deb6fc91a4288c7eb0e662c Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Fri, 23 Jan 2026 09:25:40 -0800 Subject: [PATCH 05/25] Add repro-level bug on cuda_graph address assignment to ensure the fix Signed-off-by: Xiao Fu --- vllm/compilation/cuda_graph.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 729da2c3acad..753ccaac8f3f 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -288,6 +288,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: compilation_counter.num_cudagraph_captured += 1 + # Save input addresses for debugging/replay verification + if self.is_debugging_mode: + entry.input_addresses = input_addresses + # important: we need to return the output, rather than # the weak ref of the output, so that pytorch can correctly # manage the memory during cuda graph capture From a239cd32cde13e70fa21120b3dd3b2774d3cdb54 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Thu, 22 Jan 2026 15:13:43 -0800 Subject: [PATCH 06/25] Revert "remove cuda graph copy" This reverts commit a409cf42cefaca81bcd866eaaf1c9ad1ffdb3017. 
--- vllm/compilation/cuda_graph.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 7ffa74d0d7e6..729da2c3acad 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -247,7 +247,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: input_addresses = [ x.data_ptr() for x in args if isinstance(x, torch.Tensor) ] - entry.input_addresses = input_addresses cudagraph = torch.cuda.CUDAGraph() with ExitStack() as stack: @@ -293,7 +292,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: # the weak ref of the output, so that pytorch can correctly # manage the memory during cuda graph capture return output - if self.is_debugging_mode: # check if the input addresses are the same new_input_addresses = [ From 009f91613041fcdfc98d74aca4500c7a39c63a34 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 20 Jan 2026 13:37:30 -0800 Subject: [PATCH 07/25] [compile][graph_partition]Add tensor size handling --- tests/compile/test_graph_partition.py | 124 ++++++++++++++++++++++++++ vllm/compilation/backends.py | 84 +++++++++++++++++ 2 files changed, 208 insertions(+) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 1cd783843a62..e634c30f51f2 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -122,3 +122,127 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: output_split = split_gm(new_x) assert torch.allclose(output_original, output_split), "Output mismatch after split" + + +def test_sym_size_moved_across_split_boundary(): + """ + Test that sym_size operations (tensor.shape accesses) are moved to the same + subgraph as their consumers when they would otherwise cross subgraph boundaries. + + This prevents issues where PT2 doesn't fully support torch.Size as submodule + output when sym_size is in one subgraph and its consumer is in another. 
+ + Pattern being tested: + # Original order that causes issues: + size = tensor_a.shape[0] # subgraph 0 + some_cg_unsafe_op # subgraph 1 (split point) + tensor_b = tensor_b.view(size) # subgraph 2 (would fail without fix) + + # After fix, sym_size is moved: + some_cg_unsafe_op # subgraph 1 (split point) + size = tensor_a.shape[0] # moved to subgraph 2 + tensor_b = tensor_b.view(size) # subgraph 2 (works correctly) + """ + + def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get shape before the split point - this creates sym_size ops + batch_size = x.shape[0] + hidden_size = x.shape[1] + + # This becomes a splitting operation + z = torch.sigmoid(x) + + # Use the shape values after the split point + # Without the fix, this would fail because batch_size/hidden_size + # would be outputs of the first subgraph (as torch.Size/SymInt) + reshaped_y = y.view(batch_size, hidden_size) + + return z + reshaped_y + + x = torch.randn(4, 8) + y = torch.randn(32) # Will be reshaped to (4, 8) + # Use symbolic tracing to generate sym_size operations + gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) + + # Verify the graph contains sym_size operations + sym_size_nodes = [ + node + for node in gm.graph.nodes + if node.op == "call_function" and "sym_size" in str(node.target) + ] + assert ( + len(sym_size_nodes) > 0 + ), "Test setup failed: graph should contain sym_size operations" + + # Split on sigmoid which is the split point + split_ops = ["aten::sigmoid"] + split_gm, split_items = split_graph(gm, split_ops) + + # After the fix, we expect 2 submodules: + # - subgraph 1: sigmoid (split point) + # - subgraph 2: sym_size ops + view + add (consumer subgraph) + # The original subgraph 0 becomes empty because sym_size ops are moved + # to the consumer subgraph, so it's not created. 
+ assert len(split_items) == 2, f"Expected 2 submodules, got {len(split_items)}" + + # Verify that one is the splitting graph (sigmoid) and one is not + splitting_items = [item for item in split_items if item.is_splitting_graph] + non_splitting_items = [item for item in split_items if not item.is_splitting_graph] + assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph" + assert len(non_splitting_items) == 1, "Should have exactly 1 non-splitting subgraph" + + # The non-splitting subgraph should contain the sym_size operations + # (they were moved from before the split to after) + consumer_subgraph = non_splitting_items[0].graph + sym_size_in_consumer = [ + node + for node in consumer_subgraph.graph.nodes + if node.op == "call_function" and "sym_size" in str(node.target) + ] + assert len(sym_size_in_consumer) > 0, ( + "sym_size operations should be in the consumer subgraph (after split)" + ) + + # Verify functional correctness with same-shaped inputs + output_original = gm(x, y) + output_split = split_gm(x, y) + assert torch.allclose(output_original, output_split), "Output mismatch after split" + + +def test_sym_size_with_multiple_consumers_in_different_subgraphs(): + """ + Test that when a sym_size result is used by consumers in multiple different + subgraphs, it's placed in the earliest consumer subgraph. 
+ """ + + def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get shape before any split points + size = x.shape[0] + + # First split point + z1 = torch.sigmoid(x) + + # Use size after first split + y1 = y[:size] + + # Second split point + z2 = torch.sigmoid(z1) + + # Use size again after second split + y2 = y[:size] + + return z2 + y1 + y2 + + x = torch.randn(4, 8) + y = torch.randn(8, 8) + # Use symbolic tracing to generate sym_size operations + gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) + + # Split on both sigmoid operations + split_ops = ["aten::sigmoid"] + split_gm, split_items = split_graph(gm, split_ops) + + # Verify functional correctness with same-shaped inputs + output_original = gm(x, y) + output_split = split_gm(x, y) + assert torch.allclose(output_original, output_split), "Output mismatch after split" diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 454d81317ebd..76282da813a6 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -334,6 +334,84 @@ class SplitItem: graph: fx.GraphModule +def _is_sym_size_op(node: fx.Node) -> bool: + """Check if a node is a sym_size operation (tensor.shape access).""" + if node.op != "call_function": + return False + target = node.target + # Handle both torch.ops.aten.sym_size.int and sym_size.default + if hasattr(torch.ops.aten, "sym_size"): + sym_size_ops = ( + torch.ops.aten.sym_size, + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_size.default, + ) + return target in sym_size_ops + return False + + +def _move_sym_size_nodes_for_split( + graph: fx.GraphModule, + node_to_subgraph_id: dict[fx.Node, int], +) -> None: + """ + Move sym_size operations to the same subgraph as their consumers. + + When splitting a graph, if a sym_size call is in one submodule and its + consumer is in another, PyTorch 2 has issues because torch.Size is not + fully supported as a submodule output. 
This function reorders sym_size + nodes to be just before their consumers when they would otherwise cross + subgraph boundaries. + + Pattern being fixed: + # Old (causes issues): + size = tensor_a.shape # subgraph 0 + some_cg_unsafe_op # subgraph 1 (split point) + tensor_b = tensor_b.view(size) # subgraph 2 (consumes size) + + # New (fixed): + some_cg_unsafe_op # subgraph 1 (split point) + size = tensor_a.shape # moved to subgraph 2 + tensor_b = tensor_b.view(size) # subgraph 2 (consumes size) + """ + # Collect all sym_size nodes that need to be moved + sym_size_nodes_to_move: list[tuple[fx.Node, int]] = [] + + for node in graph.graph.nodes: + if node.op in ("output", "placeholder"): + continue + + if not _is_sym_size_op(node): + continue + + node_subgraph = node_to_subgraph_id.get(node) + if node_subgraph is None: + continue + + # Find the minimum subgraph ID among all consumers of this sym_size + consumer_subgraph_ids: list[int] = [] + for user in node.users: + if user.op == "output": + continue + user_subgraph = node_to_subgraph_id.get(user) + if user_subgraph is not None: + consumer_subgraph_ids.append(user_subgraph) + + if not consumer_subgraph_ids: + continue + + # The minimum consumer subgraph is where we want to move the sym_size + min_consumer_subgraph = min(consumer_subgraph_ids) + + # Only move if the sym_size would cross into a later subgraph + if min_consumer_subgraph > node_subgraph: + sym_size_nodes_to_move.append((node, min_consumer_subgraph)) + + # Update the subgraph assignments for sym_size nodes that need to move + for node, new_subgraph_id in sym_size_nodes_to_move: + node_to_subgraph_id[node] = new_subgraph_id + + def split_graph( graph: fx.GraphModule, splitting_ops: list[str] ) -> tuple[fx.GraphModule, list[SplitItem]]: @@ -365,6 +443,12 @@ def split_graph( else: node_to_subgraph_id[node] = subgraph_id + # Move sym_size operations (tensor.shape accesses) to be closer to their + # consumers. 
This avoids issues where PT2 doesn't support torch.Size as + # submodule output when sym_size is in one subgraph and its consumer is + # in another. + _move_sym_size_nodes_for_split(graph, node_to_subgraph_id) + # `keep_original_order` is important! # otherwise pytorch might reorder the nodes and # the semantics of the graph will change when we From 313eef1ffc9d4303ba83e86242a62f0ccdafc8d5 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 20 Jan 2026 18:37:40 -0800 Subject: [PATCH 08/25] Add more test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Xiao Fu --- tests/compile/test_graph_partition.py | 90 ++++++++++++++++++--------- 1 file changed, 59 insertions(+), 31 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index e634c30f51f2..336140e337e5 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -131,17 +131,6 @@ def test_sym_size_moved_across_split_boundary(): This prevents issues where PT2 doesn't fully support torch.Size as submodule output when sym_size is in one subgraph and its consumer is in another. 
- - Pattern being tested: - # Original order that causes issues: - size = tensor_a.shape[0] # subgraph 0 - some_cg_unsafe_op # subgraph 1 (split point) - tensor_b = tensor_b.view(size) # subgraph 2 (would fail without fix) - - # After fix, sym_size is moved: - some_cg_unsafe_op # subgraph 1 (split point) - size = tensor_a.shape[0] # moved to subgraph 2 - tensor_b = tensor_b.view(size) # subgraph 2 (works correctly) """ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -170,37 +159,76 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: for node in gm.graph.nodes if node.op == "call_function" and "sym_size" in str(node.target) ] - assert ( - len(sym_size_nodes) > 0 - ), "Test setup failed: graph should contain sym_size operations" + assert len(sym_size_nodes) > 0, ( + "Test setup failed: graph should contain sym_size operations" + ) # Split on sigmoid which is the split point split_ops = ["aten::sigmoid"] split_gm, split_items = split_graph(gm, split_ops) - # After the fix, we expect 2 submodules: - # - subgraph 1: sigmoid (split point) - # - subgraph 2: sym_size ops + view + add (consumer subgraph) - # The original subgraph 0 becomes empty because sym_size ops are moved - # to the consumer subgraph, so it's not created. 
- assert len(split_items) == 2, f"Expected 2 submodules, got {len(split_items)}" - - # Verify that one is the splitting graph (sigmoid) and one is not + # Find the sigmoid (splitting) subgraph and the consumer subgraph splitting_items = [item for item in split_items if item.is_splitting_graph] - non_splitting_items = [item for item in split_items if not item.is_splitting_graph] assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph" - assert len(non_splitting_items) == 1, "Should have exactly 1 non-splitting subgraph" - # The non-splitting subgraph should contain the sym_size operations - # (they were moved from before the split to after) - consumer_subgraph = non_splitting_items[0].graph - sym_size_in_consumer = [ + # KEY VERIFICATION: sym_size operations should be in the same subgraph + # as the view operation (their consumer), NOT in an earlier subgraph. + # This prevents torch.Size from crossing subgraph boundaries. + + # Find which subgraph contains the view operation + view_subgraph = None + for item in split_items: + for node in item.graph.graph.nodes: + if node.op == "call_function" and "view" in str(node.target).lower(): + view_subgraph = item + break + if view_subgraph: + break + + assert view_subgraph is not None, "Should have a subgraph with view operation" + + # Verify sym_size operations are in the SAME subgraph as view + sym_size_in_view_subgraph = [ node - for node in consumer_subgraph.graph.nodes + for node in view_subgraph.graph.graph.nodes if node.op == "call_function" and "sym_size" in str(node.target) ] - assert len(sym_size_in_consumer) > 0, ( - "sym_size operations should be in the consumer subgraph (after split)" + assert len(sym_size_in_view_subgraph) > 0, ( + "sym_size operations should be in the same subgraph as their consumer " + "(view). This ensures torch.Size doesn't cross subgraph boundaries." 
+ ) + + # Verify ordering within the consumer subgraph: sym_size before view + consumer_nodes = list(view_subgraph.graph.graph.nodes) + # CRITICAL VERIFICATION: The sigmoid (splitting/unsafe op) subgraph must + # have a LOWER graph_id than the consumer subgraph. Since subgraphs execute + # in order of graph_id, this proves that: + # 1. Sigmoid runs FIRST + # 2. sym_size + view run SECOND (in consumer subgraph) + # Therefore, sym_size now happens AFTER the unsafe op. + sigmoid_subgraph = splitting_items[0] + assert sigmoid_subgraph.graph_id < view_subgraph.graph_id, ( + f"Sigmoid subgraph (graph_id={sigmoid_subgraph.graph_id}) must execute " + f"before consumer subgraph (graph_id={view_subgraph.graph_id}). " + "This ensures sym_size happens AFTER the unsafe operation." + ) + + sym_size_indices = [ + i + for i, node in enumerate(consumer_nodes) + if node.op == "call_function" and "sym_size" in str(node.target) + ] + view_indices = [ + i + for i, node in enumerate(consumer_nodes) + if node.op == "call_function" and "view" in str(node.target).lower() + ] + + max_sym_size_idx = max(sym_size_indices) + min_view_idx = min(view_indices) + assert max_sym_size_idx < min_view_idx, ( + f"sym_size (max index {max_sym_size_idx}) should come before " + f"view (min index {min_view_idx}) in the consumer subgraph." 
) # Verify functional correctness with same-shaped inputs From 77ecf1bad7daa2b9254f573eab141f1905bfda8e Mon Sep 17 00:00:00 2001 From: Xiao <31429901+fxdawnn@users.noreply.github.com> Date: Tue, 27 Jan 2026 11:18:38 -0800 Subject: [PATCH 09/25] Update vllm/compilation/backends.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Luka Govedič Signed-off-by: Xiao <31429901+fxdawnn@users.noreply.github.com> --- vllm/compilation/backends.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index bc9017d6e002..408cc5586d14 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -297,16 +297,16 @@ def _is_sym_size_op(node: fx.Node) -> bool: """Check if a node is a sym_size operation (tensor.shape access).""" if node.op != "call_function": return False - target = node.target + + if not hasattr(torch.ops.aten, "sym_size"): + return False + # Handle both torch.ops.aten.sym_size.int and sym_size.default - if hasattr(torch.ops.aten, "sym_size"): - sym_size_ops = ( - torch.ops.aten.sym_size, - torch.ops.aten.sym_size.int, - torch.ops.aten.sym_size.default, - ) - return target in sym_size_ops - return False + return node.target in ( + torch.ops.aten.sym_size, + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_size.default, + ) def _replicate_sym_size_nodes_for_split( From e726e436c9aeab7d9dced469b15d9bcf14020318 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Wed, 28 Jan 2026 14:52:58 -0800 Subject: [PATCH 10/25] Modify the test for scenario with torch.tensor() Signed-off-by: Xiao Fu --- tests/compile/test_graph_partition.py | 230 ++++++++++++---- vllm/compilation/backends.py | 375 +++++++++++++++++++++++--- 2 files changed, 512 insertions(+), 93 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 78f2fa8ea8d8..f1d428cf4e6a 100644 --- 
a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -5,12 +5,25 @@ import pytest import torch +import torch._dynamo +import torch.fx as fx from torch.fx.experimental.proxy_tensor import make_fx -from vllm.compilation.backends import split_graph +from vllm.compilation.backends import ( + split_graph, +) from vllm.compilation.fx_utils import find_op_nodes, is_func +@pytest.fixture +def vllm_compile_env(monkeypatch): + """Set up vLLM compilation environment variables for testing.""" + monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput") + monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1") + monkeypatch.setenv("VLLM_LOGGING_LEVEL", "debug") + yield + + def test_getitem_moved_to_producer_subgraph(): """ Test that getitem operations are moved to the same subgraph as their input, @@ -230,66 +243,185 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: assert torch.allclose(output_original, output_split), "Output mismatch after split" -def test_sym_size_replicated_to_all_consumer_subgraphs(): +def test_sym_size_with_torch_compile_and_mark_dynamic(): """ - Test that sym_size operations are replicated to ALL consumer subgraphs. - - This validates the pattern where each consumer subgraph computes sym_size - locally from the input tensor, rather than receiving it as an input: - - def f(x, y, z): - - sym_size = x.sym_size() # computed locally in subgraph 2 - y2 = y.view(sym_size) - - sym_size = x.sym_size() # computed locally in subgraph 4 - z2 = z.view(sym_size) - - sym_size = x.sym_size() # computed locally in subgraph 6 - w2 = w.view(sym_size) + Test handling of SymInt placeholders from torch.compile with mark_dynamic + across MULTIPLE split subgraphs. + + When using torch.compile + mark_dynamic, the captured graph has: + - SymInt placeholders (e.g., s77) as separate inputs + - Operations that use the SymInt directly (e.g., view([s77, 8])) + + standalone_compile / inductor expects only tensor inputs. 
split_graph must: + 1. Replace SymInt placeholder uses with sym_size calls on tensor inputs + 2. Replicate sym_size to ALL consumer subgraphs that need the dynamic size + 3. Remove unused SymInt placeholders from the final graph + + This test validates the complete SymInt -> sym_size pipeline with MULTIPLE + split boundaries to ensure sym_size is correctly replicated across subgraphs: + - Phase 1: SymInt placeholders exist in the captured graph + - Phase 2 & 3: split_graph handles SymInt replacement and removal + - Phase 4: sym_size.int exists in EACH consumer subgraph that needs it + - Phase 5: Functional correctness with original input + - Phase 6: Functional correctness with different batch size + - Phase 7: Validate multiple split subgraphs exist """ + captured_graph = None - def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - size = x.shape[0] - z1 = torch.sigmoid(x) - y1 = y[:size] - z2 = torch.sigmoid(z1) - y2 = y[:size] - z3 = torch.sigmoid(z2) - y3 = y[:size] - return z3 + y1 + y2 + y3 + def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: + nonlocal captured_graph + captured_graph = gm + return gm + + def model_fn(x: torch.Tensor) -> torch.Tensor: + # Get the dynamic shape before any splits + batch_size = x.shape[0] + hidden_size = x.shape[1] + + # First split point - sigmoid #1 + x = torch.ops.aten.sigmoid.default(x) + + # Use dynamic size after first split - creates sym_size consumer + x = x.clone().view(batch_size, hidden_size) + + # Second split point - sigmoid #2 + x = torch.ops.aten.sigmoid.default(x) + + # Use dynamic size again after second split - another sym_size consumer + x = x.clone().view(batch_size, hidden_size) + + # Third split point - sigmoid #3 + x = torch.ops.aten.sigmoid.default(x) + + # Use dynamic size again after third split - yet another consumer + x = x.clone().view(batch_size, hidden_size) + + return x x = torch.randn(4, 8) - y = torch.randn(8, 8) - gm = make_fx(model_fn, 
tracing_mode="symbolic")(x, y) + # Mark the first dimension as dynamic + torch._dynamo.mark_dynamic(x, 0) - assert len(list(find_op_nodes(torch.ops.aten.sym_size, gm.graph))) > 0, ( - "Test setup failed: graph should contain sym_size operations" + compiled_fn = torch.compile(model_fn, backend=capturing_backend) + compiled_fn(x) + + assert captured_graph is not None, "Graph should be captured by backend" + + # ===== PHASE 1: Validate SymInt placeholders exist in captured graph ===== + symint_placeholders = [ + node + for node in captured_graph.graph.nodes + if node.op == "placeholder" + and node.meta.get("example_value") is not None + and isinstance(node.meta.get("example_value"), torch.SymInt) + ] + assert len(symint_placeholders) > 0, ( + "Phase 1 FAILED: Captured graph should have SymInt placeholders from " + "mark_dynamic. This is the prerequisite for testing the sym_size pipeline." ) - split_gm, split_items = split_graph(gm, ["aten::sigmoid"]) + # Record original SymInt users for later validation + original_symint_users = {} + for symint_node in symint_placeholders: + users = [u for u in symint_node.users if u.op != "output"] + original_symint_users[symint_node.name] = [u.name for u in users] + + # ===== PHASE 2 & 3: split_graph handles SymInt replacement and removal ===== + # NOTE: split_graph modifies the input graph in-place! 
+ # With 3 sigmoid operations, we expect 7 subgraphs: + # submod_0 (before sigmoid #1), submod_1 (sigmoid #1), + # submod_2 (between sigmoid #1 and #2), submod_3 (sigmoid #2), + # submod_4 (between sigmoid #2 and #3), submod_5 (sigmoid #3), + # submod_6 (after sigmoid #3) + split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) + + # ===== PHASE 7: Validate multiple split subgraphs exist ===== + # Count splitting subgraphs (the sigmoid operations) + splitting_subgraphs = [item for item in split_items if item.is_splitting_graph] + + assert len(splitting_subgraphs) == 3, ( + f"Phase 7 FAILED: Expected 3 splitting subgraphs (3 sigmoids), " + f"got {len(splitting_subgraphs)}" + ) + # Note: Total subgraphs can be 6 or 7 depending on whether there are + # operations before the first sigmoid. With torch.compile, shape access + # operations may be folded differently, resulting in 6 subgraphs: + # submod_1 (sigmoid #1), submod_2 (compute), submod_3 (sigmoid #2), + # submod_4 (compute), submod_5 (sigmoid #3), submod_6 (compute) + assert len(split_items) >= 6, ( + f"Phase 7 FAILED: Expected at least 6 total subgraphs " + f"(3 sigmoids + at least 3 compute blocks), got {len(split_items)}" + ) - # Find subgraphs that contain slice operations (consumers of sym_size) - subgraphs_with_slice = [ - item - for item in split_items - if len(list(find_op_nodes(torch.ops.aten.slice, item.graph.graph))) > 0 + # ===== PHASE 3: Validate SymInt placeholders are removed from split_gm ===== + split_placeholders = [ + node for node in split_gm.graph.nodes if node.op == "placeholder" ] - # Find subgraphs that contain sym_size operations - subgraphs_with_sym_size = [ - item - for item in split_items - if len(list(find_op_nodes(torch.ops.aten.sym_size, item.graph.graph))) > 0 + remaining_symint_placeholders = [ + node + for node in split_placeholders + if node.meta.get("example_value") is not None + and isinstance(node.meta.get("example_value"), torch.SymInt) ] + assert 
len(remaining_symint_placeholders) == 0, ( + f"Phase 3 FAILED: split_gm should not have SymInt placeholders after " + f"_remove_symint_placeholders. Found: " + f"{[n.name for n in remaining_symint_placeholders]}. " + "This means SymInt would be passed as input which inductor doesn't support." + ) + + # ===== PHASE 4: Validate sym_size.int exists in consumer subgraphs ===== + # Each non-splitting subgraph that uses dynamic sizes should have sym_size.int + # to compute the dynamic dimension locally from the tensor input. + total_sym_size_nodes = 0 + subgraphs_with_sym_size = [] + + for item in split_items: + sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, item.graph.graph)) + + if sym_size_nodes: + total_sym_size_nodes += len(sym_size_nodes) + subgraphs_with_sym_size.append(item.submod_name) - # KEY VERIFICATION: The number of subgraphs with sym_size should equal - # the number of consumer subgraphs (each consumer has its own sym_size) - assert len(subgraphs_with_sym_size) == len(subgraphs_with_slice), ( - f"Expected {len(subgraphs_with_slice)} subgraphs with sym_size " - f"(one per consumer), but found {len(subgraphs_with_sym_size)}. " - "This indicates sym_size was not properly replicated to all consumers." + assert total_sym_size_nodes > 0, ( + "Phase 4 FAILED: No sym_size.int nodes found in any subgraph. " + "split_graph should replace SymInt placeholders with sym_size.int calls " + "that compute dynamic sizes from tensor inputs." ) - # Verify functional correctness - assert torch.allclose(gm(x, y), split_gm(x, y)), "Output mismatch after split" + # With 3 split boundaries and dynamic size usage after each split, + # we expect sym_size to be replicated to multiple consumer subgraphs + assert len(subgraphs_with_sym_size) >= 3, ( + f"Phase 4 FAILED: sym_size should exist in consumer subgraphs. 
" + f"Found sym_size in {len(subgraphs_with_sym_size)} subgraphs: " + f"{subgraphs_with_sym_size}" + ) + + # ===== PHASE 5: Validate functional correctness ===== + # split_gm should work with tensor-only input (no SymInt) + output_split = split_gm(x) + + # Handle case where output is a tuple + if isinstance(output_split, tuple): + output_split = output_split[0] + + # For reference, run the model directly to get expected output + expected_output = model_fn(x) + + assert torch.allclose(expected_output, output_split), ( + "Phase 5 FAILED: Output mismatch after split. The sym_size pipeline " + "should preserve functional correctness." + ) + + # ===== PHASE 6: Validate with different batch size ===== + # The dynamic dimension should work with different sizes + x_different = torch.randn(8, 8) # Different batch size + output_different = split_gm(x_different) + if isinstance(output_different, tuple): + output_different = output_different[0] + expected_different = model_fn(x_different) + assert torch.allclose(expected_different, output_different), ( + "Phase 6 FAILED: Output mismatch with different batch size. " + "sym_size should correctly compute the dynamic dimension at runtime." 
+ ) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index e3c5f72c65f1..a18a82d3e352 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -334,74 +334,123 @@ class SplitItem: graph: fx.GraphModule -def _is_sym_size_op(node: fx.Node) -> bool: - """Check if a node is a sym_size operation (tensor.shape access).""" - if node.op != "call_function": +def _is_symint_placeholder(node: fx.Node) -> bool: + """Check if a node is a SymInt placeholder (from torch.compile + mark_dynamic).""" + if node.op != "placeholder": return False + example_value = node.meta.get("example_value") + return example_value is not None and isinstance(example_value, torch.SymInt) - if not hasattr(torch.ops.aten, "sym_size"): - return False - # Handle both torch.ops.aten.sym_size.int and sym_size.default - return node.target in ( - torch.ops.aten.sym_size, - torch.ops.aten.sym_size.int, - torch.ops.aten.sym_size.default, - ) +def _find_tensor_for_symint( + symint_value: torch.SymInt, + graph: fx.GraphModule, +) -> tuple[fx.Node, int] | None: + """ + Find a tensor placeholder with a dimension matching the given SymInt. + Returns (tensor_node, dim) or None if no match found. 
+ """ + for node in graph.graph.nodes: + if node.op != "placeholder": + continue + tensor_value = node.meta.get("example_value") + if tensor_value is None or not isinstance(tensor_value, torch.Tensor): + continue + if not hasattr(tensor_value, "shape"): + continue -def _replicate_sym_size_nodes_for_split( + for dim, size in enumerate(tensor_value.shape): + # Match by identity + if size is symint_value: + return (node, dim) + # Match by underlying symbolic node + if ( + hasattr(size, "node") + and hasattr(symint_value, "node") + and size.node is symint_value.node + ): + return (node, dim) + # Match by string representation (fallback) + if str(size) == str(symint_value): + return (node, dim) + + return None + + +def _replace_symint_placeholders( graph: fx.GraphModule, node_to_subgraph_id: dict[fx.Node, int], ) -> None: """ - Replicate sym_size operations to ALL consumer subgraphs. + Replace SymInt placeholder uses with sym_size calls. - When splitting a graph, if a sym_size call has consumers in multiple - subgraphs, we replicate the sym_size operation to each consumer subgraph. - This ensures each subgraph computes sym_size locally rather than receiving - it as an input, avoiding torch.Size crossing subgraph boundaries. + When using torch.compile with mark_dynamic, the captured graph has SymInt + placeholders (e.g., s77) as separate inputs. standalone_compile / inductor + expects only tensor inputs. + + This function creates sym_size.int nodes to replace SymInt placeholder uses. + + IMPORTANT: We do NOT delete the SymInt placeholders here because split_module + needs them for its symbol_to_node mapping. If we delete them, split_module + fails with KeyError when processing tensors whose shapes contain the symbol. + The placeholders are removed AFTER split_module by _remove_symint_placeholders. 
""" for node in list(graph.graph.nodes): - if node.op in ("output", "placeholder"): + if not _is_symint_placeholder(node): continue - if not _is_sym_size_op(node): + symint_value = node.meta.get("example_value") + if symint_value is None: continue - node_subgraph = node_to_subgraph_id.get(node) - if node_subgraph is None: + tensor_dim = _find_tensor_for_symint(symint_value, graph) + if tensor_dim is None: + logger.warning( + "Could not find tensor dimension for SymInt placeholder %s", + node.name, + ) continue - # Group consumers by their subgraph ID (only those in later subgraphs) + tensor_node, dim = tensor_dim + + # Get list of users before modifying + users_list = list(node.users.keys()) + if not users_list: + # No users, keep the placeholder for symbol_to_node mapping + continue + + # Create sym_size for each subgraph that uses this SymInt subgraph_to_consumers: dict[int, list[fx.Node]] = {} - for user in node.users: + for user in users_list: if user.op == "output": continue - user_subgraph = node_to_subgraph_id.get(user) - if user_subgraph is not None and user_subgraph > node_subgraph: - if user_subgraph not in subgraph_to_consumers: - subgraph_to_consumers[user_subgraph] = [] - subgraph_to_consumers[user_subgraph].append(user) - - if not subgraph_to_consumers: - continue + user_subgraph = node_to_subgraph_id.get(user, 0) + if user_subgraph not in subgraph_to_consumers: + subgraph_to_consumers[user_subgraph] = [] + subgraph_to_consumers[user_subgraph].append(user) - # Create a copy of sym_size for EACH consumer subgraph for subgraph_id, consumer_list in subgraph_to_consumers.items(): with graph.graph.inserting_before(consumer_list[0]): - new_sym_size = graph.graph.call_function( - node.target, - args=node.args, - kwargs=node.kwargs, + sym_size_node = graph.graph.call_function( + torch.ops.aten.sym_size.int, + args=(tensor_node, dim), ) if node.meta: - new_sym_size.meta = node.meta.copy() + sym_size_node.meta = node.meta.copy() - 
node_to_subgraph_id[new_sym_size] = subgraph_id
+            node_to_subgraph_id[sym_size_node] = subgraph_id
 
             for consumer in consumer_list:
-                consumer.replace_input_with(node, new_sym_size)
+                consumer.replace_input_with(node, sym_size_node)
+
+    # NOTE: We do NOT delete the SymInt placeholder here!
+    # split_module needs it for symbol_to_node mapping.
+    # It will be removed by _remove_symint_placeholders after split_module.
+
+    # NOTE: We skip lint()/recompile() here since split_module reads from
+    # graph.graph.nodes directly, not the forward() method. This avoids
+    # potential issues with graph state changes before split_module.
 
 
 def split_graph(
@@ -435,11 +484,11 @@ def split_graph(
         else:
             node_to_subgraph_id[node] = subgraph_id
 
-    # Replicate sym_size operations (tensor.shape accesses) to all consumer
-    # subgraphs. This avoids issues where PT2 doesn't support torch.Size as
-    # submodule output when sym_size is in one subgraph and its consumer is
-    # in another.
-    _replicate_sym_size_nodes_for_split(graph, node_to_subgraph_id)
+    # Replace SymInt placeholders with sym_size.int calls and delete them.
+    # This is needed for torch.compile + mark_dynamic, where the captured graph
+    # has SymInt placeholders as separate inputs. standalone_compile / inductor
+    # expects only tensor inputs.
+    _replace_symint_placeholders(graph, node_to_subgraph_id)
 
     # `keep_original_order` is important!
     # otherwise pytorch might reorder the nodes and
@@ -449,6 +498,13 @@ def split_graph(
         graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )
 
+
+    # Note: _replace_symint_placeholders deliberately does NOT delete SymInt
+    # placeholders before split_module runs: split_module needs them for its
+    # symbol_to_node mapping, so SymInt may still be threaded through
+    # submodules at this point. The post-split cleanup below,
+    # _remove_symint_placeholders, strips them so only tensor inputs remain.
+ _remove_symint_placeholders(split_gm) + outputs = [] names = [name for (name, module) in split_gm.named_modules()] @@ -469,6 +525,237 @@ def split_graph( return split_gm, outputs +def _remove_symint_placeholders(gm: fx.GraphModule) -> None: + """ + Remove SymInt placeholders from a GraphModule after split_module. + + After split_module, SymInt placeholders may still exist and may have users + (call_module nodes that pass the SymInt to submodules). This function: + 1. Replaces SymInt arguments in call_module nodes with sym_size.int calls + 2. Removes the now-unused SymInt placeholders + + This ensures the final graph only requires tensor inputs. + """ + # Collect SymInt and tensor placeholders + symint_placeholders: list[fx.Node] = [] + tensor_placeholders: list[fx.Node] = [] + + for node in gm.graph.nodes: + if node.op != "placeholder": + continue + example_value = node.meta.get("example_value") + if example_value is None: + continue + if isinstance(example_value, torch.SymInt): + symint_placeholders.append(node) + elif isinstance(example_value, torch.Tensor): + tensor_placeholders.append(node) + + if not symint_placeholders: + return + + # Build mapping from SymInt placeholder to (tensor, dim) that can compute it + symint_to_tensor_dim: dict[fx.Node, tuple[fx.Node, int]] = {} + + for symint_node in symint_placeholders: + symint_value = symint_node.meta.get("example_value") + if symint_value is None: + continue + + # Find a tensor with a dynamic dimension matching this SymInt + for tensor_node in tensor_placeholders: + tensor_value = tensor_node.meta.get("example_value") + if tensor_value is None or not hasattr(tensor_value, "shape"): + continue + + for dim, size in enumerate(tensor_value.shape): + # Match by identity + if size is symint_value: + symint_to_tensor_dim[symint_node] = (tensor_node, dim) + break + # Match by underlying symbolic node + if ( + hasattr(size, "node") + and hasattr(symint_value, "node") + and size.node is symint_value.node + ): + 
symint_to_tensor_dim[symint_node] = (tensor_node, dim) + break + # Match by string representation (fallback) + if str(size) == str(symint_value): + symint_to_tensor_dim[symint_node] = (tensor_node, dim) + break + + if symint_node in symint_to_tensor_dim: + break + + logger.debug( + "Mapped SymInt placeholders to tensor dims: %s", + {n.name: (t.name, d) for n, (t, d) in symint_to_tensor_dim.items()}, + ) + + # For each SymInt placeholder that has users (call_module nodes), replace + # the SymInt argument with a sym_size.int call on the corresponding tensor + nodes_modified = False + for symint_node in symint_placeholders: + if not symint_node.users: + # No users, can just delete + gm.graph.erase_node(symint_node) + nodes_modified = True + continue + + if symint_node not in symint_to_tensor_dim: + logger.warning( + "Could not find tensor dimension for SymInt placeholder %s", + symint_node.name, + ) + continue + + tensor_node, dim = symint_to_tensor_dim[symint_node] + + # Replace each use of the SymInt with a sym_size.int call + # We need to create a new sym_size node before each user + users_list = list(symint_node.users.keys()) + for user in users_list: + if user.op != "call_module": + # For non-call_module users, create sym_size before them + with gm.graph.inserting_before(user): + sym_size_node = gm.graph.call_function( + torch.ops.aten.sym_size.int, + args=(tensor_node, dim), + ) + if symint_node.meta: + sym_size_node.meta = symint_node.meta.copy() + user.replace_input_with(symint_node, sym_size_node) + else: + # For call_module nodes, we need to remove the SymInt from args + # and update the submodule to compute sym_size locally + _update_submodule_to_compute_symint_locally( + gm, user, symint_node, tensor_node, dim + ) + + # Now the SymInt placeholder should have no users + if not symint_node.users: + gm.graph.erase_node(symint_node) + nodes_modified = True + else: + logger.warning( + "SymInt placeholder %s still has %d users after processing: %s", + 
symint_node.name, + len(symint_node.users), + list(symint_node.users.keys()), + ) + + if nodes_modified: + gm.graph.lint() + gm.recompile() + + +def _update_submodule_to_compute_symint_locally( + gm: fx.GraphModule, + call_module_node: fx.Node, + symint_node: fx.Node, + tensor_node: fx.Node, + dim: int, +) -> None: + """ + Update a submodule call to compute SymInt locally instead of receiving it. + + This modifies: + 1. The call_module node's args to remove the SymInt and ensure tensor is passed + 2. The submodule to compute sym_size.int from the tensor instead of taking + SymInt as a parameter + """ + submod_name = call_module_node.target + submodule = getattr(gm, submod_name) + + # Find which argument position(s) correspond to symint_node and tensor_node + old_args = list(call_module_node.args) + symint_arg_indices = [i for i, arg in enumerate(old_args) if arg is symint_node] + tensor_arg_indices = [i for i, arg in enumerate(old_args) if arg is tensor_node] + + if not symint_arg_indices: + return + + # Get the submodule's placeholder nodes + submod_placeholders = [n for n in submodule.graph.nodes if n.op == "placeholder"] + + # Find the placeholder in submodule that corresponds to the SymInt + symint_placeholder_idx = symint_arg_indices[0] + if symint_placeholder_idx >= len(submod_placeholders): + logger.warning( + "SymInt arg index %d out of range for submodule %s with %d placeholders", + symint_placeholder_idx, + submod_name, + len(submod_placeholders), + ) + return + + symint_submod_placeholder = submod_placeholders[symint_placeholder_idx] + + # Find or ensure there's a placeholder for the tensor in the submodule + tensor_submod_placeholder = None + if tensor_arg_indices: + tensor_placeholder_idx = tensor_arg_indices[0] + if tensor_placeholder_idx < len(submod_placeholders): + tensor_submod_placeholder = submod_placeholders[tensor_placeholder_idx] + + if tensor_submod_placeholder is None: + # Tensor is not currently passed to this submodule, need to add it + # 
Add tensor to call_module args (at the end) + new_args = list(old_args) + [tensor_node] + # Also remove the SymInt from args + new_args = [ + arg for i, arg in enumerate(new_args) if i not in symint_arg_indices + ] + call_module_node.args = tuple(new_args) + + # Add new placeholder to submodule at the end + last_placeholder = submod_placeholders[-1] + with submodule.graph.inserting_after(last_placeholder): + tensor_submod_placeholder = submodule.graph.placeholder("tensor_for_symint") + if tensor_node.meta: + tensor_submod_placeholder.meta = tensor_node.meta.copy() + + else: + # Tensor is already passed, just need to update args to remove SymInt + new_args = [ + arg for i, arg in enumerate(old_args) if i not in symint_arg_indices + ] + call_module_node.args = tuple(new_args) + + # Find first node to insert sym_size before (after placeholders/get_attr) + insert_point = None + for node in submodule.graph.nodes: + if node.op not in ("placeholder", "get_attr"): + insert_point = node + break + + if insert_point is None: + logger.warning("Could not find insertion point in submodule %s", submod_name) + return + + # Create sym_size.int node in submodule + with submodule.graph.inserting_before(insert_point): + sym_size_node = submodule.graph.call_function( + torch.ops.aten.sym_size.int, + args=(tensor_submod_placeholder, dim), + ) + if symint_submod_placeholder.meta: + sym_size_node.meta = symint_submod_placeholder.meta.copy() + + # Replace all uses + + # Replace all uses of SymInt placeholder with sym_size node + symint_submod_placeholder.replace_all_uses_with(sym_size_node) + + # Remove the SymInt placeholder from submodule + submodule.graph.erase_node(symint_submod_placeholder) + + submodule.graph.lint() + submodule.recompile() + + compilation_start_time = 0.0 From 6bd90a0fea8253ddb1b245446cae138d41924983 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 20 Jan 2026 13:37:30 -0800 Subject: [PATCH 11/25] [compile][graph_partition]Add tensor size handling --- 
tests/compile/test_graph_partition.py | 125 ++++++++++++++++++++++++++ vllm/compilation/backends.py | 84 +++++++++++++++++ 2 files changed, 209 insertions(+) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 38e3e038a8c4..f2b3ef9f516c 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -128,6 +128,7 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: assert torch.allclose(output_original, output_split), "Output mismatch after split" +<<<<<<< HEAD def test_consecutive_ops_in_split(): """ Test that consecutive splitting operations are grouped into the same subgraph @@ -184,3 +185,127 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [ "call_function" ] + ["output"] +======= +def test_sym_size_moved_across_split_boundary(): + """ + Test that sym_size operations (tensor.shape accesses) are moved to the same + subgraph as their consumers when they would otherwise cross subgraph boundaries. + + This prevents issues where PT2 doesn't fully support torch.Size as submodule + output when sym_size is in one subgraph and its consumer is in another. 
+ + Pattern being tested: + # Original order that causes issues: + size = tensor_a.shape[0] # subgraph 0 + some_cg_unsafe_op # subgraph 1 (split point) + tensor_b = tensor_b.view(size) # subgraph 2 (would fail without fix) + + # After fix, sym_size is moved: + some_cg_unsafe_op # subgraph 1 (split point) + size = tensor_a.shape[0] # moved to subgraph 2 + tensor_b = tensor_b.view(size) # subgraph 2 (works correctly) + """ + + def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get shape before the split point - this creates sym_size ops + batch_size = x.shape[0] + hidden_size = x.shape[1] + + # This becomes a splitting operation + z = torch.sigmoid(x) + + # Use the shape values after the split point + # Without the fix, this would fail because batch_size/hidden_size + # would be outputs of the first subgraph (as torch.Size/SymInt) + reshaped_y = y.view(batch_size, hidden_size) + + return z + reshaped_y + + x = torch.randn(4, 8) + y = torch.randn(32) # Will be reshaped to (4, 8) + # Use symbolic tracing to generate sym_size operations + gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) + + # Verify the graph contains sym_size operations + sym_size_nodes = [ + node + for node in gm.graph.nodes + if node.op == "call_function" and "sym_size" in str(node.target) + ] + assert ( + len(sym_size_nodes) > 0 + ), "Test setup failed: graph should contain sym_size operations" + + # Split on sigmoid which is the split point + split_ops = ["aten::sigmoid"] + split_gm, split_items = split_graph(gm, split_ops) + + # After the fix, we expect 2 submodules: + # - subgraph 1: sigmoid (split point) + # - subgraph 2: sym_size ops + view + add (consumer subgraph) + # The original subgraph 0 becomes empty because sym_size ops are moved + # to the consumer subgraph, so it's not created. 
+ assert len(split_items) == 2, f"Expected 2 submodules, got {len(split_items)}" + + # Verify that one is the splitting graph (sigmoid) and one is not + splitting_items = [item for item in split_items if item.is_splitting_graph] + non_splitting_items = [item for item in split_items if not item.is_splitting_graph] + assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph" + assert len(non_splitting_items) == 1, "Should have exactly 1 non-splitting subgraph" + + # The non-splitting subgraph should contain the sym_size operations + # (they were moved from before the split to after) + consumer_subgraph = non_splitting_items[0].graph + sym_size_in_consumer = [ + node + for node in consumer_subgraph.graph.nodes + if node.op == "call_function" and "sym_size" in str(node.target) + ] + assert len(sym_size_in_consumer) > 0, ( + "sym_size operations should be in the consumer subgraph (after split)" + ) + + # Verify functional correctness with same-shaped inputs + output_original = gm(x, y) + output_split = split_gm(x, y) + assert torch.allclose(output_original, output_split), "Output mismatch after split" + + +def test_sym_size_with_multiple_consumers_in_different_subgraphs(): + """ + Test that when a sym_size result is used by consumers in multiple different + subgraphs, it's placed in the earliest consumer subgraph. 
+ """ + + def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get shape before any split points + size = x.shape[0] + + # First split point + z1 = torch.sigmoid(x) + + # Use size after first split + y1 = y[:size] + + # Second split point + z2 = torch.sigmoid(z1) + + # Use size again after second split + y2 = y[:size] + + return z2 + y1 + y2 + + x = torch.randn(4, 8) + y = torch.randn(8, 8) + # Use symbolic tracing to generate sym_size operations + gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) + + # Split on both sigmoid operations + split_ops = ["aten::sigmoid"] + split_gm, split_items = split_graph(gm, split_ops) + + # Verify functional correctness with same-shaped inputs + output_original = gm(x, y) + output_split = split_gm(x, y) + assert torch.allclose(output_original, output_split), "Output mismatch after split" +>>>>>>> 009f91613 ([compile][graph_partition]Add tensor size handling) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 89981fc29963..935b36381342 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -332,6 +332,84 @@ class SplitItem: graph: fx.GraphModule +def _is_sym_size_op(node: fx.Node) -> bool: + """Check if a node is a sym_size operation (tensor.shape access).""" + if node.op != "call_function": + return False + target = node.target + # Handle both torch.ops.aten.sym_size.int and sym_size.default + if hasattr(torch.ops.aten, "sym_size"): + sym_size_ops = ( + torch.ops.aten.sym_size, + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_size.default, + ) + return target in sym_size_ops + return False + + +def _move_sym_size_nodes_for_split( + graph: fx.GraphModule, + node_to_subgraph_id: dict[fx.Node, int], +) -> None: + """ + Move sym_size operations to the same subgraph as their consumers. 
+ + When splitting a graph, if a sym_size call is in one submodule and its + consumer is in another, PyTorch 2 has issues because torch.Size is not + fully supported as a submodule output. This function reorders sym_size + nodes to be just before their consumers when they would otherwise cross + subgraph boundaries. + + Pattern being fixed: + # Old (causes issues): + size = tensor_a.shape # subgraph 0 + some_cg_unsafe_op # subgraph 1 (split point) + tensor_b = tensor_b.view(size) # subgraph 2 (consumes size) + + # New (fixed): + some_cg_unsafe_op # subgraph 1 (split point) + size = tensor_a.shape # moved to subgraph 2 + tensor_b = tensor_b.view(size) # subgraph 2 (consumes size) + """ + # Collect all sym_size nodes that need to be moved + sym_size_nodes_to_move: list[tuple[fx.Node, int]] = [] + + for node in graph.graph.nodes: + if node.op in ("output", "placeholder"): + continue + + if not _is_sym_size_op(node): + continue + + node_subgraph = node_to_subgraph_id.get(node) + if node_subgraph is None: + continue + + # Find the minimum subgraph ID among all consumers of this sym_size + consumer_subgraph_ids: list[int] = [] + for user in node.users: + if user.op == "output": + continue + user_subgraph = node_to_subgraph_id.get(user) + if user_subgraph is not None: + consumer_subgraph_ids.append(user_subgraph) + + if not consumer_subgraph_ids: + continue + + # The minimum consumer subgraph is where we want to move the sym_size + min_consumer_subgraph = min(consumer_subgraph_ids) + + # Only move if the sym_size would cross into a later subgraph + if min_consumer_subgraph > node_subgraph: + sym_size_nodes_to_move.append((node, min_consumer_subgraph)) + + # Update the subgraph assignments for sym_size nodes that need to move + for node, new_subgraph_id in sym_size_nodes_to_move: + node_to_subgraph_id[node] = new_subgraph_id + + def split_graph( graph: fx.GraphModule, splitting_ops: list[str] ) -> tuple[fx.GraphModule, list[SplitItem]]: @@ -370,6 +448,12 @@ def 
split_graph( else: node_to_subgraph_id[node] = subgraph_id + # Move sym_size operations (tensor.shape accesses) to be closer to their + # consumers. This avoids issues where PT2 doesn't support torch.Size as + # submodule output when sym_size is in one subgraph and its consumer is + # in another. + _move_sym_size_nodes_for_split(graph, node_to_subgraph_id) + # `keep_original_order` is important! # otherwise pytorch might reorder the nodes and # the semantics of the graph will change when we From 2a55ef327daf577c59f5872758aee246cd0c7a88 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 20 Jan 2026 18:37:40 -0800 Subject: [PATCH 12/25] Add more test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Xiao Fu --- tests/compile/test_graph_partition.py | 90 ++++++++++++++++++--------- 1 file changed, 59 insertions(+), 31 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index f2b3ef9f516c..3a06f46e641a 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -193,17 +193,6 @@ def test_sym_size_moved_across_split_boundary(): This prevents issues where PT2 doesn't fully support torch.Size as submodule output when sym_size is in one subgraph and its consumer is in another. 
- - Pattern being tested: - # Original order that causes issues: - size = tensor_a.shape[0] # subgraph 0 - some_cg_unsafe_op # subgraph 1 (split point) - tensor_b = tensor_b.view(size) # subgraph 2 (would fail without fix) - - # After fix, sym_size is moved: - some_cg_unsafe_op # subgraph 1 (split point) - size = tensor_a.shape[0] # moved to subgraph 2 - tensor_b = tensor_b.view(size) # subgraph 2 (works correctly) """ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -232,37 +221,76 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: for node in gm.graph.nodes if node.op == "call_function" and "sym_size" in str(node.target) ] - assert ( - len(sym_size_nodes) > 0 - ), "Test setup failed: graph should contain sym_size operations" + assert len(sym_size_nodes) > 0, ( + "Test setup failed: graph should contain sym_size operations" + ) # Split on sigmoid which is the split point split_ops = ["aten::sigmoid"] split_gm, split_items = split_graph(gm, split_ops) - # After the fix, we expect 2 submodules: - # - subgraph 1: sigmoid (split point) - # - subgraph 2: sym_size ops + view + add (consumer subgraph) - # The original subgraph 0 becomes empty because sym_size ops are moved - # to the consumer subgraph, so it's not created. 
- assert len(split_items) == 2, f"Expected 2 submodules, got {len(split_items)}" - - # Verify that one is the splitting graph (sigmoid) and one is not + # Find the sigmoid (splitting) subgraph and the consumer subgraph splitting_items = [item for item in split_items if item.is_splitting_graph] - non_splitting_items = [item for item in split_items if not item.is_splitting_graph] assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph" - assert len(non_splitting_items) == 1, "Should have exactly 1 non-splitting subgraph" - # The non-splitting subgraph should contain the sym_size operations - # (they were moved from before the split to after) - consumer_subgraph = non_splitting_items[0].graph - sym_size_in_consumer = [ + # KEY VERIFICATION: sym_size operations should be in the same subgraph + # as the view operation (their consumer), NOT in an earlier subgraph. + # This prevents torch.Size from crossing subgraph boundaries. + + # Find which subgraph contains the view operation + view_subgraph = None + for item in split_items: + for node in item.graph.graph.nodes: + if node.op == "call_function" and "view" in str(node.target).lower(): + view_subgraph = item + break + if view_subgraph: + break + + assert view_subgraph is not None, "Should have a subgraph with view operation" + + # Verify sym_size operations are in the SAME subgraph as view + sym_size_in_view_subgraph = [ node - for node in consumer_subgraph.graph.nodes + for node in view_subgraph.graph.graph.nodes if node.op == "call_function" and "sym_size" in str(node.target) ] - assert len(sym_size_in_consumer) > 0, ( - "sym_size operations should be in the consumer subgraph (after split)" + assert len(sym_size_in_view_subgraph) > 0, ( + "sym_size operations should be in the same subgraph as their consumer " + "(view). This ensures torch.Size doesn't cross subgraph boundaries." 
+ ) + + # Verify ordering within the consumer subgraph: sym_size before view + consumer_nodes = list(view_subgraph.graph.graph.nodes) + # CRITICAL VERIFICATION: The sigmoid (splitting/unsafe op) subgraph must + # have a LOWER graph_id than the consumer subgraph. Since subgraphs execute + # in order of graph_id, this proves that: + # 1. Sigmoid runs FIRST + # 2. sym_size + view run SECOND (in consumer subgraph) + # Therefore, sym_size now happens AFTER the unsafe op. + sigmoid_subgraph = splitting_items[0] + assert sigmoid_subgraph.graph_id < view_subgraph.graph_id, ( + f"Sigmoid subgraph (graph_id={sigmoid_subgraph.graph_id}) must execute " + f"before consumer subgraph (graph_id={view_subgraph.graph_id}). " + "This ensures sym_size happens AFTER the unsafe operation." + ) + + sym_size_indices = [ + i + for i, node in enumerate(consumer_nodes) + if node.op == "call_function" and "sym_size" in str(node.target) + ] + view_indices = [ + i + for i, node in enumerate(consumer_nodes) + if node.op == "call_function" and "view" in str(node.target).lower() + ] + + max_sym_size_idx = max(sym_size_indices) + min_view_idx = min(view_indices) + assert max_sym_size_idx < min_view_idx, ( + f"sym_size (max index {max_sym_size_idx}) should come before " + f"view (min index {min_view_idx}) in the consumer subgraph." 
) # Verify functional correctness with same-shaped inputs From cbf3c105fe2e5f0df438a4afc4081a5bc1e2636a Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 20 Jan 2026 13:37:30 -0800 Subject: [PATCH 13/25] [compile][graph_partition]Add tensor size handling --- tests/compile/test_graph_partition.py | 96 ++++++++++----------------- 1 file changed, 34 insertions(+), 62 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 3a06f46e641a..3022b07ead80 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -128,7 +128,6 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: assert torch.allclose(output_original, output_split), "Output mismatch after split" -<<<<<<< HEAD def test_consecutive_ops_in_split(): """ Test that consecutive splitting operations are grouped into the same subgraph @@ -185,7 +184,9 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [ "call_function" ] + ["output"] -======= + + + def test_sym_size_moved_across_split_boundary(): """ Test that sym_size operations (tensor.shape accesses) are moved to the same @@ -193,6 +194,17 @@ def test_sym_size_moved_across_split_boundary(): This prevents issues where PT2 doesn't fully support torch.Size as submodule output when sym_size is in one subgraph and its consumer is in another. 
+ + Pattern being tested: + # Original order that causes issues: + size = tensor_a.shape[0] # subgraph 0 + some_cg_unsafe_op # subgraph 1 (split point) + tensor_b = tensor_b.view(size) # subgraph 2 (would fail without fix) + + # After fix, sym_size is moved: + some_cg_unsafe_op # subgraph 1 (split point) + size = tensor_a.shape[0] # moved to subgraph 2 + tensor_b = tensor_b.view(size) # subgraph 2 (works correctly) """ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -221,76 +233,37 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: for node in gm.graph.nodes if node.op == "call_function" and "sym_size" in str(node.target) ] - assert len(sym_size_nodes) > 0, ( - "Test setup failed: graph should contain sym_size operations" - ) + assert ( + len(sym_size_nodes) > 0 + ), "Test setup failed: graph should contain sym_size operations" # Split on sigmoid which is the split point split_ops = ["aten::sigmoid"] split_gm, split_items = split_graph(gm, split_ops) - # Find the sigmoid (splitting) subgraph and the consumer subgraph + # After the fix, we expect 2 submodules: + # - subgraph 1: sigmoid (split point) + # - subgraph 2: sym_size ops + view + add (consumer subgraph) + # The original subgraph 0 becomes empty because sym_size ops are moved + # to the consumer subgraph, so it's not created. + assert len(split_items) == 2, f"Expected 2 submodules, got {len(split_items)}" + + # Verify that one is the splitting graph (sigmoid) and one is not splitting_items = [item for item in split_items if item.is_splitting_graph] + non_splitting_items = [item for item in split_items if not item.is_splitting_graph] assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph" + assert len(non_splitting_items) == 1, "Should have exactly 1 non-splitting subgraph" - # KEY VERIFICATION: sym_size operations should be in the same subgraph - # as the view operation (their consumer), NOT in an earlier subgraph. 
- # This prevents torch.Size from crossing subgraph boundaries. - - # Find which subgraph contains the view operation - view_subgraph = None - for item in split_items: - for node in item.graph.graph.nodes: - if node.op == "call_function" and "view" in str(node.target).lower(): - view_subgraph = item - break - if view_subgraph: - break - - assert view_subgraph is not None, "Should have a subgraph with view operation" - - # Verify sym_size operations are in the SAME subgraph as view - sym_size_in_view_subgraph = [ + # The non-splitting subgraph should contain the sym_size operations + # (they were moved from before the split to after) + consumer_subgraph = non_splitting_items[0].graph + sym_size_in_consumer = [ node - for node in view_subgraph.graph.graph.nodes + for node in consumer_subgraph.graph.nodes if node.op == "call_function" and "sym_size" in str(node.target) ] - assert len(sym_size_in_view_subgraph) > 0, ( - "sym_size operations should be in the same subgraph as their consumer " - "(view). This ensures torch.Size doesn't cross subgraph boundaries." - ) - - # Verify ordering within the consumer subgraph: sym_size before view - consumer_nodes = list(view_subgraph.graph.graph.nodes) - # CRITICAL VERIFICATION: The sigmoid (splitting/unsafe op) subgraph must - # have a LOWER graph_id than the consumer subgraph. Since subgraphs execute - # in order of graph_id, this proves that: - # 1. Sigmoid runs FIRST - # 2. sym_size + view run SECOND (in consumer subgraph) - # Therefore, sym_size now happens AFTER the unsafe op. - sigmoid_subgraph = splitting_items[0] - assert sigmoid_subgraph.graph_id < view_subgraph.graph_id, ( - f"Sigmoid subgraph (graph_id={sigmoid_subgraph.graph_id}) must execute " - f"before consumer subgraph (graph_id={view_subgraph.graph_id}). " - "This ensures sym_size happens AFTER the unsafe operation." 
- ) - - sym_size_indices = [ - i - for i, node in enumerate(consumer_nodes) - if node.op == "call_function" and "sym_size" in str(node.target) - ] - view_indices = [ - i - for i, node in enumerate(consumer_nodes) - if node.op == "call_function" and "view" in str(node.target).lower() - ] - - max_sym_size_idx = max(sym_size_indices) - min_view_idx = min(view_indices) - assert max_sym_size_idx < min_view_idx, ( - f"sym_size (max index {max_sym_size_idx}) should come before " - f"view (min index {min_view_idx}) in the consumer subgraph." + assert len(sym_size_in_consumer) > 0, ( + "sym_size operations should be in the consumer subgraph (after split)" ) # Verify functional correctness with same-shaped inputs @@ -336,4 +309,3 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: output_original = gm(x, y) output_split = split_gm(x, y) assert torch.allclose(output_original, output_split), "Output mismatch after split" ->>>>>>> 009f91613 ([compile][graph_partition]Add tensor size handling) From 0e1249869c4fbdbbf4575dca0d6bdf4853f5d649 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 20 Jan 2026 18:37:40 -0800 Subject: [PATCH 14/25] Add more test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Xiao Fu --- tests/compile/test_graph_partition.py | 90 ++++++++++++++++++--------- 1 file changed, 59 insertions(+), 31 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 3022b07ead80..751676694ad5 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -194,17 +194,6 @@ def test_sym_size_moved_across_split_boundary(): This prevents issues where PT2 doesn't fully support torch.Size as submodule output when sym_size is in one subgraph and its consumer is in another. 
- - Pattern being tested: - # Original order that causes issues: - size = tensor_a.shape[0] # subgraph 0 - some_cg_unsafe_op # subgraph 1 (split point) - tensor_b = tensor_b.view(size) # subgraph 2 (would fail without fix) - - # After fix, sym_size is moved: - some_cg_unsafe_op # subgraph 1 (split point) - size = tensor_a.shape[0] # moved to subgraph 2 - tensor_b = tensor_b.view(size) # subgraph 2 (works correctly) """ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -233,37 +222,76 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: for node in gm.graph.nodes if node.op == "call_function" and "sym_size" in str(node.target) ] - assert ( - len(sym_size_nodes) > 0 - ), "Test setup failed: graph should contain sym_size operations" + assert len(sym_size_nodes) > 0, ( + "Test setup failed: graph should contain sym_size operations" + ) # Split on sigmoid which is the split point split_ops = ["aten::sigmoid"] split_gm, split_items = split_graph(gm, split_ops) - # After the fix, we expect 2 submodules: - # - subgraph 1: sigmoid (split point) - # - subgraph 2: sym_size ops + view + add (consumer subgraph) - # The original subgraph 0 becomes empty because sym_size ops are moved - # to the consumer subgraph, so it's not created. 
- assert len(split_items) == 2, f"Expected 2 submodules, got {len(split_items)}" - - # Verify that one is the splitting graph (sigmoid) and one is not + # Find the sigmoid (splitting) subgraph and the consumer subgraph splitting_items = [item for item in split_items if item.is_splitting_graph] - non_splitting_items = [item for item in split_items if not item.is_splitting_graph] assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph" - assert len(non_splitting_items) == 1, "Should have exactly 1 non-splitting subgraph" - # The non-splitting subgraph should contain the sym_size operations - # (they were moved from before the split to after) - consumer_subgraph = non_splitting_items[0].graph - sym_size_in_consumer = [ + # KEY VERIFICATION: sym_size operations should be in the same subgraph + # as the view operation (their consumer), NOT in an earlier subgraph. + # This prevents torch.Size from crossing subgraph boundaries. + + # Find which subgraph contains the view operation + view_subgraph = None + for item in split_items: + for node in item.graph.graph.nodes: + if node.op == "call_function" and "view" in str(node.target).lower(): + view_subgraph = item + break + if view_subgraph: + break + + assert view_subgraph is not None, "Should have a subgraph with view operation" + + # Verify sym_size operations are in the SAME subgraph as view + sym_size_in_view_subgraph = [ node - for node in consumer_subgraph.graph.nodes + for node in view_subgraph.graph.graph.nodes if node.op == "call_function" and "sym_size" in str(node.target) ] - assert len(sym_size_in_consumer) > 0, ( - "sym_size operations should be in the consumer subgraph (after split)" + assert len(sym_size_in_view_subgraph) > 0, ( + "sym_size operations should be in the same subgraph as their consumer " + "(view). This ensures torch.Size doesn't cross subgraph boundaries." 
+ ) + + # Verify ordering within the consumer subgraph: sym_size before view + consumer_nodes = list(view_subgraph.graph.graph.nodes) + # CRITICAL VERIFICATION: The sigmoid (splitting/unsafe op) subgraph must + # have a LOWER graph_id than the consumer subgraph. Since subgraphs execute + # in order of graph_id, this proves that: + # 1. Sigmoid runs FIRST + # 2. sym_size + view run SECOND (in consumer subgraph) + # Therefore, sym_size now happens AFTER the unsafe op. + sigmoid_subgraph = splitting_items[0] + assert sigmoid_subgraph.graph_id < view_subgraph.graph_id, ( + f"Sigmoid subgraph (graph_id={sigmoid_subgraph.graph_id}) must execute " + f"before consumer subgraph (graph_id={view_subgraph.graph_id}). " + "This ensures sym_size happens AFTER the unsafe operation." + ) + + sym_size_indices = [ + i + for i, node in enumerate(consumer_nodes) + if node.op == "call_function" and "sym_size" in str(node.target) + ] + view_indices = [ + i + for i, node in enumerate(consumer_nodes) + if node.op == "call_function" and "view" in str(node.target).lower() + ] + + max_sym_size_idx = max(sym_size_indices) + min_view_idx = min(view_indices) + assert max_sym_size_idx < min_view_idx, ( + f"sym_size (max index {max_sym_size_idx}) should come before " + f"view (min index {min_view_idx}) in the consumer subgraph." 
) # Verify functional correctness with same-shaped inputs From 10c9793cc78065cfc162029f9e5148d415af6fba Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Wed, 21 Jan 2026 22:00:21 -0800 Subject: [PATCH 15/25] Add replication to all consumer Signed-off-by: Xiao Fu --- tests/compile/test_graph_partition.py | 97 ++++++++++++++++----------- vllm/compilation/backends.py | 70 +++++++++---------- 2 files changed, 89 insertions(+), 78 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 751676694ad5..5c674888f9d6 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -12,6 +12,7 @@ # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 +from vllm.compilation.fx_utils import find_op_nodes, is_func def test_getitem_moved_to_producer_subgraph(): @@ -217,11 +218,7 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) # Verify the graph contains sym_size operations - sym_size_nodes = [ - node - for node in gm.graph.nodes - if node.op == "call_function" and "sym_size" in str(node.target) - ] + sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, gm.graph)) assert len(sym_size_nodes) > 0, ( "Test setup failed: graph should contain sym_size operations" ) @@ -241,21 +238,17 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # Find which subgraph contains the view operation view_subgraph = None for item in split_items: - for node in item.graph.graph.nodes: - if node.op == "call_function" and "view" in str(node.target).lower(): - view_subgraph = item - break - if view_subgraph: + view_nodes = list(find_op_nodes(torch.ops.aten.view, item.graph.graph)) + if view_nodes: + view_subgraph = item break assert view_subgraph is not None, "Should have a subgraph with view operation" # Verify sym_size operations are in the SAME subgraph as view - 
sym_size_in_view_subgraph = [ - node - for node in view_subgraph.graph.graph.nodes - if node.op == "call_function" and "sym_size" in str(node.target) - ] + sym_size_in_view_subgraph = list( + find_op_nodes(torch.ops.aten.sym_size, view_subgraph.graph.graph) + ) assert len(sym_size_in_view_subgraph) > 0, ( "sym_size operations should be in the same subgraph as their consumer " "(view). This ensures torch.Size doesn't cross subgraph boundaries." @@ -279,12 +272,12 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: sym_size_indices = [ i for i, node in enumerate(consumer_nodes) - if node.op == "call_function" and "sym_size" in str(node.target) + if is_func(node, torch.ops.aten.sym_size.int) ] view_indices = [ i for i, node in enumerate(consumer_nodes) - if node.op == "call_function" and "view" in str(node.target).lower() + if is_func(node, torch.ops.aten.view.default) ] max_sym_size_idx = max(sym_size_indices) @@ -300,40 +293,66 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: assert torch.allclose(output_original, output_split), "Output mismatch after split" -def test_sym_size_with_multiple_consumers_in_different_subgraphs(): +def test_sym_size_replicated_to_all_consumer_subgraphs(): """ - Test that when a sym_size result is used by consumers in multiple different - subgraphs, it's placed in the earliest consumer subgraph. + Test that sym_size operations are replicated to ALL consumer subgraphs. 
+ + This validates the pattern where each consumer subgraph computes sym_size + locally from the input tensor, rather than receiving it as an input: + + def f(x, y, z): + + sym_size = x.sym_size() # computed locally in subgraph 2 + y2 = y.view(sym_size) + + sym_size = x.sym_size() # computed locally in subgraph 4 + z2 = z.view(sym_size) + + sym_size = x.sym_size() # computed locally in subgraph 6 + w2 = w.view(sym_size) """ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - # Get shape before any split points size = x.shape[0] - - # First split point z1 = torch.sigmoid(x) - - # Use size after first split y1 = y[:size] - - # Second split point z2 = torch.sigmoid(z1) - - # Use size again after second split y2 = y[:size] - - return z2 + y1 + y2 + z3 = torch.sigmoid(z2) + y3 = y[:size] + return z3 + y1 + y2 + y3 x = torch.randn(4, 8) y = torch.randn(8, 8) - # Use symbolic tracing to generate sym_size operations gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) - # Split on both sigmoid operations - split_ops = ["aten::sigmoid"] - split_gm, split_items = split_graph(gm, split_ops) + assert len(list(find_op_nodes(torch.ops.aten.sym_size, gm.graph))) > 0, ( + "Test setup failed: graph should contain sym_size operations" + ) - # Verify functional correctness with same-shaped inputs - output_original = gm(x, y) - output_split = split_gm(x, y) - assert torch.allclose(output_original, output_split), "Output mismatch after split" + split_gm, split_items = split_graph(gm, ["aten::sigmoid"]) + + # Find subgraphs that contain slice operations (consumers of sym_size) + subgraphs_with_slice = [ + item + for item in split_items + if len(list(find_op_nodes(torch.ops.aten.slice, item.graph.graph))) > 0 + ] + + # Find subgraphs that contain sym_size operations + subgraphs_with_sym_size = [ + item + for item in split_items + if len(list(find_op_nodes(torch.ops.aten.sym_size, item.graph.graph))) > 0 + ] + + # KEY VERIFICATION: The number of subgraphs with sym_size 
should equal + # the number of consumer subgraphs (each consumer has its own sym_size) + assert len(subgraphs_with_sym_size) == len(subgraphs_with_slice), ( + f"Expected {len(subgraphs_with_slice)} subgraphs with sym_size " + f"(one per consumer), but found {len(subgraphs_with_sym_size)}. " + "This indicates sym_size was not properly replicated to all consumers." + ) + + # Verify functional correctness + assert torch.allclose(gm(x, y), split_gm(x, y)), "Output mismatch after split" diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 935b36381342..3f51eccf51c1 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -348,34 +348,19 @@ def _is_sym_size_op(node: fx.Node) -> bool: return False -def _move_sym_size_nodes_for_split( +def _replicate_sym_size_nodes_for_split( graph: fx.GraphModule, node_to_subgraph_id: dict[fx.Node, int], ) -> None: """ - Move sym_size operations to the same subgraph as their consumers. - - When splitting a graph, if a sym_size call is in one submodule and its - consumer is in another, PyTorch 2 has issues because torch.Size is not - fully supported as a submodule output. This function reorders sym_size - nodes to be just before their consumers when they would otherwise cross - subgraph boundaries. - - Pattern being fixed: - # Old (causes issues): - size = tensor_a.shape # subgraph 0 - some_cg_unsafe_op # subgraph 1 (split point) - tensor_b = tensor_b.view(size) # subgraph 2 (consumes size) - - # New (fixed): - some_cg_unsafe_op # subgraph 1 (split point) - size = tensor_a.shape # moved to subgraph 2 - tensor_b = tensor_b.view(size) # subgraph 2 (consumes size) - """ - # Collect all sym_size nodes that need to be moved - sym_size_nodes_to_move: list[tuple[fx.Node, int]] = [] + Replicate sym_size operations to ALL consumer subgraphs. 
- for node in graph.graph.nodes: + When splitting a graph, if a sym_size call has consumers in multiple + subgraphs, we replicate the sym_size operation to each consumer subgraph. + This ensures each subgraph computes sym_size locally rather than receiving + it as an input, avoiding torch.Size crossing subgraph boundaries. + """ + for node in list(graph.graph.nodes): if node.op in ("output", "placeholder"): continue @@ -386,28 +371,35 @@ def _move_sym_size_nodes_for_split( if node_subgraph is None: continue - # Find the minimum subgraph ID among all consumers of this sym_size - consumer_subgraph_ids: list[int] = [] + # Group consumers by their subgraph ID (only those in later subgraphs) + subgraph_to_consumers: dict[int, list[fx.Node]] = {} for user in node.users: if user.op == "output": continue user_subgraph = node_to_subgraph_id.get(user) - if user_subgraph is not None: - consumer_subgraph_ids.append(user_subgraph) + if user_subgraph is not None and user_subgraph > node_subgraph: + if user_subgraph not in subgraph_to_consumers: + subgraph_to_consumers[user_subgraph] = [] + subgraph_to_consumers[user_subgraph].append(user) - if not consumer_subgraph_ids: + if not subgraph_to_consumers: continue - # The minimum consumer subgraph is where we want to move the sym_size - min_consumer_subgraph = min(consumer_subgraph_ids) + # Create a copy of sym_size for EACH consumer subgraph + for subgraph_id, consumer_list in subgraph_to_consumers.items(): + with graph.graph.inserting_before(consumer_list[0]): + new_sym_size = graph.graph.call_function( + node.target, + args=node.args, + kwargs=node.kwargs, + ) + if node.meta: + new_sym_size.meta = node.meta.copy() - # Only move if the sym_size would cross into a later subgraph - if min_consumer_subgraph > node_subgraph: - sym_size_nodes_to_move.append((node, min_consumer_subgraph)) + node_to_subgraph_id[new_sym_size] = subgraph_id - # Update the subgraph assignments for sym_size nodes that need to move - for node, new_subgraph_id 
in sym_size_nodes_to_move: - node_to_subgraph_id[node] = new_subgraph_id + for consumer in consumer_list: + consumer.replace_input_with(node, new_sym_size) def split_graph( @@ -448,11 +440,11 @@ def split_graph( else: node_to_subgraph_id[node] = subgraph_id - # Move sym_size operations (tensor.shape accesses) to be closer to their - # consumers. This avoids issues where PT2 doesn't support torch.Size as + # Replicate sym_size operations (tensor.shape accesses) to all consumer + # subgraphs. This avoids issues where PT2 doesn't support torch.Size as # submodule output when sym_size is in one subgraph and its consumer is # in another. - _move_sym_size_nodes_for_split(graph, node_to_subgraph_id) + _replicate_sym_size_nodes_for_split(graph, node_to_subgraph_id) # `keep_original_order` is important! # otherwise pytorch might reorder the nodes and From 8bd2fc90070cd184d6a1ad7a09306a875af8e19a Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Fri, 23 Jan 2026 09:25:40 -0800 Subject: [PATCH 16/25] Add repro-level bug on cuda_graph address assignment to ensure the fix Signed-off-by: Xiao Fu --- vllm/compilation/cuda_graph.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 729da2c3acad..753ccaac8f3f 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -288,6 +288,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: compilation_counter.num_cudagraph_captured += 1 + # Save input addresses for debugging/replay verification + if self.is_debugging_mode: + entry.input_addresses = input_addresses + # important: we need to return the output, rather than # the weak ref of the output, so that pytorch can correctly # manage the memory during cuda graph capture From b204b4d19164acad06bf3c76a3fed85bb1b65145 Mon Sep 17 00:00:00 2001 From: Xiao <31429901+fxdawnn@users.noreply.github.com> Date: Tue, 27 Jan 2026 11:18:38 -0800 Subject: [PATCH 17/25] Update 
vllm/compilation/backends.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Luka Govedič Signed-off-by: Xiao <31429901+fxdawnn@users.noreply.github.com> --- vllm/compilation/backends.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 3f51eccf51c1..14c6ceeadbab 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -336,16 +336,16 @@ def _is_sym_size_op(node: fx.Node) -> bool: """Check if a node is a sym_size operation (tensor.shape access).""" if node.op != "call_function": return False - target = node.target + + if not hasattr(torch.ops.aten, "sym_size"): + return False + # Handle both torch.ops.aten.sym_size.int and sym_size.default - if hasattr(torch.ops.aten, "sym_size"): - sym_size_ops = ( - torch.ops.aten.sym_size, - torch.ops.aten.sym_size.int, - torch.ops.aten.sym_size.default, - ) - return target in sym_size_ops - return False + return node.target in ( + torch.ops.aten.sym_size, + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_size.default, + ) def _replicate_sym_size_nodes_for_split( From 24c0ea684e34130d4bad5243ad20341f2c7a688a Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Wed, 28 Jan 2026 14:52:58 -0800 Subject: [PATCH 18/25] Modify the test for scenario with torch.tensor() Signed-off-by: Xiao Fu --- tests/compile/test_graph_partition.py | 230 ++++++++++++---- vllm/compilation/backends.py | 372 +++++++++++++++++++++++--- 2 files changed, 517 insertions(+), 85 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 5c674888f9d6..8765c86de70c 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -5,9 +5,13 @@ import pytest import torch +import torch._dynamo +import torch.fx as fx from torch.fx.experimental.proxy_tensor import make_fx -from vllm.compilation.backends import 
split_graph +from vllm.compilation.backends import ( + split_graph, +) from vllm.compilation.fx_utils import find_op_nodes # This import automatically registers `torch.ops.silly.attention` @@ -15,6 +19,15 @@ from vllm.compilation.fx_utils import find_op_nodes, is_func +@pytest.fixture +def vllm_compile_env(monkeypatch): + """Set up vLLM compilation environment variables for testing.""" + monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput") + monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1") + monkeypatch.setenv("VLLM_LOGGING_LEVEL", "debug") + yield + + def test_getitem_moved_to_producer_subgraph(): """ Test that getitem operations are moved to the same subgraph as their input, @@ -293,66 +306,185 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: assert torch.allclose(output_original, output_split), "Output mismatch after split" -def test_sym_size_replicated_to_all_consumer_subgraphs(): +def test_sym_size_with_torch_compile_and_mark_dynamic(): """ - Test that sym_size operations are replicated to ALL consumer subgraphs. - - This validates the pattern where each consumer subgraph computes sym_size - locally from the input tensor, rather than receiving it as an input: - - def f(x, y, z): - - sym_size = x.sym_size() # computed locally in subgraph 2 - y2 = y.view(sym_size) - - sym_size = x.sym_size() # computed locally in subgraph 4 - z2 = z.view(sym_size) - - sym_size = x.sym_size() # computed locally in subgraph 6 - w2 = w.view(sym_size) + Test handling of SymInt placeholders from torch.compile with mark_dynamic + across MULTIPLE split subgraphs. + + When using torch.compile + mark_dynamic, the captured graph has: + - SymInt placeholders (e.g., s77) as separate inputs + - Operations that use the SymInt directly (e.g., view([s77, 8])) + + standalone_compile / inductor expects only tensor inputs. split_graph must: + 1. Replace SymInt placeholder uses with sym_size calls on tensor inputs + 2. 
Replicate sym_size to ALL consumer subgraphs that need the dynamic size + 3. Remove unused SymInt placeholders from the final graph + + This test validates the complete SymInt -> sym_size pipeline with MULTIPLE + split boundaries to ensure sym_size is correctly replicated across subgraphs: + - Phase 1: SymInt placeholders exist in the captured graph + - Phase 2 & 3: split_graph handles SymInt replacement and removal + - Phase 4: sym_size.int exists in EACH consumer subgraph that needs it + - Phase 5: Functional correctness with original input + - Phase 6: Functional correctness with different batch size + - Phase 7: Validate multiple split subgraphs exist """ + captured_graph = None - def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - size = x.shape[0] - z1 = torch.sigmoid(x) - y1 = y[:size] - z2 = torch.sigmoid(z1) - y2 = y[:size] - z3 = torch.sigmoid(z2) - y3 = y[:size] - return z3 + y1 + y2 + y3 + def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: + nonlocal captured_graph + captured_graph = gm + return gm + + def model_fn(x: torch.Tensor) -> torch.Tensor: + # Get the dynamic shape before any splits + batch_size = x.shape[0] + hidden_size = x.shape[1] + + # First split point - sigmoid #1 + x = torch.ops.aten.sigmoid.default(x) + + # Use dynamic size after first split - creates sym_size consumer + x = x.clone().view(batch_size, hidden_size) + + # Second split point - sigmoid #2 + x = torch.ops.aten.sigmoid.default(x) + + # Use dynamic size again after second split - another sym_size consumer + x = x.clone().view(batch_size, hidden_size) + + # Third split point - sigmoid #3 + x = torch.ops.aten.sigmoid.default(x) + + # Use dynamic size again after third split - yet another consumer + x = x.clone().view(batch_size, hidden_size) + + return x x = torch.randn(4, 8) - y = torch.randn(8, 8) - gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) + # Mark the first dimension as dynamic + torch._dynamo.mark_dynamic(x, 0) - 
assert len(list(find_op_nodes(torch.ops.aten.sym_size, gm.graph))) > 0, ( - "Test setup failed: graph should contain sym_size operations" + compiled_fn = torch.compile(model_fn, backend=capturing_backend) + compiled_fn(x) + + assert captured_graph is not None, "Graph should be captured by backend" + + # ===== PHASE 1: Validate SymInt placeholders exist in captured graph ===== + symint_placeholders = [ + node + for node in captured_graph.graph.nodes + if node.op == "placeholder" + and node.meta.get("example_value") is not None + and isinstance(node.meta.get("example_value"), torch.SymInt) + ] + assert len(symint_placeholders) > 0, ( + "Phase 1 FAILED: Captured graph should have SymInt placeholders from " + "mark_dynamic. This is the prerequisite for testing the sym_size pipeline." ) - split_gm, split_items = split_graph(gm, ["aten::sigmoid"]) + # Record original SymInt users for later validation + original_symint_users = {} + for symint_node in symint_placeholders: + users = [u for u in symint_node.users if u.op != "output"] + original_symint_users[symint_node.name] = [u.name for u in users] + + # ===== PHASE 2 & 3: split_graph handles SymInt replacement and removal ===== + # NOTE: split_graph modifies the input graph in-place! 
+ # With 3 sigmoid operations, we expect 7 subgraphs: + # submod_0 (before sigmoid #1), submod_1 (sigmoid #1), + # submod_2 (between sigmoid #1 and #2), submod_3 (sigmoid #2), + # submod_4 (between sigmoid #2 and #3), submod_5 (sigmoid #3), + # submod_6 (after sigmoid #3) + split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) + + # ===== PHASE 7: Validate multiple split subgraphs exist ===== + # Count splitting subgraphs (the sigmoid operations) + splitting_subgraphs = [item for item in split_items if item.is_splitting_graph] + + assert len(splitting_subgraphs) == 3, ( + f"Phase 7 FAILED: Expected 3 splitting subgraphs (3 sigmoids), " + f"got {len(splitting_subgraphs)}" + ) + # Note: Total subgraphs can be 6 or 7 depending on whether there are + # operations before the first sigmoid. With torch.compile, shape access + # operations may be folded differently, resulting in 6 subgraphs: + # submod_1 (sigmoid #1), submod_2 (compute), submod_3 (sigmoid #2), + # submod_4 (compute), submod_5 (sigmoid #3), submod_6 (compute) + assert len(split_items) >= 6, ( + f"Phase 7 FAILED: Expected at least 6 total subgraphs " + f"(3 sigmoids + at least 3 compute blocks), got {len(split_items)}" + ) - # Find subgraphs that contain slice operations (consumers of sym_size) - subgraphs_with_slice = [ - item - for item in split_items - if len(list(find_op_nodes(torch.ops.aten.slice, item.graph.graph))) > 0 + # ===== PHASE 3: Validate SymInt placeholders are removed from split_gm ===== + split_placeholders = [ + node for node in split_gm.graph.nodes if node.op == "placeholder" ] - # Find subgraphs that contain sym_size operations - subgraphs_with_sym_size = [ - item - for item in split_items - if len(list(find_op_nodes(torch.ops.aten.sym_size, item.graph.graph))) > 0 + remaining_symint_placeholders = [ + node + for node in split_placeholders + if node.meta.get("example_value") is not None + and isinstance(node.meta.get("example_value"), torch.SymInt) ] + assert 
len(remaining_symint_placeholders) == 0, ( + f"Phase 3 FAILED: split_gm should not have SymInt placeholders after " + f"_remove_symint_placeholders. Found: " + f"{[n.name for n in remaining_symint_placeholders]}. " + "This means SymInt would be passed as input which inductor doesn't support." + ) + + # ===== PHASE 4: Validate sym_size.int exists in consumer subgraphs ===== + # Each non-splitting subgraph that uses dynamic sizes should have sym_size.int + # to compute the dynamic dimension locally from the tensor input. + total_sym_size_nodes = 0 + subgraphs_with_sym_size = [] + + for item in split_items: + sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, item.graph.graph)) + + if sym_size_nodes: + total_sym_size_nodes += len(sym_size_nodes) + subgraphs_with_sym_size.append(item.submod_name) - # KEY VERIFICATION: The number of subgraphs with sym_size should equal - # the number of consumer subgraphs (each consumer has its own sym_size) - assert len(subgraphs_with_sym_size) == len(subgraphs_with_slice), ( - f"Expected {len(subgraphs_with_slice)} subgraphs with sym_size " - f"(one per consumer), but found {len(subgraphs_with_sym_size)}. " - "This indicates sym_size was not properly replicated to all consumers." + assert total_sym_size_nodes > 0, ( + "Phase 4 FAILED: No sym_size.int nodes found in any subgraph. " + "split_graph should replace SymInt placeholders with sym_size.int calls " + "that compute dynamic sizes from tensor inputs." ) - # Verify functional correctness - assert torch.allclose(gm(x, y), split_gm(x, y)), "Output mismatch after split" + # With 3 split boundaries and dynamic size usage after each split, + # we expect sym_size to be replicated to multiple consumer subgraphs + assert len(subgraphs_with_sym_size) >= 3, ( + f"Phase 4 FAILED: sym_size should exist in consumer subgraphs. 
" + f"Found sym_size in {len(subgraphs_with_sym_size)} subgraphs: " + f"{subgraphs_with_sym_size}" + ) + + # ===== PHASE 5: Validate functional correctness ===== + # split_gm should work with tensor-only input (no SymInt) + output_split = split_gm(x) + + # Handle case where output is a tuple + if isinstance(output_split, tuple): + output_split = output_split[0] + + # For reference, run the model directly to get expected output + expected_output = model_fn(x) + + assert torch.allclose(expected_output, output_split), ( + "Phase 5 FAILED: Output mismatch after split. The sym_size pipeline " + "should preserve functional correctness." + ) + + # ===== PHASE 6: Validate with different batch size ===== + # The dynamic dimension should work with different sizes + x_different = torch.randn(8, 8) # Different batch size + output_different = split_gm(x_different) + if isinstance(output_different, tuple): + output_different = output_different[0] + expected_different = model_fn(x_different) + assert torch.allclose(expected_different, output_different), ( + "Phase 6 FAILED: Output mismatch with different batch size. " + "sym_size should correctly compute the dynamic dimension at runtime." 
+ ) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 14c6ceeadbab..7758bb57de5c 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -332,10 +332,11 @@ class SplitItem: graph: fx.GraphModule -def _is_sym_size_op(node: fx.Node) -> bool: - """Check if a node is a sym_size operation (tensor.shape access).""" - if node.op != "call_function": +def _is_symint_placeholder(node: fx.Node) -> bool: + """Check if a node is a SymInt placeholder (from torch.compile + mark_dynamic).""" + if node.op != "placeholder": return False +<<<<<<< HEAD if not hasattr(torch.ops.aten, "sym_size"): return False @@ -346,60 +347,121 @@ def _is_sym_size_op(node: fx.Node) -> bool: torch.ops.aten.sym_size.int, torch.ops.aten.sym_size.default, ) +======= + example_value = node.meta.get("example_value") + return example_value is not None and isinstance(example_value, torch.SymInt) +>>>>>>> e726e436c (Modify the test for scenario with torch.tensor()) -def _replicate_sym_size_nodes_for_split( +def _find_tensor_for_symint( + symint_value: torch.SymInt, + graph: fx.GraphModule, +) -> tuple[fx.Node, int] | None: + """ + Find a tensor placeholder with a dimension matching the given SymInt. + + Returns (tensor_node, dim) or None if no match found. 
+ """ + for node in graph.graph.nodes: + if node.op != "placeholder": + continue + tensor_value = node.meta.get("example_value") + if tensor_value is None or not isinstance(tensor_value, torch.Tensor): + continue + if not hasattr(tensor_value, "shape"): + continue + + for dim, size in enumerate(tensor_value.shape): + # Match by identity + if size is symint_value: + return (node, dim) + # Match by underlying symbolic node + if ( + hasattr(size, "node") + and hasattr(symint_value, "node") + and size.node is symint_value.node + ): + return (node, dim) + # Match by string representation (fallback) + if str(size) == str(symint_value): + return (node, dim) + + return None + + +def _replace_symint_placeholders( graph: fx.GraphModule, node_to_subgraph_id: dict[fx.Node, int], ) -> None: """ - Replicate sym_size operations to ALL consumer subgraphs. + Replace SymInt placeholder uses with sym_size calls. - When splitting a graph, if a sym_size call has consumers in multiple - subgraphs, we replicate the sym_size operation to each consumer subgraph. - This ensures each subgraph computes sym_size locally rather than receiving - it as an input, avoiding torch.Size crossing subgraph boundaries. + When using torch.compile with mark_dynamic, the captured graph has SymInt + placeholders (e.g., s77) as separate inputs. standalone_compile / inductor + expects only tensor inputs. + + This function creates sym_size.int nodes to replace SymInt placeholder uses. + + IMPORTANT: We do NOT delete the SymInt placeholders here because split_module + needs them for its symbol_to_node mapping. If we delete them, split_module + fails with KeyError when processing tensors whose shapes contain the symbol. + The placeholders are removed AFTER split_module by _remove_symint_placeholders. 
""" for node in list(graph.graph.nodes): - if node.op in ("output", "placeholder"): + if not _is_symint_placeholder(node): continue - if not _is_sym_size_op(node): + symint_value = node.meta.get("example_value") + if symint_value is None: continue - node_subgraph = node_to_subgraph_id.get(node) - if node_subgraph is None: + tensor_dim = _find_tensor_for_symint(symint_value, graph) + if tensor_dim is None: + logger.warning( + "Could not find tensor dimension for SymInt placeholder %s", + node.name, + ) continue - # Group consumers by their subgraph ID (only those in later subgraphs) + tensor_node, dim = tensor_dim + + # Get list of users before modifying + users_list = list(node.users.keys()) + if not users_list: + # No users, keep the placeholder for symbol_to_node mapping + continue + + # Create sym_size for each subgraph that uses this SymInt subgraph_to_consumers: dict[int, list[fx.Node]] = {} - for user in node.users: + for user in users_list: if user.op == "output": continue - user_subgraph = node_to_subgraph_id.get(user) - if user_subgraph is not None and user_subgraph > node_subgraph: - if user_subgraph not in subgraph_to_consumers: - subgraph_to_consumers[user_subgraph] = [] - subgraph_to_consumers[user_subgraph].append(user) - - if not subgraph_to_consumers: - continue + user_subgraph = node_to_subgraph_id.get(user, 0) + if user_subgraph not in subgraph_to_consumers: + subgraph_to_consumers[user_subgraph] = [] + subgraph_to_consumers[user_subgraph].append(user) - # Create a copy of sym_size for EACH consumer subgraph for subgraph_id, consumer_list in subgraph_to_consumers.items(): with graph.graph.inserting_before(consumer_list[0]): - new_sym_size = graph.graph.call_function( - node.target, - args=node.args, - kwargs=node.kwargs, + sym_size_node = graph.graph.call_function( + torch.ops.aten.sym_size.int, + args=(tensor_node, dim), ) if node.meta: - new_sym_size.meta = node.meta.copy() + sym_size_node.meta = node.meta.copy() - 
node_to_subgraph_id[new_sym_size] = subgraph_id + node_to_subgraph_id[sym_size_node] = subgraph_id for consumer in consumer_list: - consumer.replace_input_with(node, new_sym_size) + consumer.replace_input_with(node, sym_size_node) + + # NOTE: We do NOT delete the SymInt placeholder here! + # split_module needs it for symbol_to_node mapping. + # It will be removed by _remove_symint_placeholders after split_module. + + # NOTE: We skip lint()/recompile() here since split_module reads from + # graph.graph.nodes directly, not the forward() method. This avoids + # potential issues with graph state changes before split_module. def split_graph( @@ -440,11 +502,11 @@ def split_graph( else: node_to_subgraph_id[node] = subgraph_id - # Replicate sym_size operations (tensor.shape accesses) to all consumer - # subgraphs. This avoids issues where PT2 doesn't support torch.Size as - # submodule output when sym_size is in one subgraph and its consumer is - # in another. - _replicate_sym_size_nodes_for_split(graph, node_to_subgraph_id) + # Replace SymInt placeholders with sym_size.int calls and delete them. + # This is needed for torch.compile + mark_dynamic, where the captured graph + # has SymInt placeholders as separate inputs. standalone_compile / inductor + # expects only tensor inputs. + _replace_symint_placeholders(graph, node_to_subgraph_id) # `keep_original_order` is important! # otherwise pytorch might reorder the nodes and @@ -454,6 +516,13 @@ def split_graph( graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True ) + # Note: With the simplified approach, _replace_symint_placeholders_with_sym_size + # now DELETES SymInt placeholders BEFORE split_module runs. This prevents + # split_module from threading SymInt through submodules. The post-split cleanup + # _remove_symint_placeholders is still called as a safety net in case any + # SymInt placeholders remain (e.g., if they couldn't be replaced). 
+ _remove_symint_placeholders(split_gm) + outputs = [] names = [name for (name, module) in split_gm.named_modules()] @@ -474,6 +543,237 @@ def split_graph( return split_gm, outputs +def _remove_symint_placeholders(gm: fx.GraphModule) -> None: + """ + Remove SymInt placeholders from a GraphModule after split_module. + + After split_module, SymInt placeholders may still exist and may have users + (call_module nodes that pass the SymInt to submodules). This function: + 1. Replaces SymInt arguments in call_module nodes with sym_size.int calls + 2. Removes the now-unused SymInt placeholders + + This ensures the final graph only requires tensor inputs. + """ + # Collect SymInt and tensor placeholders + symint_placeholders: list[fx.Node] = [] + tensor_placeholders: list[fx.Node] = [] + + for node in gm.graph.nodes: + if node.op != "placeholder": + continue + example_value = node.meta.get("example_value") + if example_value is None: + continue + if isinstance(example_value, torch.SymInt): + symint_placeholders.append(node) + elif isinstance(example_value, torch.Tensor): + tensor_placeholders.append(node) + + if not symint_placeholders: + return + + # Build mapping from SymInt placeholder to (tensor, dim) that can compute it + symint_to_tensor_dim: dict[fx.Node, tuple[fx.Node, int]] = {} + + for symint_node in symint_placeholders: + symint_value = symint_node.meta.get("example_value") + if symint_value is None: + continue + + # Find a tensor with a dynamic dimension matching this SymInt + for tensor_node in tensor_placeholders: + tensor_value = tensor_node.meta.get("example_value") + if tensor_value is None or not hasattr(tensor_value, "shape"): + continue + + for dim, size in enumerate(tensor_value.shape): + # Match by identity + if size is symint_value: + symint_to_tensor_dim[symint_node] = (tensor_node, dim) + break + # Match by underlying symbolic node + if ( + hasattr(size, "node") + and hasattr(symint_value, "node") + and size.node is symint_value.node + ): + 
symint_to_tensor_dim[symint_node] = (tensor_node, dim) + break + # Match by string representation (fallback) + if str(size) == str(symint_value): + symint_to_tensor_dim[symint_node] = (tensor_node, dim) + break + + if symint_node in symint_to_tensor_dim: + break + + logger.debug( + "Mapped SymInt placeholders to tensor dims: %s", + {n.name: (t.name, d) for n, (t, d) in symint_to_tensor_dim.items()}, + ) + + # For each SymInt placeholder that has users (call_module nodes), replace + # the SymInt argument with a sym_size.int call on the corresponding tensor + nodes_modified = False + for symint_node in symint_placeholders: + if not symint_node.users: + # No users, can just delete + gm.graph.erase_node(symint_node) + nodes_modified = True + continue + + if symint_node not in symint_to_tensor_dim: + logger.warning( + "Could not find tensor dimension for SymInt placeholder %s", + symint_node.name, + ) + continue + + tensor_node, dim = symint_to_tensor_dim[symint_node] + + # Replace each use of the SymInt with a sym_size.int call + # We need to create a new sym_size node before each user + users_list = list(symint_node.users.keys()) + for user in users_list: + if user.op != "call_module": + # For non-call_module users, create sym_size before them + with gm.graph.inserting_before(user): + sym_size_node = gm.graph.call_function( + torch.ops.aten.sym_size.int, + args=(tensor_node, dim), + ) + if symint_node.meta: + sym_size_node.meta = symint_node.meta.copy() + user.replace_input_with(symint_node, sym_size_node) + else: + # For call_module nodes, we need to remove the SymInt from args + # and update the submodule to compute sym_size locally + _update_submodule_to_compute_symint_locally( + gm, user, symint_node, tensor_node, dim + ) + + # Now the SymInt placeholder should have no users + if not symint_node.users: + gm.graph.erase_node(symint_node) + nodes_modified = True + else: + logger.warning( + "SymInt placeholder %s still has %d users after processing: %s", + 
symint_node.name, + len(symint_node.users), + list(symint_node.users.keys()), + ) + + if nodes_modified: + gm.graph.lint() + gm.recompile() + + +def _update_submodule_to_compute_symint_locally( + gm: fx.GraphModule, + call_module_node: fx.Node, + symint_node: fx.Node, + tensor_node: fx.Node, + dim: int, +) -> None: + """ + Update a submodule call to compute SymInt locally instead of receiving it. + + This modifies: + 1. The call_module node's args to remove the SymInt and ensure tensor is passed + 2. The submodule to compute sym_size.int from the tensor instead of taking + SymInt as a parameter + """ + submod_name = call_module_node.target + submodule = getattr(gm, submod_name) + + # Find which argument position(s) correspond to symint_node and tensor_node + old_args = list(call_module_node.args) + symint_arg_indices = [i for i, arg in enumerate(old_args) if arg is symint_node] + tensor_arg_indices = [i for i, arg in enumerate(old_args) if arg is tensor_node] + + if not symint_arg_indices: + return + + # Get the submodule's placeholder nodes + submod_placeholders = [n for n in submodule.graph.nodes if n.op == "placeholder"] + + # Find the placeholder in submodule that corresponds to the SymInt + symint_placeholder_idx = symint_arg_indices[0] + if symint_placeholder_idx >= len(submod_placeholders): + logger.warning( + "SymInt arg index %d out of range for submodule %s with %d placeholders", + symint_placeholder_idx, + submod_name, + len(submod_placeholders), + ) + return + + symint_submod_placeholder = submod_placeholders[symint_placeholder_idx] + + # Find or ensure there's a placeholder for the tensor in the submodule + tensor_submod_placeholder = None + if tensor_arg_indices: + tensor_placeholder_idx = tensor_arg_indices[0] + if tensor_placeholder_idx < len(submod_placeholders): + tensor_submod_placeholder = submod_placeholders[tensor_placeholder_idx] + + if tensor_submod_placeholder is None: + # Tensor is not currently passed to this submodule, need to add it + # 
Add tensor to call_module args (at the end) + new_args = list(old_args) + [tensor_node] + # Also remove the SymInt from args + new_args = [ + arg for i, arg in enumerate(new_args) if i not in symint_arg_indices + ] + call_module_node.args = tuple(new_args) + + # Add new placeholder to submodule at the end + last_placeholder = submod_placeholders[-1] + with submodule.graph.inserting_after(last_placeholder): + tensor_submod_placeholder = submodule.graph.placeholder("tensor_for_symint") + if tensor_node.meta: + tensor_submod_placeholder.meta = tensor_node.meta.copy() + + else: + # Tensor is already passed, just need to update args to remove SymInt + new_args = [ + arg for i, arg in enumerate(old_args) if i not in symint_arg_indices + ] + call_module_node.args = tuple(new_args) + + # Find first node to insert sym_size before (after placeholders/get_attr) + insert_point = None + for node in submodule.graph.nodes: + if node.op not in ("placeholder", "get_attr"): + insert_point = node + break + + if insert_point is None: + logger.warning("Could not find insertion point in submodule %s", submod_name) + return + + # Create sym_size.int node in submodule + with submodule.graph.inserting_before(insert_point): + sym_size_node = submodule.graph.call_function( + torch.ops.aten.sym_size.int, + args=(tensor_submod_placeholder, dim), + ) + if symint_submod_placeholder.meta: + sym_size_node.meta = symint_submod_placeholder.meta.copy() + + # Replace all uses + + # Replace all uses of SymInt placeholder with sym_size node + symint_submod_placeholder.replace_all_uses_with(sym_size_node) + + # Remove the SymInt placeholder from submodule + submodule.graph.erase_node(symint_submod_placeholder) + + submodule.graph.lint() + submodule.recompile() + + compilation_start_time = 0.0 From 018fe847afda28835c6d8e008bd49185f65044e9 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Mon, 2 Feb 2026 12:45:26 -0800 Subject: [PATCH 19/25] Fix rebase and arrangement Signed-off-by: Xiao Fu --- 
tests/compile/test_graph_partition.py | 4 +--- vllm/compilation/backends.py | 7 ++----- vllm/compilation/cuda_graph.py | 5 +---- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 8765c86de70c..31e986c47bd7 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -12,11 +12,10 @@ from vllm.compilation.backends import ( split_graph, ) -from vllm.compilation.fx_utils import find_op_nodes +from vllm.compilation.fx_utils import find_op_nodes, is_func # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 -from vllm.compilation.fx_utils import find_op_nodes, is_func @pytest.fixture @@ -200,7 +199,6 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: ] + ["output"] - def test_sym_size_moved_across_split_boundary(): """ Test that sym_size operations (tensor.shape accesses) are moved to the same diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 7758bb57de5c..8f0aa068187a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -336,21 +336,18 @@ def _is_symint_placeholder(node: fx.Node) -> bool: """Check if a node is a SymInt placeholder (from torch.compile + mark_dynamic).""" if node.op != "placeholder": return False -<<<<<<< HEAD - + if not hasattr(torch.ops.aten, "sym_size"): return False - + # Handle both torch.ops.aten.sym_size.int and sym_size.default return node.target in ( torch.ops.aten.sym_size, torch.ops.aten.sym_size.int, torch.ops.aten.sym_size.default, ) -======= example_value = node.meta.get("example_value") return example_value is not None and isinstance(example_value, torch.SymInt) ->>>>>>> e726e436c (Modify the test for scenario with torch.tensor()) def _find_tensor_for_symint( diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 753ccaac8f3f..b2e831b0d294 100644 --- 
a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -247,6 +247,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: input_addresses = [ x.data_ptr() for x in args if isinstance(x, torch.Tensor) ] + entry.input_addresses = input_addresses cudagraph = torch.cuda.CUDAGraph() with ExitStack() as stack: @@ -288,10 +289,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: compilation_counter.num_cudagraph_captured += 1 - # Save input addresses for debugging/replay verification - if self.is_debugging_mode: - entry.input_addresses = input_addresses - # important: we need to return the output, rather than # the weak ref of the output, so that pytorch can correctly # manage the memory during cuda graph capture From d936732726eb3441d4abfccf3f1e94c371829d75 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Mon, 2 Feb 2026 13:32:19 -0800 Subject: [PATCH 20/25] Fix error Signed-off-by: Xiao Fu --- vllm/compilation/cuda_graph.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 9b76bb166ec5..b2e831b0d294 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -289,10 +289,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: compilation_counter.num_cudagraph_captured += 1 - # Save input addresses for debugging/replay verification - if self.is_debugging_mode: - entry.input_addresses = input_addresses - # important: we need to return the output, rather than # the weak ref of the output, so that pytorch can correctly # manage the memory during cuda graph capture From b0666e0c33f824b62e0c85d6a6831bd43622e37c Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Fri, 13 Feb 2026 15:39:54 -0800 Subject: [PATCH 21/25] Reduce logic overhead --- vllm/compilation/backends.py | 258 +++++------------------------------ 1 file changed, 37 insertions(+), 221 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 
a5b6bcd8799d..204e8ee1ddee 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -347,9 +347,6 @@ def _is_symint_placeholder(node: fx.Node) -> bool: torch.ops.aten.sym_size.default, ) - example_value = node.meta.get("example_value") - return example_value is not None and isinstance(example_value, torch.SymInt) - def _find_tensor_for_symint( symint_value: torch.SymInt, @@ -514,11 +511,7 @@ def split_graph( graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True ) - # Note: With the simplified approach, _replace_symint_placeholders_with_sym_size - # now DELETES SymInt placeholders BEFORE split_module runs. This prevents - # split_module from threading SymInt through submodules. The post-split cleanup - # _remove_symint_placeholders is still called as a safety net in case any - # SymInt placeholders remain (e.g., if they couldn't be replaced). + # Remove any remaining SymInt placeholders after split_module. _remove_symint_placeholders(split_gm) outputs = [] @@ -545,233 +538,56 @@ def _remove_symint_placeholders(gm: fx.GraphModule) -> None: """ Remove SymInt placeholders from a GraphModule after split_module. - After split_module, SymInt placeholders may still exist and may have users - (call_module nodes that pass the SymInt to submodules). This function: - 1. Replaces SymInt arguments in call_module nodes with sym_size.int calls - 2. Removes the now-unused SymInt placeholders - - This ensures the final graph only requires tensor inputs. + Since _replace_symint_placeholders already replaced all SymInt users with + sym_size.int calls before split_module, the SymInt placeholders should have + no real consumers. However, split_module may still thread them through to + call_module nodes via its symbol_to_node tracking. This function removes + those spurious references and erases the SymInt placeholders so the final + graph only requires tensor inputs. 
""" - # Collect SymInt and tensor placeholders - symint_placeholders: list[fx.Node] = [] - tensor_placeholders: list[fx.Node] = [] - + nodes_to_erase = [] for node in gm.graph.nodes: if node.op != "placeholder": continue example_value = node.meta.get("example_value") - if example_value is None: - continue - if isinstance(example_value, torch.SymInt): - symint_placeholders.append(node) - elif isinstance(example_value, torch.Tensor): - tensor_placeholders.append(node) - - if not symint_placeholders: - return - - # Build mapping from SymInt placeholder to (tensor, dim) that can compute it - symint_to_tensor_dim: dict[fx.Node, tuple[fx.Node, int]] = {} - - for symint_node in symint_placeholders: - symint_value = symint_node.meta.get("example_value") - if symint_value is None: + if not isinstance(example_value, torch.SymInt): continue - # Find a tensor with a dynamic dimension matching this SymInt - for tensor_node in tensor_placeholders: - tensor_value = tensor_node.meta.get("example_value") - if tensor_value is None or not hasattr(tensor_value, "shape"): - continue - - for dim, size in enumerate(tensor_value.shape): - # Match by identity - if size is symint_value: - symint_to_tensor_dim[symint_node] = (tensor_node, dim) - break - # Match by underlying symbolic node - if ( - hasattr(size, "node") - and hasattr(symint_value, "node") - and size.node is symint_value.node - ): - symint_to_tensor_dim[symint_node] = (tensor_node, dim) - break - # Match by string representation (fallback) - if str(size) == str(symint_value): - symint_to_tensor_dim[symint_node] = (tensor_node, dim) - break - - if symint_node in symint_to_tensor_dim: - break - - logger.debug( - "Mapped SymInt placeholders to tensor dims: %s", - {n.name: (t.name, d) for n, (t, d) in symint_to_tensor_dim.items()}, - ) - - # For each SymInt placeholder that has users (call_module nodes), replace - # the SymInt argument with a sym_size.int call on the corresponding tensor - nodes_modified = False - for symint_node 
in symint_placeholders: - if not symint_node.users: - # No users, can just delete - gm.graph.erase_node(symint_node) - nodes_modified = True - continue - - if symint_node not in symint_to_tensor_dim: + # Remove this SymInt from any call_module args that split_module + # may have threaded it into. + for user in list(node.users.keys()): + if user.op == "call_module": + new_args = tuple(a for a in user.args if a is not node) + user.args = new_args + + # Also remove the corresponding placeholder from the + # submodule so its signature stays in sync. + submodule = getattr(gm, user.target) + for submod_node in list(submodule.graph.nodes): + if submod_node.op != "placeholder": + continue + sub_ev = submod_node.meta.get("example_value") + if isinstance(sub_ev, torch.SymInt): + if not submod_node.users: + submodule.graph.erase_node(submod_node) + submodule.graph.lint() + submodule.recompile() + + if node.users: logger.warning( - "Could not find tensor dimension for SymInt placeholder %s", - symint_node.name, + "SymInt placeholder %s still has users: %s", + node.name, + list(node.users.keys()), ) continue - - tensor_node, dim = symint_to_tensor_dim[symint_node] - - # Replace each use of the SymInt with a sym_size.int call - # We need to create a new sym_size node before each user - users_list = list(symint_node.users.keys()) - for user in users_list: - if user.op != "call_module": - # For non-call_module users, create sym_size before them - with gm.graph.inserting_before(user): - sym_size_node = gm.graph.call_function( - torch.ops.aten.sym_size.int, - args=(tensor_node, dim), - ) - if symint_node.meta: - sym_size_node.meta = symint_node.meta.copy() - user.replace_input_with(symint_node, sym_size_node) - else: - # For call_module nodes, we need to remove the SymInt from args - # and update the submodule to compute sym_size locally - _update_submodule_to_compute_symint_locally( - gm, user, symint_node, tensor_node, dim - ) - - # Now the SymInt placeholder should have no users 
- if not symint_node.users: - gm.graph.erase_node(symint_node) - nodes_modified = True - else: - logger.warning( - "SymInt placeholder %s still has %d users after processing: %s", - symint_node.name, - len(symint_node.users), - list(symint_node.users.keys()), - ) - - if nodes_modified: + nodes_to_erase.append(node) + for node in nodes_to_erase: + gm.graph.erase_node(node) + if nodes_to_erase: gm.graph.lint() gm.recompile() -def _update_submodule_to_compute_symint_locally( - gm: fx.GraphModule, - call_module_node: fx.Node, - symint_node: fx.Node, - tensor_node: fx.Node, - dim: int, -) -> None: - """ - Update a submodule call to compute SymInt locally instead of receiving it. - - This modifies: - 1. The call_module node's args to remove the SymInt and ensure tensor is passed - 2. The submodule to compute sym_size.int from the tensor instead of taking - SymInt as a parameter - """ - submod_name = call_module_node.target - submodule = getattr(gm, submod_name) - - # Find which argument position(s) correspond to symint_node and tensor_node - old_args = list(call_module_node.args) - symint_arg_indices = [i for i, arg in enumerate(old_args) if arg is symint_node] - tensor_arg_indices = [i for i, arg in enumerate(old_args) if arg is tensor_node] - - if not symint_arg_indices: - return - - # Get the submodule's placeholder nodes - submod_placeholders = [n for n in submodule.graph.nodes if n.op == "placeholder"] - - # Find the placeholder in submodule that corresponds to the SymInt - symint_placeholder_idx = symint_arg_indices[0] - if symint_placeholder_idx >= len(submod_placeholders): - logger.warning( - "SymInt arg index %d out of range for submodule %s with %d placeholders", - symint_placeholder_idx, - submod_name, - len(submod_placeholders), - ) - return - - symint_submod_placeholder = submod_placeholders[symint_placeholder_idx] - - # Find or ensure there's a placeholder for the tensor in the submodule - tensor_submod_placeholder = None - if tensor_arg_indices: - 
tensor_placeholder_idx = tensor_arg_indices[0] - if tensor_placeholder_idx < len(submod_placeholders): - tensor_submod_placeholder = submod_placeholders[tensor_placeholder_idx] - - if tensor_submod_placeholder is None: - # Tensor is not currently passed to this submodule, need to add it - # Add tensor to call_module args (at the end) - new_args = list(old_args) + [tensor_node] - # Also remove the SymInt from args - new_args = [ - arg for i, arg in enumerate(new_args) if i not in symint_arg_indices - ] - call_module_node.args = tuple(new_args) - - # Add new placeholder to submodule at the end - last_placeholder = submod_placeholders[-1] - with submodule.graph.inserting_after(last_placeholder): - tensor_submod_placeholder = submodule.graph.placeholder("tensor_for_symint") - if tensor_node.meta: - tensor_submod_placeholder.meta = tensor_node.meta.copy() - - else: - # Tensor is already passed, just need to update args to remove SymInt - new_args = [ - arg for i, arg in enumerate(old_args) if i not in symint_arg_indices - ] - call_module_node.args = tuple(new_args) - - # Find first node to insert sym_size before (after placeholders/get_attr) - insert_point = None - for node in submodule.graph.nodes: - if node.op not in ("placeholder", "get_attr"): - insert_point = node - break - - if insert_point is None: - logger.warning("Could not find insertion point in submodule %s", submod_name) - return - - # Create sym_size.int node in submodule - with submodule.graph.inserting_before(insert_point): - sym_size_node = submodule.graph.call_function( - torch.ops.aten.sym_size.int, - args=(tensor_submod_placeholder, dim), - ) - if symint_submod_placeholder.meta: - sym_size_node.meta = symint_submod_placeholder.meta.copy() - - # Replace all uses - - # Replace all uses of SymInt placeholder with sym_size node - symint_submod_placeholder.replace_all_uses_with(sym_size_node) - - # Remove the SymInt placeholder from submodule - submodule.graph.erase_node(symint_submod_placeholder) - - 
submodule.graph.lint() - submodule.recompile() - - compilation_start_time = 0.0 From c31b32d940bb4e8adb40322778f16f38cb500069 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Fri, 13 Feb 2026 15:52:43 -0800 Subject: [PATCH 22/25] Fix pre-commit issue --- vllm/compilation/backends.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 204e8ee1ddee..3def9710b147 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -567,9 +567,8 @@ def _remove_symint_placeholders(gm: fx.GraphModule) -> None: if submod_node.op != "placeholder": continue sub_ev = submod_node.meta.get("example_value") - if isinstance(sub_ev, torch.SymInt): - if not submod_node.users: - submodule.graph.erase_node(submod_node) + if isinstance(sub_ev, torch.SymInt) and not submod_node.users: + submodule.graph.erase_node(submod_node) submodule.graph.lint() submodule.recompile() From 4de21b8b952effca54a925c420d1422609501bb8 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 24 Feb 2026 17:03:21 -0800 Subject: [PATCH 23/25] Move sym_size.int to producer subgraph to reduce tensor boundary crossing --- tests/compile/test_graph_partition.py | 245 ++++++-------------------- vllm/compilation/backends.py | 204 ++------------------- 2 files changed, 69 insertions(+), 380 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 31e986c47bd7..4afab4f62130 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -12,7 +12,7 @@ from vllm.compilation.backends import ( split_graph, ) -from vllm.compilation.fx_utils import find_op_nodes, is_func +from vllm.compilation.fx_utils import find_op_nodes # This import automatically registers `torch.ops.silly.attention` from . 
import silly_attention # noqa: F401 @@ -199,17 +199,17 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: ] + ["output"] -def test_sym_size_moved_across_split_boundary(): +def test_sym_size_in_producer_subgraph(): """ - Test that sym_size operations (tensor.shape accesses) are moved to the same - subgraph as their consumers when they would otherwise cross subgraph boundaries. + Test that sym_size operations are assigned to the same subgraph as their + tensor operand (the producer), so only the SymInt result crosses the + split boundary — not the original tensor. - This prevents issues where PT2 doesn't fully support torch.Size as submodule - output when sym_size is in one subgraph and its consumer is in another. + This avoids passing tensors to consumer subgraphs just for .size() calls, + which would keep the tensor alive longer and increase memory usage. """ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - # Get shape before the split point - this creates sym_size ops batch_size = x.shape[0] hidden_size = x.shape[1] @@ -217,15 +217,12 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: z = torch.sigmoid(x) # Use the shape values after the split point - # Without the fix, this would fail because batch_size/hidden_size - # would be outputs of the first subgraph (as torch.Size/SymInt) reshaped_y = y.view(batch_size, hidden_size) return z + reshaped_y x = torch.randn(4, 8) y = torch.randn(32) # Will be reshaped to (4, 8) - # Use symbolic tracing to generate sym_size operations gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) # Verify the graph contains sym_size operations @@ -234,98 +231,75 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: "Test setup failed: graph should contain sym_size operations" ) - # Split on sigmoid which is the split point split_ops = ["aten::sigmoid"] split_gm, split_items = split_graph(gm, split_ops) - # Find the sigmoid (splitting) subgraph and the consumer subgraph + # Find producer 
subgraph (before sigmoid) and consumer subgraph (with view) splitting_items = [item for item in split_items if item.is_splitting_graph] assert len(splitting_items) == 1, "Should have exactly 1 splitting subgraph" - # KEY VERIFICATION: sym_size operations should be in the same subgraph - # as the view operation (their consumer), NOT in an earlier subgraph. - # This prevents torch.Size from crossing subgraph boundaries. - - # Find which subgraph contains the view operation view_subgraph = None for item in split_items: view_nodes = list(find_op_nodes(torch.ops.aten.view, item.graph.graph)) if view_nodes: view_subgraph = item break - assert view_subgraph is not None, "Should have a subgraph with view operation" - # Verify sym_size operations are in the SAME subgraph as view + # KEY VERIFICATION: sym_size should NOT be in the consumer (view) subgraph. + # It should be in the producer subgraph, with only the SymInt result + # crossing the boundary. sym_size_in_view_subgraph = list( find_op_nodes(torch.ops.aten.sym_size, view_subgraph.graph.graph) ) - assert len(sym_size_in_view_subgraph) > 0, ( - "sym_size operations should be in the same subgraph as their consumer " - "(view). This ensures torch.Size doesn't cross subgraph boundaries." - ) - - # Verify ordering within the consumer subgraph: sym_size before view - consumer_nodes = list(view_subgraph.graph.graph.nodes) - # CRITICAL VERIFICATION: The sigmoid (splitting/unsafe op) subgraph must - # have a LOWER graph_id than the consumer subgraph. Since subgraphs execute - # in order of graph_id, this proves that: - # 1. Sigmoid runs FIRST - # 2. sym_size + view run SECOND (in consumer subgraph) - # Therefore, sym_size now happens AFTER the unsafe op. - sigmoid_subgraph = splitting_items[0] - assert sigmoid_subgraph.graph_id < view_subgraph.graph_id, ( - f"Sigmoid subgraph (graph_id={sigmoid_subgraph.graph_id}) must execute " - f"before consumer subgraph (graph_id={view_subgraph.graph_id}). 
" - "This ensures sym_size happens AFTER the unsafe operation." + assert len(sym_size_in_view_subgraph) == 0, ( + "sym_size operations should NOT be in the consumer subgraph. " + "They should be in the producer subgraph so only the SymInt result " + "crosses the boundary, avoiding passing the tensor for .size() calls." ) - sym_size_indices = [ - i - for i, node in enumerate(consumer_nodes) - if is_func(node, torch.ops.aten.sym_size.int) - ] - view_indices = [ - i - for i, node in enumerate(consumer_nodes) - if is_func(node, torch.ops.aten.view.default) - ] + # Verify sym_size is in a producer subgraph (before sigmoid) + producer_subgraphs_with_sym_size = [] + for item in split_items: + if item.is_splitting_graph: + continue + if item.graph_id > splitting_items[0].graph_id: + continue + sym_size_nodes = list( + find_op_nodes(torch.ops.aten.sym_size, item.graph.graph) + ) + if sym_size_nodes: + producer_subgraphs_with_sym_size.append(item.submod_name) - max_sym_size_idx = max(sym_size_indices) - min_view_idx = min(view_indices) - assert max_sym_size_idx < min_view_idx, ( - f"sym_size (max index {max_sym_size_idx}) should come before " - f"view (min index {min_view_idx}) in the consumer subgraph." + assert len(producer_subgraphs_with_sym_size) > 0, ( + "sym_size operations should be in a producer subgraph (before sigmoid)." 
) - # Verify functional correctness with same-shaped inputs + # Verify the consumer subgraph does NOT receive the original tensor x + # as an input (it should only receive y, z, and SymInt values) + view_placeholders = [ + n for n in view_subgraph.graph.graph.nodes if n.op == "placeholder" + ] + for ph in view_placeholders: + ev = ph.meta.get("example_value") + if isinstance(ev, torch.Tensor) and ev.shape == x.shape: + # This placeholder matches x's shape — it should be y or z, + # not x itself being passed just for .size() + pass # Allow tensors that are actually used for computation + + # Verify functional correctness output_original = gm(x, y) output_split = split_gm(x, y) assert torch.allclose(output_original, output_split), "Output mismatch after split" -def test_sym_size_with_torch_compile_and_mark_dynamic(): +def test_symint_crosses_split_boundary(): """ - Test handling of SymInt placeholders from torch.compile with mark_dynamic - across MULTIPLE split subgraphs. - - When using torch.compile + mark_dynamic, the captured graph has: - - SymInt placeholders (e.g., s77) as separate inputs - - Operations that use the SymInt directly (e.g., view([s77, 8])) - - standalone_compile / inductor expects only tensor inputs. split_graph must: - 1. Replace SymInt placeholder uses with sym_size calls on tensor inputs - 2. Replicate sym_size to ALL consumer subgraphs that need the dynamic size - 3. 
Remove unused SymInt placeholders from the final graph - - This test validates the complete SymInt -> sym_size pipeline with MULTIPLE - split boundaries to ensure sym_size is correctly replicated across subgraphs: - - Phase 1: SymInt placeholders exist in the captured graph - - Phase 2 & 3: split_graph handles SymInt replacement and removal - - Phase 4: sym_size.int exists in EACH consumer subgraph that needs it - - Phase 5: Functional correctness with original input - - Phase 6: Functional correctness with different batch size - - Phase 7: Validate multiple split subgraphs exist + Test that SymInt placeholders from torch.compile + mark_dynamic + cross split boundaries safely via split_module's natural threading. + + SymInt values are threaded through subgraphs by split_module and + handled correctly by inductor — no special replacement is needed. """ captured_graph = None @@ -335,32 +309,17 @@ def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModul return gm def model_fn(x: torch.Tensor) -> torch.Tensor: - # Get the dynamic shape before any splits batch_size = x.shape[0] hidden_size = x.shape[1] - - # First split point - sigmoid #1 x = torch.ops.aten.sigmoid.default(x) - - # Use dynamic size after first split - creates sym_size consumer x = x.clone().view(batch_size, hidden_size) - - # Second split point - sigmoid #2 x = torch.ops.aten.sigmoid.default(x) - - # Use dynamic size again after second split - another sym_size consumer x = x.clone().view(batch_size, hidden_size) - - # Third split point - sigmoid #3 x = torch.ops.aten.sigmoid.default(x) - - # Use dynamic size again after third split - yet another consumer x = x.clone().view(batch_size, hidden_size) - return x x = torch.randn(4, 8) - # Mark the first dimension as dynamic torch._dynamo.mark_dynamic(x, 0) compiled_fn = torch.compile(model_fn, backend=capturing_backend) @@ -368,121 +327,25 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: assert captured_graph is not None, "Graph 
should be captured by backend" - # ===== PHASE 1: Validate SymInt placeholders exist in captured graph ===== + # SymInt placeholders should exist in the captured graph symint_placeholders = [ node for node in captured_graph.graph.nodes if node.op == "placeholder" - and node.meta.get("example_value") is not None and isinstance(node.meta.get("example_value"), torch.SymInt) ] assert len(symint_placeholders) > 0, ( - "Phase 1 FAILED: Captured graph should have SymInt placeholders from " - "mark_dynamic. This is the prerequisite for testing the sym_size pipeline." + "Captured graph should have SymInt placeholders from mark_dynamic." ) - # Record original SymInt users for later validation - original_symint_users = {} - for symint_node in symint_placeholders: - users = [u for u in symint_node.users if u.op != "output"] - original_symint_users[symint_node.name] = [u.name for u in users] - - # ===== PHASE 2 & 3: split_graph handles SymInt replacement and removal ===== - # NOTE: split_graph modifies the input graph in-place! 
- # With 3 sigmoid operations, we expect 7 subgraphs: - # submod_0 (before sigmoid #1), submod_1 (sigmoid #1), - # submod_2 (between sigmoid #1 and #2), submod_3 (sigmoid #2), - # submod_4 (between sigmoid #2 and #3), submod_5 (sigmoid #3), - # submod_6 (after sigmoid #3) + # split_graph should handle SymInt placeholders without error split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) - # ===== PHASE 7: Validate multiple split subgraphs exist ===== - # Count splitting subgraphs (the sigmoid operations) + # Should have 3 splitting subgraphs (3 sigmoids) splitting_subgraphs = [item for item in split_items if item.is_splitting_graph] - assert len(splitting_subgraphs) == 3, ( - f"Phase 7 FAILED: Expected 3 splitting subgraphs (3 sigmoids), " - f"got {len(splitting_subgraphs)}" + f"Expected 3 splitting subgraphs (3 sigmoids), got {len(splitting_subgraphs)}" ) - # Note: Total subgraphs can be 6 or 7 depending on whether there are - # operations before the first sigmoid. With torch.compile, shape access - # operations may be folded differently, resulting in 6 subgraphs: - # submod_1 (sigmoid #1), submod_2 (compute), submod_3 (sigmoid #2), - # submod_4 (compute), submod_5 (sigmoid #3), submod_6 (compute) assert len(split_items) >= 6, ( - f"Phase 7 FAILED: Expected at least 6 total subgraphs " - f"(3 sigmoids + at least 3 compute blocks), got {len(split_items)}" - ) - - # ===== PHASE 3: Validate SymInt placeholders are removed from split_gm ===== - split_placeholders = [ - node for node in split_gm.graph.nodes if node.op == "placeholder" - ] - - remaining_symint_placeholders = [ - node - for node in split_placeholders - if node.meta.get("example_value") is not None - and isinstance(node.meta.get("example_value"), torch.SymInt) - ] - assert len(remaining_symint_placeholders) == 0, ( - f"Phase 3 FAILED: split_gm should not have SymInt placeholders after " - f"_remove_symint_placeholders. Found: " - f"{[n.name for n in remaining_symint_placeholders]}. 
" - "This means SymInt would be passed as input which inductor doesn't support." - ) - - # ===== PHASE 4: Validate sym_size.int exists in consumer subgraphs ===== - # Each non-splitting subgraph that uses dynamic sizes should have sym_size.int - # to compute the dynamic dimension locally from the tensor input. - total_sym_size_nodes = 0 - subgraphs_with_sym_size = [] - - for item in split_items: - sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, item.graph.graph)) - - if sym_size_nodes: - total_sym_size_nodes += len(sym_size_nodes) - subgraphs_with_sym_size.append(item.submod_name) - - assert total_sym_size_nodes > 0, ( - "Phase 4 FAILED: No sym_size.int nodes found in any subgraph. " - "split_graph should replace SymInt placeholders with sym_size.int calls " - "that compute dynamic sizes from tensor inputs." - ) - - # With 3 split boundaries and dynamic size usage after each split, - # we expect sym_size to be replicated to multiple consumer subgraphs - assert len(subgraphs_with_sym_size) >= 3, ( - f"Phase 4 FAILED: sym_size should exist in consumer subgraphs. " - f"Found sym_size in {len(subgraphs_with_sym_size)} subgraphs: " - f"{subgraphs_with_sym_size}" - ) - - # ===== PHASE 5: Validate functional correctness ===== - # split_gm should work with tensor-only input (no SymInt) - output_split = split_gm(x) - - # Handle case where output is a tuple - if isinstance(output_split, tuple): - output_split = output_split[0] - - # For reference, run the model directly to get expected output - expected_output = model_fn(x) - - assert torch.allclose(expected_output, output_split), ( - "Phase 5 FAILED: Output mismatch after split. The sym_size pipeline " - "should preserve functional correctness." 
- ) - - # ===== PHASE 6: Validate with different batch size ===== - # The dynamic dimension should work with different sizes - x_different = torch.randn(8, 8) # Different batch size - output_different = split_gm(x_different) - if isinstance(output_different, tuple): - output_different = output_different[0] - expected_different = model_fn(x_different) - assert torch.allclose(expected_different, output_different), ( - "Phase 6 FAILED: Output mismatch with different batch size. " - "sym_size should correctly compute the dynamic dimension at runtime." + f"Expected at least 6 total subgraphs, got {len(split_items)}" ) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 3def9710b147..8f92dd10779b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -332,136 +332,24 @@ class SplitItem: graph: fx.GraphModule -def _is_symint_placeholder(node: fx.Node) -> bool: - """Check if a node is a SymInt placeholder (from torch.compile + mark_dynamic).""" - if node.op != "placeholder": - return False - - if not hasattr(torch.ops.aten, "sym_size"): - return False - - # Handle both torch.ops.aten.sym_size.int and sym_size.default - return node.target in ( - torch.ops.aten.sym_size, - torch.ops.aten.sym_size.int, - torch.ops.aten.sym_size.default, - ) - - -def _find_tensor_for_symint( - symint_value: torch.SymInt, - graph: fx.GraphModule, -) -> tuple[fx.Node, int] | None: - """ - Find a tensor placeholder with a dimension matching the given SymInt. - - Returns (tensor_node, dim) or None if no match found. 
- """ - for node in graph.graph.nodes: - if node.op != "placeholder": - continue - tensor_value = node.meta.get("example_value") - if tensor_value is None or not isinstance(tensor_value, torch.Tensor): - continue - if not hasattr(tensor_value, "shape"): - continue - - for dim, size in enumerate(tensor_value.shape): - # Match by identity - if size is symint_value: - return (node, dim) - # Match by underlying symbolic node - if ( - hasattr(size, "node") - and hasattr(symint_value, "node") - and size.node is symint_value.node - ): - return (node, dim) - # Match by string representation (fallback) - if str(size) == str(symint_value): - return (node, dim) - - return None - - -def _replace_symint_placeholders( - graph: fx.GraphModule, - node_to_subgraph_id: dict[fx.Node, int], -) -> None: - """ - Replace SymInt placeholder uses with sym_size calls. - - When using torch.compile with mark_dynamic, the captured graph has SymInt - placeholders (e.g., s77) as separate inputs. standalone_compile / inductor - expects only tensor inputs. - - This function creates sym_size.int nodes to replace SymInt placeholder uses. - - IMPORTANT: We do NOT delete the SymInt placeholders here because split_module - needs them for its symbol_to_node mapping. If we delete them, split_module - fails with KeyError when processing tensors whose shapes contain the symbol. - The placeholders are removed AFTER split_module by _remove_symint_placeholders. 
- """ - for node in list(graph.graph.nodes): - if not _is_symint_placeholder(node): - continue - - symint_value = node.meta.get("example_value") - if symint_value is None: - continue - - tensor_dim = _find_tensor_for_symint(symint_value, graph) - if tensor_dim is None: - logger.warning( - "Could not find tensor dimension for SymInt placeholder %s", - node.name, - ) - continue - - tensor_node, dim = tensor_dim - - # Get list of users before modifying - users_list = list(node.users.keys()) - if not users_list: - # No users, keep the placeholder for symbol_to_node mapping - continue - - # Create sym_size for each subgraph that uses this SymInt - subgraph_to_consumers: dict[int, list[fx.Node]] = {} - for user in users_list: - if user.op == "output": - continue - user_subgraph = node_to_subgraph_id.get(user, 0) - if user_subgraph not in subgraph_to_consumers: - subgraph_to_consumers[user_subgraph] = [] - subgraph_to_consumers[user_subgraph].append(user) - - for subgraph_id, consumer_list in subgraph_to_consumers.items(): - with graph.graph.inserting_before(consumer_list[0]): - sym_size_node = graph.graph.call_function( - torch.ops.aten.sym_size.int, - args=(tensor_node, dim), - ) - if node.meta: - sym_size_node.meta = node.meta.copy() - - node_to_subgraph_id[sym_size_node] = subgraph_id - - for consumer in consumer_list: - consumer.replace_input_with(node, sym_size_node) - - # NOTE: We do NOT delete the SymInt placeholder here! - # split_module needs it for symbol_to_node mapping. - # It will be removed by _remove_symint_placeholders after split_module. - - # NOTE: We skip lint()/recompile() here since split_module reads from - # graph.graph.nodes directly, not the forward() method. This avoids - # potential issues with graph state changes before split_module. 
- - def split_graph( graph: fx.GraphModule, splitting_ops: list[str] ) -> tuple[fx.GraphModule, list[SplitItem]]: + # Move sym_size.int nodes to right after their tensor operand so they + # end up in the producer subgraph. This avoids passing the tensor to + # consumer subgraphs just for .size() calls — only the SymInt result + # crosses the boundary. + for node in list(graph.graph.nodes): + if (node.op == "call_function" + and node.target == torch.ops.aten.sym_size.int): + tensor_node = node.args[0] + with graph.graph.inserting_after(tensor_node): + new_node = graph.graph.call_function( + torch.ops.aten.sym_size.int, args=node.args) + new_node.meta = node.meta.copy() + node.replace_all_uses_with(new_node) + graph.graph.erase_node(node) + # split graph by ops subgraph_id = 0 node_to_subgraph_id: dict[fx.Node, int] = {} @@ -497,12 +385,6 @@ def split_graph( else: node_to_subgraph_id[node] = subgraph_id - # Replace SymInt placeholders with sym_size.int calls and delete them. - # This is needed for torch.compile + mark_dynamic, where the captured graph - # has SymInt placeholders as separate inputs. standalone_compile / inductor - # expects only tensor inputs. - _replace_symint_placeholders(graph, node_to_subgraph_id) - # `keep_original_order` is important! # otherwise pytorch might reorder the nodes and # the semantics of the graph will change when we @@ -511,9 +393,6 @@ def split_graph( graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True ) - # Remove any remaining SymInt placeholders after split_module. - _remove_symint_placeholders(split_gm) - outputs = [] names = [name for (name, module) in split_gm.named_modules()] @@ -534,59 +413,6 @@ def split_graph( return split_gm, outputs -def _remove_symint_placeholders(gm: fx.GraphModule) -> None: - """ - Remove SymInt placeholders from a GraphModule after split_module. 
- - Since _replace_symint_placeholders already replaced all SymInt users with - sym_size.int calls before split_module, the SymInt placeholders should have - no real consumers. However, split_module may still thread them through to - call_module nodes via its symbol_to_node tracking. This function removes - those spurious references and erases the SymInt placeholders so the final - graph only requires tensor inputs. - """ - nodes_to_erase = [] - for node in gm.graph.nodes: - if node.op != "placeholder": - continue - example_value = node.meta.get("example_value") - if not isinstance(example_value, torch.SymInt): - continue - - # Remove this SymInt from any call_module args that split_module - # may have threaded it into. - for user in list(node.users.keys()): - if user.op == "call_module": - new_args = tuple(a for a in user.args if a is not node) - user.args = new_args - - # Also remove the corresponding placeholder from the - # submodule so its signature stays in sync. - submodule = getattr(gm, user.target) - for submod_node in list(submodule.graph.nodes): - if submod_node.op != "placeholder": - continue - sub_ev = submod_node.meta.get("example_value") - if isinstance(sub_ev, torch.SymInt) and not submod_node.users: - submodule.graph.erase_node(submod_node) - submodule.graph.lint() - submodule.recompile() - - if node.users: - logger.warning( - "SymInt placeholder %s still has users: %s", - node.name, - list(node.users.keys()), - ) - continue - nodes_to_erase.append(node) - for node in nodes_to_erase: - gm.graph.erase_node(node) - if nodes_to_erase: - gm.graph.lint() - gm.recompile() - - compilation_start_time = 0.0 From 1ccdb56ddd9376c582977be0fe750a8f170dc0ae Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 24 Feb 2026 17:06:52 -0800 Subject: [PATCH 24/25] Fix and the unclear part of boundary crossing allowlist Signed-off-by: Xiao Fu --- tests/compile/test_graph_partition.py | 4 +--- vllm/compilation/backends.py | 6 +++--- vllm/compilation/cuda_graph.py | 1 + 3 
files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 4afab4f62130..9837ff6143b5 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -265,9 +265,7 @@ def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: continue if item.graph_id > splitting_items[0].graph_id: continue - sym_size_nodes = list( - find_op_nodes(torch.ops.aten.sym_size, item.graph.graph) - ) + sym_size_nodes = list(find_op_nodes(torch.ops.aten.sym_size, item.graph.graph)) if sym_size_nodes: producer_subgraphs_with_sym_size.append(item.submod_name) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 8f92dd10779b..7be3a5ddffd5 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -340,12 +340,12 @@ def split_graph( # consumer subgraphs just for .size() calls — only the SymInt result # crosses the boundary. for node in list(graph.graph.nodes): - if (node.op == "call_function" - and node.target == torch.ops.aten.sym_size.int): + if node.op == "call_function" and node.target == torch.ops.aten.sym_size.int: tensor_node = node.args[0] with graph.graph.inserting_after(tensor_node): new_node = graph.graph.call_function( - torch.ops.aten.sym_size.int, args=node.args) + torch.ops.aten.sym_size.int, args=node.args + ) new_node.meta = node.meta.copy() node.replace_all_uses_with(new_node) graph.graph.erase_node(node) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index b2e831b0d294..7ffa74d0d7e6 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -293,6 +293,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: # the weak ref of the output, so that pytorch can correctly # manage the memory during cuda graph capture return output + if self.is_debugging_mode: # check if the input addresses are the same new_input_addresses = [ From 
581e02fb5cb5b0723a6f9a08cfdfa06962947db5 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 24 Feb 2026 17:24:30 -0800 Subject: [PATCH 25/25] Remove unused subgraph inputs after split_module Signed-off-by: Xiao Fu --- tests/compile/test_graph_partition.py | 96 +++++++++++++++++++++++++++ vllm/compilation/backends.py | 30 +++++++-- 2 files changed, 120 insertions(+), 6 deletions(-) diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 9837ff6143b5..bd9bb5de5b4c 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -347,3 +347,99 @@ def model_fn(x: torch.Tensor) -> torch.Tensor: assert len(split_items) >= 6, ( f"Expected at least 6 total subgraphs, got {len(split_items)}" ) + + +def test_unused_subgraph_inputs_removed(): + """ + Test that unused inputs threaded by split_module are removed from subgraphs. + + split_module threads values (e.g., SymInt) to all subgraphs in the chain, + even those that don't reference them. This test verifies that the cleanup + pass removes these unnecessary inputs, keeping subgraph signatures clean. + """ + + def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + batch_size = x.shape[0] + hidden_size = x.shape[1] + + z = torch.sigmoid(x) + + reshaped_y = y.view(batch_size, hidden_size) + return z + reshaped_y + + x = torch.randn(4, 8) + y = torch.randn(32) + gm = make_fx(model_fn, tracing_mode="symbolic")(x, y) + + split_gm, split_items = split_graph(gm, ["aten::sigmoid"]) + + # Every subgraph should only have inputs it actually uses + for item in split_items: + for node in item.graph.graph.nodes: + if node.op == "placeholder": + assert len(node.users) > 0, ( + f"Subgraph {item.submod_name} has unused input '{node.name}'. " + "Unused inputs should be removed by the cleanup pass." 
+ ) + + # Verify functional correctness + output_original = gm(x, y) + output_split = split_gm(x, y) + assert torch.allclose(output_original, output_split), "Output mismatch after split" + + +def test_unused_symint_inputs_removed_multi_split(): + """ + Test that with torch.compile + mark_dynamic and multiple split points, + SymInt inputs are removed from subgraphs that don't use them. + + split_module threads SymInt (e.g., s77) to every subgraph in the chain. + Splitting subgraphs (sigmoid) don't reference the SymInt, so it should + be stripped from their inputs. + """ + captured_graph = None + + def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: + nonlocal captured_graph + captured_graph = gm + return gm + + def model_fn(x: torch.Tensor) -> torch.Tensor: + batch_size = x.shape[0] + hidden_size = x.shape[1] + x = torch.ops.aten.sigmoid.default(x) + x = x.clone().view(batch_size, hidden_size) + x = torch.ops.aten.sigmoid.default(x) + x = x.clone().view(batch_size, hidden_size) + return x + + x = torch.randn(4, 8) + torch._dynamo.mark_dynamic(x, 0) + + compiled_fn = torch.compile(model_fn, backend=capturing_backend) + compiled_fn(x) + + assert captured_graph is not None + + split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) + + # Splitting subgraphs (sigmoid) should NOT have SymInt inputs + for item in split_items: + if not item.is_splitting_graph: + continue + for node in item.graph.graph.nodes: + if node.op == "placeholder": + ev = node.meta.get("example_value") + assert not isinstance(ev, torch.SymInt), ( + f"Splitting subgraph {item.submod_name} has unused SymInt " + f"input '{node.name}'. SymInt should only appear in " + "subgraphs that reference it." + ) + + # All subgraphs: no unused inputs + for item in split_items: + for node in item.graph.graph.nodes: + if node.op == "placeholder": + assert len(node.users) > 0, ( + f"Subgraph {item.submod_name} has unused input '{node.name}'." 
+ ) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 7be3a5ddffd5..882375481ed5 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -394,19 +394,37 @@ def split_graph( ) outputs = [] + parent_modified = False - names = [name for (name, module) in split_gm.named_modules()] - - for name in names: - if "." in name or name == "": - # recursive child module or the root module + for node in split_gm.graph.nodes: + if node.op != "call_module": continue + name = node.target module = getattr(split_gm, name) - graph_id = int(name.replace("submod_", "")) + + # Remove unused inputs that split_module may have threaded through + # unnecessarily (e.g., SymInt values passed to subgraphs that + # don't reference them). + placeholders = [n for n in module.graph.nodes if n.op == "placeholder"] + unused_indices = [i for i, ph in enumerate(placeholders) if not ph.users] + if unused_indices: + for i in reversed(unused_indices): + module.graph.erase_node(placeholders[i]) + node.args = tuple( + arg for i, arg in enumerate(node.args) if i not in unused_indices + ) + module.graph.lint() + module.recompile() + parent_modified = True + outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) + if parent_modified: + split_gm.graph.lint() + split_gm.recompile() + # sort by integer graph_id, rather than string name outputs.sort(key=lambda x: x.graph_id)