NAG Optimizer with multi-precision support (apache#14568)
* nag_mp

* doc

* reuse sgd updates where convenient
Anirudh authored and haohuw committed Jun 23, 2019
1 parent 7b4dd4b commit b4db5c4
Showing 5 changed files with 275 additions and 37 deletions.
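Before the per-file diffs, a brief usage sketch: the snippet below shows how the optimizer added here is expected to be selected once this change is merged. It is illustrative only — the registered name 'nag' and the keyword arguments come from the Python class in this diff; everything else in the snippet is an assumption, not part of the commit.

    import mxnet as mx

    # Select the NAG optimizer with multi_precision=True so that float16 weights
    # keep a float32 master copy and a float32 momentum buffer (see
    # create_state_multi_precision in the diff below).
    opt = mx.optimizer.create('nag', learning_rate=0.1, momentum=0.9,
                              wd=1e-4, multi_precision=True)

    # With Gluon, the optimizer instance can be handed to a Trainer; `net` is a
    # placeholder for a network whose parameters were cast to float16.
    # trainer = mx.gluon.Trainer(net.collect_params(), opt)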
58 changes: 42 additions & 16 deletions python/mxnet/optimizer/optimizer.py
@@ -28,7 +28,7 @@
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
signsgd_update, signum_update,
signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update,
multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
multi_mp_sgd_mom_update)
from ..ndarray import sparse
@@ -1029,7 +1029,7 @@ def update(self, index, weight, grad, state):

@register
class NAG(Optimizer):
"""Nesterov accelerated SGD.
"""Nesterov accelerated gradient.
This optimizer updates each weight by::
@@ -1051,33 +1051,59 @@ def __init__(self, momentum=0.0, **kwargs):
super(NAG, self).__init__(**kwargs)
self.momentum = momentum

def create_state_multi_precision(self, index, weight):
weight_master_copy = None
if self.multi_precision and weight.dtype == numpy.float16:
weight_master_copy = weight.astype(numpy.float32)
return (self.create_state(index, weight_master_copy), weight_master_copy)
if weight.dtype == numpy.float16 and not self.multi_precision:
warnings.warn("Accumulating with float16 in optimizer can lead to "
"poor accuracy or slow convergence. "
"Consider using multi_precision=True option of the "
"NAG optimizer")
return self.create_state(index, weight)

def create_state(self, index, weight):
momentum = None
if self.momentum != 0.0:
momentum = zeros(weight.shape, weight.context, dtype=weight.dtype)
return momentum

def update(self, index, weight, grad, state):
def _update_impl(self, index, weight, grad, state, multi_precision=False):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)

grad = grad * self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
kwargs = {'rescale_grad': self.rescale_grad}
if self.momentum > 0:
kwargs['momentum'] = self.momentum
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient

if state is not None:
mom = state
mom[:] *= self.momentum
mom[:] += grad
mom[:] += wd * weight
grad[:] += self.momentum * mom
weight[:] -= lr * grad
if not multi_precision:
if state is not None:
nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
else:
sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
else:
assert self.momentum == 0.0
weight[:] += -lr * (grad + wd * weight)
if state[0] is not None:
mp_nag_mom_update(weight, grad, state[0], state[1], out=weight,
lr=lr, wd=wd, **kwargs)
else:
mp_sgd_update(weight, grad, state[1], out=weight,
lr=lr, wd=wd, **kwargs)

def update(self, index, weight, grad, state):
self._update_impl(index, weight, grad, state, multi_precision=False)

def update_multi_precision(self, index, weight, grad, state):
use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 \
and isinstance(state, (tuple, list))
self._update_impl(index, weight, grad, state,
multi_precision=use_multi_precision)


@register
class SGLD(Optimizer):
@@ -1380,7 +1406,7 @@ def update(self, index, weight, grad, state):
# preprocess grad
grad *= self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
grad = clip(grad, - self.clip_gradient, self.clip_gradient)

# accumulated g and delta initlization
acc_g, acc_delta = state
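For reference, the fused nag_mom_update call that now replaces the hand-written NDArray arithmetic in _update_impl performs the following step. The sketch below is a NumPy restatement of that formula under the commit's semantics; the function name and the in-place convention are mine, not part of the diff.

    import numpy as np

    def nag_reference_step(weight, grad, mom, lr, momentum, wd,
                           rescale_grad=1.0, clip_gradient=None):
        # Rescale and optionally clip the raw gradient.
        g = rescale_grad * grad
        if clip_gradient is not None:
            g = np.clip(g, -clip_gradient, clip_gradient)
        # Momentum buffer update, with weight decay folded into the gradient term.
        mom[:] = momentum * mom + g + wd * weight
        # Nesterov step: look ahead along the updated momentum.
        weight[:] -= lr * (momentum * mom + g)
        return weight, mom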
163 changes: 162 additions & 1 deletion src/operator/optimizer_op-inl.h
@@ -140,6 +140,7 @@ struct MultiSGDMomParam : public dmlc::Parameter<MultiSGDMomParam> {
}
};


template<typename ParamType, int input_stride>
inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
@@ -639,7 +640,7 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs,
}

template<int n_in, int n_out, int total_in>
inline bool MP_SGD_InferType(const nnvm::NodeAttrs& attrs,
inline bool MP_InferType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
@@ -1003,6 +1004,166 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
}


struct NAGParam : public dmlc::Parameter<NAGParam> {
float lr;
float wd;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(NAGParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(wd)
.set_default(0.0f)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude "
"of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
}
};

struct NAGMomParam : public dmlc::Parameter<NAGMomParam> {
float lr;
float momentum;
float wd;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(NAGMomParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(momentum)
.set_default(0.0f)
.describe("The decay rate of momentum estimates at each epoch.");
DMLC_DECLARE_FIELD(wd)
.set_default(0.0f)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude "
"of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
}
};

struct NAGMomKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data,
const DType* weight_data, const DType* grad_data,
const DType param_clip_gradient, const DType param_momentum,
const DType param_lr, const DType param_wd,
const DType param_rescale_grad, const OpReqType req) {
if (param_clip_gradient >= 0.0f) {
mom_data[i] = param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
param_clip_gradient)
+ (param_wd*weight_data[i]);
KERNEL_ASSIGN(out_data[i], req, weight_data[i]
- param_lr*(param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
param_clip_gradient)));
} else {
mom_data[i] = param_momentum*mom_data[i]
+ param_rescale_grad*grad_data[i]
+ (param_wd*weight_data[i]);
KERNEL_ASSIGN(out_data[i], req, weight_data[i]
- param_lr*(param_momentum*mom_data[i]
+ param_rescale_grad*grad_data[i]));
}
}
};

template<typename xpu>
inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
Kernel<NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_,
mom.dptr_, weight.dptr_, grad.dptr_,
static_cast<DType>(param.clip_gradient),
static_cast<DType>(param.momentum), static_cast<DType>(param.lr),
static_cast<DType>(param.wd), static_cast<DType>(param.rescale_grad),
req[0]);
});
}

struct MP_NAGMomKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
float* mom_data, const DType* weight_data,
const DType* grad_data, float* weight32,
const float param_clip_gradient,
const float param_momentum, const float param_lr,
const float param_wd, const float param_rescale_grad,
const OpReqType req) {
float w = weight32[i];
if (param_clip_gradient >= 0.0f) {
mom_data[i] = param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad
*static_cast<float>(grad_data[i]), param_clip_gradient)
+ (param_wd*w);
w = w - param_lr*(param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad
*static_cast<float>(grad_data[i]),
param_clip_gradient));
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
} else {
mom_data[i] = param_momentum*mom_data[i]
+ param_rescale_grad*static_cast<float>(grad_data[i])
+ (param_wd*w);
w = w - param_lr*(param_momentum*mom_data[i]
+ param_rescale_grad*static_cast<float>(grad_data[i]));
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
}
}
};

template<typename xpu>
inline void MP_NAGMomUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, float> mom = inputs[2].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, float> weight32 = inputs[3].FlatTo2D<xpu, float>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
Kernel<MP_NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_,
mom.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_,
param.clip_gradient, param.momentum, param.lr, param.wd,
param.rescale_grad, req[0]);
});
}


struct FTMLParam : public dmlc::Parameter<FTMLParam> {
float lr;
float beta1;
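MP_NAGMomKernel above performs the same update as NAGMomKernel but keeps the momentum buffer and a master copy of the weights in float32, while the network-visible weights stay float16. A hedged NumPy restatement of that flow (names are illustrative, not part of the diff):

    import numpy as np

    def mp_nag_reference_step(weight16, grad16, mom32, weight32, lr, momentum, wd,
                              rescale_grad=1.0, clip_gradient=None):
        # All arithmetic happens in float32 against the master copy.
        g = rescale_grad * grad16.astype(np.float32)
        if clip_gradient is not None:
            g = np.clip(g, -clip_gradient, clip_gradient)
        mom32[:] = momentum * mom32 + g + wd * weight32
        weight32[:] -= lr * (momentum * mom32 + g)
        # Only the low-precision copy handed back to the network is cast down.
        weight16[:] = weight32.astype(np.float16)
        return weight16, mom32, weight32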
57 changes: 55 additions & 2 deletions src/operator/optimizer_op.cc
@@ -35,6 +35,8 @@ DMLC_REGISTER_PARAMETER(MultiSGDParam);
DMLC_REGISTER_PARAMETER(MultiSGDMomParam);
DMLC_REGISTER_PARAMETER(FTMLParam);
DMLC_REGISTER_PARAMETER(AdamParam);
DMLC_REGISTER_PARAMETER(NAGParam);
DMLC_REGISTER_PARAMETER(NAGMomParam);
DMLC_REGISTER_PARAMETER(RMSPropParam);
DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
DMLC_REGISTER_PARAMETER(FtrlParam);
@@ -590,7 +592,7 @@ NNVM_REGISTER_OP(mp_sgd_update)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_SGD_InferType<2, 1, 3>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 3>)
.set_attr<FCompute>("FCompute<cpu>", MP_SGDUpdate<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
@@ -607,7 +609,7 @@ NNVM_REGISTER_OP(mp_sgd_mom_update)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_SGD_InferType<2, 1, 4>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 4>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
@@ -705,6 +707,57 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
.add_arguments(AdamParam::__FIELDS__());


NNVM_REGISTER_OP(nag_mom_update)
.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer.
It updates the weights using the following formula,
.. math::
v_t = \gamma v_{t-1} + \eta * \nabla J(W_{t-1} - \gamma v_{t-1})\\
W_t = W_{t-1} - v_t
Where
:math:`\eta` is the learning rate of the optimizer
:math:`\gamma` is the decay rate of the momentum estimate
:math:`v_t` is the update vector at time step `t`
:math:`W_t` is the weight vector at time step `t`
)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NAGMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2};
})
.set_attr<FCompute>("FCompute<cpu>", NAGMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_arguments(NAGMomParam::__FIELDS__());


NNVM_REGISTER_OP(mp_nag_mom_update)
.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NAGMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 4>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.set_attr<FCompute>("FCompute<cpu>", MP_NAGMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(NAGMomParam::__FIELDS__());


NNVM_REGISTER_OP(rmsprop_update)
.describe(R"code(Update function for `RMSProp` optimizer.
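With the registrations above in place, both operators become reachable from the NDArray API. A minimal smoke test, assuming an mxnet build that contains this commit, might look like:

    import mxnet as mx

    w = mx.nd.ones((2, 3))
    g = mx.nd.ones((2, 3)) * 0.1
    m = mx.nd.zeros((2, 3))

    # Fused single-precision NAG momentum update, written back into w.
    mx.nd.nag_mom_update(w, g, m, out=w, lr=0.1, momentum=0.9, wd=1e-4)

    # Multi-precision variant: float16 weight/grad, float32 momentum and master weight.
    w16 = w.astype('float16')
    g16 = g.astype('float16')
    w32 = w16.astype('float32')
    m32 = mx.nd.zeros((2, 3))
    mx.nd.mp_nag_mom_update(w16, g16, m32, w32, out=w16,
                            lr=0.1, momentum=0.9, wd=1e-4)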
6 changes: 6 additions & 0 deletions src/operator/optimizer_op.cu
@@ -251,6 +251,12 @@ NNVM_REGISTER_OP(multi_mp_sgd_update)
NNVM_REGISTER_OP(multi_mp_sgd_mom_update)
.set_attr<FCompute>("FCompute<gpu>", MultiSGDMomUpdate<gpu, single_precision, 4>);

NNVM_REGISTER_OP(nag_mom_update)
.set_attr<FCompute>("FCompute<gpu>", NAGMomUpdate<gpu>);

NNVM_REGISTER_OP(mp_nag_mom_update)
.set_attr<FCompute>("FCompute<gpu>", MP_NAGMomUpdate<gpu>);

NNVM_REGISTER_OP(ftml_update)
.set_attr<FCompute>("FCompute<gpu>", FTMLUpdate<gpu>);

(Diff for the remaining changed file did not load on this page.)
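In the spirit of the existing optimizer unit tests, the fused operator can be cross-checked against the step-by-step formula. The sketch below is a hedged illustration, not taken from the file that failed to load:

    import mxnet as mx

    lr, momentum, wd = 0.1, 0.9, 1e-4
    w = mx.nd.random.uniform(shape=(5,))
    g = mx.nd.random.uniform(shape=(5,))
    m = mx.nd.zeros_like(w)

    # Step-by-step reference on separate copies.
    w_ref, m_ref = w.copy(), m.copy()
    m_ref[:] = momentum * m_ref + g + wd * w_ref
    w_ref[:] -= lr * (momentum * m_ref + g)

    # Fused operator added by this commit.
    mx.nd.nag_mom_update(w, g, m, out=w, lr=lr, momentum=momentum, wd=wd)

    assert mx.nd.norm(w - w_ref).asscalar() < 1e-6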
