diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 3c03f8061d87..d22bd644d21f 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -1066,21 +1066,18 @@ struct NAGMomKernel { const DType param_lr, const DType param_wd, const DType param_rescale_grad, const OpReqType req) { if (param_clip_gradient >= 0.0f) { - mom_data[i] = param_momentum*mom_data[i] - + mshadow_op::clip::Map(param_rescale_grad*grad_data[i], - param_clip_gradient) - + (param_wd*weight_data[i]); - KERNEL_ASSIGN(out_data[i], req, weight_data[i] - - param_lr*(param_momentum*mom_data[i] - + mshadow_op::clip::Map(param_rescale_grad*grad_data[i], - param_clip_gradient))); + mom_data[i] = param_momentum*mom_data[i]; + KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1) + *(mom_data[i]-(param_lr*(mshadow_op::clip::Map(param_rescale_grad + *grad_data[i], param_clip_gradient)+(param_wd*weight_data[i]))))); + mom_data[i] = mom_data[i] - (param_lr*((mshadow_op::clip::Map(param_rescale_grad*grad_data[i], + param_clip_gradient))+(param_wd*weight_data[i]))); } else { - mom_data[i] = param_momentum*mom_data[i] - + param_rescale_grad*grad_data[i] - + (param_wd*weight_data[i]); - KERNEL_ASSIGN(out_data[i], req, weight_data[i] - - param_lr*(param_momentum*mom_data[i] - + param_rescale_grad*grad_data[i])); + mom_data[i] = param_momentum*mom_data[i]; + KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1) + *(mom_data[i]-(param_lr*(param_rescale_grad*grad_data[i]+param_wd*weight_data[i])))); + mom_data[i] = mom_data[i] - param_lr*((param_rescale_grad*grad_data[i]) + +(param_wd*weight_data[i])); } } }; @@ -1119,22 +1116,21 @@ struct MP_NAGMomKernel { const OpReqType req) { float w = weight32[i]; if (param_clip_gradient >= 0.0f) { - mom_data[i] = param_momentum*mom_data[i] - + mshadow_op::clip::Map(param_rescale_grad - *static_cast(grad_data[i]), param_clip_gradient) - + (param_wd*w); - w = w - param_lr*(param_momentum*mom_data[i] - + mshadow_op::clip::Map(param_rescale_grad - *static_cast(grad_data[i]), - param_clip_gradient)); + mom_data[i] = param_momentum*mom_data[i]; + w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr + *(mshadow_op::clip::Map(param_rescale_grad*static_cast(grad_data[i]), + param_clip_gradient)+(param_wd*w))); + mom_data[i] = mom_data[i] - param_lr + *((mshadow_op::clip::Map(param_rescale_grad*static_cast(grad_data[i]), + param_clip_gradient))+(param_wd*w)); weight32[i] = w; KERNEL_ASSIGN(out_data[i], req, w); } else { - mom_data[i] = param_momentum*mom_data[i] - + param_rescale_grad*static_cast(grad_data[i]) - + (param_wd*w); - w = w - param_lr*(param_momentum*mom_data[i] - + param_rescale_grad*static_cast(grad_data[i])); + mom_data[i] = param_momentum*mom_data[i]; + w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr + *(param_rescale_grad*static_cast(grad_data[i])+(param_wd*w))); + mom_data[i] = mom_data[i] - param_lr + *((param_rescale_grad*static_cast(grad_data[i]))+(param_wd*w)); weight32[i] = w; KERNEL_ASSIGN(out_data[i], req, w); } diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 3e6cdd0997ce..cea469960f64 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -384,11 +384,8 @@ def update(self, index, weight, grad, state): weight[:] += -lr * (grad + wd * weight) else: mom = state - mom[:] *= self.momentum - mom[:] += grad - mom[:] += wd * weight - grad[:] += self.momentum * mom - weight[:] -= lr * grad + weight[:] += (self.momentum**2 * mom) - lr*(self.momentum + 1)*(grad + wd*weight) + mom[:] = (self.momentum*mom) - lr*(grad + wd*weight) else: grad32 = array(grad, ctx=grad.context, dtype=np.float32) grad32 = grad32 * self.rescale_grad @@ -399,11 +396,8 @@ def update(self, index, weight, grad, state): if self.momentum == 0.0: weight32[:] += -lr * (grad32 + wd * weight32) else: - mom[:] *= self.momentum - mom[:] += grad32 - mom[:] += wd * weight32 - grad32[:] += self.momentum * mom - weight32[:] -= lr * grad32 + weight32[:] += (self.momentum**2 * mom) - lr*(self.momentum+1)*(grad32 + wd*weight32) + mom[:] = (self.momentum*mom) - lr*(grad32 + wd*weight32) tmp = weight32.astype(weight.dtype) tmp.copyto(weight)