This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Use in-place operators to prevent memory spikes in optimizer updates #13960

Merged · 1 commit, merged on Feb 15, 2019
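Background on the change: in a fused update such as `acc_g[:] = self.rho * acc_g + (1. - self.rho) * grad * grad`, the whole right-hand side is evaluated into freshly allocated parameter-sized temporaries before the result is copied back, so peak memory briefly spikes by a multiple of the parameter size. Rewriting the statement as an in-place `*=` followed by `+=` reuses the accumulator's own buffer and leaves fewer temporaries live at once. A minimal sketch of the pattern (in NumPy, whose in-place semantics mirror MXNet's NDArray for this purpose; names are illustrative):

    import numpy as np

    def ema_fused(acc, grad, rho=0.9):
        # Fused form (before this PR): the full right-hand side is
        # materialized in temporary arrays, then copied back into acc.
        acc[:] = rho * acc + (1. - rho) * grad * grad

    def ema_inplace(acc, grad, rho=0.9):
        # Split form (after this PR): acc is scaled inside its own
        # buffer, then the remaining term is accumulated into it.
        acc *= rho
        acc += (1. - rho) * grad * grad

    grad = np.random.randn(1_000_000).astype(np.float32)
    a = np.full(grad.shape, 0.5, dtype=np.float32)
    b = a.copy()
    ema_fused(a, grad)
    ema_inplace(b, grad)
    assert np.allclose(a, b)  # same result, smaller peak allocation

The same scale-then-accumulate split is applied to each optimizer in the diff below.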
python/mxnet/optimizer/optimizer.py (20 changes: 13 additions & 7 deletions)
@@ -1011,8 +1011,9 @@ def update(self, index, weight, grad, state):
         grad = grad * self.rescale_grad
         if self.clip_gradient is not None:
             grad = clip(grad, -self.clip_gradient, self.clip_gradient)
-        weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr), shape=weight.shape,
-                                                            dtype=weight.dtype, ctx=weight.context)
+        weight[:] += - lr/2 * (grad + wd * weight)
+        weight[:] += normal(0, math.sqrt(lr), shape=weight.shape,
+                            dtype=weight.dtype, ctx=weight.context)
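Both versions implement the same SGLD step; the split only applies the deterministic term and the noise term in two separate in-place additions. With λ the learning rate, wd the weight decay, and g_t the rescaled, clipped gradient, the rule is (noise variance λ, i.e. standard deviation √λ, matching `normal(0, math.sqrt(lr), ...)`):

    w_{t+1} = w_t - \frac{\lambda}{2}\left(g_t + \mathrm{wd}\, w_t\right) + \varepsilon_t,
    \qquad \varepsilon_t \sim \mathcal{N}(0, \lambda I)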



@@ -1292,9 +1293,11 @@ def update(self, index, weight, grad, state):
         acc_g, acc_delta = state

         # update g, delta
-        acc_g[:] = self.rho * acc_g + (1. - self.rho) * grad * grad
+        acc_g[:] *= self.rho
+        acc_g[:] += (1. - self.rho) * grad * grad
         current_delta = sqrt(acc_delta + self.epsilon) / sqrt(acc_g + self.epsilon) * grad
-        acc_delta[:] = self.rho * acc_delta + (1. - self.rho) * current_delta * current_delta
+        acc_delta[:] *= self.rho
+        acc_delta[:] += (1. - self.rho) * current_delta * current_delta

         # update weight
         weight[:] -= current_delta + wd * weight
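One subtlety the in-place rewrite must preserve is ordering: `current_delta` reads the freshly updated `acc_g` but the not-yet-updated `acc_delta`. A NumPy re-statement of the patched AdaDelta step (the default `rho`, `epsilon`, and `wd` values here are illustrative assumptions):

    import numpy as np

    def adadelta_step(weight, grad, acc_g, acc_delta, rho=0.9, epsilon=1e-5, wd=0.0):
        # update g, delta in place, as in the patch
        acc_g *= rho
        acc_g += (1. - rho) * grad * grad
        # reads the new acc_g but the previous step's acc_delta
        current_delta = np.sqrt(acc_delta + epsilon) / np.sqrt(acc_g + epsilon) * grad
        acc_delta *= rho
        acc_delta += (1. - rho) * current_delta * current_delta
        # update weight in place
        weight -= current_delta + wd * weight

Moving the `acc_delta` update ahead of the `current_delta` computation would change the algorithm, which is why only the scale-then-accumulate split is applied.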
@@ -1427,7 +1430,8 @@ def update(self, index, weight, grad, state):

         # update m_t and u_t
         m_t, u_t = state
-        m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad
+        m_t[:] *= self.beta1
+        m_t[:] += (1. - self.beta1) * grad
         u_t[:] = maximum(self.beta2 * u_t, NDabs(grad))

         # update weight
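In Adamax, only the first-moment update decomposes into an in-place scale-and-add; the infinity-norm state `u_t` is an element-wise maximum, which has no such decomposition and therefore stays a single fused assignment. The two state updates shown above are:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad
    u_t = \max\left(\beta_2\, u_{t-1},\ |g_t|\right)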
@@ -1490,8 +1494,10 @@ def update(self, index, weight, grad, state):

         # update m_t and v_t
         m_t, v_t = state
-        m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad
-        v_t[:] = self.beta2 * v_t + (1. - self.beta2) * grad * grad
+        m_t[:] *= self.beta1
+        m_t[:] += (1. - self.beta1) * grad
+        v_t[:] *= self.beta2
+        v_t[:] += (1. - self.beta2) * grad * grad

         grad_prime = grad / (1. - self.m_schedule)
         m_t_prime = m_t / (1. - m_schedule_next)
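The Nadam hunk applies the same split to both moment estimates; the Nesterov bias corrections that follow read the already-updated `m_t`. Writing μ_i for the warmed momentum schedule, so that `self.m_schedule` is the product of μ_1…μ_t and `m_schedule_next` extends it through step t+1 (as in Dozat's Nadam formulation), the two corrected terms computed above are:

    \bar{g}_t = \frac{g_t}{1 - \prod_{i=1}^{t} \mu_i}, \qquad
    \hat{m}_t = \frac{m_t}{1 - \prod_{i=1}^{t+1} \mu_i}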