From fe933a7e3f7f58da481d21dba7cf99c44d3cab0b Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 4 Feb 2019 15:20:44 -0800 Subject: [PATCH 1/7] Add dtype to plot_network --- python/mxnet/visualization.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 1ebdcb54f4ce..70b261a64bdc 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -208,7 +208,7 @@ def print_layer_summary(node, out_shape): print('Total params: %s' % total_params) print('_' * line_length) -def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs={}, +def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None, node_attrs={}, hide_weights=True): """Creates a visualization (Graphviz digraph object) of the given computation graph. Graphviz must be installed for this function to work. @@ -274,11 +274,19 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs draw_shape = False if shape is not None: draw_shape = True - interals = symbol.get_internals() - _, out_shapes, _ = interals.infer_shape(**shape) + internals = symbol.get_internals() + _, out_shapes, _ = internals.infer_shape(**shape) if out_shapes is None: raise ValueError("Input shape is incomplete") - shape_dict = dict(zip(interals.list_outputs(), out_shapes)) + shape_dict = dict(zip(internals.list_outputs(), out_shapes)) + draw_type = False + if dtype is not None: + draw_type = True + internals = symbol.get_internals() + _, out_types, _ = internals.infer_type(**dtype) + if out_types is None: + raise ValueError("Input type is incomplete") + type_dict = dict(zip(internals.list_outputs(), out_types)) conf = json.loads(symbol.tojson()) nodes = conf["nodes"] # check if multiple nodes have the same name @@ -370,7 +378,7 @@ def looks_like_weight(name): input_node = nodes[item[0]] input_name = input_node["name"] if input_name not in hidden_nodes: - attr = {"dir": "back", 'arrowtail':'open'} + attr = {"dir": "back", 'arrowtail':'open', 'label': ''} # add shapes if draw_shape: if input_node["op"] != "null": @@ -387,6 +395,19 @@ def looks_like_weight(name): shape = shape_dict[key][1:] label = "x".join([str(x) for x in shape]) attr["label"] = label + if draw_type: + if input_node["op"] != "null": + key = input_name + "_output" + if "attrs" in input_node: + params = input_node["attrs"] + if "num_outputs" in params: + key += str(int(params["num_outputs"]) - 1) + dtype = type_dict[key] + attr["label"] += '(' + dtype.__name__ + ')' + else: + key = input_name + dtype = type_dict[key] + attr["label"] += '(' + dtype.__name__ + ')' dot.edge(tail_name=name, head_name=input_name, **attr) return dot From 4d89c2b7c4ec5c04d642a1cd12f1fae53314ce39 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 4 Feb 2019 15:37:10 -0800 Subject: [PATCH 2/7] Added docstring for the new param --- python/mxnet/visualization.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 70b261a64bdc..e80f8b947be6 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -224,6 +224,10 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None Specifies the shape of the input tensors. If specified, the visualization will include the shape of the tensors between the nodes. `shape` is a dictionary mapping input symbol names (str) to the corresponding tensor shape (tuple). + dtype: dict, optional + Specifies the type of the input tensors. If specified, the visualization will include + the type of the tensors between the nodes. `dtype` is a dictionary mapping + input symbol names (str) to the corresponding tensor type (e.g. `numpy.float32`). node_attrs: dict, optional Specifies the attributes for nodes in the generated visualization. `node_attrs` is a dictionary of Graphviz attribute names and values. For example:: From 8d5e1a7ae6959aeaa0f98eb8760643a8260a9cfb Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 4 Feb 2019 15:39:19 -0800 Subject: [PATCH 3/7] Added dtype to the plot_network test --- tests/python/unittest/test_viz.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py index fe564b0088f8..13210993014a 100644 --- a/tests/python/unittest/test_viz.py +++ b/tests/python/unittest/test_viz.py @@ -19,6 +19,7 @@ import warnings import mxnet as mx +import numpy as np def test_print_summary(): @@ -55,6 +56,7 @@ def test_plot_network(): net = mx.sym.SoftmaxOutput(data=net, name='out') with warnings.catch_warnings(record=True) as w: digraph = mx.viz.plot_network(net, shape={'data': (100, 200)}, + dtype={'data': np.float32}, node_attrs={"fixedsize": "false"}) assert len(w) == 1 assert "There are multiple variables with the same name in your graph" in str(w[-1].message) From 49dc77cb73b3e7252343b3463135f619b9f349d6 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 5 Feb 2019 08:45:34 -0800 Subject: [PATCH 4/7] Changes from review --- python/mxnet/visualization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index e80f8b947be6..d2833f670daf 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -275,10 +275,10 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None raise ImportError("Draw network requires graphviz library") if not isinstance(symbol, Symbol): raise TypeError("symbol must be a Symbol") + internals = symbol.get_internals() draw_shape = False if shape is not None: draw_shape = True - internals = symbol.get_internals() _, out_shapes, _ = internals.infer_shape(**shape) if out_shapes is None: raise ValueError("Input shape is incomplete") @@ -286,7 +286,6 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None draw_type = False if dtype is not None: draw_type = True - internals = symbol.get_internals() _, out_types, _ = internals.infer_type(**dtype) if out_types is None: raise ValueError("Input type is incomplete") From a00c9d40e93090b875e05fa745f177e577c080a1 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 13 Feb 2019 13:57:03 -0800 Subject: [PATCH 5/7] Fixes from review --- python/mxnet/visualization.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index d2833f670daf..728e752ea43e 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -276,16 +276,14 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None if not isinstance(symbol, Symbol): raise TypeError("symbol must be a Symbol") internals = symbol.get_internals() - draw_shape = False - if shape is not None: - draw_shape = True + draw_shape = shape is not None + if draw_shape _, out_shapes, _ = internals.infer_shape(**shape) if out_shapes is None: raise ValueError("Input shape is incomplete") shape_dict = dict(zip(internals.list_outputs(), out_shapes)) - draw_type = False - if dtype is not None: - draw_type = True + draw_type = dtype is not None + if draw_type _, out_types, _ = internals.infer_type(**dtype) if out_types is None: raise ValueError("Input type is incomplete") From bb75077d10364b847b0bc8f4a9b7e0f1237ca656 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 13 Feb 2019 14:15:52 -0800 Subject: [PATCH 6/7] Fix typo --- python/mxnet/visualization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 728e752ea43e..dd3a1df345d3 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -277,13 +277,13 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None raise TypeError("symbol must be a Symbol") internals = symbol.get_internals() draw_shape = shape is not None - if draw_shape + if draw_shape: _, out_shapes, _ = internals.infer_shape(**shape) if out_shapes is None: raise ValueError("Input shape is incomplete") shape_dict = dict(zip(internals.list_outputs(), out_shapes)) draw_type = dtype is not None - if draw_type + if draw_type: _, out_types, _ = internals.infer_type(**dtype) if out_types is None: raise ValueError("Input type is incomplete") From 3d93f9d1a315b013af8bdb1edc10c5a111ef37c9 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 6 Mar 2019 10:27:06 -0800 Subject: [PATCH 7/7] Retrigger CI