Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-210]give warning for variables with same name in graph visualization #10429

Merged
merged 9 commits into from
Jul 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion python/mxnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import re
import copy
import json

import warnings
from .symbol import Symbol

def _str2tuple(string):
Expand Down Expand Up @@ -252,6 +252,15 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
shape_dict = dict(zip(interals.list_outputs(), out_shapes))
conf = json.loads(symbol.tojson())
nodes = conf["nodes"]
# check if multiple nodes have the same name
if len(nodes) != len(set([node["name"] for node in nodes])):
seen_nodes = set()
# find all repeated names
repeated = set(node['name'] for node in nodes if node['name'] in seen_nodes
or seen_nodes.add(node['name']))
warning_message = "There are multiple variables with the same name in your graph, " \
"this may result in cyclic graph. Repeated names: " + ','.join(repeated)
warnings.warn(warning_message, RuntimeWarning)
# default attributes of node
node_attr = {"shape": "box", "fixedsize": "true",
"width": "1.3", "height": "0.8034", "style": "filled"}
Expand Down
30 changes: 29 additions & 1 deletion tests/python/unittest/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
# specific language governing permissions and limitations
# under the License.

import unittest
import warnings

import mxnet as mx


def test_print_summary():
data = mx.sym.Variable('data')
bias = mx.sym.Variable('fc1_bias', lr_mult=1.0)
Expand All @@ -32,5 +36,29 @@ def test_print_summary():
shape["data"]=(1,3,28,28)
mx.viz.print_summary(sc1, shape)

def graphviz_exists():
try:
import graphviz
except ImportError:
return False
else:
return True

@unittest.skipIf(not graphviz_exists(), "Skip test_plot_network as Graphviz could not be imported")
def test_plot_network():
# Test warnings for cyclic graph
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=128)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=10)
net = mx.sym.SoftmaxOutput(data=net, name='out')
with warnings.catch_warnings(record=True) as w:
digraph = mx.viz.plot_network(net, shape={'data': (100, 200)},
node_attrs={"fixedsize": "false"})
assert len(w) == 1
assert "There are multiple variables with the same name in your graph" in str(w[-1].message)
assert "fc" in str(w[-1].message)

if __name__ == "__main__":
test_print_summary()
import nose
nose.runmodule()