Skip to content

Commit

Permalink
Bug fix
Browse files Browse the repository at this point in the history
bnorm: don't add bnorm in last layer
  • Loading branch information
jonghyuk0605 committed Sep 19, 2015
1 parent b55e7fb commit 99eaa72
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
5 changes: 4 additions & 1 deletion rcn.m
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
opts = vl_argparse(opts, varargin);

exp_name = 'exp';
if opts.useBnorm
exp_name = 'exp_bn';
end
for problem_iter = 1:numel(opts.problems)
problem = opts.problems{problem_iter};
switch problem.type
Expand All @@ -43,7 +46,7 @@
exp_name = sprintf('%s_resid%d_depth%d', exp_name, opts.resid, opts.depth);
opts.expDir = fullfile('data/exp',exp_name);

rep=20;
rep=25;
opts.learningRate = [0.1*ones(1,rep) 0.01*ones(1,rep) 0.001*ones(1,rep) 0.0001*ones(1,rep)];%*0.99 .^ (0:500);
opts.gradRange = 1e-4;
if ~exist('data/result', 'dir'), mkdir('data/result'); end
Expand Down
2 changes: 1 addition & 1 deletion rcn_init.m
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
% optionally switch to batch normalization
if opts.useBnorm
d = 1;
while d < numel(net.layers)
while d+1 < numel(net.layers)
if strcmp(net.layers{d}.type,'conv')
net = insertBnorm(net, d);
end
Expand Down
9 changes: 5 additions & 4 deletions rcn_train.m
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@
end
if opts.useGpu
imlow = gpuArray(imlow);
imhigh = gpuArray(imhigh);
end

% predict
Expand Down Expand Up @@ -440,13 +441,13 @@
end
switch evalType
case 'PSNR'
eval_base(problem_iter) = eval_base(problem_iter) + compute_psnr(imhigh,imlow);
eval_ours(problem_iter) = eval_ours(problem_iter) + compute_psnr(imhigh,impred);
eval_base(problem_iter) = eval_base(problem_iter) + gather(compute_psnr(imhigh,imlow));
eval_ours(problem_iter) = eval_ours(problem_iter) + gather(compute_psnr(imhigh,impred));
end

if printPic && f_n == 1
imwrite(imlow, strcat(problem.type,'_low.bmp'));
imwrite(impred, strcat(problem.type,'_pred.bmp'));
imwrite(gather(imlow), strcat(problem.type,'_low.bmp'));
imwrite(gather(impred), strcat(problem.type,'_pred.bmp'));
end
end
end
Expand Down

0 comments on commit 99eaa72

Please sign in to comment.