Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-210]give warning for variables with same name in graph visuali…
Browse files Browse the repository at this point in the history
…zation (#10429)

* give warning for variables with same name in graph visualization

* fix line too long

* print repetead node names

* update warning and unit test

* add assert for repeated node

* add graphviz for arm

* update docker install

* skip unittest if graphviz could not be imported

* optimize imports
  • Loading branch information
roywei authored and szha committed Jul 2, 2018
1 parent defd544 commit 7c74d1f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
11 changes: 10 additions & 1 deletion python/mxnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import re
import copy
import json

import warnings
from .symbol import Symbol

def _str2tuple(string):
Expand Down Expand Up @@ -252,6 +252,15 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs
shape_dict = dict(zip(interals.list_outputs(), out_shapes))
conf = json.loads(symbol.tojson())
nodes = conf["nodes"]
# check if multiple nodes have the same name
if len(nodes) != len(set([node["name"] for node in nodes])):
seen_nodes = set()
# find all repeated names
repeated = set(node['name'] for node in nodes if node['name'] in seen_nodes
or seen_nodes.add(node['name']))
warning_message = "There are multiple variables with the same name in your graph, " \
"this may result in cyclic graph. Repeated names: " + ','.join(repeated)
warnings.warn(warning_message, RuntimeWarning)
# default attributes of node
node_attr = {"shape": "box", "fixedsize": "true",
"width": "1.3", "height": "0.8034", "style": "filled"}
Expand Down
30 changes: 29 additions & 1 deletion tests/python/unittest/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
# specific language governing permissions and limitations
# under the License.

import unittest
import warnings

import mxnet as mx


def test_print_summary():
data = mx.sym.Variable('data')
bias = mx.sym.Variable('fc1_bias', lr_mult=1.0)
Expand All @@ -32,5 +36,29 @@ def test_print_summary():
shape["data"]=(1,3,28,28)
mx.viz.print_summary(sc1, shape)

def graphviz_exists():
try:
import graphviz
except ImportError:
return False
else:
return True

@unittest.skipIf(not graphviz_exists(), "Skip test_plot_network as Graphviz could not be imported")
def test_plot_network():
# Test warnings for cyclic graph
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=128)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc', num_hidden=10)
net = mx.sym.SoftmaxOutput(data=net, name='out')
with warnings.catch_warnings(record=True) as w:
digraph = mx.viz.plot_network(net, shape={'data': (100, 200)},
node_attrs={"fixedsize": "false"})
assert len(w) == 1
assert "There are multiple variables with the same name in your graph" in str(w[-1].message)
assert "fc" in str(w[-1].message)

if __name__ == "__main__":
test_print_summary()
import nose
nose.runmodule()

0 comments on commit 7c74d1f

Please sign in to comment.