From 1a065e8cc747e443da672dd9a360ec96a95355a2 Mon Sep 17 00:00:00 2001
From: solin319
Date: Wed, 19 Dec 2018 11:51:02 +0800
Subject: [PATCH 1/6] fix bug in nag optimizer

The current update sequence

```
grad += wd * weight
mom[:] += grad
grad[:] += self.momentum * mom
weight[:] += -lr * grad
```

applies `wd * weight` twice: once inside the momentum state and a second
time in the weight update, because `grad` is modified in place before it
is reused. The documented update rule

```
state = momentum * state + grad + wd * weight
weight = weight - lr * (grad + momentum * state)
```

applies it only once.
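A quick scalar check (plain Python with hypothetical values, not MXNet
code) shows the two sequences differ by exactly one extra decay term,
`lr * wd * weight`:

```python
# Hypothetical scalar values for one NAG step.
momentum, lr, wd = 0.9, 0.1, 0.01
weight, grad, mom = 1.0, 0.5, 0.2

# Buggy sequence: grad is polluted with wd * weight before the
# weight update, so weight decay is applied a second time there.
g = grad + wd * weight       # grad += wd * weight
m = momentum * mom + g       # mom[:] *= momentum; mom[:] += grad
step = g + momentum * m      # grad[:] += self.momentum * mom
w_buggy = weight - lr * step

# Documented rule: wd * weight enters the state once, and the
# weight update uses the raw gradient.
state = momentum * mom + grad + wd * weight
w_ref = weight - lr * (grad + momentum * state)

print(w_buggy - w_ref)  # -lr * wd * weight = -0.001, the doubled decay
```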
---
 python/mxnet/optimizer/optimizer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index a085b6fe2ef6..430a6872d47c 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -974,8 +974,7 @@ def update(self, index, weight, grad, state):
         if state is not None:
             mom = state
             mom[:] *= self.momentum
-            grad += wd * weight
-            mom[:] += grad
+            mom[:] += grad + wd * weight
             grad[:] += self.momentum * mom
             weight[:] += -lr * grad
         else:

From 641bf6c696e2b043116489cdf6df03853d00a94e Mon Sep 17 00:00:00 2001
From: solin319
Date: Wed, 19 Dec 2018 15:07:00 +0800
Subject: [PATCH 2/6] fix bug in nag test

fix bug in nag test
---
 tests/python/unittest/test_optimizer.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index acf24ee1b794..16fc1678f274 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -385,8 +385,7 @@ def update(self, index, weight, grad, state):
             else:
                 mom = state
                 mom[:] *= self.momentum
-                grad += wd * weight
-                mom[:] += grad
+                mom[:] += grad + wd * weight
                 grad[:] += self.momentum * mom
                 weight[:] += -lr * grad
         else:
@@ -400,8 +399,7 @@ def update(self, index, weight, grad, state):
                 weight32[:] += -lr * (grad32 + wd * weight32)
             else:
                 mom[:] *= self.momentum
-                grad32 += wd * weight32
-                mom[:] += grad32
+                mom[:] += grad32 + wd * weight32
                 grad32[:] += self.momentum * mom
                 weight32[:] += -lr * grad32
             tmp = weight32.astype(weight.dtype)

From ea63cb16731dc2a93c5fe7ebfda683404731d3b5 Mon Sep 17 00:00:00 2001
From: solin319
Date: Thu, 20 Dec 2018 10:09:02 +0800
Subject: [PATCH 3/6] rewrite nag test

---
 tests/python/unittest/test_optimizer.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 16fc1678f274..85bbb65c2e67 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -384,10 +384,9 @@ def update(self, index, weight, grad, state):
                 weight[:] += -lr * (grad + wd * weight)
             else:
                 mom = state
-                mom[:] *= self.momentum
-                mom[:] += grad + wd * weight
+                mom[:] = self.momentum * mom[:] + grad + wd * weight
                 grad[:] += self.momentum * mom
-                weight[:] += -lr * grad
+                weight[:] -= lr * grad
         else:
             grad32 = array(grad, ctx=grad.context, dtype=np.float32)
             grad32 = grad32 * self.rescale_grad
@@ -398,10 +397,9 @@ def update(self, index, weight, grad, state):
             if self.momentum == 0.0:
                 weight32[:] += -lr * (grad32 + wd * weight32)
             else:
-                mom[:] *= self.momentum
-                mom[:] += grad32 + wd * weight32
+                mom[:] = self.momentum * mom[:] + grad32 + wd * weight32
                 grad32[:] += self.momentum * mom
-                weight32[:] += -lr * grad32
+                weight32[:] -= lr * grad32
             tmp = weight32.astype(weight.dtype)
             tmp.copyto(weight)

From 7633a1dbc2ae808ab1192320ab4ccd69c11e8601 Mon Sep 17 00:00:00 2001
From: solin319
Date: Thu, 20 Dec 2018 10:11:10 +0800
Subject: [PATCH 4/6] rewrite nag

---
 python/mxnet/optimizer/optimizer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index 430a6872d47c..ad79b78eb256 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -973,10 +973,9 @@ def update(self, index, weight, grad, state):

         if state is not None:
             mom = state
-            mom[:] *= self.momentum
-            mom[:] += grad + wd * weight
+            mom[:] = self.momentum * mom[:] + grad + wd * weight
             grad[:] += self.momentum * mom
-            weight[:] += -lr * grad
+            weight[:] -= lr * grad
         else:
             assert self.momentum == 0.0
             weight[:] += -lr * (grad + wd * weight)

From 510dd33b7f6e3feab84b048d35c2b38eec7bfd90 Mon Sep 17 00:00:00 2001
From: solin319
Date: Wed, 16 Jan 2019 22:12:53 +0800
Subject: [PATCH 5/6] fix nag with in-place operations

---
 tests/python/unittest/test_optimizer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 85bbb65c2e67..1a42b3903a03 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -384,7 +384,9 @@ def update(self, index, weight, grad, state):
                 weight[:] += -lr * (grad + wd * weight)
             else:
                 mom = state
-                mom[:] = self.momentum * mom[:] + grad + wd * weight
+                mom[:] *= self.momentum
+                mom[:] += grad
+                mom[:] += wd * weight
                 grad[:] += self.momentum * mom
                 weight[:] -= lr * grad
         else:
@@ -397,7 +399,9 @@ def update(self, index, weight, grad, state):
             if self.momentum == 0.0:
                 weight32[:] += -lr * (grad32 + wd * weight32)
             else:
-                mom[:] = self.momentum * mom[:] + grad32 + wd * weight32
+                mom[:] *= self.momentum
+                mom[:] += grad32
+                mom[:] += wd * weight32
                 grad32[:] += self.momentum * mom
                 weight32[:] -= lr * grad32
             tmp = weight32.astype(weight.dtype)

From 0fa146619ca8e1f89591acdaff547c44cf4f2eb1 Mon Sep 17 00:00:00 2001
From: solin319
Date: Wed, 16 Jan 2019 22:15:15 +0800
Subject: [PATCH 6/6] fix nag with in-place operations

---
 python/mxnet/optimizer/optimizer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index ad79b78eb256..501ad84d0b2e 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -973,7 +973,9 @@ def update(self, index, weight, grad, state):

         if state is not None:
             mom = state
-            mom[:] = self.momentum * mom[:] + grad + wd * weight
+            mom[:] *= self.momentum
+            mom[:] += grad
+            mom[:] += wd * weight
             grad[:] += self.momentum * mom
             weight[:] -= lr * grad
         else:
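---

The series settles on in-place statements (`mom[:] *= ...`, `mom[:] += ...`)
rather than the single assignment `mom[:] = self.momentum * mom[:] + grad +
wd * weight`, which computes the same value but presumably allocates
intermediate arrays for the right-hand side. As a sanity check, here is a
small NumPy sketch (illustrative only, not part of the patches; the array
values are arbitrary) confirming that the final in-place sequence from
PATCH 6/6 matches the closed-form NAG update:

```python
import numpy as np

momentum, lr, wd = 0.9, 0.1, 0.01
rng = np.random.default_rng(0)
weight = rng.normal(size=5)
grad = rng.normal(size=5)
mom = rng.normal(size=5)

# Closed-form reference from the commit message of PATCH 1/6.
state = momentum * mom + grad + wd * weight
ref = weight - lr * (grad + momentum * state)

# In-place sequence from PATCH 6/6.
w, g, m = weight.copy(), grad.copy(), mom.copy()
m *= momentum            # mom[:] *= self.momentum
m += g                   # mom[:] += grad
m += wd * w              # mom[:] += wd * weight
g += momentum * m        # grad[:] += self.momentum * mom
w -= lr * g              # weight[:] -= lr * grad

assert np.allclose(w, ref)  # the two formulations agree
```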