From 4ac083b444338e8db424fbaf16d3f0c512a78f8f Mon Sep 17 00:00:00 2001 From: Anirudh Acharya Date: Wed, 13 Mar 2019 15:44:21 -0700 Subject: [PATCH 1/3] nag_mp --- python/mxnet/optimizer/optimizer.py | 64 ++++-- src/operator/optimizer_op-inl.h | 250 +++++++++++++++++++++++- src/operator/optimizer_op.cc | 86 +++++++- src/operator/optimizer_op.cu | 12 ++ tests/python/unittest/test_optimizer.py | 28 +-- 5 files changed, 401 insertions(+), 39 deletions(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index 613ae8985aca..dc712dcda44a 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -28,9 +28,9 @@ from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, - signsgd_update, signum_update, - multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update, - multi_mp_sgd_mom_update) + signsgd_update, signum_update, nag_update, nag_mom_update, mp_nag_update, + mp_nag_mom_update, multi_sgd_update, multi_sgd_mom_update, + multi_mp_sgd_update, multi_mp_sgd_mom_update) from ..ndarray import sparse from ..random import normal @@ -1029,7 +1029,7 @@ def update(self, index, weight, grad, state): @register class NAG(Optimizer): - """Nesterov accelerated SGD. + """Nesterov accelerated gradient. This optimizer updates each weight by:: @@ -1051,33 +1051,61 @@ def __init__(self, momentum=0.0, **kwargs): super(NAG, self).__init__(**kwargs) self.momentum = momentum + def create_state_multi_precision(self, index, weight): + weight_master_copy = None + if self.multi_precision and weight.dtype == numpy.float16: + weight_master_copy = weight.astype(numpy.float32) + return (self.create_state(index, weight_master_copy), weight_master_copy) + if weight.dtype == numpy.float16 and not self.multi_precision: + warnings.warn("Accumulating with float16 in optimizer can lead to " + "poor accuracy or slow convergence. 
" + "Consider using multi_precision=True option of the " + "NAG optimizer") + return self.create_state(index, weight) + def create_state(self, index, weight): momentum = None if self.momentum != 0.0: momentum = zeros(weight.shape, weight.context, dtype=weight.dtype) return momentum - def update(self, index, weight, grad, state): + def _update_impl(self, index, weight, grad, state, multi_precision=False): assert(isinstance(weight, NDArray)) assert(isinstance(grad, NDArray)) self._update_count(index) lr = self._get_lr(index) wd = self._get_wd(index) - grad = grad * self.rescale_grad - if self.clip_gradient is not None: - grad = clip(grad, -self.clip_gradient, self.clip_gradient) + kwargs = {'rescale_grad': self.rescale_grad} + if self.momentum > 0: + kwargs['momentum'] = self.momentum + if self.clip_gradient: + kwargs['clip_gradient'] = self.clip_gradient - if state is not None: - mom = state - mom[:] *= self.momentum - mom[:] += grad - mom[:] += wd * weight - grad[:] += self.momentum * mom - weight[:] -= lr * grad + if not multi_precision: + if state is not None: + nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs) + else: + nag_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) + else: + if state[0] is not None: + mp_nag_mom_update(weight, grad, state[0], state[1], out=weight, + lr=lr, wd=wd, **kwargs) + else: + mp_nag_update(weight, grad, state[1], out=weight, + lr=lr, wd=wd, **kwargs) + + def update(self, index, weight, grad, state): + self._update_impl(index, weight, grad, state, multi_precision=False) + + def update_multi_precision(self, index, weight, grad, state): + if not isinstance(index, (tuple, list)): + use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 else: - assert self.momentum == 0.0 - weight[:] += -lr * (grad + wd * weight) + use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16 + self._update_impl(index, weight, grad, state, + multi_precision=use_multi_precision) + @register class SGLD(Optimizer): @@ -1380,7 +1408,7 @@ def update(self, index, weight, grad, state): # preprocess grad grad *= self.rescale_grad if self.clip_gradient is not None: - grad = clip(grad, -self.clip_gradient, self.clip_gradient) + grad = clip(grad, - self.clip_gradient, self.clip_gradient) # accumulated g and delta initlization acc_g, acc_delta = state diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index bd923aebbb80..bc736aa1f818 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -140,6 +140,7 @@ struct MultiSGDMomParam : public dmlc::Parameter { } }; + template inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, @@ -639,7 +640,7 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs, } template -inline bool MP_SGD_InferType(const nnvm::NodeAttrs& attrs, +inline bool MP_InferType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(total_in)) << " in operator " << attrs.name; @@ -1003,6 +1004,253 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, } +struct NAGParam : public dmlc::Parameter { + float lr; + float wd; + float rescale_grad; + float clip_gradient; + DMLC_DECLARE_PARAMETER(NAGParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(wd) + .set_default(0.0f) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. 
" + "The penalty scales with the square of the magnitude " + "of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + } +}; + +struct NAGKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, + const DType* weight_data, const DType* grad_data, + const DType param_clip_gradient, const DType param_lr, + const DType param_wd, const DType param_rescale_grad, + const OpReqType req) { + if (param_clip_gradient >= 0.0f) { + KERNEL_ASSIGN(out_data[i], req, + weight_data[i] + - param_lr * (mshadow_op::clip::Map(param_rescale_grad*grad_data[i], + param_clip_gradient) + + param_wd*weight_data[i])); + } else { + KERNEL_ASSIGN(out_data[i], req, + weight_data[i] + - param_lr * (param_rescale_grad*grad_data[i] + + (param_wd*weight_data[i]))); + } + } +}; + +template +inline void NAGUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + const NAGParam& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, + weight.dptr_, grad.dptr_, static_cast(param.clip_gradient), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad), req[0]); + }); +} + +struct NAGMomParam : public dmlc::Parameter { + float lr; + float momentum; + float wd; + float rescale_grad; + float clip_gradient; + DMLC_DECLARE_PARAMETER(NAGMomParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(momentum) + .set_default(0.0f) + .describe("The decay rate of momentum estimates at each epoch."); + DMLC_DECLARE_FIELD(wd) + .set_default(0.0f) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude " + "of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. 
" + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + } +}; + +struct NAGMomKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, + const DType* weight_data, const DType* grad_data, + const DType param_clip_gradient, const DType param_momentum, + const DType param_lr, const DType param_wd, + const DType param_rescale_grad, const OpReqType req) { + if (param_clip_gradient >= 0.0f) { + mom_data[i] = param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad*grad_data[i], + param_clip_gradient) + + (param_wd*weight_data[i]); + KERNEL_ASSIGN(out_data[i], req, weight_data[i] + - param_lr*(param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad*grad_data[i], + param_clip_gradient))); + } else { + mom_data[i] = param_momentum*mom_data[i] + + param_rescale_grad*grad_data[i] + + (param_wd*weight_data[i]); + KERNEL_ASSIGN(out_data[i], req, weight_data[i] + - param_lr*(param_momentum*mom_data[i] + + param_rescale_grad*grad_data[i])); + } + } +}; + +template +inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + NAGMomParam param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mom = inputs[2].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, + mom.dptr_, weight.dptr_, grad.dptr_, + static_cast(param.clip_gradient), + static_cast(param.momentum), static_cast(param.lr), + static_cast(param.wd), static_cast(param.rescale_grad), + req[0]); + }); +} + +struct MP_NAGKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, + const DType* weight_data, const DType* grad_data, + float* weight32, const float param_clip_gradient, + const float param_lr, const float param_wd, + const float param_rescale_grad, + const OpReqType req) { + if (param_clip_gradient >= 0.0f) { + float w = weight32[i]; + w = w - param_lr * (mshadow_op::clip::Map(param_rescale_grad + *static_cast(grad_data[i]), param_clip_gradient) + + param_wd*w); + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, (DType)w); + } else { + float w = weight32[i]; + w = w - param_lr * (param_rescale_grad + *static_cast(grad_data[i]) + (param_wd*w)); + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, (DType)w); + } + } +}; + +template +inline void MP_NAGUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + const NAGParam& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor weight32 = inputs[2].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, + weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, + param.lr, param.wd, param.rescale_grad, req[0]); + }); +} + +struct MP_NAGMomKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, + float* mom_data, const DType* weight_data, + const DType* grad_data, float* weight32, + const float param_clip_gradient, + const float param_momentum, const float param_lr, + const float param_wd, const float 
param_rescale_grad, + const OpReqType req) { + float w = weight32[i]; + if (param_clip_gradient >= 0.0f) { + mom_data[i] = param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad + *static_cast(grad_data[i]), param_clip_gradient) + + (param_wd*w); + w = w - param_lr*(param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad + *static_cast(grad_data[i]), + param_clip_gradient)); + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, w); + } else { + mom_data[i] = param_momentum*mom_data[i] + + param_rescale_grad*static_cast(grad_data[i]) + + (param_wd*w); + w = w - param_lr*(param_momentum*mom_data[i] + + param_rescale_grad*static_cast(grad_data[i])); + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, w); + } + } +}; + +template +inline void MP_NAGMomUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + NAGMomParam param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mom = inputs[2].FlatTo2D(s); + Tensor weight32 = inputs[3].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, + mom.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, + param.clip_gradient, param.momentum, param.lr, param.wd, + param.rescale_grad, req[0]); + }); +} + + struct FTMLParam : public dmlc::Parameter { float lr; float beta1; diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 367b91b2646c..859fa58638b0 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -35,6 +35,8 @@ DMLC_REGISTER_PARAMETER(MultiSGDParam); DMLC_REGISTER_PARAMETER(MultiSGDMomParam); DMLC_REGISTER_PARAMETER(FTMLParam); DMLC_REGISTER_PARAMETER(AdamParam); +DMLC_REGISTER_PARAMETER(NAGParam); +DMLC_REGISTER_PARAMETER(NAGMomParam); DMLC_REGISTER_PARAMETER(RMSPropParam); DMLC_REGISTER_PARAMETER(RMSPropAlexParam); DMLC_REGISTER_PARAMETER(FtrlParam); @@ -590,7 +592,7 @@ NNVM_REGISTER_OP(mp_sgd_update) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<3, 1>) -.set_attr("FInferType", MP_SGD_InferType<2, 1, 3>) +.set_attr("FInferType", MP_InferType<2, 1, 3>) .set_attr("FCompute", MP_SGDUpdate) .set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { @@ -607,7 +609,7 @@ NNVM_REGISTER_OP(mp_sgd_mom_update) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<4, 1>) -.set_attr("FInferType", MP_SGD_InferType<2, 1, 4>) +.set_attr("FInferType", MP_InferType<2, 1, 4>) .set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { return std::vector{2, 3}; @@ -705,6 +707,86 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a .add_arguments(AdamParam::__FIELDS__()); +NNVM_REGISTER_OP(nag_update) +MXNET_ADD_SPARSE_OP_ALIAS(nag_update) +.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. 
+NAG update consists of the following steps, + +state = momentum * state + grad + wd * weight +weight = weight - (lr * (grad + momentum * state)) +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<2, 1>) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FCompute", NAGUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_arguments(NAGParam::__FIELDS__()); + + +NNVM_REGISTER_OP(nag_mom_update) +MXNET_ADD_SPARSE_OP_ALIAS(nag_mom_update) +.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. +)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<3, 1>) +.set_attr("FInferType", ElemwiseType<3, 1>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2}; + }) +.set_attr("FCompute", NAGMomUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mom", "NDArray-or-Symbol", "Momentum") +.add_arguments(NAGMomParam::__FIELDS__()); + + +NNVM_REGISTER_OP(mp_nag_update) +MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_update) +.describe(R"code(Multi-precision NAG update. +)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<3, 1>) +.set_attr("FInferType", MP_InferType<2, 1, 3>) +.set_attr("FCompute", MP_NAGUpdate) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2}; + }) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "gradient") +.add_argument("weight32", "NDArray-or-Symbol", "Weight32") +.add_arguments(NAGParam::__FIELDS__()); + + +NNVM_REGISTER_OP(mp_nag_mom_update) +MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_mom_update) +.describe(R"code(Multi-precision NAG update. +)code" ADD_FILELINE) +.set_num_inputs(4) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<4, 1>) +.set_attr("FInferType", MP_InferType<2, 1, 4>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3}; + }) +.set_attr("FCompute", MP_NAGMomUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mom", "NDArray-or-Symbol", "Momentum") +.add_argument("weight32", "NDArray-or-Symbol", "Weight32") +.add_arguments(NAGMomParam::__FIELDS__()); + + NNVM_REGISTER_OP(rmsprop_update) .describe(R"code(Update function for `RMSProp` optimizer. 
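For reference, a rough NumPy sketch of the arithmetic that the NAGMomKernel / nag_mom_update operator added above performs, matching the "state = momentum * state + grad + wd * weight; weight = weight - lr * (grad + momentum * state)" formula in its docstring. The helper name nag_mom_update_ref and the float32-only scope are illustrative, not part of the patch:

    import numpy as np

    def nag_mom_update_ref(weight, grad, mom, lr, momentum, wd=0.0,
                           rescale_grad=1.0, clip_gradient=-1.0):
        # Rescale and (optionally) clip the raw gradient, as the kernel does.
        g = rescale_grad * grad
        if clip_gradient >= 0:
            g = np.clip(g, -clip_gradient, clip_gradient)
        # The momentum buffer absorbs the processed gradient plus weight decay.
        mom[:] = momentum * mom + g + wd * weight
        # Nesterov step: apply the gradient together with the *updated* momentum.
        weight[:] = weight - lr * (momentum * mom + g)
        return weight, mom
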
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index c42cf1831c43..5d361b7e528d 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -251,6 +251,18 @@ NNVM_REGISTER_OP(multi_mp_sgd_update) NNVM_REGISTER_OP(multi_mp_sgd_mom_update) .set_attr("FCompute", MultiSGDMomUpdate); +NNVM_REGISTER_OP(nag_update) +.set_attr("FCompute", NAGUpdate); + +NNVM_REGISTER_OP(nag_mom_update) +.set_attr("FCompute", NAGMomUpdate); + +NNVM_REGISTER_OP(mp_nag_update) +.set_attr("FCompute", MP_NAGUpdate); + +NNVM_REGISTER_OP(mp_nag_mom_update) +.set_attr("FCompute", MP_NAGMomUpdate); + NNVM_REGISTER_OP(ftml_update) .set_attr("FCompute", FTMLUpdate); diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index d5aabcb4b1e5..e151cfde2306 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -346,7 +346,7 @@ def create_state(self, index, weight): if self.momentum != 0.0: momentum = mx.nd.zeros(weight.shape, weight.context, dtype=np.float32) weight_master_copy = array(weight, ctx=weight.context, dtype=np.float32) - return (weight_master_copy, momentum) + return (momentum, weight_master_copy) else: if self.momentum != 0.0: momentum = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype) @@ -394,8 +394,8 @@ def update(self, index, weight, grad, state): grad32 = grad32 * self.rescale_grad if self.clip_gradient is not None: grad32 = mx.nd.clip(grad32, -self.clip_gradient, self.clip_gradient) - mom = state[1] - weight32 = state[0] + mom = state[0] + weight32 = state[1] if self.momentum == 0.0: weight32[:] += -lr * (grad32 + wd * weight32) else: @@ -417,23 +417,15 @@ def test_nag(): rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}] + for dtype in [np.float16, np.float32, np.float64]: - for mom_option in mom_options: - for cg_option in cg_options: - for rg_option in rg_options: - for wd_option in wd_options: - for mp_option in mp_options: - kwarg = {} - kwarg.update(mom_option) - kwarg.update(cg_option) - kwarg.update(rg_option) - kwarg.update(wd_option) - kwarg.update(mp_option) - if (dtype == np.float16 and - ('multi_precision' not in kwarg or + for params in itertools.product(mom_options, cg_options, rg_options, + wd_options, mp_options): + kwarg = {k: v for param in params for k, v in param.items()} + if (dtype == np.float16 and ('multi_precision' not in kwarg or not kwarg['multi_precision'])): - continue - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) + continue + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) #SGLD class PySGLD(mx.optimizer.Optimizer): From 11895e5bdf6a92ee83ce8639ffeffa6d5b6f1453 Mon Sep 17 00:00:00 2001 From: Anirudh Acharya Date: Sat, 13 Apr 2019 18:00:18 -0700 Subject: [PATCH 2/3] doc --- python/mxnet/optimizer/optimizer.py | 6 ++---- src/operator/optimizer_op.cc | 26 +++++++++++++++++--------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index dc712dcda44a..60ee44d03b20 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -1099,10 +1099,8 @@ def update(self, index, weight, grad, state): self._update_impl(index, weight, grad, state, multi_precision=False) def update_multi_precision(self, index, weight, grad, state): - if not isinstance(index, 
(tuple, list)): - use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 - else: - use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16 + use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 \ + and isinstance(state, (tuple, list)) self._update_impl(index, weight, grad, state, multi_precision=use_multi_precision) diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 859fa58638b0..e77bd416e37a 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -708,12 +708,11 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a NNVM_REGISTER_OP(nag_update) -MXNET_ADD_SPARSE_OP_ALIAS(nag_update) .describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. -NAG update consists of the following steps, +It updates the weights using the following formula, + +weight = weight - (lr * (grad + wd * weight)) -state = momentum * state + grad + wd * weight -weight = weight - (lr * (grad + momentum * state)) )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -727,8 +726,19 @@ weight = weight - (lr * (grad + momentum * state)) NNVM_REGISTER_OP(nag_mom_update) -MXNET_ADD_SPARSE_OP_ALIAS(nag_mom_update) .describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. +It updates the weights using the following formula, + +.. math:: + v_t = \gamma v_{t-1} + \eta * \nabla J(W_{t-1} - \gamma v_{t-1})\\ + W_t = W_{t-1} - v_t + +Where +:math:`\eta` is the learning rate of the optimizer +:math:`\gamma` is the decay rate of the momentum estimate +:math:`\v_t` is the update vector at time step `t` +:math:`\W_t` is the weight vector at time step `t` + )code" ADD_FILELINE) .set_num_inputs(3) .set_num_outputs(1) @@ -747,8 +757,7 @@ MXNET_ADD_SPARSE_OP_ALIAS(nag_mom_update) NNVM_REGISTER_OP(mp_nag_update) -MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_update) -.describe(R"code(Multi-precision NAG update. +.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer. )code" ADD_FILELINE) .set_num_inputs(3) .set_num_outputs(1) @@ -767,8 +776,7 @@ MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_update) NNVM_REGISTER_OP(mp_nag_mom_update) -MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_mom_update) -.describe(R"code(Multi-precision NAG update. +.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer. 
)code" ADD_FILELINE) .set_num_inputs(4) .set_num_outputs(1) From 5c17c79d4144c054c68e1b6793a7b22833cab8b9 Mon Sep 17 00:00:00 2001 From: Anirudh Acharya Date: Mon, 29 Apr 2019 08:54:28 -0700 Subject: [PATCH 3/3] reuse sgd updates where convenient --- python/mxnet/optimizer/optimizer.py | 10 +-- src/operator/optimizer_op-inl.h | 87 ------------------------- src/operator/optimizer_op.cc | 37 ----------- src/operator/optimizer_op.cu | 6 -- tests/python/unittest/test_optimizer.py | 2 +- 5 files changed, 6 insertions(+), 136 deletions(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index 60ee44d03b20..c2c1aa6a76f4 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -28,9 +28,9 @@ from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, - signsgd_update, signum_update, nag_update, nag_mom_update, mp_nag_update, - mp_nag_mom_update, multi_sgd_update, multi_sgd_mom_update, - multi_mp_sgd_update, multi_mp_sgd_mom_update) + signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update, + multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update, + multi_mp_sgd_mom_update) from ..ndarray import sparse from ..random import normal @@ -1086,13 +1086,13 @@ def _update_impl(self, index, weight, grad, state, multi_precision=False): if state is not None: nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs) else: - nag_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) + sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) else: if state[0] is not None: mp_nag_mom_update(weight, grad, state[0], state[1], out=weight, lr=lr, wd=wd, **kwargs) else: - mp_nag_update(weight, grad, state[1], out=weight, + mp_sgd_update(weight, grad, state[1], out=weight, lr=lr, wd=wd, **kwargs) def update(self, index, weight, grad, state): diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index bc736aa1f818..50637a8e7b42 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -1029,48 +1029,6 @@ struct NAGParam : public dmlc::Parameter { } }; -struct NAGKernel { - template - MSHADOW_XINLINE static void Map(int i, DType* out_data, - const DType* weight_data, const DType* grad_data, - const DType param_clip_gradient, const DType param_lr, - const DType param_wd, const DType param_rescale_grad, - const OpReqType req) { - if (param_clip_gradient >= 0.0f) { - KERNEL_ASSIGN(out_data[i], req, - weight_data[i] - - param_lr * (mshadow_op::clip::Map(param_rescale_grad*grad_data[i], - param_clip_gradient) - + param_wd*weight_data[i])); - } else { - KERNEL_ASSIGN(out_data[i], req, - weight_data[i] - - param_lr * (param_rescale_grad*grad_data[i] - + (param_wd*weight_data[i]))); - } - } -}; - -template -inline void NAGUpdate(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace mxnet_op; - const NAGParam& param = nnvm::get(attrs.parsed); - Stream* s = ctx.get_stream(); - MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Tensor weight = inputs[0].FlatTo2D(s); - Tensor grad = inputs[1].FlatTo2D(s); - Tensor out = outputs[0].FlatTo2D(s); - Kernel::Launch(s, weight.shape_.Size(), out.dptr_, - weight.dptr_, grad.dptr_, static_cast(param.clip_gradient), - 
static_cast(param.lr), static_cast(param.wd), - static_cast(param.rescale_grad), req[0]); - }); -} - struct NAGMomParam : public dmlc::Parameter { float lr; float momentum; @@ -1150,51 +1108,6 @@ inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs, }); } -struct MP_NAGKernel { - template - MSHADOW_XINLINE static void Map(int i, DType* out_data, - const DType* weight_data, const DType* grad_data, - float* weight32, const float param_clip_gradient, - const float param_lr, const float param_wd, - const float param_rescale_grad, - const OpReqType req) { - if (param_clip_gradient >= 0.0f) { - float w = weight32[i]; - w = w - param_lr * (mshadow_op::clip::Map(param_rescale_grad - *static_cast(grad_data[i]), param_clip_gradient) - + param_wd*w); - weight32[i] = w; - KERNEL_ASSIGN(out_data[i], req, (DType)w); - } else { - float w = weight32[i]; - w = w - param_lr * (param_rescale_grad - *static_cast(grad_data[i]) + (param_wd*w)); - weight32[i] = w; - KERNEL_ASSIGN(out_data[i], req, (DType)w); - } - } -}; - -template -inline void MP_NAGUpdate(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace mxnet_op; - const NAGParam& param = nnvm::get(attrs.parsed); - Stream* s = ctx.get_stream(); - MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Tensor weight = inputs[0].FlatTo2D(s); - Tensor grad = inputs[1].FlatTo2D(s); - Tensor weight32 = inputs[2].FlatTo2D(s); - Tensor out = outputs[0].FlatTo2D(s); - Kernel::Launch(s, weight.shape_.Size(), out.dptr_, - weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, - param.lr, param.wd, param.rescale_grad, req[0]); - }); -} - struct MP_NAGMomKernel { template MSHADOW_XINLINE static void Map(int i, DType* out_data, diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index e77bd416e37a..01410863640f 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -707,24 +707,6 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a .add_arguments(AdamParam::__FIELDS__()); -NNVM_REGISTER_OP(nag_update) -.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. -It updates the weights using the following formula, - -weight = weight - (lr * (grad + wd * weight)) - -)code" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", ElemwiseShape<2, 1>) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FCompute", NAGUpdate) -.add_argument("weight", "NDArray-or-Symbol", "Weight") -.add_argument("grad", "NDArray-or-Symbol", "Gradient") -.add_arguments(NAGParam::__FIELDS__()); - - NNVM_REGISTER_OP(nag_mom_update) .describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. It updates the weights using the following formula, @@ -756,25 +738,6 @@ Where .add_arguments(NAGMomParam::__FIELDS__()); -NNVM_REGISTER_OP(mp_nag_update) -.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer. 
-)code" ADD_FILELINE) -.set_num_inputs(3) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", ElemwiseShape<3, 1>) -.set_attr("FInferType", MP_InferType<2, 1, 3>) -.set_attr("FCompute", MP_NAGUpdate) -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - return std::vector{2}; - }) -.add_argument("weight", "NDArray-or-Symbol", "Weight") -.add_argument("grad", "NDArray-or-Symbol", "gradient") -.add_argument("weight32", "NDArray-or-Symbol", "Weight32") -.add_arguments(NAGParam::__FIELDS__()); - - NNVM_REGISTER_OP(mp_nag_mom_update) .describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer. )code" ADD_FILELINE) diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index 5d361b7e528d..2c72462de016 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -251,15 +251,9 @@ NNVM_REGISTER_OP(multi_mp_sgd_update) NNVM_REGISTER_OP(multi_mp_sgd_mom_update) .set_attr("FCompute", MultiSGDMomUpdate); -NNVM_REGISTER_OP(nag_update) -.set_attr("FCompute", NAGUpdate); - NNVM_REGISTER_OP(nag_mom_update) .set_attr("FCompute", NAGMomUpdate); -NNVM_REGISTER_OP(mp_nag_update) -.set_attr("FCompute", MP_NAGUpdate); - NNVM_REGISTER_OP(mp_nag_mom_update) .set_attr("FCompute", MP_NAGMomUpdate); diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index e151cfde2306..3e6cdd0997ce 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -425,7 +425,7 @@ def test_nag(): if (dtype == np.float16 and ('multi_precision' not in kwarg or not kwarg['multi_precision'])): continue - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, rtol=1e-3, atol=1e-4) #SGLD class PySGLD(mx.optimizer.Optimizer):
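
After the final patch, the NAG optimizer dispatches to nag_mom_update / mp_nag_mom_update when momentum > 0 and reuses the existing sgd_update / mp_sgd_update kernels otherwise. A minimal usage sketch, assuming an MXNet build that includes these changes (shapes and hyperparameters are arbitrary; multi_precision only takes effect for float16 parameters and is shown here just to illustrate the option):

    import mxnet as mx

    # Direct call to the new fused operator; weight and mom are updated in place.
    weight = mx.nd.ones((4, 5))
    grad = mx.nd.full((4, 5), 0.5)
    mom = mx.nd.zeros((4, 5))
    mx.nd.nag_mom_update(weight, grad, mom, out=weight,
                         lr=0.1, momentum=0.9, wd=1e-4, rescale_grad=1.0)

    # Or through the registered optimizer, e.g. with a Gluon Trainer.
    net = mx.gluon.nn.Dense(10)
    net.initialize()
    trainer = mx.gluon.Trainer(net.collect_params(), 'nag',
                               {'learning_rate': 0.1, 'momentum': 0.9,
                                'wd': 1e-4, 'multi_precision': True})
    x = mx.nd.random.uniform(shape=(32, 20))
    with mx.autograd.record():
        loss = net(x).sum()
    loss.backward()
    trainer.step(batch_size=32)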