Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
readable updates in unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudhacharya committed Sep 4, 2019
1 parent 9029fff commit 5f885e3
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,7 @@ def update(self, index, weight, grad, state):
weight[:] += -lr * (grad + wd * weight)
else:
mom = state
mom[:] *= self.momentum
weight[:] -= self.momentum * mom[:]
grad += wd * weight
grad *= lr
weight[:] -= (self.momentum + 1) * grad
weight[:] += (self.momentum**2 * mom) - lr*(self.momentum + 1)*(grad + wd*weight)
else:
grad32 = array(grad, ctx=grad.context, dtype=np.float32)
grad32 = grad32 * self.rescale_grad
Expand All @@ -399,11 +395,7 @@ def update(self, index, weight, grad, state):
if self.momentum == 0.0:
weight32[:] += -lr * (grad32 + wd * weight32)
else:
mom[:] *= self.momentum
weight32[:] -= self.momentum * mom[:]
grad32 += wd * weight32
grad32 *= lr
weight32[:] -= (self.momentum + 1) * grad32
weight32[:] += (self.momentum**2 * mom) - lr*(self.momentum+1)*(grad32 + wd*weight32)
tmp = weight32.astype(weight.dtype)
tmp.copyto(weight)

Expand Down

0 comments on commit 5f885e3

Please sign in to comment.