
[compile][graph_partition] Remove unused subgraph inputs after split_module #35251

Closed
fxdawnn wants to merge 30 commits into vllm-project:main from fxdawnn:graph_redundancy_removal

Conversation


@fxdawnn fxdawnn commented Feb 25, 2026

Follow-up of #32747

Summary

split_module threads values through every subgraph in the chain between a producer and its consumer, including subgraphs that never reference them. For example, a SymInt needed by a compute subgraph gets threaded through intermediate sigmoid subgraphs that do not use it. These unused inputs add noise to subgraph signatures.

This PR adds a cleanup pass in split_graph that strips unused placeholder inputs from submodules and their corresponding args from call_module nodes in the parent graph.

Before (without cleanup):

submod_1 [SPLIT]:   inputs: x (Tensor), s77 (SymInt)    # s77 unused
                    ops: sigmoid
submod_2 [COMPUTE]: inputs: y (Tensor), s77 (SymInt)    # s77 used
                    ops: clone, view
submod_3 [SPLIT]:   inputs: z (Tensor), s77 (SymInt)    # s77 unused
                    ops: sigmoid

After (with cleanup):

submod_1 [SPLIT]:   inputs: x (Tensor)                  # s77 removed
                    ops: sigmoid
submod_2 [COMPUTE]: inputs: y (Tensor), s77 (SymInt)    # s77 kept (used)
                    ops: clone, view
submod_3 [SPLIT]:   inputs: z (Tensor)                  # s77 removed
                    ops: sigmoid

Change

The cleanup is merged into the existing output-building loop in split_graph (+14 net lines in backends.py):

# Inside split_graph's output-building loop: `module` is the current
# submodule GraphModule and `node` is its call_module node in the parent graph.
placeholders = [
    n for n in module.graph.nodes if n.op == "placeholder"
]
unused_indices = [i for i, ph in enumerate(placeholders) if not ph.users]
if unused_indices:
    # Erase user-free placeholders from the submodule graph.
    for i in reversed(unused_indices):
        module.graph.erase_node(placeholders[i])
    # Drop the matching positional args from the parent call site
    # so the call signature stays in sync.
    node.args = tuple(
        arg for i, arg in enumerate(node.args)
        if i not in unused_indices
    )
    module.graph.lint()
    module.recompile()

For each submodule, it finds placeholders with no users, erases them, and removes the corresponding positional args from the parent's call_module node to keep signatures in sync.
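
The placeholder/arg bookkeeping can be exercised with tiny stand-ins for the FX structures. This sketch assumes only the property the pass relies on — a placeholder with an empty `users` set is safe to erase — and mirrors the cleanup loop shape (the `Node`/`Graph` classes are mocks, not torch.fx):

```python
class Node:
    def __init__(self, op, name, users=()):
        self.op, self.name, self.users = op, name, set(users)

class Graph:
    def __init__(self, nodes):
        self.nodes = list(nodes)
    def erase_node(self, n):
        # FX refuses to erase nodes that still have users; mirror that here.
        assert not n.users, "only user-free nodes may be erased"
        self.nodes.remove(n)

# submod_1 from the example: x feeds sigmoid, s77 has no users.
graph = Graph([
    Node("placeholder", "x", users=["sigmoid"]),
    Node("placeholder", "s77"),
])
call_args = ("x_outer", "s77_outer")  # parent call_module args

placeholders = [n for n in graph.nodes if n.op == "placeholder"]
unused = [i for i, ph in enumerate(placeholders) if not ph.users]
for i in reversed(unused):
    graph.erase_node(placeholders[i])
call_args = tuple(a for i, a in enumerate(call_args) if i not in unused)
# graph now has only "x"; call_args is ("x_outer",)
```

Because the same `unused` index list drives both the erase loop and the arg filter, the submodule and its call site cannot drift apart.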

Test plan

  • test_unused_subgraph_inputs_removed — single split with make_fx symbolic tracing, verifies every subgraph placeholder has at least one user after cleanup
  • test_unused_symint_inputs_removed_multi_split — multi split with torch.compile + mark_dynamic, verifies sigmoid subgraphs don't receive SymInt inputs they don't use (depends on sym_size producer PR)
  • Existing tests continue to pass

Graph Outlook

  CASE 1: Single consumer (one split boundary)
======================================================================

  Original function:
def model_fn(x, y):
    batch_size = x.shape[0]
    hidden_size = x.shape[1]
    z = torch.sigmoid(x)            # <-- split point
    reshaped_y = y.view(batch_size, hidden_size)
    return z + reshaped_y

  Captured FX graph (make_fx symbolic):
    sigmoid = sigmoid.default(x_1)
    sym_size_int = sym_size.int(x_1, 0)
    sym_size_int_1 = sym_size.int(x_1, 1)
    view = view.default(y_1, [sym_size_int, sym_size_int_1])
    add = add.Tensor(sigmoid, view)

  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
  WITHOUT CLEANUP: unused inputs remain
  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

  submod_0 [COMPUTE]  << sym_size HERE
    inputs: x_1 (NoneType)
    ops:    int, int

  submod_1 [SPLIT]
    inputs: x_1 (NoneType)
    ops:    default

  submod_2 [COMPUTE]
    inputs: y_1 (NoneType), sym_size_int_2 (NoneType), sym_size_int_3 (NoneType), sigmoid (NoneType)
    ops:    default, Tensor

  Total unused inputs: 0

  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
  WITH CLEANUP: unused inputs removed
  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

  submod_0 [COMPUTE]  << sym_size HERE
    inputs: x_1 (NoneType)
    ops:    int, int

  submod_1 [SPLIT]
    inputs: x_1 (NoneType)
    ops:    default

  submod_2 [COMPUTE]
    inputs: y_1 (NoneType), sym_size_int_2 (NoneType), sym_size_int_3 (NoneType), sigmoid (NoneType)
    ops:    default, Tensor

  Total unused inputs: 0


======================================================================
  CASE 2: Multiple consumers (three split boundaries)
======================================================================

  Original function:
def model_fn(x):
    batch_size = x.shape[0]
    hidden_size = x.shape[1]
    x = torch.sigmoid(x)            # <-- split point 1
    x = x.clone().view(batch_size, hidden_size)
    x = torch.sigmoid(x)            # <-- split point 2
    x = x.clone().view(batch_size, hidden_size)
    x = torch.sigmoid(x)            # <-- split point 3
    x = x.clone().view(batch_size, hidden_size)
    return x

  Captured FX graph (make_fx symbolic):
    sigmoid = sigmoid.default(x_1)
    clone = clone.default(sigmoid)
    sym_size_int = sym_size.int(x_1, 0)
    sym_size_int_1 = sym_size.int(x_1, 1)
    view = view.default(clone, [sym_size_int, sym_size_int_1])
    sigmoid_1 = sigmoid.default(view)
    clone_1 = clone.default(sigmoid_1)
    view_1 = view.default(clone_1, [sym_size_int, sym_size_int_1])
    sigmoid_2 = sigmoid.default(view_1)
    clone_2 = clone.default(sigmoid_2)
    view_2 = view.default(clone_2, [sym_size_int, sym_size_int_1])

  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
  WITHOUT CLEANUP: SymInt threaded to all subgraphs
  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

  submod_0 [COMPUTE]  << sym_size HERE
    inputs: x_1 (NoneType)
    ops:    int, int

  submod_1 [SPLIT]
    inputs: x_1 (NoneType)
    ops:    default

  submod_2 [COMPUTE]
    inputs: sigmoid (NoneType), sym_size_int_2 (NoneType), sym_size_int_3 (NoneType)
    ops:    default, default

  submod_3 [SPLIT]
    inputs: view (NoneType)
    ops:    default

  submod_4 [COMPUTE]
    inputs: sigmoid_1 (NoneType), sym_size_int_2 (NoneType), sym_size_int_3 (NoneType)
    ops:    default, default

  submod_5 [SPLIT]
    inputs: view_1 (NoneType)
    ops:    default

  submod_6 [COMPUTE]
    inputs: sigmoid_2 (NoneType), sym_size_int_2 (NoneType), sym_size_int_3 (NoneType)
    ops:    default, default

  Total unused inputs: 0

  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
  WITH CLEANUP: unused SymInt inputs stripped
  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

  submod_0 [COMPUTE]  << sym_size HERE
    inputs: x_1 (NoneType)
    ops:    int, int

  submod_1 [SPLIT]
    inputs: x_1 (NoneType)
    ops:    default

  submod_2 [COMPUTE]
    inputs: sigmoid (NoneType), sym_size_int_2 (NoneType), sym_size_int_3 (NoneType)
    ops:    default, default

  submod_3 [SPLIT]
    inputs: view (NoneType)
    ops:    default

  submod_4 [COMPUTE]
    inputs: sigmoid_1 (NoneType), sym_size_int_2 (NoneType), sym_size_int_3 (NoneType)
    ops:    default, default

  submod_5 [SPLIT]
    inputs: view_1 (NoneType)
    ops:    default

  submod_6 [COMPUTE]
    inputs: sigmoid_2 (NoneType), sym_size_int_2 (NoneType), sym_size_int_3 (NoneType)
    ops:    default, default

  Total unused inputs: 0


======================================================================
  CASE 3: torch.compile + mark_dynamic (SymInt passthrough)
======================================================================

  Original function:
def model_fn(x):
    batch_size = x.shape[0]
    hidden_size = x.shape[1]
    x = sigmoid(x)                  # <-- split point 1
    x = x.clone().view(batch_size, hidden_size)
    x = sigmoid(x)                  # <-- split point 2
    x = x.clone().view(batch_size, hidden_size)
    return x

  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
  WITHOUT CLEANUP: s77 (SymInt) in sigmoid subgraphs
  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

  submod_1 [SPLIT]
    inputs: l_x_ (Tensor (s77, 8)), s77 (SymInt) [UNUSED]
    ops:    default

  submod_2 [COMPUTE]
    inputs: x (Tensor (s77, 8)), s77 (SymInt)
    ops:    clone, view

  submod_3 [SPLIT]
    inputs: x_1 (Tensor (s77, 8)), s77 (SymInt) [UNUSED]
    ops:    default

  submod_4 [COMPUTE]
    inputs: x_2 (Tensor (s77, 8)), s77 (SymInt)
    ops:    clone, view

  Total unused inputs: 2

  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
  WITH CLEANUP: s77 removed from subgraphs that don't use it
  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

  submod_1 [SPLIT]
    inputs: l_x_ (Tensor (s77, 8))
    ops:    default

  submod_2 [COMPUTE]
    inputs: x (Tensor (s77, 8)), s77 (SymInt)
    ops:    clone, view

  submod_3 [SPLIT]
    inputs: x_1 (Tensor (s77, 8))
    ops:    default

  submod_4 [COMPUTE]
    inputs: x_2 (Tensor (s77, 8)), s77 (SymInt)
    ops:    clone, view

  Total unused inputs: 0

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

fxdawnn and others added 29 commits January 20, 2026 13:37
@fxdawnn fxdawnn requested a review from zou3519 as a code owner February 25, 2026 01:21

mergify bot commented Feb 25, 2026

⚠️ The sha of the head commit of this PR conflicts with #32747. Mergify cannot evaluate rules on this PR. ⚠️

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a graph pass to move torch.ops.aten.sym_size.int operations to be immediately after their tensor operand. This is an optimization to ensure that when the graph is split, the sym_size operation is in the producer subgraph, avoiding the need to pass the full tensor to consumer subgraphs just for a size lookup. The change is implemented in vllm/compilation/backends.py and is accompanied by new tests.

My review found a couple of issues in the new tests: an unused pytest fixture and a verification loop that has no effect. I've left specific comments on these.

It's also worth noting that the PR title and description seem to describe a different feature - removing unused inputs from subgraphs after splitting. The implemented code is about repositioning sym_size nodes before splitting. You may want to update the PR description to reflect the actual changes.

Comment on lines +21 to +27
@pytest.fixture
def vllm_compile_env(monkeypatch):
"""Set up vLLM compilation environment variables for testing."""
monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
monkeypatch.setenv("VLLM_LOGGING_LEVEL", "debug")
yield

Severity: high

This pytest fixture vllm_compile_env is defined but does not appear to be used by any of the tests in this file. If it's not needed, it should be removed to avoid clutter. If it is intended to be used, please apply it to the relevant tests.

Comment on lines +281 to +286
for ph in view_placeholders:
ev = ph.meta.get("example_value")
if isinstance(ev, torch.Tensor) and ev.shape == x.shape:
# This placeholder matches x's shape — it should be y or z,
# not x itself being passed just for .size()
pass # Allow tensors that are actually used for computation

Severity: high

This loop for verifying that the original tensor x is not passed to the consumer subgraph is currently ineffective as it only contains a pass statement and performs no assertions. This can be misleading as it looks like a verification is being done. The primary assertion assert len(sym_size_in_view_subgraph) == 0 already covers the main goal of this test. I recommend removing this loop to avoid confusion, unless a reliable assertion can be added.


mergify bot commented Feb 25, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fxdawnn.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


fxdawnn commented Mar 2, 2026

@fxdawnn fxdawnn closed this Mar 23, 2026
