Skip to content

Commit

Permalink
Fixed visualize (#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
Arkoniak authored and pluskid committed Jan 7, 2017
1 parent 590055b commit 4d0aa87
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 8 deletions.
24 changes: 16 additions & 8 deletions src/visualize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ function to_graphviz(network :: SymbolicNode; title="Network Visualization", inp
attr = deepcopy(node_attr)
label = op

# Up to 0.8 version of mxnet additional info was stored in
# node["param"]. Staring from pre0.9 `param` was changed to `attr`.
if haskey(node, "param")
node_info = node["param"]
elseif haskey(node, "attr")
node_info = node["attr"]
end

if op == "null"
if i heads
# heads are output nodes
Expand All @@ -62,23 +70,23 @@ function to_graphviz(network :: SymbolicNode; title="Network Visualization", inp
end
elseif op == "Convolution"
label = format("Convolution\nkernel={1}\nstride={2}\nn-filter={3}",
_extract_shape(node["param"]["kernel"]),
_extract_shape(node["param"]["stride"]),
node["param"]["num_filter"])
_extract_shape(node_info["kernel"]),
_extract_shape(node_info["stride"]),
node_info["num_filter"])
colorkey = 2
elseif op == "FullyConnected"
label = format("FullyConnected\nnum-hidden={1}", node["param"]["num_hidden"])
label = format("FullyConnected\nnum-hidden={1}", node_info["num_hidden"])
colorkey = 2
elseif op == "Activation"
label = format("Activation\nact-type={1}", node["param"]["act_type"])
label = format("Activation\nact-type={1}", node_info["act_type"])
colorkey = 3
elseif op == "BatchNorm"
colorkey = 4
elseif op == "Pooling"
label = format("Pooling\ntype={1}\nkernel={2}\nstride={3}",
node["param"]["pool_type"],
_extract_shape(node["param"]["kernel"]),
_extract_shape(node["param"]["stride"]))
node_info["pool_type"],
_extract_shape(node_info["kernel"]),
_extract_shape(node_info["stride"]))
colorkey = 5
elseif op ("Concat", "Flatten", "Reshape")
colorkey = 6
Expand Down
34 changes: 34 additions & 0 deletions test/unittest/visualize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module TestVisualize
using MXNet
using Base.Test

using ..Main: mlp2

################################################################################
# Test Implementations
################################################################################

function test_basic()
info("Visualize::basic")

mlp = mlp2()

# Order of elements or default color values can change, but length of the output should be more or less stable
@test length(mx.to_graphviz(mlp)) == length(
"""
digraph "Network Visualization" {
node [fontsize=10];
edge [fontsize=10];
"fc1" [label="fc1\\nFullyConnected\\nnum-hidden=1000",style="rounded,filled",fixedsize=true,width=1.3,fillcolor="#fb8072",shape=box,penwidth=2,height=0.8034,color="#941305"];
"activation0" [label="activation0\\nActivation\\nact-type=relu",style="rounded,filled",fixedsize=true,width=1.3,fillcolor="#ffffb3",shape=box,penwidth=2,height=0.8034,color="#999900"];
"fc2" [label="fc2\\nFullyConnected\\nnum-hidden=10",style="rounded,filled",fixedsize=true,width=1.3,fillcolor="#fb8072",shape=box,penwidth=2,height=0.8034,color="#941305"];
"activation0" -> "fc1" [arrowtail=open,color="#737373",dir=back];
"fc2" -> "activation0" [arrowtail=open,color="#737373",dir=back];
}
""")
end
################################################################################
# Run tests
################################################################################
test_basic()
end

0 comments on commit 4d0aa87

Please sign in to comment.