# Copyright (c) 2021, Zhiqiang Wang
# Copyright (c) 2020, Thomas Viehmann
"""
Visualizing JIT Modules

Modified from https://github.com/t-vi/pytorch-tvmisc/tree/master/hacks,
licensed under CC-BY-SA 4.0.

Please link to Thomas's blog post or the original GitHub source (linked from the
blog post) and keep the attribution notice.
"""
from collections import OrderedDict

from graphviz import Digraph


class TorchScriptVisualizer:
    def __init__(self, module):
        self.module = module

        # Ops that should not show up as edge labels (they still pass their
        # predecessors through).
        self.unseen_ops = {
            'prim::ListConstruct', 'prim::ListUnpack',
            'prim::TupleConstruct', 'prim::TupleUnpack',
            'aten::Int',
            'aten::unbind', 'aten::detach',
            'aten::contiguous', 'aten::to',
            'aten::unsqueeze', 'aten::squeeze',
            'aten::index', 'aten::slice', 'aten::select',
            'aten::constant_pad_nd',
            'aten::size', 'aten::split_with_sizes',
            'aten::expand_as', 'aten::expand',
            'aten::_shape_as_tensor',
        }
        # Ops that absorb their inputs, i.e. predecessors are not propagated
        # through them. (There are probably also partially absorbing ops.)
        self.absorbing_ops = ('aten::size', 'aten::_shape_as_tensor')

    def render(
        self,
        classes_to_visit={'YOLO', 'YOLOHead'},
        format='svg',
        labelloc='t',
        attr_size='8,7',
    ):
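        """Render the module's TorchScript graph as a ``graphviz.Digraph``.

        Args:
            classes_to_visit: class names (or fully qualified names) of
                submodules whose subgraphs are expanded into clusters;
                the default targets modules named ``YOLO`` and ``YOLOHead``.
            format: output format passed to ``graphviz.Digraph``.
            labelloc: location of the graph label ('t' puts it at the top).
            attr_size: value for the graph's ``size`` attribute.
        """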
        self.clean_status()

        model_input = next(self.module.graph.inputs())
        model_type = self.get_node_names(model_input)[-1]
        dot = Digraph(
            format=format,
            graph_attr={'label': model_type, 'labelloc': labelloc},
        )
        self.make_graph(self.module, dot=dot, classes_to_visit=classes_to_visit)

        dot.attr(size=attr_size)
        return dot

    def clean_status(self):
        self._seen_edges = set()
        self._seen_input_names = set()
        self._predictions = OrderedDict()

    @staticmethod
    def get_node_names(node):
        return node.type().str().split('.')

    @staticmethod
    def get_function_name(node):
        return repr(node.type()).split('.')[-1]

    def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None,
                   classes_to_visit=None, classes_found=None):
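        """Recursively draw ``module``'s TorchScript graph into ``dot``.

        Calls to submodules whose class name (or fully qualified name) is in
        ``classes_to_visit`` are expanded into sub-clusters (when
        ``classes_to_visit`` is None, every non-``torch.nn`` module and every
        ``torch.nn`` container is expanded); other calls become single boxes,
        and the remaining ops are accumulated in ``self._predictions`` so they
        show up as rounded labels on the connecting edges.
        """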
        graph = module.graph

        self_input = next(graph.inputs())
        # Each entry maps a Value to a pair of sets: (predecessor names, op labels)
        self._predictions[self_input] = (set(), set())

        for nr, i in enumerate(list(graph.inputs())[1:]):
            name = f'{prefix}input_{i.debugName()}'
            self._predictions[i] = {name}, set()
            dot.node(name, shape='ellipse')
            if input_preds is not None:
                pred, op = input_preds[nr]
                self.make_edges(pred, f'input_{name}', name, op, parent_dot)

        for node in graph.nodes():
            node_inputs = list(node.inputs())
            only_first_ops = {'aten::expand_as'}
            rel_inp_end = 1 if node.kind() in only_first_ops else None

            relevant_inputs = [i for i in node_inputs[:rel_inp_end] if is_relevant_type(i.type())]
            relevant_outputs = [o for o in node.outputs() if is_relevant_type(o.type())]

            if node.kind() == 'prim::CallMethod':
                node_names = self.get_node_names(node_inputs[0])
                fq_submodule_name = '.'.join([nc for nc in node_names if not nc.startswith('__')])
                submodule_type = node_names[-1]
                submodule_name = find_name(node_inputs[0], self_input)
                name = f'{prefix}.{node.output().debugName()}'
                label = f'{prefix}{submodule_name} ({submodule_type})'

                if classes_found is not None:
                    classes_found.add(fq_submodule_name)

                if ((classes_to_visit is None
                     and (not fq_submodule_name.startswith('torch.nn')
                          or fq_submodule_name.startswith('torch.nn.modules.container')))
                        or (classes_to_visit is not None
                            and (submodule_type in classes_to_visit
                                 or fq_submodule_name in classes_to_visit))):

                    # go into subgraph
                    sub_prefix = f'{prefix}{submodule_name}.'

                    for i, o in enumerate(node.outputs()):
                        self._predictions[o] = {f'{sub_prefix}output_{i}'}, set()

                    with dot.subgraph(name=f'cluster_{name}') as sub_dot:
                        sub_dot.attr(label=label)
                        sub_module = module
                        for k in submodule_name.split('.'):
                            sub_module = getattr(sub_module, k)

                        self.make_graph(
                            sub_module,
                            dot=sub_dot,
                            parent_dot=dot,
                            prefix=sub_prefix,
                            input_preds=[self._predictions[i] for i in node_inputs[1:]],
                            classes_to_visit=classes_to_visit,
                            classes_found=classes_found,
                        )

                else:
                    dot.node(name, label=label, shape='box')
                    for i in relevant_inputs:
                        pred, op = self._predictions[i]
                        self.make_edges(pred, prefix + i.debugName(), name, op, dot)
                    for o in node.outputs():
                        self._predictions[o] = {name}, set()

            elif node.kind() == 'prim::CallFunction':
                name = f'{prefix}.{node.output().debugName()}'
                fun_name = self.get_function_name(node_inputs[0])
                dot.node(name, label=fun_name, shape='box')
                for i in relevant_inputs:
                    pred, op = self._predictions[i]
                    self.make_edges(pred, prefix + i.debugName(), name, op, dot)
                for o in node.outputs():
                    self._predictions[o] = {name}, set()

            else:
                label = node.kind().split('::')[-1].rstrip('_')
                pred, op = set(), set()
                for i in relevant_inputs:
                    apred, aop = self._predictions[i]
                    pred |= apred
                    op |= aop

                if node.kind() in self.absorbing_ops:
                    pred, op = set(), set()
                elif (len(relevant_inputs) > 0
                        and len(relevant_outputs) > 0
                        and node.kind() not in self.unseen_ops):
                    op.add(label)
                for o in node.outputs():
                    self._predictions[o] = pred, op

        for i, o in enumerate(graph.outputs()):
            name = f'{prefix}output_{i}'
            dot.node(name, shape='ellipse')
            pred, op = self._predictions[o]
            self.make_edges(pred, f'input_{name}', name, op, dot)

    def add_edge(self, dot, n1, n2):
        if (n1, n2) not in self._seen_edges:
            self._seen_edges.add((n1, n2))
            dot.edge(n1, n2)

    def make_edges(self, preds, input_name, name, op, edge_dot):
        if len(op) > 0:
            if input_name not in self._seen_input_names:
                self._seen_input_names.add(input_name)
                label_lines = [[]]
                line_len = 0
                for w in op:
                    if line_len >= 20:
                        label_lines.append([])
                        line_len = 0
                    label_lines[-1].append(w)
                    line_len += len(w) + 1

                edge_dot.node(
                    input_name,
                    label='\n'.join([' '.join(w) for w in label_lines]),
                    shape='box',
                    style='rounded',
                )
                for p in preds:
                    self.add_edge(edge_dot, p, input_name)
            self.add_edge(edge_dot, input_name, name)
        else:
            for p in preds:
                self.add_edge(edge_dot, p, name)


def find_name(layer_input, self_input, suffix=None):
    if layer_input == self_input:
        return suffix
    cur = layer_input.node().s('name')
    if suffix is not None:
        cur = f'{cur}.{suffix}'
    of = next(layer_input.node().inputs())
    return find_name(of, self_input, suffix=cur)


def is_relevant_type(t):
    kind = t.kind()
    if kind == 'TensorType':
        return True
    if kind in ('ListType', 'OptionalType'):
        return is_relevant_type(t.getElementType())
    if kind == 'TupleType':
        return any(is_relevant_type(tt) for tt in t.elements())
    return False
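

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes `torch` is
# installed alongside the `graphviz` package and the Graphviz binaries, and it
# uses a small throwaway `TinyNet` module purely for illustration; a real use
# case would pass a scripted detection model and visit its own classes, e.g.
# classes_to_visit={'YOLO', 'YOLOHead'}.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch
    from torch import nn

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(
                nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1)
            )
            self.head = nn.Linear(8, 4)

        def forward(self, x):
            x = self.backbone(x)
            x = torch.flatten(x, 1)
            return self.head(x)

    scripted_module = torch.jit.script(TinyNet())
    visualizer = TorchScriptVisualizer(scripted_module)
    # Expand the `Sequential` backbone into its own cluster in the drawing
    dot = visualizer.render(classes_to_visit={'Sequential'}, format='svg')
    dot.render('tinynet_graph')  # writes tinynet_graph.svg via the `dot` binary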