This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fix bug in nag optimizer (#13683)
* fix bug in nag optimizer

```
grad += wd * weight
mom[:] += grad
grad[:] += self.momentum * mom
weight[:] += -lr * grad
```
This subtracts wd * weight twice per step, whereas the documented rule `state = momentum * state + grad + wd * weight; weight = weight - lr * (grad + momentum * state)` applies the weight-decay term only once; see the numerical sketch after this commit list.

* fix bug in nag test

* rewrite nag test

* rewrite nag

* fix nag with in-place operations

* fix nag with in-place operations
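
For illustration, here is a minimal NumPy sketch of the point above (not part of the commit; the lr, wd, momentum, and gradient values are made up). After one step the two weights differ by exactly lr * wd * weight, the extra decay applied by the old code.

```python
# Minimal sketch: one step of the old in-place update vs. the documented NAG rule.
# All values below are illustrative, not taken from MXNet.
import numpy as np

lr, wd, momentum = 0.1, 0.01, 0.9
grad = np.array([0.5])
weight_old = np.array([1.0])   # updated with the old (buggy) scheme
weight_doc = np.array([1.0])   # updated with the documented rule
mom = np.zeros(1)
state = np.zeros(1)

# Old scheme: wd * weight enters both the gradient term and the momentum term.
g = grad + wd * weight_old
mom = momentum * mom + g
weight_old -= lr * (g + momentum * mom)

# Documented rule: wd * weight enters the update only through the state.
state = momentum * state + grad + wd * weight_doc
weight_doc -= lr * (grad + momentum * state)

print(weight_doc - weight_old)   # ~0.001, i.e. lr * wd * weight subtracted one extra time
```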
solin319 authored and szha committed Jan 16, 2019
1 parent 2616275 commit 9314689
Showing 2 changed files with 6 additions and 6 deletions.
python/mxnet/optimizer/optimizer.py: 2 additions & 2 deletions
```diff
@@ -978,10 +978,10 @@ def update(self, index, weight, grad, state):
         if state is not None:
             mom = state
             mom[:] *= self.momentum
-            grad += wd * weight
             mom[:] += grad
+            mom[:] += wd * weight
             grad[:] += self.momentum * mom
-            weight[:] += -lr * grad
+            weight[:] -= lr * grad
         else:
             assert self.momentum == 0.0
             weight[:] += -lr * (grad + wd * weight)
```
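
Read on its own, the patched branch behaves like the following plain-NumPy sketch (hypothetical helper name, not MXNet API). Folding wd * weight into mom instead of grad makes the state equal momentum * mom + grad + wd * weight and the step equal grad + momentum * mom, which is the documented rule.

```python
# Hypothetical plain-NumPy rendering of the patched update path (not MXNet API).
import numpy as np

def nag_step(weight, grad, mom, lr, wd, momentum):
    """One NAG step; mutates weight and mom in place, numerically matching the patched branch."""
    mom *= momentum
    mom += grad
    mom += wd * weight            # weight decay accumulates only into the state
    step = grad + momentum * mom  # Nesterov lookahead: gradient plus momentum * updated state
    weight -= lr * step
    return weight, mom
```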
tests/python/unittest/test_optimizer.py: 4 additions & 4 deletions
```diff
@@ -385,10 +385,10 @@ def update(self, index, weight, grad, state):
             else:
                 mom = state
                 mom[:] *= self.momentum
-                grad += wd * weight
                 mom[:] += grad
+                mom[:] += wd * weight
                 grad[:] += self.momentum * mom
-                weight[:] += -lr * grad
+                weight[:] -= lr * grad
         else:
             grad32 = array(grad, ctx=grad.context, dtype=np.float32)
             grad32 = grad32 * self.rescale_grad
@@ -400,10 +400,10 @@ def update(self, index, weight, grad, state):
                 weight32[:] += -lr * (grad32 + wd * weight32)
             else:
                 mom[:] *= self.momentum
-                grad32 += wd * weight32
                 mom[:] += grad32
+                mom[:] += wd * weight32
                 grad32[:] += self.momentum * mom
-                weight32[:] += -lr * grad32
+                weight32[:] -= lr * grad32
             tmp = weight32.astype(weight.dtype)
             tmp.copyto(weight)
```


2 comments on commit 9314689

@chinakook (Contributor) commented:

If this bug is fixed, shall we change our learning rate when we are using NAG?

@szha (Member) commented on 9314689 on Jan 17, 2019:

We should. Let's take the time before 1.5 to come up with a recommendation regarding this change.
