From b077965b7e41108cafe0123659c4802cb771442b Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Thu, 14 Mar 2019 18:14:44 -0700 Subject: [PATCH] Add dtype visualization to plot_network (#14066) * Add dtype to plot_network * Added docstring for the new param * Added dtype to the plot_network test * Changes from review * Fixes from review * Fix typo * Retrigger CI --- python/mxnet/visualization.py | 38 ++++++++++++++++++++++++------- tests/python/unittest/test_viz.py | 2 ++ 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 1ebdcb54f4ce..dd3a1df345d3 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -208,7 +208,7 @@ def print_layer_summary(node, out_shape): print('Total params: %s' % total_params) print('_' * line_length) -def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs={}, +def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None, node_attrs={}, hide_weights=True): """Creates a visualization (Graphviz digraph object) of the given computation graph. Graphviz must be installed for this function to work. @@ -224,6 +224,10 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs Specifies the shape of the input tensors. If specified, the visualization will include the shape of the tensors between the nodes. `shape` is a dictionary mapping input symbol names (str) to the corresponding tensor shape (tuple). + dtype: dict, optional + Specifies the type of the input tensors. If specified, the visualization will include + the type of the tensors between the nodes. `dtype` is a dictionary mapping + input symbol names (str) to the corresponding tensor type (e.g. `numpy.float32`). node_attrs: dict, optional Specifies the attributes for nodes in the generated visualization. `node_attrs` is a dictionary of Graphviz attribute names and values. For example:: @@ -271,14 +275,19 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs raise ImportError("Draw network requires graphviz library") if not isinstance(symbol, Symbol): raise TypeError("symbol must be a Symbol") - draw_shape = False - if shape is not None: - draw_shape = True - interals = symbol.get_internals() - _, out_shapes, _ = interals.infer_shape(**shape) + internals = symbol.get_internals() + draw_shape = shape is not None + if draw_shape: + _, out_shapes, _ = internals.infer_shape(**shape) if out_shapes is None: raise ValueError("Input shape is incomplete") - shape_dict = dict(zip(interals.list_outputs(), out_shapes)) + shape_dict = dict(zip(internals.list_outputs(), out_shapes)) + draw_type = dtype is not None + if draw_type: + _, out_types, _ = internals.infer_type(**dtype) + if out_types is None: + raise ValueError("Input type is incomplete") + type_dict = dict(zip(internals.list_outputs(), out_types)) conf = json.loads(symbol.tojson()) nodes = conf["nodes"] # check if multiple nodes have the same name @@ -370,7 +379,7 @@ def looks_like_weight(name): input_node = nodes[item[0]] input_name = input_node["name"] if input_name not in hidden_nodes: - attr = {"dir": "back", 'arrowtail':'open'} + attr = {"dir": "back", 'arrowtail':'open', 'label': ''} # add shapes if draw_shape: if input_node["op"] != "null": @@ -387,6 +396,19 @@ def looks_like_weight(name): shape = shape_dict[key][1:] label = "x".join([str(x) for x in shape]) attr["label"] = label + if draw_type: + if input_node["op"] != "null": + key = input_name + "_output" + if "attrs" in input_node: + params = input_node["attrs"] + if "num_outputs" in params: + key += str(int(params["num_outputs"]) - 1) + dtype = type_dict[key] + attr["label"] += '(' + dtype.__name__ + ')' + else: + key = input_name + dtype = type_dict[key] + attr["label"] += '(' + dtype.__name__ + ')' dot.edge(tail_name=name, head_name=input_name, **attr) return dot diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py index fe564b0088f8..13210993014a 100644 --- a/tests/python/unittest/test_viz.py +++ b/tests/python/unittest/test_viz.py @@ -19,6 +19,7 @@ import warnings import mxnet as mx +import numpy as np def test_print_summary(): @@ -55,6 +56,7 @@ def test_plot_network(): net = mx.sym.SoftmaxOutput(data=net, name='out') with warnings.catch_warnings(record=True) as w: digraph = mx.viz.plot_network(net, shape={'data': (100, 200)}, + dtype={'data': np.float32}, node_attrs={"fixedsize": "false"}) assert len(w) == 1 assert "There are multiple variables with the same name in your graph" in str(w[-1].message)