[OP] changing data type of 't' to int in lamb_update_phase1 (#16903)
* changing data type of 't' to int in lamb_update_phase1

* taking operation beta^t out of kernel call
access2rohit authored and eric-haibin-lin committed Dec 10, 2019
1 parent c7d484e commit ff27b4b
Showing 1 changed file with 9 additions and 7 deletions.
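The second bullet of the commit message is the performance-relevant part: beta1^t and beta2^t are now computed once on the host per kernel launch and passed in as scalars, instead of power::Map(beta, t) being re-evaluated for every tensor element inside the kernel. A minimal, self-contained sketch of that pattern (plain C++ with illustrative names such as bias_correct_naive and bias_correct_hoisted, not MXNet's actual mshadow kernel) looks like this:

// Sketch only: illustrates hoisting pow(beta, t) out of a per-element loop.
#include <cmath>
#include <cstddef>

// Before: pow(beta1, t) is re-evaluated for every element.
void bias_correct_naive(float* out, const float* mean,
                        float beta1, int t, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = mean[i] / (1.0f - std::pow(beta1, static_cast<float>(t)));
  }
}

// After: beta1^t is computed once and reused as a scalar, mirroring the
// beta1_t / beta2_t values passed into the kernel launch in the diff below.
void bias_correct_hoisted(float* out, const float* mean,
                          float beta1, int t, std::size_t n) {
  const float beta1_t = std::pow(beta1, static_cast<float>(t));
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = mean[i] / (1.0f - beta1_t);
  }
}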
16 changes: 9 additions & 7 deletions src/operator/optimizer_op-inl.h
@@ -1567,7 +1567,7 @@ struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam>
 float beta1;
 float beta2;
 float epsilon;
-float t;
+int t;
 bool bias_correction;
 float wd;
 float rescale_grad;
@@ -1623,8 +1623,8 @@ struct LambUpdatePhaseOneKernel {
 MSHADOW_XINLINE static void Map(int i, DType* out_data,
 DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
 const DType clip_gradient, const DType rescale_grad,
-const DType beta1, const DType beta2, const DType wd,
-const DType epsilon, const DType t,
+const DType beta1, const DType beta1_t, const DType beta2, const DType beta2_t,
+const DType wd, const DType epsilon, const int t,
 bool bias_correction, const OpReqType req) {
 using namespace mshadow_op;

@@ -1639,8 +1639,8 @@ struct LambUpdatePhaseOneKernel {
 DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

 if (bias_correction) {
-DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
-DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
+DType mean_hat = mean_data[i] / (1. - beta1_t);
+DType var_hat = var_data[i] / (1 - beta2_t);
 g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
 }
 KERNEL_ASSIGN(out_data[i], req, g);
@@ -1657,6 +1657,8 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
 const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
 Stream<xpu>* s = ctx.get_stream<xpu>();
 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+DType beta1_t = std::pow(param.beta1, param.t);
+DType beta2_t = std::pow(param.beta2, param.t);
 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
 Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
@@ -1666,9 +1668,9 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
 Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
 out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
 static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
-static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
+static_cast<DType>(param.beta1), beta1_t, static_cast<DType>(param.beta2), beta2_t,
 static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
-static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+static_cast<int>(param.t), static_cast<bool>(param.bias_correction), req[0]);
 });
 }

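For reference, when bias_correction is enabled the kernel above applies the Adam-style corrected moments before forming the update direction; reading the code in the diff (with m = mean_data[i], v = var_data[i], w = weight_data[i]), that is

\hat{m} = \frac{m}{1 - \beta_1^{\,t}}, \qquad
\hat{v} = \frac{v}{1 - \beta_2^{\,t}}, \qquad
g = \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon} + \mathrm{wd}\cdot w

which is why the scalars beta_1^t and beta_2^t can be precomputed once per launch on the host and passed into the kernel.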