|
2 | 2 | """Visualize an LBANN model's layer graph and save to file."""
|
3 | 3 |
|
4 | 4 | import argparse
|
| 5 | +import random |
5 | 6 | import re
|
6 | 7 | import graphviz
|
7 |
| -import google.protobuf.text_format |
8 | 8 | from lbann import lbann_pb2, layers_pb2
|
| 9 | +from lbann.proto import serialize |
| 10 | + |
| 11 | +# Pastel rainbow (slightly shuffled) from colorkit.co |
| 12 | +palette = [ |
| 13 | + '#ffffff', '#a0c4ff', '#ffadad', '#fdffb6', '#caffbf', '#9bf6ff', |
| 14 | + '#bdb2ff', '#ffc6ff', '#ffd6a5' |
| 15 | +] |
9 | 16 |
|
10 | 17 | # Parse command-line arguments
|
11 | 18 | parser = argparse.ArgumentParser(
|
|
17 | 24 | parser.add_argument('output',
|
18 | 25 | action='store',
|
19 | 26 | nargs='?',
|
20 |
| - default='graph.pdf', |
| 27 | + default='graph.dot', |
21 | 28 | type=str,
|
22 |
| - help='output file (default: graph.pdf)') |
| 29 | + help='output file (default: graph.dot)') |
23 | 30 | parser.add_argument('--file-format',
|
24 | 31 | action='store',
|
25 |
| - default='pdf', |
| 32 | + default='dot', |
26 | 33 | type=str,
|
27 |
| - help='output file format (default: pdf)', |
| 34 | + help='output file format (default: dot)', |
28 | 35 | metavar='FORMAT')
|
29 | 36 | parser.add_argument('--label-format',
|
30 | 37 | action='store',
|
|
39 | 46 | type=str,
|
40 | 47 | help='Graphviz visualization scheme (default: dot)',
|
41 | 48 | metavar='ENGINE')
|
| 49 | +parser.add_argument('--color-cross-grid', |
| 50 | + action='store_true', |
| 51 | + default=False, |
| 52 | + help='Highlight cross-grid edges') |
42 | 53 | args = parser.parse_args()
|
43 | 54 |
|
44 | 55 | # Strip extension from filename
|
|
51 | 62 | label_format = re.sub(r' |-|_', '', args.label_format.lower())
|
52 | 63 |
|
53 | 64 | # Read prototext file
|
54 |
| -proto = lbann_pb2.LbannPB() |
55 |
| -with open(args.input, 'r') as f: |
56 |
| - google.protobuf.text_format.Merge(f.read(), proto) |
| 65 | +proto = serialize.generic_load(args.input) |
57 | 66 | model = proto.model
|
58 | 67 |
|
59 | 68 | # Construct graphviz graph
|
|
62 | 71 | engine=args.graphviz_engine)
|
63 | 72 | graph.attr('node', shape='rect')
|
64 | 73 |
|
| 74 | +layer_to_grid_tag = {} |
| 75 | + |
65 | 76 | # Construct nodes in layer graph
|
66 | 77 | layer_types = (set(layers_pb2.Layer.DESCRIPTOR.fields_by_name.keys()) - set([
|
67 | 78 | 'name', 'parents', 'children', 'datatype', 'data_layout',
|
68 | 79 | 'device_allocation', 'weights', 'freeze', 'hint_layer', 'top', 'bottom',
|
69 |
| - 'type', 'motif_layer' |
| 80 | + 'type', 'motif_layer', 'parallel_strategy', 'grid_tag' |
70 | 81 | ]))
|
71 | 82 | for l in model.layer:
|
72 | 83 |
|
73 | 84 | # Determine layer type
|
74 |
| - type = '' |
| 85 | + ltype = '' |
75 | 86 | for _type in layer_types:
|
76 | 87 | if l.HasField(_type):
|
77 |
| - type = getattr(l, _type).DESCRIPTOR.name |
| 88 | + ltype = getattr(l, _type).DESCRIPTOR.name |
78 | 89 | break
|
79 | 90 |
|
| 91 | + # If operator layer, use operator type |
| 92 | + if ltype == 'OperatorLayer': |
| 93 | + url = l.operator_layer.ops[0].parameters.type_url |
| 94 | + ltype = url[url.rfind('.') + 1:] |
| 95 | + |
80 | 96 | # Construct node label
|
81 | 97 | label = ''
|
82 | 98 | if label_format == 'nameonly':
|
83 | 99 | label = l.name
|
84 | 100 | elif label_format == 'typeonly':
|
85 |
| - label = type |
| 101 | + label = ltype |
86 | 102 | elif label_format == 'typeandname':
|
87 |
| - label = '<{0}<br/>{1}>'.format(type, l.name) |
| 103 | + label = '<{0}<br/>{1}>'.format(ltype, l.name) |
88 | 104 | elif label_format == 'full':
|
89 | 105 | label = '<'
|
90 | 106 | for (index, line) in enumerate(str(l).strip().split('\n')):
|
|
94 | 110 | label += '>'
|
95 | 111 |
|
96 | 112 | # Add layer as layer graph node
|
97 |
| - graph.node(l.name, label=label) |
| 113 | + tag = l.grid_tag.value |
| 114 | + layer_to_grid_tag[l.name] = tag |
| 115 | + attrs = {} |
| 116 | + if tag != 0: |
| 117 | + attrs = dict(style='filled', fillcolor=palette[tag % len(palette)]) |
| 118 | + graph.node(l.name, label=label, **attrs) |
98 | 119 |
|
99 | 120 | # Add parent/child relationships as layer graph edges
|
100 | 121 | edges = set()
|
| 122 | +cross_grid_edges = set() |
101 | 123 | for l in model.layer:
|
102 |
| - edges.update([(p, l.name) for p in l.parents.split()]) |
103 |
| - edges.update([(l.name, c) for c in l.children.split()]) |
| 124 | + tag = layer_to_grid_tag[l.name] |
| 125 | + for p in l.parents: |
| 126 | + if tag != layer_to_grid_tag[p]: |
| 127 | + cross_grid_edges.add((p, l.name)) |
| 128 | + else: |
| 129 | + edges.add((p, l.name)) |
| 130 | + |
| 131 | + for c in l.children: |
| 132 | + if tag != layer_to_grid_tag[c]: |
| 133 | + cross_grid_edges.add((l.name, c)) |
| 134 | + else: |
| 135 | + edges.add((l.name, c)) |
| 136 | + |
104 | 137 | graph.edges(edges)
|
| 138 | +if args.color_cross_grid: |
| 139 | + for u, v in cross_grid_edges: |
| 140 | + graph.edge(u, v, color='red') |
| 141 | +else: |
| 142 | + graph.edges(cross_grid_edges) |
105 | 143 |
|
106 | 144 | # Save to file
|
107 | 145 | graph.render(filename=filename, cleanup=True, format=file_format)
|
0 commit comments