Skip to content

Commit 3e245cf

Browse files
committed
Visualization: support new models, operator layers, grid tag coloring, and cross-grid edge highlighting
1 parent f9e43d0 commit 3e245cf

File tree

1 file changed

+54
-16
lines changed

1 file changed

+54
-16
lines changed

Diff for: scripts/viz.py

+54-16
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22
"""Visualize an LBANN model's layer graph and save to file."""
33

44
import argparse
5+
import random
56
import re
67
import graphviz
7-
import google.protobuf.text_format
88
from lbann import lbann_pb2, layers_pb2
9+
from lbann.proto import serialize
10+
11+
# Pastel rainbow (slightly shuffled) from colorkit.co
12+
palette = [
13+
'#ffffff', '#a0c4ff', '#ffadad', '#fdffb6', '#caffbf', '#9bf6ff',
14+
'#bdb2ff', '#ffc6ff', '#ffd6a5'
15+
]
916

1017
# Parse command-line arguments
1118
parser = argparse.ArgumentParser(
@@ -17,14 +24,14 @@
1724
parser.add_argument('output',
1825
action='store',
1926
nargs='?',
20-
default='graph.pdf',
27+
default='graph.dot',
2128
type=str,
22-
help='output file (default: graph.pdf)')
29+
help='output file (default: graph.dot)')
2330
parser.add_argument('--file-format',
2431
action='store',
25-
default='pdf',
32+
default='dot',
2633
type=str,
27-
help='output file format (default: pdf)',
34+
help='output file format (default: dot)',
2835
metavar='FORMAT')
2936
parser.add_argument('--label-format',
3037
action='store',
@@ -39,6 +46,10 @@
3946
type=str,
4047
help='Graphviz visualization scheme (default: dot)',
4148
metavar='ENGINE')
49+
parser.add_argument('--color-cross-grid',
50+
action='store_true',
51+
default=False,
52+
help='Highlight cross-grid edges')
4253
args = parser.parse_args()
4354

4455
# Strip extension from filename
@@ -51,9 +62,7 @@
5162
label_format = re.sub(r' |-|_', '', args.label_format.lower())
5263

5364
# Read prototext file
54-
proto = lbann_pb2.LbannPB()
55-
with open(args.input, 'r') as f:
56-
google.protobuf.text_format.Merge(f.read(), proto)
65+
proto = serialize.generic_load(args.input)
5766
model = proto.model
5867

5968
# Construct graphviz graph
@@ -62,29 +71,36 @@
6271
engine=args.graphviz_engine)
6372
graph.attr('node', shape='rect')
6473

74+
layer_to_grid_tag = {}
75+
6576
# Construct nodes in layer graph
6677
layer_types = (set(layers_pb2.Layer.DESCRIPTOR.fields_by_name.keys()) - set([
6778
'name', 'parents', 'children', 'datatype', 'data_layout',
6879
'device_allocation', 'weights', 'freeze', 'hint_layer', 'top', 'bottom',
69-
'type', 'motif_layer'
80+
'type', 'motif_layer', 'parallel_strategy', 'grid_tag'
7081
]))
7182
for l in model.layer:
7283

7384
# Determine layer type
74-
type = ''
85+
ltype = ''
7586
for _type in layer_types:
7687
if l.HasField(_type):
77-
type = getattr(l, _type).DESCRIPTOR.name
88+
ltype = getattr(l, _type).DESCRIPTOR.name
7889
break
7990

91+
# If operator layer, use operator type
92+
if ltype == 'OperatorLayer':
93+
url = l.operator_layer.ops[0].parameters.type_url
94+
ltype = url[url.rfind('.') + 1:]
95+
8096
# Construct node label
8197
label = ''
8298
if label_format == 'nameonly':
8399
label = l.name
84100
elif label_format == 'typeonly':
85-
label = type
101+
label = ltype
86102
elif label_format == 'typeandname':
87-
label = '<{0}<br/>{1}>'.format(type, l.name)
103+
label = '<{0}<br/>{1}>'.format(ltype, l.name)
88104
elif label_format == 'full':
89105
label = '<'
90106
for (index, line) in enumerate(str(l).strip().split('\n')):
@@ -94,14 +110,36 @@
94110
label += '>'
95111

96112
# Add layer as layer graph node
97-
graph.node(l.name, label=label)
113+
tag = l.grid_tag.value
114+
layer_to_grid_tag[l.name] = tag
115+
attrs = {}
116+
if tag != 0:
117+
attrs = dict(style='filled', fillcolor=palette[tag % len(palette)])
118+
graph.node(l.name, label=label, **attrs)
98119

99120
# Add parent/child relationships as layer graph edges
100121
edges = set()
122+
cross_grid_edges = set()
101123
for l in model.layer:
102-
edges.update([(p, l.name) for p in l.parents.split()])
103-
edges.update([(l.name, c) for c in l.children.split()])
124+
tag = layer_to_grid_tag[l.name]
125+
for p in l.parents:
126+
if tag != layer_to_grid_tag[p]:
127+
cross_grid_edges.add((p, l.name))
128+
else:
129+
edges.add((p, l.name))
130+
131+
for c in l.children:
132+
if tag != layer_to_grid_tag[c]:
133+
cross_grid_edges.add((l.name, c))
134+
else:
135+
edges.add((l.name, c))
136+
104137
graph.edges(edges)
138+
if args.color_cross_grid:
139+
for u, v in cross_grid_edges:
140+
graph.edge(u, v, color='red')
141+
else:
142+
graph.edges(cross_grid_edges)
105143

106144
# Save to file
107145
graph.render(filename=filename, cleanup=True, format=file_format)

0 commit comments

Comments
 (0)