This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Multi-precision AdamW update op #14171

Merged · 8 commits · Feb 20, 2019
165 changes: 128 additions & 37 deletions src/operator/contrib/adamw-inl.h
@@ -33,6 +33,7 @@
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include <cmath>
#include "../operator_common.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
@@ -48,7 +49,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
float epsilon;
float wd;
float eta;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(AdamWParam) {
DMLC_DECLARE_FIELD(lr)
@@ -69,9 +69,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(eta)
.describe("Learning rate schedule multiplier");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
@@ -80,44 +77,138 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
}
};

// rescale_grad is a reserved argument at position -1. Example:
// n_in = 2: weight, grad (fp16)
// n_out = 1: weight (fp16)
// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
template<int n_in, int n_out, int total_in>
inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
// rescale_grad.shape = (1,)
SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mshadow::Shape1(1));
return ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string, n_in, n_out>(
attrs, in_attrs, out_attrs, TShape());
}

// rescale_grad is a reserved argument at position -1. Example:
// n_in = 2: weight, grad (fp16)
// n_out = 1: weight (fp16)
// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
template<int n_in, int n_out, int total_in>
inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
for (int i = n_in; i < total_in; ++i) {
TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
}
return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in, n_out>(
attrs, in_attrs, out_attrs, -1);
}

template<int req>
struct MPAdamWKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data,
float* var_data, const DType* weight_data, const DType* grad_data, float* weight32,
const float param_clip_gradient, const float param_beta1, const float param_beta2,
const float param_eta, const float param_lr, const float param_wd,
const float param_rescale_grad, const float param_epsilon) {
float w = weight32[i];
float mean = mean_data[i];
float var = var_data[i];
float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
if (param_clip_gradient >= 0.0f) {
mean = param_beta1 * mean +
(1 - param_beta1) * mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
var = param_beta2 * var + (1 - param_beta2) *
mshadow_op::square::Map(mshadow_op::clip::Map(scaled_grad, param_clip_gradient));
} else {
mean = param_beta1 * mean + (1 - param_beta1) * scaled_grad;
var = param_beta2 * var + (1 - param_beta2) * mshadow_op::square::Map(scaled_grad);
}
mean_data[i] = mean;
var_data[i] = var;
w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
+ param_wd * w);
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
}
};
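
For readers cross-checking the arithmetic, here is a minimal NumPy sketch of the per-element update MPAdamWKernel performs (a hypothetical standalone reference, not part of this PR): the gradient is read in fp16, all optimizer state and the master weight are kept in fp32, and the result is cast back to fp16 for the output.

```python
import numpy as np

def mp_adamw_step(grad16, mean, var, w32, lr, beta1, beta2,
                  epsilon, wd, eta, rescale_grad, clip_gradient=-1.0):
    """NumPy reference for MPAdamWKernel: fp16 grad, fp32 state/master weight."""
    g = rescale_grad * grad16.astype(np.float32)
    if clip_gradient >= 0.0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    mean[:] = beta1 * mean + (1.0 - beta1) * g            # 1st moment
    var[:]  = beta2 * var + (1.0 - beta2) * np.square(g)  # 2nd moment
    w32[:]  = w32 - eta * (lr * mean / (np.sqrt(var) + epsilon) + wd * w32)
    return w32.astype(np.float16)  # value written to the fp16 weight output
```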


template<typename xpu>
struct MPAdamWUpdate {
static inline void Forward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs,
const float rescale_grad) {
using namespace mxnet_op;
AdamWParam param = nnvm::get<AdamWParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<MPAdamWKernel<req_type>, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_,
var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.beta1,
param.beta2, param.eta, param.lr, param.wd, rescale_grad, param.epsilon);
});
});
}
};

/*
* \brief adam_w update.
*/
template<typename xpu>
inline void AdamWUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow_op;
const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
struct AdamWUpdate {
static inline void Forward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs,
const float rescale_grad) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow_op;
const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

grad = scalar<DType>(param.rescale_grad) * grad;
if (param.clip_gradient >= 0.0f) {
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
F<clip>(grad, DType(param.clip_gradient));
var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
F<clip>(grad, DType(param.clip_gradient)));
} else {
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
}
Assign(out, req[0],
weight -
scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
(scalar<DType>(param.wd) * weight)));
});
}
grad = scalar<DType>(rescale_grad) * grad;
if (param.clip_gradient >= 0.0f) {
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
F<clip>(grad, DType(param.clip_gradient));
var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
F<clip>(grad, DType(param.clip_gradient)));
} else {
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
}
Assign(out, req[0],
weight -
scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
(scalar<DType>(param.wd) * weight)));
});
}
};

} // namespace op
} // namespace mxnet
76 changes: 72 additions & 4 deletions src/operator/contrib/adamw.cc
@@ -24,12 +24,76 @@
* \author Haibin Lin
*/
#include "./adamw-inl.h"
#include "../optimizer_op-inl.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(AdamWParam);

template<template <typename xpu> class F>
inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
// read the rescale_grad scalar (already on the CPU) and skip the update if it is not finite or zero
TBlob scale_blob = inputs[inputs.size() - 1];
MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
float scalef = static_cast<float>(*scale_blob.dptr<DType>());
if (!std::isfinite(scalef) || scalef == 0) return;
std::vector<TBlob> inputs_wo_scale;
size_t num_in = inputs.size();
inputs_wo_scale.reserve(num_in - 1);
for (size_t i = 0; i < num_in - 1; i++) inputs_wo_scale.emplace_back(inputs[i]);
F<cpu>::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
});
}

NNVM_REGISTER_OP(_contrib_mp_adamw_update)
.describe(R"code(Update function for multi-precision AdamW optimizer.

AdamW is a modification of Adam that decouples the weight decay from the
optimization steps taken w.r.t. the loss function.

The Adam update consists of the following steps, where g represents the gradient and m, v
are the 1st and 2nd order moment estimates (mean and uncentered variance).

.. math::

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})

It updates the weights using::

m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*(grad**2)
w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

Note that the gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf,
or 0, the update is skipped.
)code" ADD_FILELINE)
.set_num_inputs(6)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamWParam>)
.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<2, 1, 6>)
.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<2, 1, 6>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3, 4};
})
.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<MPAdamWUpdate>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_argument("rescale_grad", "NDArray-or-Symbol",
"Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
.add_arguments(AdamWParam::__FIELDS__());
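
As a usage sketch: the `mx.nd.contrib.mp_adamw_update` frontend name and keyword arguments below are inferred from the registration above (the generated Python binding is not shown in this diff), so treat this as a hedged example rather than the definitive API. One motivation for passing rescale_grad as a (1,)-shaped tensor rather than a scalar attribute is presumably that it can change on every step, as in dynamic loss scaling.

```python
import numpy as np
import mxnet as mx

shape = (4,)
weight = mx.nd.ones(shape, dtype='float16')
grad = mx.nd.ones(shape, dtype='float16') * 0.1
mean = mx.nd.zeros(shape)                    # fp32 optimizer state
var = mx.nd.zeros(shape)
weight32 = weight.astype('float32')          # fp32 master copy of the weights
rescale = mx.nd.array([1.0])                 # rescale_grad: fp32 tensor, shape (1,)

mx.nd.contrib.mp_adamw_update(weight, grad, mean, var, weight32, rescale,
                              lr=0.01, eta=1.0, wd=0.01, out=weight)

# With a non-finite (or zero) scale the whole update is skipped:
rescale[:] = np.nan
mx.nd.contrib.mp_adamw_update(weight, grad, mean, var, weight32, rescale,
                              lr=0.01, eta=1.0, wd=0.01, out=weight)  # no-op
```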

NNVM_REGISTER_OP(_contrib_adamw_update)
.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.
@@ -50,21 +114,25 @@ It updates the weights using::
v = beta2*v + (1-beta2)*(grad**2)
w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

Note that the gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf,
or 0, the update is skipped.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamWParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<4, 1, 5>)
.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<4, 1, 5>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.set_attr<FCompute>("FCompute<cpu>", AdamWUpdate<cpu>)
.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<AdamWUpdate>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_argument("rescale_grad", "NDArray-or-Symbol",
"Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
.add_arguments(AdamWParam::__FIELDS__());
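
The NaN/Inf/zero skip is what makes these ops convenient for dynamic loss scaling: the host can poison the scale when it detects a gradient overflow, and the op discards the step. A minimal sketch, assuming a build that includes this PR (the overflow check here is deliberately naive and user-supplied):

```python
import numpy as np
import mxnet as mx

shape = (4,)
weight, grad = mx.nd.ones(shape), mx.nd.ones(shape) * 0.1
mean, var = mx.nd.zeros(shape), mx.nd.zeros(shape)

loss_scale = 1024.0
rescale = mx.nd.array([1.0 / loss_scale])    # undoes the scaling of the loss

# ... forward/backward pass with the loss multiplied by loss_scale ...

if not np.isfinite(grad.asnumpy()).all():    # overflow in the scaled gradients?
    rescale[:] = np.nan                      # the op will skip this update

mx.nd.contrib.adamw_update(weight, grad, mean, var, rescale,
                           lr=1e-3, eta=1.0, wd=0.01, out=weight)
```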

} // namespace op
27 changes: 26 additions & 1 deletion src/operator/contrib/adamw.cu
@@ -28,8 +28,33 @@
namespace mxnet {
namespace op {

template<template <typename xpu> class F>
inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
// copy rescale_grad to the host and skip the update if it is not finite or zero
TBlob scale_blob = inputs[inputs.size() - 1];
MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
DType scale = 0;
CUDA_CALL(cudaMemcpy(&scale, scale_blob.dptr<DType>(), sizeof(DType),
cudaMemcpyDeviceToHost));
float scalef = static_cast<float>(scale);
if (!std::isfinite(scalef) || scalef == 0) return;
std::vector<TBlob> inputs_wo_scale;
size_t num_in = inputs.size();
inputs_wo_scale.reserve(num_in - 1);
for (size_t i = 0; i < num_in - 1; i++) inputs_wo_scale.emplace_back(inputs[i]);
F<gpu>::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
});
}

NNVM_REGISTER_OP(_contrib_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", AdamWUpdate<gpu>);
.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<AdamWUpdate>);

NNVM_REGISTER_OP(_contrib_mp_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<MPAdamWUpdate>);

} // namespace op
} // namespace mxnet