From 93146890e7984f0375d87ef417f7cad751488cca Mon Sep 17 00:00:00 2001 From: solin319 Date: Thu, 17 Jan 2019 04:16:21 +0800 Subject: [PATCH] fix bug in nag optimizer (#13683) * fix bug in nag optimizer ``` grad += wd * weight mom[:] += grad grad[:] += self.momentum * mom weight[:] += -lr * grad ``` This subtracts the wd * weight term twice (once directly through grad and once through mom), whereas the documented update rule `state = momentum * state + grad + wd * weight; weight = weight - (lr * (grad + momentum * state))` subtracts it only once. * fix bug in nag test fix bug in nag test * rewrite nag test * rewrite nag * fix nag with in-place operations * fix nag with in-place operations --- python/mxnet/optimizer/optimizer.py | 4 ++-- tests/python/unittest/test_optimizer.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index d290a3f2fea2..6ffbbcffc384 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -978,10 +978,10 @@ def update(self, index, weight, grad, state): if state is not None: mom = state mom[:] *= self.momentum - grad += wd * weight mom[:] += grad + mom[:] += wd * weight grad[:] += self.momentum * mom - weight[:] += -lr * grad + weight[:] -= lr * grad else: assert self.momentum == 0.0 weight[:] += -lr * (grad + wd * weight) diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 935bd9ab1823..3fdd1cd6bb87 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -385,10 +385,10 @@ def update(self, index, weight, grad, state): else: mom = state mom[:] *= self.momentum - grad += wd * weight mom[:] += grad + mom[:] += wd * weight grad[:] += self.momentum * mom - weight[:] += -lr * grad + weight[:] -= lr * grad else: grad32 = array(grad, ctx=grad.context, dtype=np.float32) grad32 = grad32 * self.rescale_grad @@ -400,10 +400,10 @@ def update(self, index, weight, grad, state): weight32[:] += -lr * (grad32 + wd * weight32) 
else: mom[:] *= self.momentum - grad32 += wd * weight32 mom[:] += grad32 + mom[:] += wd * weight32 grad32[:] += self.momentum * mom - weight32[:] += -lr * grad32 + weight32[:] -= lr * grad32 tmp = weight32.astype(weight.dtype) tmp.copyto(weight)