-
-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[Bugfix] Eliminate tuple inputs to submodules in graph partitioning #28533
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
06c5967
[Bugfix] Eliminate tuple inputs to submodules in graph partitioning
gmagogsfm 4dffe34
fix lint
gmagogsfm 287c005
fix unit tests to split graph correctly
gmagogsfm beabd84
fix pre-existing lint on type annotation
gmagogsfm 6b1fed4
fix comment to include standalone_compile
gmagogsfm 8b04734
always assert input_node is already in a subgraph
gmagogsfm 152328d
support getitem on tensor case
gmagogsfm File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import operator | ||
|
|
||
| import pytest | ||
| import torch | ||
| from torch.fx.experimental.proxy_tensor import make_fx | ||
|
|
||
| from vllm.compilation.backends import split_graph | ||
|
|
||
|
|
||
def test_getitem_moved_to_producer_subgraph():
    """
    Verify that after split_graph, every getitem lives in the same submodule
    as the op that produced its tuple, so no submodule receives a tuple input.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split returns a tuple, creating real getitem operations
        # Should become first submodule that produces tuple
        chunks = torch.split(x, x.shape[0] // 2, dim=0)

        # Following ops should become second submodule that consumes tuple
        result_0 = torch.relu(chunks[0])
        result_1 = torch.relu(chunks[1])
        return torch.cat([result_0, result_1], dim=0)

    sample = torch.randn(4, 3)
    gm = make_fx(model_fn)(sample)

    # Sanity-check the fixture: tracing torch.split must yield getitem nodes,
    # otherwise the test would pass vacuously.
    assert any(
        n.op == "call_function" and n.target == operator.getitem
        for n in gm.graph.nodes
    ), "Test setup failed: graph should contain getitem operations"

    # Split on tuple producer aten::split
    split_gm, split_items = split_graph(gm, ["aten::split.Tensor"])
    assert len(split_items) == 2, "Graph should be split into 2 submodules"

    for item in split_items:
        piece = item.graph

        # A getitem whose first arg is a placeholder means the submodule was
        # handed a tuple from outside — exactly what the fix must prevent.
        getitem_on_placeholder = [
            n
            for n in piece.graph.nodes
            if n.op == "call_function"
            and n.target == operator.getitem
            and n.args[0].op == "placeholder"
        ]

        assert len(getitem_on_placeholder) == 0, (
            f"Submodule {item.submod_name} has getitem operations on "
            f"placeholder nodes: {[n.name for n in getitem_on_placeholder]}. "
            "This means tuple inputs were not properly eliminated."
        )

    # The split graph must stay numerically equivalent to the original.
    fresh = torch.randn(4, 3)
    assert torch.allclose(gm(fresh), split_gm(fresh)), "Output mismatch"
|
|
||
|
|
||
def test_no_tuple_inputs_with_multiple_consumers():
    """
    Verify tuple-input elimination when several split-off submodules all
    consume elements of the same tuple: no submodule may apply getitem to
    one of its placeholders.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split returns a tuple, creating real getitem operations
        # Should become first submodule that produces tuple
        chunks = torch.split(x, x.shape[0] // 2, dim=0)

        # These should become second submodule consuming tuple
        result_1 = torch.relu(chunks[0])
        result_2 = torch.relu(chunks[1])

        # Artificial graph splitting point to create another
        # independent submodule that consumes tuple later
        # This would become the third submodule
        result_1 = torch.sigmoid(result_1)

        # Fourth submodule that consumes tuple
        return torch.cat([chunks[0], chunks[1], result_1, result_2])

    sample = torch.randn(4, 3)
    gm = make_fx(model_fn)(sample)

    # Fixture sanity check: without getitem nodes the assertions below
    # would be vacuous.
    assert any(
        n.op == "call_function" and n.target == operator.getitem
        for n in gm.graph.nodes
    ), "Test setup failed: graph should contain getitem operations"

    # Two split points -> four submodules, three of which touch the tuple.
    split_gm, split_items = split_graph(gm, ["aten::split.Tensor", "aten::sigmoid"])
    assert len(split_items) == 4, "Graph should be split into 4 submodules"

    for item in split_items:
        piece = item.graph

        for node in piece.graph.nodes:
            is_getitem_on_input = (
                node.op == "call_function"
                and node.target == operator.getitem
                and node.args[0].op == "placeholder"
            )
            if is_getitem_on_input:
                pytest.fail(
                    f"Submodule {item.submod_name} has getitem on "
                    f"placeholder {node.args[0].name}, indicating it receives "
                    "a tuple input"
                )

    # Splitting must not change the computed result.
    fresh = torch.randn(4, 3)
    assert torch.allclose(gm(fresh), split_gm(fresh)), "Output mismatch after split"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.