diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 23e448300a45..fc6db1ddcb31 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -26,7 +26,7 @@ import re import copy import json -import logging +import warnings from .symbol import Symbol def _str2tuple(string): @@ -254,13 +254,13 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs nodes = conf["nodes"] # check if multiple nodes have the same name if len(nodes) != len(set([node["name"] for node in nodes])): - seen = set() - seen_add = seen.add + seen_nodes = set() # find all repeated names - repeated = set(node['name'] for node in nodes if node['name'] in seen - or seen_add(node['name'])) - logging.warning("There are multiple variables with the same name in your graph, " - "this may result in cyclic graph. Repeated names: %s", ','.join(repeated)) + repeated = set(node['name'] for node in nodes if node['name'] in seen_nodes + or seen_nodes.add(node['name'])) + warning_message = "There are multiple variables with the same name in your graph, " \ + "this may result in cyclic graph. Repeated names: " + ','.join(repeated) + warnings.warn(warning_message, RuntimeWarning) # default attributes of node node_attr = {"shape": "box", "fixedsize": "true", "width": "1.3", "height": "0.8034", "style": "filled"} diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py index 73cfa94ba030..e8aaebada152 100644 --- a/tests/python/unittest/test_viz.py +++ b/tests/python/unittest/test_viz.py @@ -16,6 +16,7 @@ # under the License. import mxnet as mx +import warnings def test_print_summary(): data = mx.sym.Variable('data') @@ -32,5 +33,19 @@ def test_print_summary(): shape["data"]=(1,3,28,28) mx.viz.print_summary(sc1, shape) +def test_plot_network(): + # Test warnings for cyclic graph + net = mx.sym.Variable('data') + net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=128) + net = mx.sym.Activation(data=net, name='relu1', act_type="relu") + net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=10) + net = mx.sym.SoftmaxOutput(data=net, name='out') + with warnings.catch_warnings(record=True) as w: + digraph = mx.viz.plot_network(net, shape={'data': (100, 200)}, + node_attrs={"fixedsize": "false"}) + assert len(w) == 1 + assert "There are multiple variables with the same name in your graph" in str(w[-1].message) + if __name__ == "__main__": - test_print_summary() + import nose + nose.runmodule()