diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
index 3d76b33ae765..66bd4f3f3ba4 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adamw-inl.h
@@ -33,6 +33,7 @@
 #include 
 #include 
 #include 
+#include 
 #include "../operator_common.h"
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
@@ -48,7 +49,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
   float epsilon;
   float wd;
   float eta;
-  float rescale_grad;
   float clip_gradient;
   DMLC_DECLARE_PARAMETER(AdamWParam) {
     DMLC_DECLARE_FIELD(lr)
@@ -69,9 +69,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
               "The penalty scales with the square of the magnitude of each weight.");
     DMLC_DECLARE_FIELD(eta)
     .describe("Learning rate schedule multiplier");
-    DMLC_DECLARE_FIELD(rescale_grad)
-    .set_default(1.0f)
-    .describe("Rescale gradient to grad = rescale_grad*grad.");
     DMLC_DECLARE_FIELD(clip_gradient)
     .set_default(-1.0f)
     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
@@ -80,44 +77,138 @@
   }
 };
 
+// rescale_grad is a reserved argument at position -1. Example:
+// n_in = 2: weight, grad (fp16)
+// n_out = 1: weight (fp16)
+// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
+template<int n_in, int n_out, int total_in>
+inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs,
+                               std::vector<TShape> *in_attrs,
+                               std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
+  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  // rescale_grad.shape = (1,)
+  SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mshadow::Shape1(1));
+  return ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string, n_in, n_out>(
+    attrs, in_attrs, out_attrs, TShape());
+}
+
+// rescale_grad is a reserved argument at position -1. Example:
+// n_in = 2: weight, grad (fp16)
+// n_out = 1: weight (fp16)
+// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
+template<int n_in, int n_out, int total_in>
+inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs,
+                              std::vector<int> *in_attrs,
+                              std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
+  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  for (int i = n_in; i < total_in; ++i) {
+    TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
+  }
+  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in, n_out>(
+    attrs, in_attrs, out_attrs, -1);
+}
+
+template<int req>
+struct MPAdamWKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data,
+    float* var_data, const DType* weight_data, const DType* grad_data, float* weight32,
+    const float param_clip_gradient, const float param_beta1, const float param_beta2,
+    const float param_eta, const float param_lr, const float param_wd,
+    const float param_rescale_grad, const float param_epsilon) {
+    float w = weight32[i];
+    float mean = mean_data[i];
+    float var = var_data[i];
+    float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
+    if (param_clip_gradient >= 0.0f) {
+      mean = param_beta1 * mean +
+             (1 - param_beta1) * mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
+      var = param_beta2 * var + (1 - param_beta2) *
+            mshadow_op::square::Map(mshadow_op::clip::Map(scaled_grad, param_clip_gradient));
+    } else {
+      mean = param_beta1 * mean + (1 - param_beta1) * scaled_grad;
+      var = param_beta2 * var + (1 - param_beta2) * mshadow_op::square::Map(scaled_grad);
+    }
+    mean_data[i] = mean;
+    var_data[i] = var;
+    w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
+          + param_wd * w);
+    weight32[i] = w;
+    KERNEL_ASSIGN(out_data[i], req, w);
+  }
+};
+
+
+template<typename xpu>
+struct MPAdamWUpdate {
+  static inline void Forward(const nnvm::NodeAttrs& attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs,
+                             const float rescale_grad) {
+    using namespace mxnet_op;
+    AdamWParam param = nnvm::get<AdamWParam>(attrs.parsed);
+    Stream<xpu>* s = ctx.get_stream<xpu>();
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
+      Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
+      Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
+      Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<MPAdamWKernel<req_type>, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_,
+          var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.beta1,
+          param.beta2, param.eta, param.lr, param.wd, rescale_grad, param.epsilon);
+      });
+    });
+  }
+};
+
 /*
  * \brief adam_w update.
  */
 template<typename xpu>
-inline void AdamWUpdate(const nnvm::NodeAttrs& attrs,
-                        const OpContext &ctx,
-                        const std::vector<TBlob> &inputs,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
-  const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
-  Stream<xpu>* s = ctx.get_stream<xpu>();
-  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+struct AdamWUpdate {
+  static inline void Forward(const nnvm::NodeAttrs& attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs,
+                             const float rescale_grad) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    using namespace mshadow_op;
+    const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
+    Stream<xpu>* s = ctx.get_stream<xpu>();
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
+      Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
 
-    grad = scalar<DType>(param.rescale_grad) * grad;
-    if (param.clip_gradient >= 0.0f) {
-      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
-          F<clip>(grad, DType(param.clip_gradient));
-      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
-          F<clip>(grad, DType(param.clip_gradient)));
-    } else {
-      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
-      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
-    }
-    Assign(out, req[0],
-           weight -
-           scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
-           mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
-           (scalar<DType>(param.wd) * weight)));
-  });
-}
+      grad = scalar<DType>(rescale_grad) * grad;
+      if (param.clip_gradient >= 0.0f) {
+        mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
+            F<clip>(grad, DType(param.clip_gradient));
+        var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
+            F<clip>(grad, DType(param.clip_gradient)));
+      } else {
+        mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
+        var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
+      }
+      Assign(out, req[0],
+             weight -
+             scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
+             mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
+             (scalar<DType>(param.wd) * weight)));
+    });
+  }
+};
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index 94623fe08a9e..2fbc39743c93 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -24,12 +24,76 @@
  * \author Haibin Lin
  */
 #include "./adamw-inl.h"
+#include "../optimizer_op-inl.h"
 
 namespace mxnet {
 namespace op {
 
 DMLC_REGISTER_PARAMETER(AdamWParam);
 
+template
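Note (not part of the diff): the truncated adamw.cc hunk above is where these templates would be wired into the operator registration. Below is a minimal sketch of that wiring, assuming the surrounding adamw.cc context; the wrapper name MPUpdateCPU, the input counts, and the abridged registration flags are illustrative assumptions, not lines taken from the patch.

// Sketch only: a hypothetical CPU-side FCompute wrapper that reads the scalar
// rescale_grad from the reserved last input (shape (1,), fp32) and forwards the
// remaining inputs to an update functor exposing a static Forward(), such as
// AdamWUpdate<cpu> or MPAdamWUpdate<cpu> from adamw-inl.h.
template<typename MPUpdate>
inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
  const float rescale_grad = inputs[inputs.size() - 1].dptr<float>()[0];
  std::vector<TBlob> op_inputs(inputs.begin(), inputs.end() - 1);  // drop rescale_grad
  MPUpdate::Forward(attrs, ctx, op_inputs, req, outputs, rescale_grad);
}

// Abridged registration sketch: for the non-multi-precision update the inputs are
// weight, grad, mean, var, rescale_grad, so total_in = 5, n_in = 4, n_out = 1,
// matching the convention documented above MPUpdateInferShape/MPUpdateInferType.
NNVM_REGISTER_OP(_contrib_adamw_update)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamWParam>)
.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<4, 1, 5>)
.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<4, 1, 5>)
.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<AdamWUpdate<cpu>>);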