From 8bd267364ae807c71322566ab995e995cd3901bb Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Wed, 15 Sep 2021 13:11:59 -0400 Subject: [PATCH] Move predictions to attributes --- yolort/relaying/ir_visualizer.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/yolort/relaying/ir_visualizer.py b/yolort/relaying/ir_visualizer.py index c464f5a6..0e1f40c0 100644 --- a/yolort/relaying/ir_visualizer.py +++ b/yolort/relaying/ir_visualizer.py @@ -9,6 +9,7 @@ Please link to Thomas's blog post or the original github source (linked from the blog post) with the attribution notice. """ +from collections import OrderedDict from graphviz import Digraph @@ -18,11 +19,12 @@ def __init__(self, module): self.module = module self.seen_edges = set() self.seen_input_names = set() + self.predictions = OrderedDict() self.unseen_ops = { - 'aten::Int', 'prim::ListConstruct', 'prim::ListUnpack', 'prim::TupleConstruct', 'prim::TupleUnpack', + 'aten::Int', 'aten::unbind', 'aten::detach', 'aten::contiguous', 'aten::to', 'aten::unsqueeze', 'aten::squeeze', @@ -53,14 +55,13 @@ def get_function_name(node): def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None, classes_to_visit=None, classes_found=None): graph = module.graph - preds = {} self_input = next(graph.inputs()) - preds[self_input] = (set(), set()) # inps, ops + self.predictions[self_input] = (set(), set()) # inputs, ops for nr, i in enumerate(list(graph.inputs())[1:]): name = f'{prefix}input_{i.debugName()}' - preds[i] = {name}, set() + self.predictions[i] = {name}, set() dot.node(name, shape='ellipse') if input_preds is not None: pred, op = input_preds[nr] @@ -103,36 +104,36 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N dot=sub_dot, parent_dot=dot, prefix=sub_prefix, - input_preds=[preds[i] for i in node_inputs[1:]], + input_preds=[self.predictions[i] for i in node_inputs[1:]], classes_to_visit=classes_to_visit, classes_found=classes_found, ) for i, o in enumerate(node.outputs()): - preds[o] = {f'{sub_prefix}output_{i}'}, set() + self.predictions[o] = {f'{sub_prefix}output_{i}'}, set() else: dot.node(name, label=label, shape='box') for i in relevant_inputs: - pred, op = preds[i] + pred, op = self.predictions[i] self.make_edges(pred, prefix + i.debugName(), name, op, dot) for o in node.outputs(): - preds[o] = {name}, set() + self.predictions[o] = {name}, set() elif node.kind() == 'prim::CallFunction': name = f'{prefix}.{node.output().debugName()}' fun_name = self.get_function_name(node_inputs[0]) dot.node(name, label=fun_name, shape='box') for i in relevant_inputs: - pred, op = preds[i] + pred, op = self.predictions[i] self.make_edges(pred, prefix + i.debugName(), name, op, dot) for o in node.outputs(): - preds[o] = {name}, set() + self.predictions[o] = {name}, set() else: label = node.kind().split('::')[-1].rstrip('_') pred, op = set(), set() for i in relevant_inputs: - apred, aop = preds[i] + apred, aop = self.predictions[i] pred |= apred op |= aop @@ -143,12 +144,12 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N and node.kind() not in self.unseen_ops): op.add(label) for o in node.outputs(): - preds[o] = pred, op + self.predictions[o] = pred, op for i, o in enumerate(graph.outputs()): name = f'{prefix}output_{i}' dot.node(name, shape='ellipse') - pred, op = preds[o] + pred, op = self.predictions[o] self.make_edges(pred, f'input_{name}', name, op, dot) def add_edge(self, dot, n1, n2):