
Commit 992b8c3

Refactor graph visualization (#165)
* Move graph_utils.py to relaying
* Add a tracing wrapper for YOLOModule
* Fix lint
* Make the IR visualizer a class
* Reset the unseen_ops
* Refactor the dot initialization
* Adopt f-strings
* Make get_node_names explicit
* Add minor utils
* Move predictions to attributes
* Move self.predictions above make_graph
* Fix status initialization
* Add a get_trace_module unit test
* Update the TVM Relay and graph visualization notebooks
* Add more assertions to the unit test
* Add a docstring for get_trace_module
1 parent f8560bd commit 992b8c3
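
With this refactor the visualization helpers move under the yolort.relaying package. A small sketch of the resulting import surface (based only on the files in this diff; per the new __init__.py, only get_trace_module is re-exported at package level):

from yolort.relaying import get_trace_module                      # re-exported in __init__.py
from yolort.relaying.ir_visualizer import TorchScriptVisualizer   # imported from its module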

File tree

7 files changed: +935, -488 lines


notebooks/export-relay-inference-tvm.ipynb
+31 -133 (large diff not rendered by default)

notebooks/model-graph-visualization.ipynb
+606 -170 (large diff not rendered by default)
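
For context, the updated export-relay-inference-tvm.ipynb notebook feeds the traced module into TVM Relay. A minimal sketch of that flow, assuming TVM is installed; the input name "input0" and the 416x416 shape are illustrative choices rather than values taken from this diff:

from tvm import relay

from yolort.models import yolov5s
from yolort.relaying import get_trace_module

in_size = (416, 416)
trace_module = get_trace_module(yolov5s(pretrained=True), input_shape=in_size)

# Relay's PyTorch frontend takes the traced module plus (input name, shape) pairs.
shape_list = [("input0", (1, 3, *in_size))]
mod, params = relay.frontend.from_pytorch(trace_module, shape_list)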

test/test_relaying.py

+11
@@ -0,0 +1,11 @@
from torch.jit._trace import TopLevelTracedModule

from yolort.models import yolov5s
from yolort.relaying import get_trace_module


def test_get_trace_module():
    model_func = yolov5s(pretrained=True)
    script_module = get_trace_module(model_func, input_shape=(416, 320))
    assert isinstance(script_module, TopLevelTracedModule)
    assert script_module.code is not None

yolort/relaying/__init__.py

+2
@@ -0,0 +1,2 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from .trace_wrapper import get_trace_module

yolort/relaying/ir_visualizer.py

+222
@@ -0,0 +1,222 @@
# Copyright (c) 2021, Zhiqiang Wang
# Copyright (c) 2020, Thomas Viehmann
"""
Visualizing JIT Modules

Modified from https://github.com/t-vi/pytorch-tvmisc/tree/master/hacks
with license under the CC-BY-SA 4.0.

Please link to Thomas's blog post or the original github source (linked from the
blog post) with the attribution notice.
"""
from collections import OrderedDict

from graphviz import Digraph


class TorchScriptVisualizer:
    def __init__(self, module):

        self.module = module

        self.unseen_ops = {
            'prim::ListConstruct', 'prim::ListUnpack',
            'prim::TupleConstruct', 'prim::TupleUnpack',
            'aten::Int',
            'aten::unbind', 'aten::detach',
            'aten::contiguous', 'aten::to',
            'aten::unsqueeze', 'aten::squeeze',
            'aten::index', 'aten::slice', 'aten::select',
            'aten::constant_pad_nd',
            'aten::size', 'aten::split_with_sizes',
            'aten::expand_as', 'aten::expand',
            'aten::_shape_as_tensor',
        }
        # probably also partially absorbing ops. :/
        self.absorbing_ops = ('aten::size', 'aten::_shape_as_tensor')

    def render(
        self,
        classes_to_visit={'YOLO', 'YOLOHead'},
        format='svg',
        labelloc='t',
        attr_size='8,7',
    ):
        self.clean_status()

        model_input = next(self.module.graph.inputs())
        model_type = self.get_node_names(model_input)[-1]
        dot = Digraph(
            format=format,
            graph_attr={'label': model_type, 'labelloc': labelloc},
        )
        self.make_graph(self.module, dot=dot, classes_to_visit=classes_to_visit)

        dot.attr(size=attr_size)
        return dot

    def clean_status(self):
        self._seen_edges = set()
        self._seen_input_names = set()
        self._predictions = OrderedDict()

    @staticmethod
    def get_node_names(node):
        return node.type().str().split('.')

    @staticmethod
    def get_function_name(node):
        return node.type().__repr__().split('.')[-1]

    def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None,
                   classes_to_visit=None, classes_found=None):
        graph = module.graph

        self_input = next(graph.inputs())
        self._predictions[self_input] = (set(), set())  # Stand for `input` and `op` respectively

        for nr, i in enumerate(list(graph.inputs())[1:]):
            name = f'{prefix}input_{i.debugName()}'
            self._predictions[i] = {name}, set()
            dot.node(name, shape='ellipse')
            if input_preds is not None:
                pred, op = input_preds[nr]
                self.make_edges(pred, f'input_{name}', name, op, parent_dot)

        for node in graph.nodes():
            node_inputs = list(node.inputs())
            only_first_ops = {'aten::expand_as'}
            rel_inp_end = 1 if node.kind() in only_first_ops else None

            relevant_inputs = [i for i in node_inputs[:rel_inp_end] if is_relevant_type(i.type())]
            relevant_outputs = [o for o in node.outputs() if is_relevant_type(o.type())]

            if node.kind() == 'prim::CallMethod':
                node_names = self.get_node_names(node_inputs[0])
                fq_submodule_name = '.'.join([nc for nc in node_names if not nc.startswith('__')])
                submodule_type = node_names[-1]
                submodule_name = find_name(node_inputs[0], self_input)
                name = f'{prefix}.{node.output().debugName()}'
                label = f'{prefix}{submodule_name} ({submodule_type})'

                if classes_found is not None:
                    classes_found.add(fq_submodule_name)

                if ((classes_to_visit is None and (not fq_submodule_name.startswith('torch.nn')
                                                   or fq_submodule_name.startswith('torch.nn.modules.container')))
                        or (classes_to_visit is not None and (submodule_type in classes_to_visit
                                                              or fq_submodule_name in classes_to_visit))):

                    # go into subgraph
                    sub_prefix = f'{prefix}{submodule_name}.'

                    for i, o in enumerate(node.outputs()):
                        self._predictions[o] = {f'{sub_prefix}output_{i}'}, set()

                    with dot.subgraph(name=f'cluster_{name}') as sub_dot:
                        sub_dot.attr(label=label)
                        sub_module = module
                        for k in submodule_name.split('.'):
                            sub_module = getattr(sub_module, k)

                        self.make_graph(
                            sub_module,
                            dot=sub_dot,
                            parent_dot=dot,
                            prefix=sub_prefix,
                            input_preds=[self._predictions[i] for i in node_inputs[1:]],
                            classes_to_visit=classes_to_visit,
                            classes_found=classes_found,
                        )

                else:
                    dot.node(name, label=label, shape='box')
                    for i in relevant_inputs:
                        pred, op = self._predictions[i]
                        self.make_edges(pred, prefix + i.debugName(), name, op, dot)
                    for o in node.outputs():
                        self._predictions[o] = {name}, set()

            elif node.kind() == 'prim::CallFunction':
                name = f'{prefix}.{node.output().debugName()}'
                fun_name = self.get_function_name(node_inputs[0])
                dot.node(name, label=fun_name, shape='box')
                for i in relevant_inputs:
                    pred, op = self._predictions[i]
                    self.make_edges(pred, prefix + i.debugName(), name, op, dot)
                for o in node.outputs():
                    self._predictions[o] = {name}, set()

            else:
                label = node.kind().split('::')[-1].rstrip('_')
                pred, op = set(), set()
                for i in relevant_inputs:
                    apred, aop = self._predictions[i]
                    pred |= apred
                    op |= aop

                if node.kind() in self.absorbing_ops:
                    pred, op = set(), set()
                elif (len(relevant_inputs) > 0
                        and len(relevant_outputs) > 0
                        and node.kind() not in self.unseen_ops):
                    op.add(label)
                for o in node.outputs():
                    self._predictions[o] = pred, op

        for i, o in enumerate(graph.outputs()):
            name = f'{prefix}output_{i}'
            dot.node(name, shape='ellipse')
            pred, op = self._predictions[o]
            self.make_edges(pred, f'input_{name}', name, op, dot)

    def add_edge(self, dot, n1, n2):
        if (n1, n2) not in self._seen_edges:
            self._seen_edges.add((n1, n2))
            dot.edge(n1, n2)

    def make_edges(self, preds, input_name, name, op, edge_dot):
        if len(op) > 0:
            if input_name not in self._seen_input_names:
                self._seen_input_names.add(input_name)
                label_lines = [[]]
                line_len = 0
                for w in op:
                    if line_len >= 20:
                        label_lines.append([])
                        line_len = 0
                    label_lines[-1].append(w)
                    line_len += len(w) + 1

                edge_dot.node(
                    input_name,
                    label='\n'.join([' '.join(w) for w in label_lines]),
                    shape='box',
                    style='rounded',
                )
                for p in preds:
                    self.add_edge(edge_dot, p, input_name)
            self.add_edge(edge_dot, input_name, name)
        else:
            for p in preds:
                self.add_edge(edge_dot, p, name)


def find_name(layer_input, self_input, suffix=None):
    if layer_input == self_input:
        return suffix
    cur = layer_input.node().s('name')
    if suffix is not None:
        cur = f'{cur}.{suffix}'
    of = next(layer_input.node().inputs())
    return find_name(of, self_input, suffix=cur)


def is_relevant_type(t):
    kind = t.kind()
    if kind == 'TensorType':
        return True
    if kind in ('ListType', 'OptionalType'):
        return is_relevant_type(t.getElementType())
    if kind == 'TupleType':
        return any([is_relevant_type(tt) for tt in t.elements()])
    return False
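
A usage sketch for the new visualizer (not part of this diff; it mirrors what the model-graph-visualization.ipynb notebook is expected to do). The traced module returned by get_trace_module carries the TorchScript graph that TorchScriptVisualizer walks; the output filename is an arbitrary choice:

from yolort.models import yolov5s
from yolort.relaying import get_trace_module
from yolort.relaying.ir_visualizer import TorchScriptVisualizer

# Trace the detector, then render its TorchScript IR with graphviz.
trace_module = get_trace_module(yolov5s(pretrained=True), input_shape=(416, 416))
visualizer = TorchScriptVisualizer(trace_module)

dot = visualizer.render(classes_to_visit={'YOLO', 'YOLOHead'})  # a graphviz.Digraph
dot.render('yolov5s_graph')  # writes the DOT source plus yolov5s_graph.svg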

yolort/relaying/trace_wrapper.py

+63
@@ -0,0 +1,63 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from typing import Dict, Tuple, Callable

import torch
from torch import nn, Tensor


def dict_to_tuple(out_dict: Dict[str, Tensor]) -> Tuple:
    """
    Convert the model output dictionary to tuple format.
    """
    if "masks" in out_dict.keys():
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]


class TraceWrapper(nn.Module):
    """
    This is a wrapper for `torch.jit.trace`, as there are some scenarios
    where `torch.jit.script` support is limited.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        return dict_to_tuple(out[0])


@torch.no_grad()
def get_trace_module(
    model_func: Callable[..., nn.Module],
    input_shape: Tuple[int, int] = (416, 416),
):
    """
    Get the traced module of a given model function.

    Example:

        >>> from yolort.models import yolov5s
        >>> from yolort.relaying.trace_wrapper import get_trace_module
        >>>
        >>> model = yolov5s(pretrained=True)
        >>> tracing_module = get_trace_module(model)
        >>> print(tracing_module.code)
        def forward(self,
            x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
          _0, _1, _2, = (self.model).forward(x, )
          return (_0, _1, _2)

    Args:
        model_func (Callable): The model function to be traced.
        input_shape (Tuple[int, int]): Shape size of the input image.
    """
    model = TraceWrapper(model_func)
    model.eval()

    dummy_input = torch.rand(1, 3, *input_shape)
    trace_module = torch.jit.trace(model, dummy_input)
    trace_module.eval()

    return trace_module
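
As a quick sanity check (a sketch, not included in this commit), the traced module can be called directly on a dummy image. For yolov5s there is no "masks" entry, so dict_to_tuple yields a three-tuple of boxes, scores, and labels:

import torch

from yolort.models import yolov5s
from yolort.relaying import get_trace_module

trace_module = get_trace_module(yolov5s(pretrained=True), input_shape=(416, 416))

# The wrapper flattens the detection dict into a tuple, so plain unpacking works.
boxes, scores, labels = trace_module(torch.rand(1, 3, 416, 416))
print(boxes.shape, scores.shape, labels.shape)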

0 commit comments