forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
symbol_unet.R
81 lines (73 loc) · 3.55 KB
/
symbol_unet.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
library(mxnet)
# Assemble one building block of the network on top of `net`.
#
# The stages below are applied in this fixed order; each one is optional:
#   1. up_pool     - 2x2, stride-2 deconvolution with `filter_count` filters,
#                    followed by BatchNorm and (if `act_type` is non-empty)
#                    an activation.
#   2. convolution - convolution using `kernel_size`, `stride`, `pad_size`
#                    and `filter_count`.
#   3. batch_norm  - BatchNorm.
#   4. act_type    - activation of that type; skipped when act_type == "".
#   5. down_pool   - 2x2, stride-2 max pooling.
#
# Arguments:
#   net          input symbol the block is stacked on
#   kernel_size  kernel passed to the convolution
#   pad_size     padding passed to the convolution
#   filter_count number of filters for the (de)convolution
#   stride       stride passed to the convolution
#   work_space   forwarded as `workspace` to the (de)convolution operators
#   batch_norm, down_pool, up_pool, convolution
#                switches for the corresponding stages
#   act_type     activation type; "" disables the activation
#
# Returns the symbol produced by the last enabled stage (`net` itself when
# every stage is disabled).
convolution_module <- function(net, kernel_size, pad_size,
                               filter_count, stride = c(1, 1), work_space = 2048,
                               batch_norm = TRUE, down_pool = FALSE, up_pool = FALSE,
                               act_type = "relu", convolution = TRUE) {
  # Wrap `sym` in an activation unless the activation type is empty.
  activate <- function(sym) {
    if (act_type != "") {
      sym <- mx.symbol.Activation(sym, act_type = act_type)
    }
    sym
  }

  symbol <- net

  if (up_pool) {
    symbol <- mx.symbol.Deconvolution(symbol, kernel = c(2, 2), pad = c(0, 0),
                                      stride = c(2, 2), num_filter = filter_count,
                                      workspace = work_space)
    # NOTE: this normalization is applied even when batch_norm is FALSE.
    symbol <- activate(mx.symbol.BatchNorm(symbol))
  }

  if (convolution) {
    symbol <- mx.symbol.Convolution(data = symbol, kernel = kernel_size,
                                    stride = stride, pad = pad_size,
                                    num_filter = filter_count,
                                    workspace = work_space)
  }

  if (batch_norm) {
    symbol <- mx.symbol.BatchNorm(symbol)
  }

  symbol <- activate(symbol)

  if (down_pool) {
    symbol <- mx.symbol.Pooling(symbol, pool_type = "max", kernel = c(2, 2),
                                stride = c(2, 2))
  }

  symbol
}
# Build the U-Net-style classification symbol.
#
# Structure, as assembled below:
#   * contracting path: five convolution_module stages with max pooling
#     (32, 64, 128, 128 and 256 filters), with dropout after the fourth;
#   * expanding path: up-sampling stages whose outputs are concatenated with
#     the saved pool3, pool2 and pool1 outputs (skip connections), each
#     concatenation followed by dropout and further convolution stages;
#   * head: Flatten -> FullyConnected(num_hidden = num_classes) ->
#     SoftmaxOutput named "softmax".
#
# Arguments:
#   num_classes  number of output units of the final fully connected layer
#
# Returns the SoftmaxOutput symbol of the assembled network.
get_symbol <- function(num_classes = 10) {
  data <- mx.symbol.Variable("data")
  kernel_size <- c(3, 3)
  pad_size <- c(1, 1)
  filter_count <- 32

  # Contracting path; pool1..pool3 are kept for the skip connections.
  pool1 <- convolution_module(data, kernel_size, pad_size,
                              filter_count = filter_count, down_pool = TRUE)
  pool2 <- convolution_module(pool1, kernel_size, pad_size,
                              filter_count = filter_count * 2, down_pool = TRUE)
  pool3 <- convolution_module(pool2, kernel_size, pad_size,
                              filter_count = filter_count * 4, down_pool = TRUE)
  pool4 <- convolution_module(pool3, kernel_size, pad_size,
                              filter_count = filter_count * 4, down_pool = TRUE)
  net <- mx.symbol.Dropout(pool4)
  pool5 <- convolution_module(net, kernel_size, pad_size,
                              filter_count = filter_count * 8, down_pool = TRUE)

  # Expanding path: two up-sampling stages back towards the pool3 level.
  net <- convolution_module(pool5, kernel_size, pad_size,
                            filter_count = filter_count * 4, up_pool = TRUE)
  net <- convolution_module(net, kernel_size, pad_size,
                            filter_count = filter_count * 4, up_pool = TRUE)
  # Dirty "crop" to the wanted size: an unpadded 4x4 convolution is used in
  # place of a crop operator (the original author was on an old MXNet branch).
  net <- convolution_module(net, c(4, 4), c(0, 0),
                            filter_count = filter_count * 4)
  net <- mx.symbol.Concat(c(pool3, net), num.args = 2)
  net <- mx.symbol.Dropout(net)
  net <- convolution_module(net, kernel_size, pad_size,
                            filter_count = filter_count * 4)
  net <- convolution_module(net, kernel_size, pad_size,
                            filter_count = filter_count * 4, up_pool = TRUE)

  net <- mx.symbol.Concat(c(pool2, net), num.args = 2)
  net <- mx.symbol.Dropout(net)
  net <- convolution_module(net, kernel_size, pad_size,
                            filter_count = filter_count * 4)
  net <- convolution_module(net, kernel_size, pad_size,
                            filter_count = filter_count * 4, up_pool = TRUE)
  # The result of this stage must be assigned back to `net`; calling
  # convolution_module without keeping its return value builds a symbol that
  # never becomes part of the returned graph.
  net <- convolution_module(net, kernel_size, pad_size,
                            filter_count = filter_count * 4)

  net <- mx.symbol.Concat(c(pool1, net), num.args = 2)
  net <- mx.symbol.Dropout(net)
  net <- convolution_module(net, kernel_size, pad_size,
                            filter_count = filter_count * 2)
  net <- convolution_module(net, kernel_size, pad_size,
                            filter_count = filter_count * 2, up_pool = TRUE)

  # Classification head.
  net <- mx.symbol.Flatten(net)
  net <- mx.symbol.FullyConnected(data = net, num_hidden = num_classes)
  mx.symbol.SoftmaxOutput(data = net, name = "softmax")
}