This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

changing data type of 't' to int in lamb_update_phase1 #16903

Merged
2 commits merged on Dec 7, 2019
Changes from all commits
16 changes: 9 additions & 7 deletions src/operator/optimizer_op-inl.h
@@ -1567,7 +1567,7 @@ struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam>
   float beta1;
   float beta2;
   float epsilon;
-  float t;
+  int t;
   bool bias_correction;
   float wd;
   float rescale_grad;
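
Context, not part of the diff: 't' is the optimizer's update-step counter, which is inherently an integer. One plausible motivation for storing it as int rather than float (a guess, not stated in the PR) is that a floating-point counter silently stops advancing once the step count exceeds the significand range. A minimal, purely illustrative C++ sketch of that failure mode:

    // Why a float step counter eventually stops counting: float has a
    // 24-bit significand, so integers above 2^24 are no longer exact.
    #include <cstdio>

    int main() {
      float t_float = 16777216.0f;            // 2^24, still representable exactly
      std::printf("%.1f\n", t_float + 1.0f);  // prints 16777216.0 -- the increment is lost
      int t_int = 16777216;
      std::printf("%d\n", t_int + 1);         // prints 16777217 -- int keeps counting
      return 0;
    }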
@@ -1623,8 +1623,8 @@ struct LambUpdatePhaseOneKernel {
   MSHADOW_XINLINE static void Map(int i, DType* out_data,
     DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
     const DType clip_gradient, const DType rescale_grad,
-    const DType beta1, const DType beta2, const DType wd,
-    const DType epsilon, const DType t,
+    const DType beta1, const DType beta1_t, const DType beta2, const DType beta2_t,
+    const DType wd, const DType epsilon, const int t,
     bool bias_correction, const OpReqType req) {
     using namespace mshadow_op;

@@ -1639,8 +1639,8 @@ struct LambUpdatePhaseOneKernel {
     DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

     if (bias_correction) {
-      DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
-      DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
+      DType mean_hat = mean_data[i] / (1. - beta1_t);
+      DType var_hat = var_data[i] / (1 - beta2_t);
       g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
     }
     KERNEL_ASSIGN(out_data[i], req, g);
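
For reference, this branch is the usual Adam-style bias correction, where beta1 and beta2 are the decay rates, t is the step count, epsilon the stabilizer, and wd the weight decay (notation below is illustrative, not from the PR):

    \hat{m}_i = m_i / (1 - \beta_1^t), \quad
    \hat{v}_i = v_i / (1 - \beta_2^t), \quad
    g_i = \hat{m}_i / (\sqrt{\hat{v}_i} + \epsilon) + wd \cdot w_i

Because beta1^t and beta2^t depend only on the step count and not on the element index i, they are loop-invariant; passing them in as beta1_t and beta2_t lets the caller evaluate each power once per update instead of once per tensor element inside the kernel.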
@@ -1657,6 +1657,8 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
   const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
   Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    DType beta1_t = std::pow(param.beta1, param.t);
+    DType beta2_t = std::pow(param.beta2, param.t);
     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
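
The two added lines hoist the power computation onto the host, computed once per kernel launch rather than via power::Map for every tensor element inside the kernel. A CPU-only sketch of the same loop-invariant hoisting, under that reading of the change (names are illustrative, not MXNet APIs):

    #include <cmath>
    #include <vector>

    // Before: the loop-invariant pow is evaluated once per element.
    void bias_correct_naive(std::vector<float>* m, float beta, int t) {
      for (float& x : *m) x /= (1.0f - std::pow(beta, t));
    }

    // After: compute the factor once, as the PR does with beta1_t / beta2_t.
    void bias_correct_hoisted(std::vector<float>* m, float beta, int t) {
      const float beta_t = std::pow(beta, t);  // once per call, not per element
      const float denom = 1.0f - beta_t;
      for (float& x : *m) x /= denom;
    }

    int main() {
      std::vector<float> m = {0.1f, 0.2f, 0.3f};
      bias_correct_hoisted(&m, 0.9f, 10);      // same result as the naive version
      return 0;
    }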
@@ -1666,9 +1668,9 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
     Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
       out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
-      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
+      static_cast<DType>(param.beta1), beta1_t, static_cast<DType>(param.beta2), beta2_t,
       static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
-      static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+      static_cast<int>(param.t), static_cast<bool>(param.bias_correction), req[0]);
   });
 }
