diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 698f7977a963..146e411b447c 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1567,7 +1567,7 @@ struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
   float beta1;
   float beta2;
   float epsilon;
-  float t;
+  int t;
   bool bias_correction;
   float wd;
   float rescale_grad;
@@ -1623,8 +1623,8 @@ struct LambUpdatePhaseOneKernel {
   MSHADOW_XINLINE static void Map(int i, DType* out_data,
     DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
     const DType clip_gradient, const DType rescale_grad,
-    const DType beta1, const DType beta2, const DType wd,
-    const DType epsilon, const DType t,
+    const DType beta1, const DType beta1_t, const DType beta2, const DType beta2_t,
+    const DType wd, const DType epsilon, const int t,
     bool bias_correction, const OpReqType req) {
     using namespace mshadow_op;
@@ -1639,8 +1639,8 @@ struct LambUpdatePhaseOneKernel {
     DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

     if (bias_correction) {
-      DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
-      DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
+      DType mean_hat = mean_data[i] / (1. - beta1_t);
+      DType var_hat = var_data[i] / (1 - beta2_t);
       g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
     }
     KERNEL_ASSIGN(out_data[i], req, g);
@@ -1657,6 +1657,8 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
   const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
   Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    DType beta1_t = std::pow(param.beta1, param.t);
+    DType beta2_t = std::pow(param.beta2, param.t);
     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
@@ -1666,9 +1668,9 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
     Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
       out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
-      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
+      static_cast<DType>(param.beta1), beta1_t, static_cast<DType>(param.beta2), beta2_t,
      static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
-      static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+      static_cast<int>(param.t), static_cast<bool>(param.bias_correction), req[0]);
   });
}
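
Reviewer note (not part of the patch): the gist of the change is that the bias-correction denominators (1 - beta1^t) and (1 - beta2^t) depend only on the step count t, so they can be computed once per update call on the host with std::pow and passed into the kernel, rather than recomputed with power::Map for every tensor element. Making t an int also reflects that it is a step counter, not a continuous quantity. Below is a minimal standalone C++ sketch of that pattern; it only mirrors the kernel body shown in the diff (mean/var are assumed already updated), and the function and variable names are illustrative, not MXNet API.

#include <cmath>
#include <cstdio>
#include <vector>

// Sketch of the patched phase-one body: beta1_t/beta2_t arrive precomputed,
// so the per-element loop does no transcendental work for bias correction.
void phase_one_sketch(std::vector<float>* out,
                      const std::vector<float>& mean,
                      const std::vector<float>& var,
                      const std::vector<float>& weight,
                      float beta1, float beta2, float epsilon,
                      float wd, int t, bool bias_correction) {
  // Hoisted out of the loop: one pow() per update call, not one per element.
  const float beta1_t = static_cast<float>(std::pow(beta1, t));
  const float beta2_t = static_cast<float>(std::pow(beta2, t));
  for (std::size_t i = 0; i < weight.size(); ++i) {
    float g = mean[i] / (std::sqrt(var[i]) + epsilon) + wd * weight[i];
    if (bias_correction) {
      const float mean_hat = mean[i] / (1.f - beta1_t);
      const float var_hat = var[i] / (1.f - beta2_t);
      g = mean_hat / (std::sqrt(var_hat) + epsilon) + wd * weight[i];
    }
    (*out)[i] = g;
  }
}

int main() {
  // Toy inputs for a two-element weight tensor at step t = 1.
  const std::vector<float> mean = {0.01f, -0.02f};
  const std::vector<float> var = {0.0001f, 0.0004f};
  const std::vector<float> weight = {0.5f, -0.25f};
  std::vector<float> out(2, 0.f);
  phase_one_sketch(&out, mean, var, weight, 0.9f, 0.999f, 1e-6f, 0.01f, 1, true);
  std::printf("g = [%f, %f]\n", out[0], out[1]);
  return 0;
}

For a GPU launch over N elements this saves 2N pow evaluations per update; the host computes the two scalars once and the kernel sees them as plain arguments, which is presumably the motivation for threading beta1_t/beta2_t through LambUpdatePhaseOneKernel::Map.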