-
Notifications
You must be signed in to change notification settings - Fork 3
/
rcn_init.m
61 lines (59 loc) · 2.49 KB
/
rcn_init.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
function net = rcn_init(opts)
% RCN_INIT  Build the layer list of a plain convolutional network that ends
% in a 'euclidloss' layer.
%
%   NET = RCN_INIT(OPTS) returns a struct whose field NET.layers is a cell
%   array of layer structs, in this order:
%     1. if OPTS.depth > 1: a 3x3 conv with 1 input channel and
%        OPTS.filterSize output channels, followed by a 'relu';
%     2. OPTS.depth - 2 repetitions of a 3x3 conv
%        (OPTS.filterSize -> OPTS.filterSize channels) followed by a 'relu';
%     3. a final 3x3 conv with a single output channel;
%     4. a 'euclidloss' layer.
%   Every conv layer uses stride 1 and pad 1. OPTS.depth conv layers are
%   created in total (a single conv when OPTS.depth <= 1).
%
%   OPTS fields read here:
%     depth      - number of conv layers (see above)
%     filterSize - output channels of every conv layer except the last
%     resid      - true: final bias starts at 0; false: it starts at 0.5
%     useBnorm   - true: insert a 'bnorm' layer (via insertBnorm) after
%                  every conv layer except the final one
%
%   Filters are Gaussian with std sqrt(2 / (9 * inputChannels)), i.e. a
%   He-style scale for 3x3 kernels. The final conv's std is further
%   multiplied by 0.001, so the initial output is close to its bias.
net.layers = {} ;

% Hidden layers ----------------------------------------------------------
nIn = 1;   % channels feeding the final conv; stays 1 for a single-conv net
if opts.depth > 1
  net.layers{end+1} = makeConvLayer( ...
    sqrt(2/9)*randn(3,3,1,opts.filterSize, 'single'), ...
    zeros(1, opts.filterSize, 'single'));
  net.layers{end+1} = struct('type', 'relu');
  nIn = opts.filterSize;
end
for k = 1:opts.depth - 2
  net.layers{end+1} = makeConvLayer( ...
    sqrt(2/9/opts.filterSize)*randn(3,3,opts.filterSize,opts.filterSize, 'single'), ...
    zeros(1, opts.filterSize, 'single'));
  net.layers{end+1} = struct('type', 'relu');
end

% Prediction layer -------------------------------------------------------
if opts.resid
  finalBias = 0;    % residual target: centred around zero
else
  finalBias = 0.5;  % direct target: centred around 0.5, the average DC level
end
net.layers{end+1} = makeConvLayer( ...
  0.001*sqrt(2/9/nIn)*randn(3,3,nIn,1, 'single'), ...
  finalBias + zeros(1,1,'single'));
net.layers{end+1} = struct('type', 'euclidloss') ;

% Optional batch normalization --------------------------------------------
% Insert a bnorm layer after every conv layer except the final (prediction)
% conv. Work back-to-front so earlier insertions do not shift the indices of
% the conv layers that still have to be processed.
if opts.useBnorm
  isConv = cellfun(@(layer) strcmp(layer.type, 'conv'), net.layers);
  convIdx = find(isConv);
  for k = fliplr(convIdx(1:end-1))
    net = insertBnorm(net, k);
  end
end

% --------------------------------------------------------------------
function layer = makeConvLayer(filters, biases)
% --------------------------------------------------------------------
% Build a stride-1, pad-1 'conv' layer struct with the given parameters.
layer = struct('type', 'conv', ...
               'filters', filters, ...
               'biases', biases, ...
               'stride', 1, ...
               'pad', 1);
% --------------------------------------------------------------------
function net = insertBnorm(net, l)
% --------------------------------------------------------------------
% INSERTBNORM  Insert a 'bnorm' layer directly after layer L of NET.
%
%   Layer L must carry a 'filters' field (asserted). The new layer gets one
%   gain/offset pair per entry along dimension 4 of those filters,
%   initialised to ones and zeros. It has learning-rate multipliers [1 1]
%   and zero weight decay. The biases of layer L are cleared (set to []),
%   presumably because the bnorm offset makes them redundant.
target = net.layers{l};
assert(isfield(target, 'filters'));
nOut = size(target.filters, 4);
gain = ones(nOut, 1, 'single');
offset = zeros(nOut, 1, 'single');
bn = struct('type', 'bnorm', ...
            'weights', {{gain, offset}}, ...
            'learningRate', [1 1], ...
            'weightDecay', [0 0]) ;
net.layers{l}.biases = [] ;
net.layers = [net.layers(1:l), {bn}, net.layers(l+1:end)] ;