Skip to content

Commit 6dbffd5

Browse files
committed
Merge branch 'deep_supervised_rnn' of https://github.com/kjw0612/rcn into deep_supervised_rnn
2 parents 28a8608 + 094cd91 commit 6dbffd5

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

rcn_train_dag.m

+5-1
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,16 @@
275275
net.params(i).der = min(max(net.params(i).der, -opts.gradRange/mult), opts.gradRange/mult);
276276
thisDecay = opts.weightDecay * net.params(i).weightDecay ;
277277

278+
momentum_prev = state.momentum{i};
278279
state.momentum{i} = opts.momentum * state.momentum{i} ...
279280
- lr * net.params(i).learningRate * ...
280281
thisDecay * net.params(i).value ...
281282
- lr * net.params(i).learningRate * (1 / batchSize) * net.params(i).der ;
282283

283-
net.params(i).value = net.params(i).value + state.momentum{i};
284+
%Nesterov
285+
net.params(i).value = net.params(i).value ...
286+
- opts.momentum * momentum_prev ...
287+
+ (1 + opts.momentum) * state.momentum{i};
284288
end
285289

286290
% -------------------------------------------------------------------------

0 commit comments

Comments
 (0)