diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
index 3d76b33ae765..28b68047c40e 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adamw-inl.h
@@ -80,6 +80,57 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
   }
 };
 
+struct MPAdamWKernel {
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data,
+    float* var_data, const DType* weight_data, const DType* grad_data, float* weight32,
+    const float param_clip_gradient, const float param_beta1, const float param_beta2,
+    const float param_eta, const float param_lr, const float param_wd,
+    const float param_rescale_grad, const float param_epsilon, const OpReqType req) {
+    float w = weight32[i];
+    float mean = mean_data[i];
+    float var = var_data[i];
+    float scaled_grad = param_rescale_grad * static_cast<float>(grad_data[i]);
+    if (param_clip_gradient >= 0.0f) {
+      mean = param_beta1 * mean +
+             (1 - param_beta1) * mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
+      var = param_beta2 * var + (1 - param_beta2) *
+            mshadow_op::square::Map(mshadow_op::clip::Map(scaled_grad, param_clip_gradient));
+    } else {
+      mean = param_beta1 * mean + (1 - param_beta1) * scaled_grad;
+      var = param_beta2 * var + (1 - param_beta2) * mshadow_op::square::Map(scaled_grad);
+    }
+    mean_data[i] = mean;
+    var_data[i] = var;
+    w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
+                         + param_wd * w);
+    weight32[i] = w;
+    KERNEL_ASSIGN(out_data[i], req, w);
+  }
+};
+
+template <typename xpu>
+inline void MPAdamWUpdate(const nnvm::NodeAttrs& attrs,
+                          const OpContext &ctx,
+                          const std::vector<TBlob> &inputs,
+                          const std::vector<OpReqType> &req,
+                          const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  AdamWParam param = nnvm::get<AdamWParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
+    Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
+    Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+    Kernel<MPAdamWKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_,
+      var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.beta1,
+      param.beta2, param.eta, param.lr, param.wd, param.rescale_grad, param.epsilon, req[0]);
+  });
+}
+
 /*
  * \brief adam_w update.
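For reference, MPAdamWKernel applies AdamW's decoupled weight-decay rule element-wise, keeping the moving moments and the master copy of the weights in single precision. A sketch of the math, with m, v, w corresponding to mean_data, var_data, and weight32, and g-hat the rescaled (and, when clip_gradient >= 0, clipped) gradient:

\begin{align}
  \hat g &= \operatorname{clip}(\text{rescale\_grad} \cdot g,\ \text{clip\_gradient}) \\
  m &\leftarrow \beta_1 m + (1 - \beta_1)\,\hat g \\
  v &\leftarrow \beta_2 v + (1 - \beta_2)\,\hat g^{\,2} \\
  w &\leftarrow w - \eta \left( \frac{\text{lr} \cdot m}{\sqrt{v} + \epsilon} + \text{wd} \cdot w \right)
\end{align}

Note that, unlike the textbook Adam step, the kernel applies no bias correction (no 1 - beta^t factors); any scheduling of eta and lr is left to the caller.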
  */
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index 94623fe08a9e..3fa5b0c6a0f5 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -24,12 +24,33 @@
  * \author Haibin Lin
  */
 #include "./adamw-inl.h"
+#include "../optimizer_op-inl.h"
 
 namespace mxnet {
 namespace op {
 
 DMLC_REGISTER_PARAMETER(AdamWParam);
 
+NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+.describe("Update function for multi-precision AdamW optimizer.")
+.set_num_inputs(5)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<AdamWParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
+// TODO: rename MP_SGD_InferType, which is shared with the mp_sgd operators
+.set_attr<nnvm::FInferType>("FInferType", MP_SGD_InferType<2, 1, 5>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2, 3, 4};
+  })
+.set_attr<FCompute>("FCompute", MPAdamWUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
+.add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
+.add_arguments(AdamWParam::__FIELDS__());
+
 NNVM_REGISTER_OP(_contrib_adamw_update)
 .describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
 Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index b7452f861e2d..c4213569f1ba 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -31,5 +31,8 @@ namespace op {
 NNVM_REGISTER_OP(_contrib_adamw_update)
 .set_attr<FCompute>("FCompute", AdamWUpdate<gpu>);
 
+NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+.set_attr<FCompute>("FCompute", MPAdamWUpdate<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
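To make the operator's semantics concrete, below is a minimal standalone sketch in plain C++ (no MXNet headers; mp_adamw_step and its test values are illustrative, not part of the library) that mirrors the kernel's per-element arithmetic. In the real operator the weights and gradients may be lower precision (e.g. fp16) while mean, var, and weight32 stay fp32; the sketch uses fp32 throughout.

// Standalone reference sketch of one mp_adamw_update step; not the MXNet API.
#include <algorithm>
#include <cmath>
#include <cstdio>

void mp_adamw_step(int n, float* weight32, float* mean, float* var,
                   const float* grad, float clip_gradient, float beta1,
                   float beta2, float eta, float lr, float wd,
                   float rescale_grad, float epsilon) {
  for (int i = 0; i < n; ++i) {
    // Rescale, then clamp to [-clip_gradient, clip_gradient] if clipping is on.
    float g = rescale_grad * grad[i];
    if (clip_gradient >= 0.0f)
      g = std::max(-clip_gradient, std::min(g, clip_gradient));
    // Exponential moving averages of the gradient and its square.
    mean[i] = beta1 * mean[i] + (1.0f - beta1) * g;
    var[i] = beta2 * var[i] + (1.0f - beta2) * g * g;
    // Decoupled weight decay: wd * w is added outside the adaptive term.
    weight32[i] -= eta * (lr * mean[i] / (std::sqrt(var[i]) + epsilon)
                          + wd * weight32[i]);
  }
}

int main() {
  float w[2] = {1.0f, -0.5f}, m[2] = {0.0f, 0.0f}, v[2] = {0.0f, 0.0f};
  float g[2] = {0.1f, -0.2f};
  mp_adamw_step(2, w, m, v, g, /*clip_gradient=*/-1.0f, /*beta1=*/0.9f,
                /*beta2=*/0.999f, /*eta=*/1.0f, /*lr=*/1e-3f, /*wd=*/1e-2f,
                /*rescale_grad=*/1.0f, /*epsilon=*/1e-8f);
  std::printf("w[0]=%f w[1]=%f\n", w[0], w[1]);
  return 0;
}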