diff --git a/src/visualize.jl b/src/visualize.jl index f0dd74efdc67..42d31a22d69e 100644 --- a/src/visualize.jl +++ b/src/visualize.jl @@ -51,6 +51,14 @@ function to_graphviz(network :: SymbolicNode; title="Network Visualization", inp attr = deepcopy(node_attr) label = op + # Up to 0.8 version of mxnet additional info was stored in + # node["param"]. Staring from pre0.9 `param` was changed to `attr`. + if haskey(node, "param") + node_info = node["param"] + elseif haskey(node, "attr") + node_info = node["attr"] + end + if op == "null" if i ∈ heads # heads are output nodes @@ -62,23 +70,23 @@ function to_graphviz(network :: SymbolicNode; title="Network Visualization", inp end elseif op == "Convolution" label = format("Convolution\nkernel={1}\nstride={2}\nn-filter={3}", - _extract_shape(node["param"]["kernel"]), - _extract_shape(node["param"]["stride"]), - node["param"]["num_filter"]) + _extract_shape(node_info["kernel"]), + _extract_shape(node_info["stride"]), + node_info["num_filter"]) colorkey = 2 elseif op == "FullyConnected" - label = format("FullyConnected\nnum-hidden={1}", node["param"]["num_hidden"]) + label = format("FullyConnected\nnum-hidden={1}", node_info["num_hidden"]) colorkey = 2 elseif op == "Activation" - label = format("Activation\nact-type={1}", node["param"]["act_type"]) + label = format("Activation\nact-type={1}", node_info["act_type"]) colorkey = 3 elseif op == "BatchNorm" colorkey = 4 elseif op == "Pooling" label = format("Pooling\ntype={1}\nkernel={2}\nstride={3}", - node["param"]["pool_type"], - _extract_shape(node["param"]["kernel"]), - _extract_shape(node["param"]["stride"])) + node_info["pool_type"], + _extract_shape(node_info["kernel"]), + _extract_shape(node_info["stride"])) colorkey = 5 elseif op ∈ ("Concat", "Flatten", "Reshape") colorkey = 6 diff --git a/test/unittest/visualize.jl b/test/unittest/visualize.jl new file mode 100644 index 000000000000..973c2b7034d0 --- /dev/null +++ b/test/unittest/visualize.jl @@ -0,0 +1,34 @@ +module TestVisualize +using MXNet +using Base.Test + +using ..Main: mlp2 + +################################################################################ +# Test Implementations +################################################################################ + +function test_basic() + info("Visualize::basic") + + mlp = mlp2() + + # Order of elements or default color values can change, but length of the output should be more or less stable + @test length(mx.to_graphviz(mlp)) == length( +""" +digraph "Network Visualization" { +node [fontsize=10]; +edge [fontsize=10]; +"fc1" [label="fc1\\nFullyConnected\\nnum-hidden=1000",style="rounded,filled",fixedsize=true,width=1.3,fillcolor="#fb8072",shape=box,penwidth=2,height=0.8034,color="#941305"]; +"activation0" [label="activation0\\nActivation\\nact-type=relu",style="rounded,filled",fixedsize=true,width=1.3,fillcolor="#ffffb3",shape=box,penwidth=2,height=0.8034,color="#999900"]; +"fc2" [label="fc2\\nFullyConnected\\nnum-hidden=10",style="rounded,filled",fixedsize=true,width=1.3,fillcolor="#fb8072",shape=box,penwidth=2,height=0.8034,color="#941305"]; +"activation0" -> "fc1" [arrowtail=open,color="#737373",dir=back]; +"fc2" -> "activation0" [arrowtail=open,color="#737373",dir=back]; +} +""") +end +################################################################################ +# Run tests +################################################################################ +test_basic() +end