Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ steps:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_graph_partition.py
- pytest -v -s compile/test_config.py
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
Expand Down
124 changes: 124 additions & 0 deletions tests/compile/test_graph_partition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import operator

import pytest
import torch
from torch.fx.experimental.proxy_tensor import make_fx

from vllm.compilation.backends import split_graph


def test_getitem_moved_to_producer_subgraph():
"""
Test that getitem operations are moved to the same subgraph as their input,
preventing tuple inputs to submodules.
"""

def model_fn(x: torch.Tensor) -> torch.Tensor:
# torch.split returns a tuple, creating real getitem operations
# Should become first submodule that produces tuple
chunks = torch.split(x, x.shape[0] // 2, dim=0)

# Following ops should become second submodule that consumes tuple
result_0 = torch.relu(chunks[0])
result_1 = torch.relu(chunks[1])
return torch.cat([result_0, result_1], dim=0)

x = torch.randn(4, 3)
gm = make_fx(model_fn)(x)

has_getitem = any(
node.op == "call_function" and node.target == operator.getitem
for node in gm.graph.nodes
)
assert has_getitem, "Test setup failed: graph should contain getitem operations"

# Split on tuple producer aten::split
split_ops = ["aten::split.Tensor"]
split_gm, split_items = split_graph(gm, split_ops)
assert len(split_items) == 2, "Graph should be split into 2 submodules"

for split_item in split_items:
submodule = split_item.graph

getitem_on_placeholder = []
for node in submodule.graph.nodes:
if (
node.op == "call_function"
and node.target == operator.getitem
and node.args[0].op == "placeholder"
):
getitem_on_placeholder.append(node)

assert len(getitem_on_placeholder) == 0, (
f"Submodule {split_item.submod_name} has getitem operations on "
f"placeholder nodes: {[n.name for n in getitem_on_placeholder]}. "
"This means tuple inputs were not properly eliminated."
)

new_x = torch.randn(4, 3)
output_original = gm(new_x)
output_split = split_gm(new_x)

assert torch.allclose(output_original, output_split), "Output mismatch"


def test_no_tuple_inputs_with_multiple_consumers():
"""
Test that when a tuple is consumed by multiple split operations,
getitem operations are properly moved to avoid tuple inputs.
"""

def model_fn(x: torch.Tensor) -> torch.Tensor:
# torch.split returns a tuple, creating real getitem operations
# Should become first submodule that produces tuple
chunks = torch.split(x, x.shape[0] // 2, dim=0)

# These should become second submodule consuming tuple
result_1 = torch.relu(chunks[0])
result_2 = torch.relu(chunks[1])

# Artificial graph splitting point to create another
# independent submodule that consumes tuple later
# This would become the third submodule
result_1 = torch.sigmoid(result_1)

# Fourth submodule that consumes tuple
result = torch.cat([chunks[0], chunks[1], result_1, result_2])
return result

x = torch.randn(4, 3)
gm = make_fx(model_fn)(x)

has_getitem = any(
node.op == "call_function" and node.target == operator.getitem
for node in gm.graph.nodes
)
assert has_getitem, "Test setup failed: graph should contain getitem operations"

split_ops = ["aten::split.Tensor", "aten::sigmoid"]
split_gm, split_items = split_graph(gm, split_ops)
assert len(split_items) == 4, "Graph should be split into 4 submodules"

for split_item in split_items:
submodule = split_item.graph

for node in submodule.graph.nodes:
if (
node.op == "call_function"
and node.target == operator.getitem
and node.args[0].op == "placeholder"
):
pytest.fail(
f"Submodule {split_item.submod_name} has getitem on "
f"placeholder {node.args[0].name}, indicating it receives "
"a tuple input"
)

new_x = torch.randn(4, 3)
output_original = gm(new_x)
output_split = split_gm(new_x)

assert torch.allclose(output_original, output_split), "Output mismatch after split"
17 changes: 15 additions & 2 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ast
import dataclasses
import hashlib
import operator
import os
import pprint
import time
Expand Down Expand Up @@ -307,12 +308,24 @@ def split_graph(
) -> tuple[fx.GraphModule, list[SplitItem]]:
# split graph by ops
subgraph_id = 0
node_to_subgraph_id = {}
split_op_graphs = []
node_to_subgraph_id: dict[fx.Node, int] = {}
split_op_graphs: list[int] = []
for node in graph.graph.nodes:
if node.op in ("output", "placeholder"):
continue

# Check if this is a getitem operation on a node from an earlier subgraph.
# If so, assign it to the same subgraph as its input to avoid passing entire
# tuple as input to submodules, which is against standalone_compile and
# AoTAutograd input requirement.
if node.op == "call_function" and node.target == operator.getitem:
# Assign this getitem to the same subgraph as its input
input_node = node.args[0]
if input_node.op != "placeholder":
assert input_node in node_to_subgraph_id
node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
continue

if should_split(node, splitting_ops):
subgraph_id += 1
node_to_subgraph_id[node] = subgraph_id
Expand Down