This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

NAG Optimizer with multi-precision support #14568

Merged (3 commits) on May 30, 2019
58 changes: 42 additions & 16 deletions python/mxnet/optimizer/optimizer.py
@@ -28,7 +28,7 @@
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                       mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                       signsgd_update, signum_update,
                       signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update,
                       multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
                       multi_mp_sgd_mom_update)
from ..ndarray import sparse
@@ -1029,7 +1029,7 @@ def update(self, index, weight, grad, state):

@register
class NAG(Optimizer):
"""Nesterov accelerated SGD.
"""Nesterov accelerated gradient.

This optimizer updates each weight by::

@@ -1051,33 +1051,59 @@ def __init__(self, momentum=0.0, **kwargs):
        super(NAG, self).__init__(**kwargs)
        self.momentum = momentum

    def create_state_multi_precision(self, index, weight):
        weight_master_copy = None
        if self.multi_precision and weight.dtype == numpy.float16:
            weight_master_copy = weight.astype(numpy.float32)
            return (self.create_state(index, weight_master_copy), weight_master_copy)
        if weight.dtype == numpy.float16 and not self.multi_precision:
            warnings.warn("Accumulating with float16 in optimizer can lead to "
                          "poor accuracy or slow convergence. "
                          "Consider using multi_precision=True option of the "
                          "NAG optimizer")
        return self.create_state(index, weight)

    def create_state(self, index, weight):
        momentum = None
        if self.momentum != 0.0:
            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype)
        return momentum

    def update(self, index, weight, grad, state):
    def _update_impl(self, index, weight, grad, state, multi_precision=False):
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        grad = grad * self.rescale_grad
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
        kwargs = {'rescale_grad': self.rescale_grad}
        if self.momentum > 0:
            kwargs['momentum'] = self.momentum
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient

        if state is not None:
            mom = state
            mom[:] *= self.momentum
            mom[:] += grad
            mom[:] += wd * weight
            grad[:] += self.momentum * mom
            weight[:] -= lr * grad
        if not multi_precision:
            if state is not None:
                nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
            else:
                sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
        else:
            assert self.momentum == 0.0
            weight[:] += -lr * (grad + wd * weight)
            if state[0] is not None:
                mp_nag_mom_update(weight, grad, state[0], state[1], out=weight,
                                  lr=lr, wd=wd, **kwargs)
            else:
                mp_sgd_update(weight, grad, state[1], out=weight,
                              lr=lr, wd=wd, **kwargs)

    def update(self, index, weight, grad, state):
        self._update_impl(index, weight, grad, state, multi_precision=False)

    def update_multi_precision(self, index, weight, grad, state):
        use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 \
                              and isinstance(state, (tuple, list))
        self._update_impl(index, weight, grad, state,
                          multi_precision=use_multi_precision)
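For reviewers who want to exercise the new code path end to end, here is a minimal sketch (illustration only, assuming a build that includes this PR); the shapes and hyper-parameters are arbitrary placeholders:

```python
import mxnet as mx

# fp16 weight standing in for a real network parameter (placeholder data).
weight = mx.nd.random.normal(shape=(8,)).astype('float16')
grad = mx.nd.ones_like(weight) * 0.01

opt = mx.optimizer.NAG(learning_rate=0.1, momentum=0.9, multi_precision=True)
# For fp16 weights with multi_precision=True this returns
# (fp32 momentum, fp32 master copy of the weight).
state = opt.create_state_multi_precision(0, weight)

# Dispatches to mp_nag_mom_update; without multi_precision it would hit nag_mom_update.
opt.update_multi_precision(0, weight, grad, state)
```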


@register
class SGLD(Optimizer):
@@ -1380,7 +1406,7 @@ def update(self, index, weight, grad, state):
        # preprocess grad
        grad *= self.rescale_grad
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
            grad = clip(grad, - self.clip_gradient, self.clip_gradient)

        # accumulated g and delta initialization
        acc_g, acc_delta = state
163 changes: 162 additions & 1 deletion src/operator/optimizer_op-inl.h
@@ -140,6 +140,7 @@ struct MultiSGDMomParam : public dmlc::Parameter<MultiSGDMomParam> {
}
};


template<typename ParamType, int input_stride>
inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
@@ -639,7 +640,7 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs,
}

template<int n_in, int n_out, int total_in>
inline bool MP_SGD_InferType(const nnvm::NodeAttrs& attrs,
inline bool MP_InferType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
@@ -1003,6 +1004,166 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
}


struct NAGParam : public dmlc::Parameter<NAGParam> {
  float lr;
  float wd;
  float rescale_grad;
  float clip_gradient;
  DMLC_DECLARE_PARAMETER(NAGParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude "
              "of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};

struct NAGMomParam : public dmlc::Parameter<NAGMomParam> {
  float lr;
  float momentum;
  float wd;
  float rescale_grad;
  float clip_gradient;
  DMLC_DECLARE_PARAMETER(NAGMomParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(momentum)
    .set_default(0.0f)
    .describe("The decay rate of momentum estimates at each epoch.");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude "
              "of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};

struct NAGMomKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data,
                                  const DType* weight_data, const DType* grad_data,
                                  const DType param_clip_gradient, const DType param_momentum,
                                  const DType param_lr, const DType param_wd,
                                  const DType param_rescale_grad, const OpReqType req) {
    if (param_clip_gradient >= 0.0f) {
      mom_data[i] = param_momentum*mom_data[i]
                    + mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
                                            param_clip_gradient)
                    + (param_wd*weight_data[i]);
      KERNEL_ASSIGN(out_data[i], req, weight_data[i]
                    - param_lr*(param_momentum*mom_data[i]
                                + mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
                                                        param_clip_gradient)));
    } else {
      mom_data[i] = param_momentum*mom_data[i]
                    + param_rescale_grad*grad_data[i]
                    + (param_wd*weight_data[i]);
      KERNEL_ASSIGN(out_data[i], req, weight_data[i]
                    - param_lr*(param_momentum*mom_data[i]
                                + param_rescale_grad*grad_data[i]));
    }
  }
};
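As a cross-check against the Python-side update above, the element-wise math of this kernel can be written out as a small NumPy sketch (illustration only, not part of the PR):

```python
import numpy as np

def nag_mom_step(weight, grad, mom, lr, momentum, wd, rescale_grad, clip_gradient=-1.0):
    """Mirror of NAGMomKernel: returns the updated (weight, mom) arrays."""
    g = rescale_grad * grad
    if clip_gradient >= 0.0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    mom = momentum * mom + g + wd * weight          # momentum update
    weight = weight - lr * (momentum * mom + g)     # Nesterov-style weight step
    return weight, mom
```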

template<typename xpu>
inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs,
                         const OpContext &ctx,
                         const std::vector<TBlob> &inputs,
                         const std::vector<OpReqType> &req,
                         const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
    Kernel<NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_,
      mom.dptr_, weight.dptr_, grad.dptr_,
      static_cast<DType>(param.clip_gradient),
      static_cast<DType>(param.momentum), static_cast<DType>(param.lr),
      static_cast<DType>(param.wd), static_cast<DType>(param.rescale_grad),
      req[0]);
  });
}

struct MP_NAGMomKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data,
                                  float* mom_data, const DType* weight_data,
                                  const DType* grad_data, float* weight32,
                                  const float param_clip_gradient,
                                  const float param_momentum, const float param_lr,
                                  const float param_wd, const float param_rescale_grad,
                                  const OpReqType req) {
    float w = weight32[i];
    if (param_clip_gradient >= 0.0f) {
      mom_data[i] = param_momentum*mom_data[i]
                    + mshadow_op::clip::Map(param_rescale_grad
                                            *static_cast<float>(grad_data[i]), param_clip_gradient)
                    + (param_wd*w);
      w = w - param_lr*(param_momentum*mom_data[i]
                        + mshadow_op::clip::Map(param_rescale_grad
                                                *static_cast<float>(grad_data[i]),
                                                param_clip_gradient));
      weight32[i] = w;
      KERNEL_ASSIGN(out_data[i], req, w);
    } else {
      mom_data[i] = param_momentum*mom_data[i]
                    + param_rescale_grad*static_cast<float>(grad_data[i])
                    + (param_wd*w);
      w = w - param_lr*(param_momentum*mom_data[i]
                        + param_rescale_grad*static_cast<float>(grad_data[i]));
      weight32[i] = w;
      KERNEL_ASSIGN(out_data[i], req, w);
    }
  }
};
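The multi-precision kernel differs only in where the arithmetic happens: gradients arrive in the weight's dtype (typically fp16), while the momentum and the master weight stay in fp32, and the fp16 output is a downcast of the fp32 result. A NumPy sketch (illustration only):

```python
import numpy as np

def mp_nag_mom_step(grad16, mom32, weight32, lr, momentum, wd,
                    rescale_grad, clip_gradient=-1.0):
    """Mirror of MP_NAGMomKernel: fp32 master update, fp16 copy for the output."""
    g = rescale_grad * grad16.astype(np.float32)
    if clip_gradient >= 0.0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    mom32 = momentum * mom32 + g + wd * weight32
    weight32 = weight32 - lr * (momentum * mom32 + g)
    return weight32.astype(np.float16), mom32, weight32   # (out, mom, master weight)
```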

template<typename xpu>
inline void MP_NAGMomUpdate(const nnvm::NodeAttrs& attrs,
                            const OpContext &ctx,
                            const std::vector<TBlob> &inputs,
                            const std::vector<OpReqType> &req,
                            const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, float> mom = inputs[2].FlatTo2D<xpu, float>(s);
    Tensor<xpu, 2, float> weight32 = inputs[3].FlatTo2D<xpu, float>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
    Kernel<MP_NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_,
      mom.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_,
      param.clip_gradient, param.momentum, param.lr, param.wd,
      param.rescale_grad, req[0]);
  });
}


struct FTMLParam : public dmlc::Parameter<FTMLParam> {
float lr;
float beta1;
57 changes: 55 additions & 2 deletions src/operator/optimizer_op.cc
@@ -35,6 +35,8 @@ DMLC_REGISTER_PARAMETER(MultiSGDParam);
DMLC_REGISTER_PARAMETER(MultiSGDMomParam);
DMLC_REGISTER_PARAMETER(FTMLParam);
DMLC_REGISTER_PARAMETER(AdamParam);
DMLC_REGISTER_PARAMETER(NAGParam);
DMLC_REGISTER_PARAMETER(NAGMomParam);
DMLC_REGISTER_PARAMETER(RMSPropParam);
DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
DMLC_REGISTER_PARAMETER(FtrlParam);
@@ -590,7 +592,7 @@ NNVM_REGISTER_OP(mp_sgd_update)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_SGD_InferType<2, 1, 3>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 3>)
.set_attr<FCompute>("FCompute<cpu>", MP_SGDUpdate<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
@@ -607,7 +609,7 @@ NNVM_REGISTER_OP(mp_sgd_mom_update)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_SGD_InferType<2, 1, 4>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 4>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
@@ -705,6 +707,57 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
.add_arguments(AdamParam::__FIELDS__());


NNVM_REGISTER_OP(nag_mom_update)
.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer.
It updates the weights using the following formula,

.. math::
v_t = \gamma v_{t-1} + \eta * \nabla J(W_{t-1} - \gamma v_{t-1})\\
W_t = W_{t-1} - v_t

Where
:math:`\eta` is the learning rate of the optimizer
:math:`\gamma` is the decay rate of the momentum estimate
:math:`\v_t` is the update vector at time step `t`
:math:`\W_t` is the weight vector at time step `t`

)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NAGMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2};
})
.set_attr<FCompute>("FCompute<cpu>", NAGMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_arguments(NAGMomParam::__FIELDS__());
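A quick way to poke the new operator from the NDArray front end once this branch is built (values below are arbitrary; `mom` is updated in place, as declared by `FMutateInputs`):

```python
import mxnet as mx

weight = mx.nd.ones((5,))
grad = mx.nd.full((5,), 0.1)
mom = mx.nd.zeros((5,))

mx.nd.nag_mom_update(weight, grad, mom, out=weight,
                     lr=0.1, momentum=0.9, wd=0.0, rescale_grad=1.0)
```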


NNVM_REGISTER_OP(mp_nag_mom_update)
.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NAGMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 4>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.set_attr<FCompute>("FCompute<cpu>", MP_NAGMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(NAGMomParam::__FIELDS__());
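And its multi-precision counterpart, which additionally takes the fp32 master weight (again a sketch against a build of this branch):

```python
import mxnet as mx

weight16 = mx.nd.ones((5,), dtype='float16')
grad16 = mx.nd.full((5,), 0.1, dtype='float16')
mom32 = mx.nd.zeros((5,))                      # momentum kept in fp32
weight32 = weight16.astype('float32')          # fp32 master copy

mx.nd.mp_nag_mom_update(weight16, grad16, mom32, weight32, out=weight16,
                        lr=0.1, momentum=0.9)
```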


NNVM_REGISTER_OP(rmsprop_update)
.describe(R"code(Update function for `RMSProp` optimizer.

6 changes: 6 additions & 0 deletions src/operator/optimizer_op.cu
@@ -251,6 +251,12 @@ NNVM_REGISTER_OP(multi_mp_sgd_update)
NNVM_REGISTER_OP(multi_mp_sgd_mom_update)
.set_attr<FCompute>("FCompute<gpu>", MultiSGDMomUpdate<gpu, single_precision, 4>);

NNVM_REGISTER_OP(nag_mom_update)
.set_attr<FCompute>("FCompute<gpu>", NAGMomUpdate<gpu>);

NNVM_REGISTER_OP(mp_nag_mom_update)
.set_attr<FCompute>("FCompute<gpu>", MP_NAGMomUpdate<gpu>);

NNVM_REGISTER_OP(ftml_update)
.set_attr<FCompute>("FCompute<gpu>", FTMLUpdate<gpu>);
