Skip to content

Commit 44ca1a5

Browse files
authored
Naive graph decomposer (#338)
* support checking model redundancy * revert change of vision_model_test * reformat python code. * reformat bert_model_test.py and utils.py * minor fix * fix failed check by comparing directories after os.path.realpath() * fix bugs in check_validate.sh * set dynamic=False in single_device_runner.py * reset graph hash * naive_graph_decomposer
1 parent 221b77e commit 44ca1a5

File tree

5 files changed

+128
-11
lines changed

5 files changed

+128
-11
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#!/bin/bash
2+
# input model path
3+
MODEL_PATH_IN_SAMPLES=/timm/resnet18
4+
# output model path
5+
OUTPUT_DIR=/tmp/naive_decompose_workspace
6+
7+
mkdir -p $OUTPUT_DIR
8+
# extract subgraph 0-8, 8-16
9+
export GRAPH_NET_NAIVE_DECOMPOSER_SPLIT_POS=0,8,16
10+
export GRAPH_NET_EXTRACT_WORKSPACE=$OUTPUT_DIR
11+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
12+
os.path.dirname(graph_net.__file__))")
13+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --custom-extractor-path=$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py

graph_net/torch/decompose_util.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,16 @@ def convert_to_submodules_graph(
3333
split_positions = [
3434
max(0, min(pos, len(submodules_body_nodes))) for pos in split_positions
3535
]
36-
submodule_ranges = [
37-
(start, end)
36+
range_idx2submodule_body_nodes = [
37+
submodules_body_nodes[start:end]
3838
for i in range(len(split_positions) - 1)
3939
for start in [split_positions[i]]
4040
for end in [split_positions[i + 1]]
4141
if end > start
4242
]
4343

4444
def get_body_nodes(range_idx):
45-
start, end = submodule_ranges[range_idx]
46-
return submodules_body_nodes[start:end]
45+
return range_idx2submodule_body_nodes[range_idx]
4746

4847
def get_name2sub_submodule():
4948
used_module_names = set(
@@ -55,15 +54,28 @@ def get_name2sub_submodule():
5554
if name in used_module_names
5655
}
5756

58-
for range_idx in range(len(submodule_ranges)):
59-
start_node_idx, end_node_idx = submodule_ranges[range_idx]
57+
def get_start_node_idx(range_idx):
58+
start_node = get_body_nodes(range_idx)[0]
59+
for i, node in enumerate(original_gm.graph.nodes):
60+
if node == start_node:
61+
return i
62+
raise NotImplementedError("Dead code.")
63+
64+
def get_end_node_idx(range_idx):
65+
last_node = get_body_nodes(range_idx)[-1]
66+
for i, node in enumerate(original_gm.graph.nodes):
67+
if node == last_node:
68+
return i + 1
69+
raise NotImplementedError("Dead code.")
70+
71+
for range_idx in range(len(range_idx2submodule_body_nodes)):
6072
(
6173
submodule_input_nodes,
6274
submodule_output_nodes,
6375
) = _get_submodule_inputs_and_outputs(
6476
original_gm=original_gm,
65-
start_node_idx=(num_placeholders + start_node_idx),
66-
end_node_idx=(num_placeholders + end_node_idx),
77+
start_node_idx=get_start_node_idx(range_idx),
78+
end_node_idx=get_end_node_idx(range_idx),
6779
)
6880

6981
def get_input_nodes(range_idx):

graph_net/torch/extractor.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,13 @@ def try_rename_placeholder(node):
125125
return gm.forward
126126

127127

128-
def extract(name, dynamic=True, mut_graph_codes=None, placeholder_auto_rename=False):
128+
def extract(
129+
name,
130+
dynamic=True,
131+
mut_graph_codes=None,
132+
placeholder_auto_rename=False,
133+
custom_extractor_path=None,
134+
):
129135
"""
130136
Extract computation graphs from PyTorch nn.Module.
131137
The extracted computation graphs will be saved into directory of env var $GRAPH_NET_EXTRACT_WORKSPACE.
@@ -194,9 +200,19 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
194200
>>>
195201
"""
196202

203+
def get_graph_extractor_cls():
204+
if custom_extractor_path is None:
205+
return GraphExtractor
206+
import importlib.util as imp
207+
208+
spec = imp.spec_from_file_location("graph_extractor", custom_extractor_path)
209+
graph_extractor = imp.module_from_spec(spec)
210+
spec.loader.exec_module(graph_extractor)
211+
return graph_extractor.GraphExtractor
212+
197213
def wrapper(model: torch.nn.Module):
198214
assert isinstance(model, torch.nn.Module), f"{type(model)=}"
199-
extractor = GraphExtractor(
215+
extractor = get_graph_extractor_cls()(
200216
name, dynamic, mut_graph_codes, placeholder_auto_rename
201217
)
202218
# return torch.compile(backend=extractor, dynamic=dynamic)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import os
2+
import torch
3+
import json
4+
import shutil
5+
from typing import Union, Callable
6+
from graph_net.torch import utils
7+
from graph_net.torch.decompose_util import convert_to_submodules_graph
8+
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
9+
10+
11+
class GraphExtractor:
12+
def __init__(
13+
self, name, dynamic, mut_graph_codes=None, placeholder_auto_rename=False
14+
):
15+
self.subgraph_counter = 0
16+
self.name = name
17+
self.dynamic = dynamic
18+
self.mut_graph_codes = mut_graph_codes
19+
self.placeholder_auto_rename = placeholder_auto_rename
20+
self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
21+
if not self.workspace_path:
22+
raise EnvironmentError(
23+
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
24+
)
25+
split_pos_str = os.environ.get("GRAPH_NET_NAIVE_DECOMPOSER_SPLIT_POS")
26+
if split_pos_str is None:
27+
raise EnvironmentError(
28+
"Environment variable 'GRAPH_NET_NAIVE_DECOMPOSER_SPLIT_POS' is not set."
29+
)
30+
self.split_positions = [int(pos) for pos in split_pos_str.split(",")]
31+
32+
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
33+
return convert_to_submodules_graph(
34+
gm,
35+
split_positions=self.split_positions,
36+
submodule_hook=self.get_naive_decomposer_extractor,
37+
group_head_and_tail=False,
38+
)
39+
40+
def get_naive_decomposer_extractor(self, submodule, seq_no):
41+
return NaiveDecomposerExtractor(self, submodule, seq_no)
42+
43+
44+
class NaiveDecomposerExtractor(torch.nn.Module):
45+
def __init__(self, parent_graph_extractor, submodule, seq_no):
46+
super().__init__()
47+
self.parent_graph_extractor = parent_graph_extractor
48+
self.submodule = submodule
49+
self.seq_no = seq_no
50+
self.extracted = False
51+
name = f"{parent_graph_extractor.name}_{self.seq_no}"
52+
self.builtin_extractor = BuiltinGraphExtractor(
53+
name=name,
54+
dynamic=False,
55+
mut_graph_codes=[],
56+
placeholder_auto_rename=parent_graph_extractor.placeholder_auto_rename,
57+
)
58+
59+
def forward(self, *args):
60+
if not self.extracted:
61+
self.builtin_extractor(self.submodule, args)
62+
self.extracted = True
63+
return self.submodule(*args)

graph_net/torch/single_device_runner.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,13 @@ def main(args):
5959
print(f"{model_path=}")
6060
if args.enable_extract:
6161
assert args.extract_name is not None
62-
kwargs = dict(name=args.extract_name, dynamic=False, **dump_graph_options)
62+
63+
kwargs = dict(
64+
name=args.extract_name,
65+
dynamic=False,
66+
custom_extractor_path=args.custom_extractor_path,
67+
**dump_graph_options,
68+
)
6369
model = extract(**kwargs)(model)
6470

6571
inputs_params = utils.load_converted_from_text(f"{model_path}")
@@ -110,5 +116,12 @@ def main(args):
110116
default=None,
111117
help="Extracted graph's name",
112118
)
119+
parser.add_argument(
120+
"--custom-extractor-path",
121+
type=str,
122+
required=False,
123+
default=None,
124+
help="Custom extractor python file path",
125+
)
113126
args = parser.parse_args()
114127
main(args=args)

0 commit comments

Comments
 (0)