Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
changing data type of 't' to int in lamb_update_phase1
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kumar Srivastava committed Nov 25, 2019
1 parent 6bff547 commit 04813e7
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1567,7 +1567,7 @@ struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam>
float beta1;
float beta2;
float epsilon;
float t;
int t;
bool bias_correction;
float wd;
float rescale_grad;
Expand Down Expand Up @@ -1624,7 +1624,7 @@ struct LambUpdatePhaseOneKernel {
DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
const DType clip_gradient, const DType rescale_grad,
const DType beta1, const DType beta2, const DType wd,
const DType epsilon, const DType t,
const DType epsilon, const int t,
bool bias_correction, const OpReqType req) {
using namespace mshadow_op;

Expand All @@ -1639,8 +1639,8 @@ struct LambUpdatePhaseOneKernel {
DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

if (bias_correction) {
DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
DType mean_hat = mean_data[i] / (1. - std::pow(beta1, t));
DType var_hat = var_data[i] / (1 - std::pow(beta2, t));
g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
}
KERNEL_ASSIGN(out_data[i], req, g);
Expand Down Expand Up @@ -1668,7 +1668,7 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
static_cast<int>(param.t), static_cast<bool>(param.bias_correction), req[0]);
});
}

Expand Down

0 comments on commit 04813e7

Please sign in to comment.