Skip to content

Commit

Permalink
Multi-precision AdamW update op (apache#14171)
Browse files Browse the repository at this point in the history
* mp adamw update

* Softmax fp16 (#201)

* softmax for fp16 with fp32 accumulator

* return AType in kernel

* add dtype

* kernel

* adamw with nan check

* add doc

* Revert "Softmax fp16 (#201)"

This reverts commit 5869e0a.

* add test

* more test for fp16

* skip update for rescale = 0
  • Loading branch information
eric-haibin-lin authored Feb 20, 2019
1 parent b396a63 commit 349803c
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 42 deletions.
165 changes: 128 additions & 37 deletions src/operator/contrib/adamw-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include <cmath>
#include "../operator_common.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
Expand All @@ -48,7 +49,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
float epsilon;
float wd;
float eta;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(AdamWParam) {
DMLC_DECLARE_FIELD(lr)
Expand All @@ -69,9 +69,6 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(eta)
.describe("Learning rate schedule multiplier");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
Expand All @@ -80,44 +77,138 @@ struct AdamWParam : public dmlc::Parameter<AdamWParam> {
}
};

// rescale_grad is a reserved argument at position -1. Example:
// n_in = 2: weight, grad (fp16)
// n_out = 1: weight (fp16)
// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
template<int n_in, int n_out, int total_in>
inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
// rescale_grad.shape = (1,)
SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mshadow::Shape1(1));
return ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string, n_in, n_out>(
attrs, in_attrs, out_attrs, TShape());
}

// rescale_grad is a reserved argument at position -1. Example:
// n_in = 2: weight, grad (fp16)
// n_out = 1: weight (fp16)
// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32)
template<int n_in, int n_out, int total_in>
inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
for (int i = n_in; i < total_in; ++i) {
TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
}
return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in, n_out>(
attrs, in_attrs, out_attrs, -1);
}

template<int req>
struct MPAdamWKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data,
float* var_data, const DType* weight_data, const DType* grad_data, float* weight32,
const float param_clip_gradient, const float param_beta1, const float param_beta2,
const float param_eta, const float param_lr, const float param_wd,
const float param_rescale_grad, const float param_epsilon) {
float w = weight32[i];
float mean = mean_data[i];
float var = var_data[i];
float scaled_grad = param_rescale_grad*static_cast<float>(grad_data[i]);
if (param_clip_gradient >= 0.0f) {
mean = param_beta1 * mean +
(1 - param_beta1) * mshadow_op::clip::Map(scaled_grad, param_clip_gradient);
var = param_beta2 * var + (1 - param_beta2) *
mshadow_op::square::Map(mshadow_op::clip::Map(scaled_grad, param_clip_gradient));
} else {
mean = param_beta1 * mean + (1 - param_beta1) * scaled_grad;
var = param_beta2 * var + (1 - param_beta2) * mshadow_op::square::Map(scaled_grad);
}
mean_data[i] = mean;
var_data[i] = var;
w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)
+ param_wd * w);
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
}
};


template<typename xpu>
struct MPAdamWUpdate {
static inline void Forward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs,
const float rescale_grad) {
using namespace mxnet_op;
AdamWParam param = nnvm::get<AdamWParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<MPAdamWKernel<req_type>, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_,
var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.beta1,
param.beta2, param.eta, param.lr, param.wd, rescale_grad, param.epsilon);
});
});
}
};

/*
* \brief adam_w update.
*/
template<typename xpu>
inline void AdamWUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow_op;
const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
struct AdamWUpdate {
static inline void Forward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs,
const float rescale_grad) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow_op;
const AdamWParam& param = nnvm::get<AdamWParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

grad = scalar<DType>(param.rescale_grad) * grad;
if (param.clip_gradient >= 0.0f) {
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
F<clip>(grad, DType(param.clip_gradient));
var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
F<clip>(grad, DType(param.clip_gradient)));
} else {
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
}
Assign(out, req[0],
weight -
scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
(scalar<DType>(param.wd) * weight)));
});
}
grad = scalar<DType>(rescale_grad) * grad;
if (param.clip_gradient >= 0.0f) {
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
F<clip>(grad, DType(param.clip_gradient));
var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
F<clip>(grad, DType(param.clip_gradient)));
} else {
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
}
Assign(out, req[0],
weight -
scalar<DType>(param.eta) * (scalar<DType>(param.lr) *
mean / (F<square_root>(var) + scalar<DType>(param.epsilon)) +
(scalar<DType>(param.wd) * weight)));
});
}
};

} // namespace op
} // namespace mxnet
Expand Down
76 changes: 72 additions & 4 deletions src/operator/contrib/adamw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,76 @@
* \author Haibin Lin
*/
#include "./adamw-inl.h"
#include "../optimizer_op-inl.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(AdamWParam);

template<template <typename xpu> class F>
inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
// copy to cpu and check NaN value
TBlob scale_blob = inputs[inputs.size() - 1];
MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
float scalef = static_cast<float>(*scale_blob.dptr<DType>());
if (!std::isfinite(scalef) || scalef == 0) return;
std::vector<TBlob> inputs_wo_scale;
size_t num_in = inputs.size();
inputs_wo_scale.reserve(num_in - 1);
for (size_t i = 0; i < num_in - 1; i++) inputs_wo_scale.emplace_back(inputs[i]);
F<cpu>::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
});
}

NNVM_REGISTER_OP(_contrib_mp_adamw_update)
.describe(R"code(Update function for multi-precision AdamW optimizer.
AdamW is seen as a modification of Adam by decoupling the weight decay from the
optimization steps taken w.r.t. the loss function.
Adam update consists of the following steps, where g represents gradient and m, v
are 1st and 2nd order moment estimates (mean and variance).
.. math::
g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})
It updates the weights using::
m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*(grad**2)
w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
the update is skipped.
)code" ADD_FILELINE)
.set_num_inputs(6)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamWParam>)
.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<2, 1, 6>)
.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<2, 1, 6>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3, 4};
})
.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<MPAdamWUpdate>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_argument("rescale_grad", "NDArray-or-Symbol",
"Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
.add_arguments(AdamWParam::__FIELDS__());

NNVM_REGISTER_OP(_contrib_adamw_update)
.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.
Expand All @@ -50,21 +114,25 @@ It updates the weights using::
v = beta2*v + (1-beta2)*(grad**2)
w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)
Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0,
the update is skipped.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamWParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<nnvm::FInferShape>("FInferShape", MPUpdateInferShape<4, 1, 5>)
.set_attr<nnvm::FInferType>("FInferType", MPUpdateInferType<4, 1, 5>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.set_attr<FCompute>("FCompute<cpu>", AdamWUpdate<cpu>)
.set_attr<FCompute>("FCompute<cpu>", MPUpdateCPU<AdamWUpdate>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_argument("rescale_grad", "NDArray-or-Symbol",
"Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
.add_arguments(AdamWParam::__FIELDS__());

} // namespace op
Expand Down
27 changes: 26 additions & 1 deletion src/operator/contrib/adamw.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,33 @@
namespace mxnet {
namespace op {

template<template <typename xpu> class F>
inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
// copy to cpu and check NaN value
TBlob scale_blob = inputs[inputs.size() - 1];
MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, {
DType scale = 0;
CUDA_CALL(cudaMemcpy(&scale, scale_blob.dptr<DType>(), sizeof(DType),
cudaMemcpyDeviceToHost));
float scalef = static_cast<float>(scale);
if (!std::isfinite(scalef) || scalef == 0) return;
std::vector<TBlob> inputs_wo_scale;
size_t num_in = inputs.size();
inputs_wo_scale.reserve(num_in - 1);
for (size_t i = 0; i < num_in - 1; i++) inputs_wo_scale.emplace_back(inputs[i]);
F<gpu>::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef);
});
}

NNVM_REGISTER_OP(_contrib_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", AdamWUpdate<gpu>);
.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<AdamWUpdate>);

NNVM_REGISTER_OP(_contrib_mp_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<MPAdamWUpdate>);

} // namespace op
} // namespace mxnet
Loading

0 comments on commit 349803c

Please sign in to comment.