Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add dtype visualization to plot_network #14066

Merged
merged 9 commits into from
Mar 15, 2019
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions python/mxnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def print_layer_summary(node, out_shape):
print('Total params: %s' % total_params)
print('_' * line_length)

def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs={},
def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None, node_attrs={},
hide_weights=True):
"""Creates a visualization (Graphviz digraph object) of the given computation graph.
Graphviz must be installed for this function to work.
Expand All @@ -224,6 +224,10 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
Specifies the shape of the input tensors. If specified, the visualization will include
the shape of the tensors between the nodes. `shape` is a dictionary mapping
input symbol names (str) to the corresponding tensor shape (tuple).
dtype: dict, optional
Specifies the type of the input tensors. If specified, the visualization will include
the type of the tensors between the nodes. `dtype` is a dictionary mapping
input symbol names (str) to the corresponding tensor type (e.g. `numpy.float32`).
node_attrs: dict, optional
Specifies the attributes for nodes in the generated visualization. `node_attrs` is
a dictionary of Graphviz attribute names and values. For example::
Expand Down Expand Up @@ -271,14 +275,19 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
raise ImportError("Draw network requires graphviz library")
if not isinstance(symbol, Symbol):
raise TypeError("symbol must be a Symbol")
draw_shape = False
if shape is not None:
draw_shape = True
interals = symbol.get_internals()
_, out_shapes, _ = interals.infer_shape(**shape)
internals = symbol.get_internals()
draw_shape = shape is not None
if draw_shape:
_, out_shapes, _ = internals.infer_shape(**shape)
if out_shapes is None:
raise ValueError("Input shape is incomplete")
shape_dict = dict(zip(interals.list_outputs(), out_shapes))
shape_dict = dict(zip(internals.list_outputs(), out_shapes))
draw_type = dtype is not None
if draw_type:
_, out_types, _ = internals.infer_type(**dtype)
if out_types is None:
raise ValueError("Input type is incomplete")
eric-haibin-lin marked this conversation as resolved.
Show resolved Hide resolved
type_dict = dict(zip(internals.list_outputs(), out_types))
conf = json.loads(symbol.tojson())
nodes = conf["nodes"]
# check if multiple nodes have the same name
Expand Down Expand Up @@ -370,7 +379,7 @@ def looks_like_weight(name):
input_node = nodes[item[0]]
input_name = input_node["name"]
if input_name not in hidden_nodes:
attr = {"dir": "back", 'arrowtail':'open'}
attr = {"dir": "back", 'arrowtail':'open', 'label': ''}
# add shapes
if draw_shape:
if input_node["op"] != "null":
Expand All @@ -387,6 +396,19 @@ def looks_like_weight(name):
shape = shape_dict[key][1:]
label = "x".join([str(x) for x in shape])
attr["label"] = label
if draw_type:
if input_node["op"] != "null":
key = input_name + "_output"
if "attrs" in input_node:
params = input_node["attrs"]
if "num_outputs" in params:
key += str(int(params["num_outputs"]) - 1)
dtype = type_dict[key]
attr["label"] += '(' + dtype.__name__ + ')'
else:
key = input_name
dtype = type_dict[key]
attr["label"] += '(' + dtype.__name__ + ')'
dot.edge(tail_name=name, head_name=input_name, **attr)

return dot
2 changes: 2 additions & 0 deletions tests/python/unittest/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import warnings

import mxnet as mx
import numpy as np


def test_print_summary():
Expand Down Expand Up @@ -55,6 +56,7 @@ def test_plot_network():
net = mx.sym.SoftmaxOutput(data=net, name='out')
with warnings.catch_warnings(record=True) as w:
digraph = mx.viz.plot_network(net, shape={'data': (100, 200)},
dtype={'data': np.float32},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we be testing for dtypes other than np.float32?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not really matter - this test just tries to catch errors during preparation of the picture, and for that the exact type used does not make any difference.

node_attrs={"fixedsize": "false"})
assert len(w) == 1
assert "There are multiple variables with the same name in your graph" in str(w[-1].message)
Expand Down