Skip to content

Commit

Permalink
In-place updates for Nadam, Adadelta, Adamax and SGLD (apache#13960)
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudhacharya authored and stephenrawls committed Feb 16, 2019
1 parent 0bba2ab commit f51c8cf
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions python/mxnet/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,8 +1091,9 @@ def update(self, index, weight, grad, state):
grad = grad * self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
- weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr), shape=weight.shape,
- dtype=weight.dtype, ctx=weight.context)
+ weight[:] += - lr/2 * (grad + wd * weight)
+ weight[:] += normal(0, math.sqrt(lr), shape=weight.shape,
+ dtype=weight.dtype, ctx=weight.context)



Expand Down Expand Up @@ -1372,9 +1373,11 @@ def update(self, index, weight, grad, state):
acc_g, acc_delta = state

# update g, delta
- acc_g[:] = self.rho * acc_g + (1. - self.rho) * grad * grad
+ acc_g[:] *= self.rho
+ acc_g[:] += (1. - self.rho) * grad * grad
current_delta = sqrt(acc_delta + self.epsilon) / sqrt(acc_g + self.epsilon) * grad
- acc_delta[:] = self.rho * acc_delta + (1. - self.rho) * current_delta * current_delta
+ acc_delta[:] *= self.rho
+ acc_delta[:] += (1. - self.rho) * current_delta * current_delta

# update weight
weight[:] -= current_delta + wd * weight
Expand Down Expand Up @@ -1507,7 +1510,8 @@ def update(self, index, weight, grad, state):

# update m_t and u_t
m_t, u_t = state
- m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad
+ m_t[:] *= self.beta1
+ m_t[:] += (1. - self.beta1) * grad
u_t[:] = maximum(self.beta2 * u_t, NDabs(grad))

# update weight
Expand Down Expand Up @@ -1570,8 +1574,10 @@ def update(self, index, weight, grad, state):

# update m_t and v_t
m_t, v_t = state
- m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad
- v_t[:] = self.beta2 * v_t + (1. - self.beta2) * grad * grad
+ m_t[:] *= self.beta1
+ m_t[:] += (1. - self.beta1) * grad
+ v_t[:] *= self.beta2
+ v_t[:] += (1. - self.beta2) * grad * grad

grad_prime = grad / (1. - self.m_schedule)
m_t_prime = m_t / (1. - m_schedule_next)
Expand Down

0 comments on commit f51c8cf

Please sign in to comment.