mp adamw update
Ubuntu committed Feb 9, 2019
1 parent 2ff8ce0 commit bc07cfd
Showing 3 changed files with 76 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/operator/contrib/adamw-inl.h
@@ -80,6 +80,58 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
}
};

struct MPAdamWKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data,
    float* var_data, const DType* weight_data, const DType* grad_data, float* weight32,
    const float param_clip_gradient, const float param_beta1, const float param_beta2,
    const float param_eta, const float param_lr, const float param_wd,
    const float param_rescale_grad, const float param_epsilon, const OpReqType req) {
    // All arithmetic is done in fp32; weight32 holds the master copy of the weights.
    float w = weight32[i];
    float mean = mean_data[i];
    float var = var_data[i];
    float scaled_grad = param_rescale_grad * static_cast<float>(grad_data[i]);
    if (param_clip_gradient >= 0.0f) {
      mean = param_beta1 * mean +
             (1 - param_beta1) * mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
      var = param_beta2 * var + (1 - param_beta2) *
            mshadow_op::square::Map(mshadow_op::clip::Map(scaled_grad, param_clip_gradient));
    } else {
      mean = param_beta1 * mean + (1 - param_beta1) * scaled_grad;
      var = param_beta2 * var + (1 - param_beta2) * mshadow_op::square::Map(scaled_grad);
    }
    mean_data[i] = mean;
    var_data[i] = var;
    // Decoupled weight decay: wd acts on the weight directly, scaled only by eta.
    w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
                         + param_wd * w);
    weight32[i] = w;
    KERNEL_ASSIGN(out_data[i], req, w);
  }
};

template<typename xpu>
inline void MPAdamWUpdate(const nnvm::NodeAttrs& attrs,
                          const OpContext &ctx,
                          const std::vector<TBlob> &inputs,
                          const std::vector<OpReqType> &req,
                          const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  AdamWParam param = nnvm::get<AdamWParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
    Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
    Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
    // req[0] is forwarded to the kernel directly rather than dispatching
    // through MXNET_ASSIGN_REQ_SWITCH here.
    Kernel<MPAdamWKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_,
      var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.beta1,
      param.beta2, param.eta, param.lr, param.wd, param.rescale_grad, param.epsilon, req[0]);
  });
}

/*
 * \brief adam_w update.
 */
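For orientation, MPAdamWKernel boils down to the usual AdamW step carried out in fp32 on the master weights: rescale and optionally clip the gradient, update the running mean and var, then apply the decoupled weight-decay update w -= eta * (lr * mean / (sqrt(var) + epsilon) + wd * w), storing w back to weight32 and casting it to the low-precision output. The scalar sketch below is purely illustrative: the function name and the plain std::vector interface are hypothetical and not part of this operator, and, like the kernel, it applies no bias correction to the moment estimates.

// Hypothetical scalar reference of the MPAdamWKernel math, for illustration only.
// DType stands in for the low-precision weight/grad type; weight32 is the fp32 master copy.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

template <typename DType>
void mp_adamw_reference(std::vector<DType>* out, std::vector<float>* mean,
                        std::vector<float>* var, const std::vector<DType>& grad,
                        std::vector<float>* weight32,
                        float clip_gradient, float beta1, float beta2,
                        float eta, float lr, float wd,
                        float rescale_grad, float epsilon) {
  for (std::size_t i = 0; i < weight32->size(); ++i) {
    // Rescale and (optionally) clip the gradient in fp32.
    float g = rescale_grad * static_cast<float>(grad[i]);
    if (clip_gradient >= 0.0f) {
      g = std::max(-clip_gradient, std::min(g, clip_gradient));
    }
    // Update the first and second moment estimates (kept in fp32).
    float m = beta1 * (*mean)[i] + (1.0f - beta1) * g;
    float v = beta2 * (*var)[i] + (1.0f - beta2) * g * g;
    (*mean)[i] = m;
    (*var)[i] = v;
    // Decoupled weight decay: wd is applied to the weight itself, outside the
    // adaptive lr * m / (sqrt(v) + epsilon) term.
    float w = (*weight32)[i];
    w -= eta * (lr * m / (std::sqrt(v) + epsilon) + wd * w);
    (*weight32)[i] = w;                  // master copy stays fp32
    (*out)[i] = static_cast<DType>(w);   // low-precision copy used by the network
  }
}

Keeping the master weights in fp32 is what lets the small per-step updates survive rounding when the network weights themselves are stored in fp16.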
21 changes: 21 additions & 0 deletions src/operator/contrib/adamw.cc
@@ -24,12 +24,33 @@
* \author Haibin Lin
*/
#include "./adamw-inl.h"
#include "../optimizer_op-inl.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(AdamWParam);

NNVM_REGISTER_OP(_contrib_mp_adamw_update)
.describe("Update function for multi-precision AdamW optimizer")
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamWParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
// TODO rename: MP_SGD_InferType
.set_attr<nnvm::FInferType>("FInferType", MP_SGD_InferType<2, 1, 5>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3, 4};
  })
.set_attr<FCompute>("FCompute<cpu>", MPAdamWUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(AdamWParam::__FIELDS__());

NNVM_REGISTER_OP(_contrib_adamw_update)
.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.
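As registered, _contrib_mp_adamw_update takes five inputs in the order weight, grad, mean, var, weight32 and writes a single output; the FMutateInputs list {2, 3, 4} marks mean, var, and the fp32 master weight as updated in place. Type inference reuses MP_SGD_InferType from ../optimizer_op-inl.h, which explains both the extra include and the TODO about renaming that shared helper.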
3 changes: 3 additions & 0 deletions src/operator/contrib/adamw.cu
@@ -31,5 +31,8 @@ namespace op {
NNVM_REGISTER_OP(_contrib_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", AdamWUpdate<gpu>);

NNVM_REGISTER_OP(_contrib_mp_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", MPAdamWUpdate<gpu>);

} // namespace op
} // namespace mxnet
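The GPU registration simply binds the same MPAdamWUpdate<xpu> template from adamw-inl.h, so the Kernel<MPAdamWKernel, xpu>::Launch path shown earlier serves both the CPU and GPU builds.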
