Skip to content

Commit

Permalink
fix bug in nag optimizer (apache#13683)
Browse files Browse the repository at this point in the history
* fix bug in nag optimizer

```
grad += wd * weight
mom[:] += grad
grad[:] += self.momentum * mom
weight[:] += -lr * grad
```
This will minus wd*weight twice, but in`state = momentum * state + grad + wd * weight   weight = weight - (lr * (grad + momentum * state)) ` only minus once.

* fix bug in nag test

fix bug in nag test

* rewrite nag test

* rewrite nag

* fix nag with in-place operations

* fix nag with in-place operations
  • Loading branch information
solin319 authored and haohuw committed Jun 23, 2019
1 parent 8c14c36 commit 90e1318
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions python/mxnet/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,10 +978,10 @@ def update(self, index, weight, grad, state):
if state is not None:
mom = state
mom[:] *= self.momentum
grad += wd * weight
mom[:] += grad
mom[:] += wd * weight
grad[:] += self.momentum * mom
weight[:] += -lr * grad
weight[:] -= lr * grad
else:
assert self.momentum == 0.0
weight[:] += -lr * (grad + wd * weight)
Expand Down
8 changes: 4 additions & 4 deletions tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,10 @@ def update(self, index, weight, grad, state):
else:
mom = state
mom[:] *= self.momentum
grad += wd * weight
mom[:] += grad
mom[:] += wd * weight
grad[:] += self.momentum * mom
weight[:] += -lr * grad
weight[:] -= lr * grad
else:
grad32 = array(grad, ctx=grad.context, dtype=np.float32)
grad32 = grad32 * self.rescale_grad
Expand All @@ -400,10 +400,10 @@ def update(self, index, weight, grad, state):
weight32[:] += -lr * (grad32 + wd * weight32)
else:
mom[:] *= self.momentum
grad32 += wd * weight32
mom[:] += grad32
mom[:] += wd * weight32
grad32[:] += self.momentum * mom
weight32[:] += -lr * grad32
weight32[:] -= lr * grad32
tmp = weight32.astype(weight.dtype)
tmp.copyto(weight)

Expand Down

0 comments on commit 90e1318

Please sign in to comment.