From 80885711fc55018fcda44d5031e931c0ceeb6ddf Mon Sep 17 00:00:00 2001
From: Rohit Kumar Srivastava
Date: Mon, 25 Nov 2019 19:30:04 +0000
Subject: [PATCH 1/2] changing data type of 't' to int in lamb_update_phase1

---
 src/operator/optimizer_op-inl.h | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 698f7977a963..e28c061122bc 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1567,7 +1567,7 @@ struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
   float beta1;
   float beta2;
   float epsilon;
-  float t;
+  int t;
   bool bias_correction;
   float wd;
   float rescale_grad;
@@ -1624,7 +1624,7 @@ struct LambUpdatePhaseOneKernel {
     DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
     const DType clip_gradient, const DType rescale_grad,
     const DType beta1, const DType beta2, const DType wd,
-    const DType epsilon, const DType t,
+    const DType epsilon, const int t,
     bool bias_correction, const OpReqType req) {
     using namespace mshadow_op;

@@ -1639,8 +1639,8 @@ struct LambUpdatePhaseOneKernel {
     DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

     if (bias_correction) {
-      DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
-      DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
+      DType mean_hat = mean_data[i] / (1. - std::pow(beta1, t));
+      DType var_hat = var_data[i] / (1 - std::pow(beta2, t));
       g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
     }
     KERNEL_ASSIGN(out_data[i], req, g);
@@ -1668,7 +1668,7 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
       static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
       static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
-      static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+      static_cast<int>(param.t), static_cast<bool>(param.bias_correction), req[0]);
     });
 }

From c5b221c6443ab6dd28e8850500ec1ed6c832668a Mon Sep 17 00:00:00 2001
From: Rohit Kumar Srivastava
Date: Tue, 3 Dec 2019 00:24:29 +0000
Subject: [PATCH 2/2] taking operation beta^t out of kernel call

---
 src/operator/optimizer_op-inl.h | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index e28c061122bc..146e411b447c 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1623,8 +1623,8 @@ struct LambUpdatePhaseOneKernel {
   MSHADOW_XINLINE static void Map(int i, DType* out_data,
     DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
     const DType clip_gradient, const DType rescale_grad,
-    const DType beta1, const DType beta2, const DType wd,
-    const DType epsilon, const int t,
+    const DType beta1, const DType beta1_t, const DType beta2, const DType beta2_t,
+    const DType wd, const DType epsilon, const int t,
     bool bias_correction, const OpReqType req) {
     using namespace mshadow_op;

@@ -1639,8 +1639,8 @@ struct LambUpdatePhaseOneKernel {
     DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

     if (bias_correction) {
-      DType mean_hat = mean_data[i] / (1. - std::pow(beta1, t));
-      DType var_hat = var_data[i] / (1 - std::pow(beta2, t));
+      DType mean_hat = mean_data[i] / (1. - beta1_t);
+      DType var_hat = var_data[i] / (1 - beta2_t);
       g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
     }
     KERNEL_ASSIGN(out_data[i], req, g);
@@ -1657,6 +1657,8 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
   const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
   Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    DType beta1_t = std::pow(param.beta1, param.t);
+    DType beta2_t = std::pow(param.beta2, param.t);
     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
@@ -1666,7 +1668,7 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
     Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
       out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
-      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
+      static_cast<DType>(param.beta1), beta1_t, static_cast<DType>(param.beta2), beta2_t,
      static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
       static_cast<int>(param.t), static_cast<bool>(param.bias_correction), req[0]);
     });
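
Reviewer note (illustration only, not part of the patches): the second commit hoists pow(beta1, t) and pow(beta2, t) out of the element-wise kernel because the step count t is the same for every element, so each power only needs to be evaluated once per update on the host and can be passed to the kernel as the scalars beta1_t and beta2_t, instead of being recomputed per element; it also avoids calling std::pow inside an MSHADOW_XINLINE body. The short standalone C++ sketch below shows the same hoisting pattern outside MXNet; the function name bias_corrected_update and all constant values are made up for the example.

// Standalone sketch (not MXNet code) of the hoisting pattern in PATCH 2/2:
// t is identical for every element, so pow(beta, t) is computed once per
// update and passed into the element-wise update as a scalar.
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the bias-corrected part of LAMB phase one.
// beta1_t and beta2_t are the precomputed pow(beta1, t) and pow(beta2, t).
float bias_corrected_update(float mean, float var, float weight,
                            float beta1_t, float beta2_t,
                            float epsilon, float wd) {
  float mean_hat = mean / (1.f - beta1_t);
  float var_hat  = var  / (1.f - beta2_t);
  return mean_hat / (std::sqrt(var_hat) + epsilon) + wd * weight;
}

int main() {
  const float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-6f, wd = 0.01f;
  const int t = 10;  // update step; an int after PATCH 1/2

  // Hoisted: evaluated once per update instead of once per element.
  const float beta1_t = std::pow(beta1, static_cast<float>(t));
  const float beta2_t = std::pow(beta2, static_cast<float>(t));

  std::vector<float> mean{0.1f, 0.2f}, var{0.01f, 0.02f}, weight{1.f, -1.f};
  for (std::size_t i = 0; i < weight.size(); ++i) {
    float g = bias_corrected_update(mean[i], var[i], weight[i],
                                    beta1_t, beta2_t, epsilon, wd);
    std::printf("g[%zu] = %f\n", i, g);
  }
  return 0;
}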