Take the beta^t operation out of the kernel call
Rohit Kumar Srivastava committed Dec 3, 2019
1 parent 8088571 commit c5b221c
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/operator/optimizer_op-inl.h
@@ -1623,8 +1623,8 @@ struct LambUpdatePhaseOneKernel {
   MSHADOW_XINLINE static void Map(int i, DType* out_data,
     DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
     const DType clip_gradient, const DType rescale_grad,
-    const DType beta1, const DType beta2, const DType wd,
-    const DType epsilon, const int t,
+    const DType beta1, const DType beta1_t, const DType beta2, const DType beta2_t,
+    const DType wd, const DType epsilon, const int t,
     bool bias_correction, const OpReqType req) {
     using namespace mshadow_op;

@@ -1639,8 +1639,8 @@ struct LambUpdatePhaseOneKernel {
     DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

     if (bias_correction) {
-      DType mean_hat = mean_data[i] / (1. - std::pow(beta1, t));
-      DType var_hat = var_data[i] / (1 - std::pow(beta2, t));
+      DType mean_hat = mean_data[i] / (1. - beta1_t);
+      DType var_hat = var_data[i] / (1 - beta2_t);
       g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
     }
     KERNEL_ASSIGN(out_data[i], req, g);
@@ -1657,6 +1657,8 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
   const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
   Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    DType beta1_t = std::pow(param.beta1, param.t);
+    DType beta2_t = std::pow(param.beta2, param.t);
     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
@@ -1666,7 +1668,7 @@ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
     Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
       out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
-      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
+      static_cast<DType>(param.beta1), beta1_t, static_cast<DType>(param.beta2), beta2_t,
       static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
       static_cast<int>(param.t), static_cast<bool>(param.bias_correction), req[0]);
   });
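
The change hoists the loop-invariant bias-correction factors out of the per-element kernel: std::pow(beta1, t) and std::pow(beta2, t) depend only on the step count t, so LambUpdatePhaseOne now computes them once per launch on the host and passes them to the kernel as scalars (beta1_t, beta2_t) instead of re-evaluating pow for every element. The following minimal, self-contained C++ sketch illustrates the same pattern; it is not the MXNet code, it omits the weight-decay and clipping terms of the full LAMB update, and the helper names update_before and update_after are illustrative only.

// Sketch of the optimization in this commit: beta1^t and beta2^t do not depend
// on the element index, so they can be computed once per launch instead of
// once per element inside the kernel.
#include <cmath>
#include <cstdio>
#include <vector>

// Per-element update before the change: pow() is evaluated for every element.
float update_before(float mean, float var, float beta1, float beta2,
                    float epsilon, int t) {
  float mean_hat = mean / (1.f - std::pow(beta1, t));
  float var_hat  = var  / (1.f - std::pow(beta2, t));
  return mean_hat / (std::sqrt(var_hat) + epsilon);
}

// Per-element update after the change: the caller precomputes beta1_t = beta1^t
// and beta2_t = beta2^t once and passes them in as scalars.
float update_after(float mean, float var, float beta1_t, float beta2_t,
                   float epsilon) {
  float mean_hat = mean / (1.f - beta1_t);
  float var_hat  = var  / (1.f - beta2_t);
  return mean_hat / (std::sqrt(var_hat) + epsilon);
}

int main() {
  const float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-6f;
  const int t = 5;
  std::vector<float> mean = {0.1f, 0.2f, 0.3f};
  std::vector<float> var  = {0.01f, 0.02f, 0.03f};

  // Hoisted out of the element loop, mirroring the new code in LambUpdatePhaseOne.
  const float beta1_t = std::pow(beta1, t);
  const float beta2_t = std::pow(beta2, t);

  for (size_t i = 0; i < mean.size(); ++i) {
    float a = update_before(mean[i], var[i], beta1, beta2, epsilon, t);
    float b = update_after(mean[i], var[i], beta1_t, beta2_t, epsilon);
    std::printf("i=%zu  before=%f  after=%f\n", i, a, b);  // identical results
  }
  return 0;
}

The results are bit-identical; only the number of pow evaluations changes, from one pair per element to one pair per kernel launch.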