Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add dtype visualization to plot_network #14066

Merged
merged 9 commits into from
Mar 15, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions python/mxnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def print_layer_summary(node, out_shape):
print('Total params: %s' % total_params)
print('_' * line_length)

def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs={},
def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None, node_attrs={},
hide_weights=True):
"""Creates a visualization (Graphviz digraph object) of the given computation graph.
Graphviz must be installed for this function to work.
Expand All @@ -224,6 +224,10 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
Specifies the shape of the input tensors. If specified, the visualization will include
the shape of the tensors between the nodes. `shape` is a dictionary mapping
input symbol names (str) to the corresponding tensor shape (tuple).
dtype: dict, optional
Specifies the type of the input tensors. If specified, the visualization will include
the type of the tensors between the nodes. `dtype` is a dictionary mapping
input symbol names (str) to the corresponding tensor type (e.g. `numpy.float32`).
node_attrs: dict, optional
Specifies the attributes for nodes in the generated visualization. `node_attrs` is
a dictionary of Graphviz attribute names and values. For example::
Expand Down Expand Up @@ -274,11 +278,19 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
draw_shape = False
if shape is not None:
szha marked this conversation as resolved.
Show resolved Hide resolved
draw_shape = True
interals = symbol.get_internals()
_, out_shapes, _ = interals.infer_shape(**shape)
internals = symbol.get_internals()
szha marked this conversation as resolved.
Show resolved Hide resolved
_, out_shapes, _ = internals.infer_shape(**shape)
if out_shapes is None:
raise ValueError("Input shape is incomplete")
shape_dict = dict(zip(interals.list_outputs(), out_shapes))
shape_dict = dict(zip(internals.list_outputs(), out_shapes))
draw_type = False
if dtype is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just if dtype: ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as shapes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this seems slightly more readable.

draw_type = dtype is not None
if draw_type:
   ...

draw_type = True
internals = symbol.get_internals()
_, out_types, _ = internals.infer_type(**dtype)
if out_types is None:
raise ValueError("Input type is incomplete")
eric-haibin-lin marked this conversation as resolved.
Show resolved Hide resolved
type_dict = dict(zip(internals.list_outputs(), out_types))
conf = json.loads(symbol.tojson())
nodes = conf["nodes"]
# check if multiple nodes have the same name
Expand Down Expand Up @@ -370,7 +382,7 @@ def looks_like_weight(name):
input_node = nodes[item[0]]
input_name = input_node["name"]
if input_name not in hidden_nodes:
attr = {"dir": "back", 'arrowtail':'open'}
attr = {"dir": "back", 'arrowtail':'open', 'label': ''}
# add shapes
if draw_shape:
if input_node["op"] != "null":
Expand All @@ -387,6 +399,19 @@ def looks_like_weight(name):
shape = shape_dict[key][1:]
label = "x".join([str(x) for x in shape])
attr["label"] = label
if draw_type:
if input_node["op"] != "null":
key = input_name + "_output"
if "attrs" in input_node:
params = input_node["attrs"]
if "num_outputs" in params:
key += str(int(params["num_outputs"]) - 1)
dtype = type_dict[key]
attr["label"] += '(' + dtype.__name__ + ')'
else:
key = input_name
dtype = type_dict[key]
attr["label"] += '(' + dtype.__name__ + ')'
dot.edge(tail_name=name, head_name=input_name, **attr)

return dot
2 changes: 2 additions & 0 deletions tests/python/unittest/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import warnings

import mxnet as mx
import numpy as np


def test_print_summary():
Expand Down Expand Up @@ -55,6 +56,7 @@ def test_plot_network():
net = mx.sym.SoftmaxOutput(data=net, name='out')
with warnings.catch_warnings(record=True) as w:
digraph = mx.viz.plot_network(net, shape={'data': (100, 200)},
dtype={'data': np.float32},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we be testing for dtypes other than np.float32?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not really matter - this test just tries to catch errors during preparation of the picture, and for that the exact type used does not make any difference.

node_attrs={"fixedsize": "false"})
assert len(w) == 1
assert "There are multiple variables with the same name in your graph" in str(w[-1].message)
Expand Down