Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update warning and unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
roywei committed Apr 12, 2018
1 parent 73000eb commit 485fd7b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
14 changes: 7 additions & 7 deletions python/mxnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import re
import copy
import json
import logging
import warnings
from .symbol import Symbol

def _str2tuple(string):
Expand Down Expand Up @@ -254,13 +254,13 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
nodes = conf["nodes"]
# check if multiple nodes have the same name
if len(nodes) != len(set([node["name"] for node in nodes])):
seen = set()
seen_add = seen.add
seen_nodes = set()
# find all repeated names
repeated = set(node['name'] for node in nodes if node['name'] in seen
or seen_add(node['name']))
logging.warning("There are multiple variables with the same name in your graph, "
"this may result in cyclic graph. Repeated names: %s", ','.join(repeated))
repeated = set(node['name'] for node in nodes if node['name'] in seen_nodes
or seen_nodes.add(node['name']))
warning_message = "There are multiple variables with the same name in your graph, " \
"this may result in cyclic graph. Repeated names: " + ','.join(repeated)
warnings.warn(warning_message, RuntimeWarning)
# default attributes of node
node_attr = {"shape": "box", "fixedsize": "true",
"width": "1.3", "height": "0.8034", "style": "filled"}
Expand Down
17 changes: 16 additions & 1 deletion tests/python/unittest/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import mxnet as mx
import warnings

def test_print_summary():
data = mx.sym.Variable('data')
Expand All @@ -32,5 +33,19 @@ def test_print_summary():
shape["data"]=(1,3,28,28)
mx.viz.print_summary(sc1, shape)

def test_plot_network():
# Test warnings for cyclic graph
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=128)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=10)
net = mx.sym.SoftmaxOutput(data=net, name='out')
with warnings.catch_warnings(record=True) as w:
digraph = mx.viz.plot_network(net, shape={'data': (100, 200)},
node_attrs={"fixedsize": "false"})
assert len(w) == 1
assert "There are multiple variables with the same name in your graph" in str(w[-1].message)

if __name__ == "__main__":
test_print_summary()
import nose
nose.runmodule()

0 comments on commit 485fd7b

Please sign in to comment.