Skip to content

Commit

Permalink
Move predictions to attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 16, 2021
1 parent fbcb672 commit 8bd2673
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions yolort/relaying/ir_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Please link to Thomas's blog post or the original github source (linked from the
blog post) with the attribution notice.
"""
from collections import OrderedDict
from graphviz import Digraph


Expand All @@ -18,11 +19,12 @@ def __init__(self, module):
self.module = module
self.seen_edges = set()
self.seen_input_names = set()
self.predictions = OrderedDict()

self.unseen_ops = {
'aten::Int',
'prim::ListConstruct', 'prim::ListUnpack',
'prim::TupleConstruct', 'prim::TupleUnpack',
'aten::Int',
'aten::unbind', 'aten::detach',
'aten::contiguous', 'aten::to',
'aten::unsqueeze', 'aten::squeeze',
Expand Down Expand Up @@ -53,14 +55,13 @@ def get_function_name(node):
def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None,
classes_to_visit=None, classes_found=None):
graph = module.graph
preds = {}

self_input = next(graph.inputs())
preds[self_input] = (set(), set()) # inps, ops
self.predictions[self_input] = (set(), set()) # inputs, ops

for nr, i in enumerate(list(graph.inputs())[1:]):
name = f'{prefix}input_{i.debugName()}'
preds[i] = {name}, set()
self.predictions[i] = {name}, set()
dot.node(name, shape='ellipse')
if input_preds is not None:
pred, op = input_preds[nr]
Expand Down Expand Up @@ -103,36 +104,36 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
dot=sub_dot,
parent_dot=dot,
prefix=sub_prefix,
input_preds=[preds[i] for i in node_inputs[1:]],
input_preds=[self.predictions[i] for i in node_inputs[1:]],
classes_to_visit=classes_to_visit,
classes_found=classes_found,
)

for i, o in enumerate(node.outputs()):
preds[o] = {f'{sub_prefix}output_{i}'}, set()
self.predictions[o] = {f'{sub_prefix}output_{i}'}, set()
else:
dot.node(name, label=label, shape='box')
for i in relevant_inputs:
pred, op = preds[i]
pred, op = self.predictions[i]
self.make_edges(pred, prefix + i.debugName(), name, op, dot)
for o in node.outputs():
preds[o] = {name}, set()
self.predictions[o] = {name}, set()

elif node.kind() == 'prim::CallFunction':
name = f'{prefix}.{node.output().debugName()}'
fun_name = self.get_function_name(node_inputs[0])
dot.node(name, label=fun_name, shape='box')
for i in relevant_inputs:
pred, op = preds[i]
pred, op = self.predictions[i]
self.make_edges(pred, prefix + i.debugName(), name, op, dot)
for o in node.outputs():
preds[o] = {name}, set()
self.predictions[o] = {name}, set()

else:
label = node.kind().split('::')[-1].rstrip('_')
pred, op = set(), set()
for i in relevant_inputs:
apred, aop = preds[i]
apred, aop = self.predictions[i]
pred |= apred
op |= aop

Expand All @@ -143,12 +144,12 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
and node.kind() not in self.unseen_ops):
op.add(label)
for o in node.outputs():
preds[o] = pred, op
self.predictions[o] = pred, op

for i, o in enumerate(graph.outputs()):
name = f'{prefix}output_{i}'
dot.node(name, shape='ellipse')
pred, op = preds[o]
pred, op = self.predictions[o]
self.make_edges(pred, f'input_{name}', name, op, dot)

def add_edge(self, dot, n1, n2):
Expand Down

0 comments on commit 8bd2673

Please sign in to comment.