Skip to content

Commit

Permalink
add new update method of adam
Browse files Browse the repository at this point in the history
  • Loading branch information
dhc committed Dec 4, 2017
1 parent e3c539d commit 4328ad4
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
5 changes: 3 additions & 2 deletions python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,12 +555,13 @@ class Adam(Optimizer):
epsilon : float, optional
Small value to avoid division by 0.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,use_tusimple_update=False,
**kwargs):
super(Adam, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.use_tusimple_update = use_tusimple_update

def create_state(self, index, weight):
return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean
Expand All @@ -579,7 +580,7 @@ def update(self, index, weight, grad, state):
lr *= math.sqrt(coef2)/coef1

kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
'rescale_grad': self.rescale_grad}
'rescale_grad': self.rescale_grad, 'use_tusimple_update':self.use_tusimple_update}
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient

Expand Down
2 changes: 1 addition & 1 deletion src/operator/convolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Operator* CreateOp<gpu>(ConvolutionParam param, int dtype,
backward_compute_type, ctx);
}
if (!convolutionIsSupported) {
LOG(WARNING) << "This convolution is not supported by cudnn, MXNET convolution is applied.";
// LOG(WARNING) << "This convolution is not supported by cudnn, MXNET convolution is applied.";
op = new ConvolutionOp<gpu, DType>(param);
} else {
if (forward_compute_type != desired_forward_compute_type)
Expand Down
24 changes: 20 additions & 4 deletions src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
float wd;
float rescale_grad;
float clip_gradient;
bool use_tusimple_update;
DMLC_DECLARE_PARAMETER(AdamParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
Expand All @@ -308,6 +309,9 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
DMLC_DECLARE_FIELD(use_tusimple_update)
.set_default(true)
.describe("whether use the gradient of weight decay when caculate mean & var");
}
};

Expand All @@ -328,10 +332,13 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

grad = scalar<DType>(param.rescale_grad) * grad +
scalar<DType>(param.wd) * weight;

if (param.use_tusimple_update == 0){
grad = scalar<DType>(param.rescale_grad) * grad +
scalar<DType>(param.wd) * weight;
}
else{
grad = scalar<DType>(param.rescale_grad) * grad;
}
if (param.clip_gradient >= 0.0f) {
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
F<clip>(grad, DType(param.clip_gradient));
Expand All @@ -341,10 +348,19 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
}
if (param.use_tusimple_update == 0){
Assign(out, req[0],
weight -
scalar<DType>(param.lr) * mean /
(F<square_root>(var) + scalar<DType>(param.epsilon)));
}
else{
Assign(out, req[0],
scalar<DType>(1.f-param.lr*param.wd)*weight -
scalar<DType>(param.lr) * mean /
(F<square_root>(var) + scalar<DType>(param.epsilon)));

}
});
}

Expand Down

0 comments on commit 4328ad4

Please sign in to comment.