Skip to content

Commit

Permalink
Add dtype visualization to plot_network (apache#14066)
Browse files Browse the repository at this point in the history
* Add dtype to plot_network

* Added docstring for the new param

* Added dtype to the plot_network test

* Changes from review

* Fixes from review

* Fix typo

* Retrigger CI
  • Loading branch information
ptrendx authored and vdantu committed Mar 31, 2019
1 parent dcacf28 commit 8ab4e96
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
38 changes: 30 additions & 8 deletions python/mxnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def print_layer_summary(node, out_shape):
print('Total params: %s' % total_params)
print('_' * line_length)

def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs={},
def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None, node_attrs={},
hide_weights=True):
"""Creates a visualization (Graphviz digraph object) of the given computation graph.
Graphviz must be installed for this function to work.
Expand All @@ -224,6 +224,10 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
Specifies the shape of the input tensors. If specified, the visualization will include
the shape of the tensors between the nodes. `shape` is a dictionary mapping
input symbol names (str) to the corresponding tensor shape (tuple).
dtype: dict, optional
Specifies the type of the input tensors. If specified, the visualization will include
the type of the tensors between the nodes. `dtype` is a dictionary mapping
input symbol names (str) to the corresponding tensor type (e.g. `numpy.float32`).
node_attrs: dict, optional
Specifies the attributes for nodes in the generated visualization. `node_attrs` is
a dictionary of Graphviz attribute names and values. For example::
Expand Down Expand Up @@ -271,14 +275,19 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
raise ImportError("Draw network requires graphviz library")
if not isinstance(symbol, Symbol):
raise TypeError("symbol must be a Symbol")
draw_shape = False
if shape is not None:
draw_shape = True
interals = symbol.get_internals()
_, out_shapes, _ = interals.infer_shape(**shape)
internals = symbol.get_internals()
draw_shape = shape is not None
if draw_shape:
_, out_shapes, _ = internals.infer_shape(**shape)
if out_shapes is None:
raise ValueError("Input shape is incomplete")
shape_dict = dict(zip(interals.list_outputs(), out_shapes))
shape_dict = dict(zip(internals.list_outputs(), out_shapes))
draw_type = dtype is not None
if draw_type:
_, out_types, _ = internals.infer_type(**dtype)
if out_types is None:
raise ValueError("Input type is incomplete")
type_dict = dict(zip(internals.list_outputs(), out_types))
conf = json.loads(symbol.tojson())
nodes = conf["nodes"]
# check if multiple nodes have the same name
Expand Down Expand Up @@ -370,7 +379,7 @@ def looks_like_weight(name):
input_node = nodes[item[0]]
input_name = input_node["name"]
if input_name not in hidden_nodes:
attr = {"dir": "back", 'arrowtail':'open'}
attr = {"dir": "back", 'arrowtail':'open', 'label': ''}
# add shapes
if draw_shape:
if input_node["op"] != "null":
Expand All @@ -387,6 +396,19 @@ def looks_like_weight(name):
shape = shape_dict[key][1:]
label = "x".join([str(x) for x in shape])
attr["label"] = label
if draw_type:
if input_node["op"] != "null":
key = input_name + "_output"
if "attrs" in input_node:
params = input_node["attrs"]
if "num_outputs" in params:
key += str(int(params["num_outputs"]) - 1)
dtype = type_dict[key]
attr["label"] += '(' + dtype.__name__ + ')'
else:
key = input_name
dtype = type_dict[key]
attr["label"] += '(' + dtype.__name__ + ')'
dot.edge(tail_name=name, head_name=input_name, **attr)

return dot
2 changes: 2 additions & 0 deletions tests/python/unittest/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import warnings

import mxnet as mx
import numpy as np


def test_print_summary():
Expand Down Expand Up @@ -55,6 +56,7 @@ def test_plot_network():
net = mx.sym.SoftmaxOutput(data=net, name='out')
with warnings.catch_warnings(record=True) as w:
digraph = mx.viz.plot_network(net, shape={'data': (100, 200)},
dtype={'data': np.float32},
node_attrs={"fixedsize": "false"})
assert len(w) == 1
assert "There are multiple variables with the same name in your graph" in str(w[-1].message)
Expand Down

0 comments on commit 8ab4e96

Please sign in to comment.