diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index e5c5ff026112..23562921bc88 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -1099,10 +1099,8 @@ def update(self, index, weight, grad, state):
         self._update_impl(index, weight, grad, state, multi_precision=False)
 
     def update_multi_precision(self, index, weight, grad, state):
-        if not isinstance(index, (tuple, list)):
-            use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
-        else:
-            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
+        use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 \
+                              and isinstance(state, (tuple, list))
         self._update_impl(index, weight, grad, state,
                           multi_precision=use_multi_precision)
 
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 859fa58638b0..e77bd416e37a 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -708,12 +708,11 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
 
 NNVM_REGISTER_OP(nag_update)
-MXNET_ADD_SPARSE_OP_ALIAS(nag_update)
 .describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer.
-NAG update consists of the following steps,
+It updates the weights using the following formula,
+
+weight = weight - (lr * (grad + wd * weight))
 
-state = momentum * state + grad + wd * weight
-weight = weight - (lr * (grad + momentum * state))
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
 
@@ -727,8 +726,19 @@ weight = weight - (lr * (grad + momentum * state))
 
 NNVM_REGISTER_OP(nag_mom_update)
-MXNET_ADD_SPARSE_OP_ALIAS(nag_mom_update)
 .describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer.
+It updates the weights using the following formula,
+
+.. math::
+   v_t = \gamma v_{t-1} + \eta * \nabla J(W_{t-1} - \gamma v_{t-1})\\
+   W_t = W_{t-1} - v_t
+
+Where
+:math:`\eta` is the learning rate of the optimizer
+:math:`\gamma` is the decay rate of the momentum estimate
+:math:`v_t` is the update vector at time step `t`
+:math:`W_t` is the weight vector at time step `t`
+
 )code" ADD_FILELINE)
 .set_num_inputs(3)
 .set_num_outputs(1)
 
@@ -747,8 +757,7 @@ MXNET_ADD_SPARSE_OP_ALIAS(nag_mom_update)
 
 NNVM_REGISTER_OP(mp_nag_update)
-MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_update)
-.describe(R"code(Multi-precision NAG update.
+.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient (NAG) optimizer.
 )code" ADD_FILELINE)
 .set_num_inputs(3)
 .set_num_outputs(1)
 
@@ -767,8 +776,7 @@ MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_update)
 
 NNVM_REGISTER_OP(mp_nag_mom_update)
-MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_mom_update)
-.describe(R"code(Multi-precision NAG update.
+.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient (NAG) optimizer.
 )code" ADD_FILELINE)
 .set_num_inputs(4)
 .set_num_outputs(1)
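
For reference, a minimal NumPy sketch of the two update rules the new docstrings describe. The function names (`nag_update_ref`, `nag_mom_update_ref`), the weight-decay handling in the momentum form, and the assumption that `grad` was already evaluated at the look-ahead point `W_{t-1} - gamma * v_{t-1}` are illustrative only, not MXNet API:

import numpy as np

def nag_update_ref(weight, grad, lr, wd):
    # Momentum-free form from the nag_update docstring:
    #   weight = weight - lr * (grad + wd * weight)
    return weight - lr * (grad + wd * weight)

def nag_mom_update_ref(weight, state, grad, lr, momentum, wd=0.0):
    # Momentum form from the nag_mom_update docstring, assuming `grad` is
    # dJ/dW evaluated at the look-ahead point W_{t-1} - momentum * state,
    # with weight decay folded into the effective gradient:
    #   v_t = gamma * v_{t-1} + eta * grad
    state = momentum * state + lr * (grad + wd * weight)
    #   W_t = W_{t-1} - v_t
    weight = weight - state
    return weight, state

# Tiny usage example with made-up values.
w = np.array([1.0, -2.0])
v = np.zeros_like(w)
g = np.array([0.1, 0.3])
w, v = nag_mom_update_ref(w, v, g, lr=0.1, momentum=0.9)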