From 39e19b4fd1b8b254fd8a057e621fdd25181995e3 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 5 Apr 2018 17:08:54 -0700 Subject: [PATCH 1/9] give warning for variables with same name in graph visualization --- python/mxnet/visualization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 2b9da15212db..b0827dbe7a5a 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -26,7 +26,7 @@ import re import copy import json - +import logging from .symbol import Symbol def _str2tuple(string): @@ -252,6 +252,9 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs shape_dict = dict(zip(interals.list_outputs(), out_shapes)) conf = json.loads(symbol.tojson()) nodes = conf["nodes"] + # check if multiple nodes have the same name + if len(nodes) != len(set([node["name"] for node in nodes])): + logging.warning("There are multiple variables with the same name in your graph, this may result in cyclic graph") # default attributes of node node_attr = {"shape": "box", "fixedsize": "true", "width": "1.3", "height": "0.8034", "style": "filled"} From 38cbfb5c165ff6929cce0c42aea98f9d2577ba6f Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Fri, 6 Apr 2018 10:54:15 -0700 Subject: [PATCH 2/9] fix line too long --- python/mxnet/visualization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index b0827dbe7a5a..683c9ba655e2 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -254,7 +254,8 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs nodes = conf["nodes"] # check if multiple nodes have the same name if len(nodes) != len(set([node["name"] for node in nodes])): - logging.warning("There are multiple variables with the same name in your graph, this may result in cyclic graph") + logging.warning("There are multiple variables with the same name in your graph, " + "this may result in cyclic graph") # default attributes of node node_attr = {"shape": "box", "fixedsize": "true", "width": "1.3", "height": "0.8034", "style": "filled"} From c287e879d343c7dfcd7ca1ec088bd436b8b4ab9d Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Mon, 9 Apr 2018 14:45:30 -0700 Subject: [PATCH 3/9] print repetead node names --- python/mxnet/visualization.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 683c9ba655e2..23e448300a45 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -254,8 +254,13 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs nodes = conf["nodes"] # check if multiple nodes have the same name if len(nodes) != len(set([node["name"] for node in nodes])): + seen = set() + seen_add = seen.add + # find all repeated names + repeated = set(node['name'] for node in nodes if node['name'] in seen + or seen_add(node['name'])) logging.warning("There are multiple variables with the same name in your graph, " - "this may result in cyclic graph") + "this may result in cyclic graph. Repeated names: %s", ','.join(repeated)) # default attributes of node node_attr = {"shape": "box", "fixedsize": "true", "width": "1.3", "height": "0.8034", "style": "filled"} From bca8542da5938341e7477a2ce13b6e877f826f33 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 12 Apr 2018 12:10:05 -0700 Subject: [PATCH 4/9] update warning and unit test --- python/mxnet/visualization.py | 14 +++++++------- tests/python/unittest/test_viz.py | 17 ++++++++++++++++- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 23e448300a45..fc6db1ddcb31 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -26,7 +26,7 @@ import re import copy import json -import logging +import warnings from .symbol import Symbol def _str2tuple(string): @@ -254,13 +254,13 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs nodes = conf["nodes"] # check if multiple nodes have the same name if len(nodes) != len(set([node["name"] for node in nodes])): - seen = set() - seen_add = seen.add + seen_nodes = set() # find all repeated names - repeated = set(node['name'] for node in nodes if node['name'] in seen - or seen_add(node['name'])) - logging.warning("There are multiple variables with the same name in your graph, " - "this may result in cyclic graph. Repeated names: %s", ','.join(repeated)) + repeated = set(node['name'] for node in nodes if node['name'] in seen_nodes + or seen_nodes.add(node['name'])) + warning_message = "There are multiple variables with the same name in your graph, " \ + "this may result in cyclic graph. Repeated names: " + ','.join(repeated) + warnings.warn(warning_message, RuntimeWarning) # default attributes of node node_attr = {"shape": "box", "fixedsize": "true", "width": "1.3", "height": "0.8034", "style": "filled"} diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py index 73cfa94ba030..e8aaebada152 100644 --- a/tests/python/unittest/test_viz.py +++ b/tests/python/unittest/test_viz.py @@ -16,6 +16,7 @@ # under the License. import mxnet as mx +import warnings def test_print_summary(): data = mx.sym.Variable('data') @@ -32,5 +33,19 @@ def test_print_summary(): shape["data"]=(1,3,28,28) mx.viz.print_summary(sc1, shape) +def test_plot_network(): + # Test warnings for cyclic graph + net = mx.sym.Variable('data') + net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=128) + net = mx.sym.Activation(data=net, name='relu1', act_type="relu") + net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=10) + net = mx.sym.SoftmaxOutput(data=net, name='out') + with warnings.catch_warnings(record=True) as w: + digraph = mx.viz.plot_network(net, shape={'data': (100, 200)}, + node_attrs={"fixedsize": "false"}) + assert len(w) == 1 + assert "There are multiple variables with the same name in your graph" in str(w[-1].message) + if __name__ == "__main__": - test_print_summary() + import nose + nose.runmodule() From f9a44e7c73018580e1b9caf6896a42723029f2af Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 12 Apr 2018 12:12:53 -0700 Subject: [PATCH 5/9] add assert for repeated node --- tests/python/unittest/test_viz.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py index e8aaebada152..32a63976afb1 100644 --- a/tests/python/unittest/test_viz.py +++ b/tests/python/unittest/test_viz.py @@ -45,7 +45,8 @@ def test_plot_network(): node_attrs={"fixedsize": "false"}) assert len(w) == 1 assert "There are multiple variables with the same name in your graph" in str(w[-1].message) - + assert "fc" in str(w[-1].message) + if __name__ == "__main__": import nose nose.runmodule() From 4880fff36d106b4d651663a94c4eddf0484655bb Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 12 Apr 2018 14:43:08 -0700 Subject: [PATCH 6/9] add graphviz for arm --- ci/docker/install/arm64_openblas.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/docker/install/arm64_openblas.sh b/ci/docker/install/arm64_openblas.sh index 3151a4b5c2a6..d6562baa8711 100755 --- a/ci/docker/install/arm64_openblas.sh +++ b/ci/docker/install/arm64_openblas.sh @@ -32,4 +32,5 @@ make install ln -s /opt/OpenBLAS/lib/libopenblas.so /usr/lib/libopenblas.so ln -s /opt/OpenBLAS/lib/libopenblas.a /usr/lib/libopenblas.a ln -s /opt/OpenBLAS/lib/libopenblas.a /usr/lib/liblapack.a +pip install graphviz popd \ No newline at end of file From 0ef7c1e185cef22ed57862b35c9b526441c6c6fc Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 12 Apr 2018 15:25:26 -0700 Subject: [PATCH 7/9] update docker install --- ci/docker/install/arm64_openblas.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/ci/docker/install/arm64_openblas.sh b/ci/docker/install/arm64_openblas.sh index d6562baa8711..3151a4b5c2a6 100755 --- a/ci/docker/install/arm64_openblas.sh +++ b/ci/docker/install/arm64_openblas.sh @@ -32,5 +32,4 @@ make install ln -s /opt/OpenBLAS/lib/libopenblas.so /usr/lib/libopenblas.so ln -s /opt/OpenBLAS/lib/libopenblas.a /usr/lib/libopenblas.a ln -s /opt/OpenBLAS/lib/libopenblas.a /usr/lib/liblapack.a -pip install graphviz popd \ No newline at end of file From e424405f47c1a76bb1b0434646b2e8c89da10079 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Mon, 16 Apr 2018 11:42:31 -0700 Subject: [PATCH 8/9] skip unittest if graphviz could not be imported --- tests/python/unittest/test_viz.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py index 32a63976afb1..48fea71fb38c 100644 --- a/tests/python/unittest/test_viz.py +++ b/tests/python/unittest/test_viz.py @@ -17,6 +17,7 @@ import mxnet as mx import warnings +import unittest def test_print_summary(): data = mx.sym.Variable('data') @@ -33,6 +34,15 @@ def test_print_summary(): shape["data"]=(1,3,28,28) mx.viz.print_summary(sc1, shape) +def graphviz_exists(): + try: + import graphviz + except ImportError: + return False + else: + return True + +@unittest.skipIf(not graphviz_exists(), "Skip test_plot_network as Graphviz could not be imported") def test_plot_network(): # Test warnings for cyclic graph net = mx.sym.Variable('data') From 1b778a64beb08acfc94eb0c4ac6f9c70b8e22d5f Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Fri, 29 Jun 2018 09:55:34 -0700 Subject: [PATCH 9/9] optimize imports --- tests/python/unittest/test_viz.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py index 48fea71fb38c..eb5921f2e823 100644 --- a/tests/python/unittest/test_viz.py +++ b/tests/python/unittest/test_viz.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. -import mxnet as mx -import warnings import unittest +import warnings + +import mxnet as mx + def test_print_summary(): data = mx.sym.Variable('data') @@ -56,7 +58,7 @@ def test_plot_network(): assert len(w) == 1 assert "There are multiple variables with the same name in your graph" in str(w[-1].message) assert "fc" in str(w[-1].message) - + if __name__ == "__main__": import nose nose.runmodule()