Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
mom update
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudhacharya committed Sep 8, 2019
1 parent 0a52a98 commit a3500d7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1069,11 +1069,15 @@ struct NAGMomKernel {
mom_data[i] = param_momentum*mom_data[i];
KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1)
*(mom_data[i]-(param_lr*(mshadow_op::clip::Map(param_rescale_grad
*grad_data[i],param_clip_gradient)+(param_wd*weight_data[i])))));
*grad_data[i], param_clip_gradient)+(param_wd*weight_data[i])))));
mom_data[i] = mom_data[i] - (param_lr*((mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
param_clip_gradient))+(param_wd*weight_data[i])));
} else {
mom_data[i] = param_momentum*mom_data[i];
KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1)
*(mom_data[i]-(param_lr*(param_rescale_grad*grad_data[i]+param_wd*weight_data[i]))));
mom_data[i] = mom_data[i] - param_lr*((param_rescale_grad*grad_data[i])
+(param_wd*weight_data[i]));
}
}
};
Expand Down Expand Up @@ -1116,12 +1120,17 @@ struct MP_NAGMomKernel {
w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr
*(mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
param_clip_gradient)+(param_wd*w)));
mom_data[i] = mom_data[i] - param_lr
*((mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
param_clip_gradient))+(param_wd*w));
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
} else {
mom_data[i] = param_momentum*mom_data[i];
w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr
*(param_rescale_grad*static_cast<float>(grad_data[i])+(param_wd*w)));
mom_data[i] = mom_data[i] - param_lr
*((param_rescale_grad*static_cast<float>(grad_data[i]))+(param_wd*w));
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
}
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ def update(self, index, weight, grad, state):
else:
mom = state
weight[:] += (self.momentum**2 * mom) - lr*(self.momentum + 1)*(grad + wd*weight)
mom[:] = (self.momentum*mom) - lr*(grad + wd*weight)
else:
grad32 = array(grad, ctx=grad.context, dtype=np.float32)
grad32 = grad32 * self.rescale_grad
Expand All @@ -396,6 +397,7 @@ def update(self, index, weight, grad, state):
weight32[:] += -lr * (grad32 + wd * weight32)
else:
weight32[:] += (self.momentum**2 * mom) - lr*(self.momentum+1)*(grad32 + wd*weight32)
mom[:] = (self.momentum*mom) - lr*(grad32 + wd*weight32)
tmp = weight32.astype(weight.dtype)
tmp.copyto(weight)

Expand Down

0 comments on commit a3500d7

Please sign in to comment.