From 7c74d1f6367324258480e5b78bdadb1cfa557e6a Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Mon, 2 Jul 2018 13:59:07 -0700 Subject: [PATCH] [MXNET-210]give warning for variables with same name in graph visualization (#10429) * give warning for variables with same name in graph visualization * fix line too long * print repetead node names * update warning and unit test * add assert for repeated node * add graphviz for arm * update docker install * skip unittest if graphviz could not be imported * optimize imports --- python/mxnet/visualization.py | 11 ++++++++++- tests/python/unittest/test_viz.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 2b9da15212db..fc6db1ddcb31 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -26,7 +26,7 @@ import re import copy import json - +import warnings from .symbol import Symbol def _str2tuple(string): @@ -252,6 +252,15 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs shape_dict = dict(zip(interals.list_outputs(), out_shapes)) conf = json.loads(symbol.tojson()) nodes = conf["nodes"] + # check if multiple nodes have the same name + if len(nodes) != len(set([node["name"] for node in nodes])): + seen_nodes = set() + # find all repeated names + repeated = set(node['name'] for node in nodes if node['name'] in seen_nodes + or seen_nodes.add(node['name'])) + warning_message = "There are multiple variables with the same name in your graph, " \ + "this may result in cyclic graph. Repeated names: " + ','.join(repeated) + warnings.warn(warning_message, RuntimeWarning) # default attributes of node node_attr = {"shape": "box", "fixedsize": "true", "width": "1.3", "height": "0.8034", "style": "filled"} diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py index 73cfa94ba030..eb5921f2e823 100644 --- a/tests/python/unittest/test_viz.py +++ b/tests/python/unittest/test_viz.py @@ -15,8 +15,12 @@ # specific language governing permissions and limitations # under the License. +import unittest +import warnings + import mxnet as mx + def test_print_summary(): data = mx.sym.Variable('data') bias = mx.sym.Variable('fc1_bias', lr_mult=1.0) @@ -32,5 +36,29 @@ def test_print_summary(): shape["data"]=(1,3,28,28) mx.viz.print_summary(sc1, shape) +def graphviz_exists(): + try: + import graphviz + except ImportError: + return False + else: + return True + +@unittest.skipIf(not graphviz_exists(), "Skip test_plot_network as Graphviz could not be imported") +def test_plot_network(): + # Test warnings for cyclic graph + net = mx.sym.Variable('data') + net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=128) + net = mx.sym.Activation(data=net, name='relu1', act_type="relu") + net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=10) + net = mx.sym.SoftmaxOutput(data=net, name='out') + with warnings.catch_warnings(record=True) as w: + digraph = mx.viz.plot_network(net, shape={'data': (100, 200)}, + node_attrs={"fixedsize": "false"}) + assert len(w) == 1 + assert "There are multiple variables with the same name in your graph" in str(w[-1].message) + assert "fc" in str(w[-1].message) + if __name__ == "__main__": - test_print_summary() + import nose + nose.runmodule()