Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fix bug in nag optimizer #13683

Merged
merged 6 commits (source/target branch names lost in extraction)
Jan 16, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions python/mxnet/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,11 +973,9 @@ def update(self, index, weight, grad, state):

if state is not None:
mom = state
mom[:] *= self.momentum
grad += wd * weight
mom[:] += grad
mom[:] = self.momentum * mom[:] + grad + wd * weight
grad[:] += self.momentum * mom
weight[:] += -lr * grad
weight[:] -= lr * grad
else:
assert self.momentum == 0.0
weight[:] += -lr * (grad + wd * weight)
Expand Down
12 changes: 4 additions & 8 deletions tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,9 @@ def update(self, index, weight, grad, state):
weight[:] += -lr * (grad + wd * weight)
else:
mom = state
mom[:] *= self.momentum
grad += wd * weight
mom[:] += grad
mom[:] = self.momentum * mom[:] + grad + wd * weight
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try doing all these with in-place operators.

grad[:] += self.momentum * mom
weight[:] += -lr * grad
weight[:] -= lr * grad
else:
grad32 = array(grad, ctx=grad.context, dtype=np.float32)
grad32 = grad32 * self.rescale_grad
Expand All @@ -399,11 +397,9 @@ def update(self, index, weight, grad, state):
if self.momentum == 0.0:
weight32[:] += -lr * (grad32 + wd * weight32)
else:
mom[:] *= self.momentum
grad32 += wd * weight32
mom[:] += grad32
mom[:] = self.momentum * mom[:] + grad32 + wd * weight32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try doing all these with in-place operators.

grad32[:] += self.momentum * mom
weight32[:] += -lr * grad32
weight32[:] -= lr * grad32
tmp = weight32.astype(weight.dtype)
tmp.copyto(weight)

Expand Down