From 9c21d96fa60603306daf805cd115e7ce6a914454 Mon Sep 17 00:00:00 2001 From: Anirudh Acharya Date: Wed, 4 Sep 2019 21:41:27 +0000 Subject: [PATCH] mom update --- src/operator/optimizer_op-inl.h | 11 ++++++++++- tests/python/unittest/test_optimizer.py | 2 ++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index adffbd3a430e..08a6832879e9 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -1069,11 +1069,15 @@ struct NAGMomKernel { mom_data[i] = param_momentum*mom_data[i]; KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1) *(mom_data[i]-(param_lr*(mshadow_op::clip::Map(param_rescale_grad - *grad_data[i],param_clip_gradient)+(param_wd*weight_data[i]))))); + *grad_data[i], param_clip_gradient)+(param_wd*weight_data[i]))))); + mom_data[i] = mom_data[i] - (param_lr*((mshadow_op::clip::Map(param_rescale_grad*grad_data[i], + param_clip_gradient))+(param_wd*weight_data[i]))); } else { mom_data[i] = param_momentum*mom_data[i]; KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1) *(mom_data[i]-(param_lr*(param_rescale_grad*grad_data[i]+param_wd*weight_data[i])))); + mom_data[i] = mom_data[i] - param_lr*((param_rescale_grad*grad_data[i]) + +(param_wd*weight_data[i])); } } }; @@ -1116,12 +1120,17 @@ struct MP_NAGMomKernel { w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr *(mshadow_op::clip::Map(param_rescale_grad*static_cast(grad_data[i]), param_clip_gradient)+(param_wd*w))); + mom_data[i] = mom_data[i] - param_lr + *((mshadow_op::clip::Map(param_rescale_grad*static_cast(grad_data[i]), + param_clip_gradient))+(param_wd*w)); weight32[i] = w; KERNEL_ASSIGN(out_data[i], req, w); } else { mom_data[i] = param_momentum*mom_data[i]; w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr *(param_rescale_grad*static_cast(grad_data[i])+(param_wd*w))); + mom_data[i] = mom_data[i] - param_lr + *((param_rescale_grad*static_cast(grad_data[i]))+(param_wd*w)); weight32[i] = w; KERNEL_ASSIGN(out_data[i], req, w); } diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index b30f0f36950b..cea469960f64 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -385,6 +385,7 @@ def update(self, index, weight, grad, state): else: mom = state weight[:] += (self.momentum**2 * mom) - lr*(self.momentum + 1)*(grad + wd*weight) + mom[:] = (self.momentum*mom) - lr*(grad + wd*weight) else: grad32 = array(grad, ctx=grad.context, dtype=np.float32) grad32 = grad32 * self.rescale_grad @@ -396,6 +397,7 @@ def update(self, index, weight, grad, state): weight32[:] += -lr * (grad32 + wd * weight32) else: weight32[:] += (self.momentum**2 * mom) - lr*(self.momentum+1)*(grad32 + wd*weight32) + mom[:] = (self.momentum*mom) - lr*(grad32 + wd*weight32) tmp = weight32.astype(weight.dtype) tmp.copyto(weight)