diff --git a/docs/api/python/optimization/contrib.md b/docs/api/python/optimization/contrib.md
index 9d3f3483113e..8fc261f4f052 100644
--- a/docs/api/python/optimization/contrib.md
+++ b/docs/api/python/optimization/contrib.md
@@ -35,7 +35,7 @@ In the rest of this document, we list routines provided by the `optimizer.contrib`
 .. autosummary::
    :nosignatures:
 
-    ProximalGroupAdaGrad
+    GroupAdaGrad
 ```
 
 ## API Reference
diff --git a/python/mxnet/optimizer/__init__.py b/python/mxnet/optimizer/__init__.py
index 4840413ccaa6..72eb5a741520 100644
--- a/python/mxnet/optimizer/__init__.py
+++ b/python/mxnet/optimizer/__init__.py
@@ -17,6 +17,7 @@
 """Optimizer API of MXNet."""
 
 from . import optimizer, contrib
+# pylint: disable=wildcard-import
 from .optimizer import *
 # pylint: enable=wildcard-import
diff --git a/python/mxnet/optimizer/contrib.py b/python/mxnet/optimizer/contrib.py
index 8cf48261036e..1baf2ff1020a 100644
--- a/python/mxnet/optimizer/contrib.py
+++ b/python/mxnet/optimizer/contrib.py
@@ -18,19 +18,18 @@
 # pylint: disable=too-many-lines
 """Contrib optimizers."""
 
-from ..ndarray import (NDArray, clip, contrib, full, mean, norm, sparse, sqrt,
-                       square, zeros)
+from ..ndarray import (NDArray, clip, contrib, mean, sqrt, square, zeros)
 from .optimizer import Optimizer
 
 # convenience wrapper for Optimizer.Register
 register = Optimizer.register  # pylint: disable=invalid-name
 
-__all__ = ['ProximalGroupAdaGrad']
+__all__ = ['GroupAdaGrad']
 
 
 @register
-class ProximalGroupAdaGrad(Optimizer):
-    """Proximal Adagrad optimizer with row-wise learning rates.
+class GroupAdaGrad(Optimizer):
+    """Adagrad optimizer with row-wise learning rates.
 
     This class implements the AdaGrad optimizer described in *Adaptive
     Subgradient Methods for Online Learning and Stochastic Optimization*, and
@@ -44,12 +43,11 @@ class ProximalGroupAdaGrad(Optimizer):
         div = grad / sqrt(history + float_stable_eps)
         weight -= div * lr
 
-    If `l2_regularization_strength > 0` a proximal operator is used to optimize
-    with group lasso objective. Weights are updated lazily if the gradient is
-    sparse. In particular, before using a set of weights for a forward pass,
-    you may want to ensure that the lazily accumulated group lasso
-    regularization is applied. This can be achieved by creating a sparse
-    gradient array that contains explicit 0 data for the indices to be updated:
+    Weights are updated lazily if the gradient is sparse: only the rows whose
+    indices appear in the gradient are updated in a given step. If you need to
+    run the update for other rows as well, for example before reusing the
+    weights in a forward pass, you can create a sparse gradient array that
+    contains explicit 0 data for the indices to be updated:
 
         fake_grad = mx.nd.sparse.row_sparse_array(
             (mx.nd.zeros((len(indices), dim)), indices))
@@ -60,38 +58,27 @@ class ProximalGroupAdaGrad(Optimizer):
         trainer.step(batch_size=1)
 
     For details of the update algorithm see
-    :class:`~mxnet.ndarray.contrib.proximal_group_adagrad_update`.
+    :class:`~mxnet.ndarray.contrib.group_adagrad_update`.
 
     This optimizer accepts the following parameters in addition to those
    accepted by :class:`.Optimizer`. Weight decay is not supported.
 
     Parameters
     ----------
-    l2_regularization_strength : float
-        Strength of group lasso L2 regularization.
     eps: float, optional
         Initial value of the history accumulator. Avoids division by 0.
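+
+    A minimal usage sketch (illustrative only: ``net`` and ``data`` stand in
+    for any Gluon block with 2-D parameters and an input batch)::
+
+        opt = mx.optimizer.contrib.GroupAdaGrad(learning_rate=0.1)
+        trainer = mx.gluon.Trainer(net.collect_params(), opt)
+        with mx.autograd.record():
+            loss = net(data).sum()
+        loss.backward()
+        trainer.step(batch_size=1)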
""" - def __init__(self, l2_regularization_strength=0.0, eps=1e-5, **kwargs): - super(ProximalGroupAdaGrad, self).__init__(**kwargs) - self.l2_regularization_strength = l2_regularization_strength + def __init__(self, eps=1e-5, **kwargs): + super(GroupAdaGrad, self).__init__(**kwargs) self.float_stable_eps = eps def create_state(self, index, weight): assert len(weight.shape) == 2 history = zeros( (weight.shape[0], 1), weight.context, stype=weight.stype) - last_update = None - if self.l2_regularization_strength > 0: - last_update = full( - shape=(weight.shape[0], ), - val=self.num_update, - ctx=weight.context) - else: - last_update = zeros(1, ctx=weight.context) - return (history, last_update) + return history def update(self, index, weight, grad, state): assert (isinstance(weight, NDArray)) @@ -99,11 +86,9 @@ def update(self, index, weight, grad, state): self._update_count(index) lr = self._get_lr(index) wd = self._get_wd(index) - assert wd == 0, 'Weight decay is not supported for ProximalGroupAdaGrad' + assert wd == 0, 'Weight decay is not supported for GroupAdaGrad' is_sparse = grad.stype == 'row_sparse' - history = state[0] - last_update = state[1] if is_sparse: kwargs = { 'epsilon': self.float_stable_eps, @@ -111,35 +96,17 @@ def update(self, index, weight, grad, state): } if self.clip_gradient: kwargs['clip_gradient'] = self.clip_gradient - if self.l2_regularization_strength: - kwargs['l2_regularization_strength'] = \ - self.l2_regularization_strength - contrib.proximal_group_adagrad_update( + contrib.group_adagrad_update( weight, grad, - history, + state, out=weight, - last_update=last_update, lr=lr, - current_update=self.num_update, **kwargs) - elif self.l2_regularization_strength > 0: - grad = grad * self.rescale_grad - if self.clip_gradient is not None: - grad = clip(grad, -self.clip_gradient, self.clip_gradient) - history[:] += mean(square(grad), axis=1, keepdims=True) - div = lr * grad / sqrt(history + self.float_stable_eps) - num_skipped = (self.num_update - last_update).expand_dims(1) - scaled_l2 = lr / sqrt(history + self.float_stable_eps) \ - * self.l2_regularization_strength * num_skipped - nrm = norm(weight - div, ord=2, axis=1, keepdims=True) - weight[:] = (weight - div) * (1 - scaled_l2 / nrm) - weight[:] *= nrm > scaled_l2 - last_update[:] = self.num_update else: grad = grad * self.rescale_grad if self.clip_gradient is not None: grad = clip(grad, -self.clip_gradient, self.clip_gradient) - history[:] += mean(square(grad), axis=1, keepdims=True) - div = lr * grad / sqrt(history + self.float_stable_eps) + state[:] += mean(square(grad), axis=1, keepdims=True) + div = lr * grad / sqrt(state + self.float_stable_eps) weight[:] -= div diff --git a/src/operator/contrib/optimizer_op-inl.h b/src/operator/contrib/optimizer_op-inl.h index 0bbe9cf7d1f4..fd556a4231cb 100644 --- a/src/operator/contrib/optimizer_op-inl.h +++ b/src/operator/contrib/optimizer_op-inl.h @@ -43,15 +43,12 @@ namespace mxnet { namespace op { -struct ProximalGroupAdagradParam - : public dmlc::Parameter { +struct GroupAdagradParam : public dmlc::Parameter { float lr; float epsilon; float rescale_grad; float clip_gradient; - float l2_regularization_strength; - float current_update; - DMLC_DECLARE_PARAMETER(ProximalGroupAdagradParam) { + DMLC_DECLARE_PARAMETER(GroupAdagradParam) { DMLC_DECLARE_FIELD(lr).describe("Learning rate"); DMLC_DECLARE_FIELD(rescale_grad) .set_default(1.0f) @@ -62,29 +59,21 @@ struct ProximalGroupAdagradParam "Clip gradient to the range of [-clip_gradient, clip_gradient] " "If clip_gradient 
<= 0, gradient clipping is turned off. " "grad = max(min(grad, clip_gradient), -clip_gradient)."); - DMLC_DECLARE_FIELD(l2_regularization_strength) - .set_default(0.0f) - .describe("Lambda term for group lasso objective."); DMLC_DECLARE_FIELD(epsilon).set_default(1.0e-5).describe( "Epsilon for numerical stability"); - DMLC_DECLARE_FIELD(current_update) - .set_default(0.0f) - .describe("Current update iteration for lazy update with group lasso " - "objective."); } }; -inline bool ProximalGroupAdagradStorageType(const nnvm::NodeAttrs &attrs, - const int dev_mask, - DispatchMode *dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 4U); +inline bool GroupAdagradStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), 1U); const int weight_stype = in_attrs->at(0); const int grad_stype = in_attrs->at(1); const int state_stype = in_attrs->at(2); - const int counter_stype = in_attrs->at(3); bool dispatched = false; if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { // dns, ... -> dns @@ -92,7 +81,6 @@ inline bool ProximalGroupAdagradStorageType(const nnvm::NodeAttrs &attrs, DispatchMode::kFCompute); } if (!dispatched && grad_stype == kRowSparseStorage && - counter_stype == kDefaultStorage && (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) && state_stype == weight_stype) { // weight and state share stype, grad's stype = rsp @@ -105,14 +93,13 @@ inline bool ProximalGroupAdagradStorageType(const nnvm::NodeAttrs &attrs, /*! \brief kernel for sparse adagrad update with group sparsity regularization */ -template struct ProximalGroupAdagradDnsRspKernel { +template struct GroupAdagradDnsRspKernel { template MSHADOW_XINLINE static void Map(int i, const index_t row_length, DType *out_data, DType *state_data, DType *weight_data, const IType *grad_idx, const DType *grad_data, - DType *last_update_data, const DType current_update, - const DType clip_gradient, const DType rescale_grad, - const DType l2_regularization_strength, const DType lr, const DType eps) { + const DType clip_gradient, const DType rescale_grad, const DType lr, + const DType eps) { using namespace mshadow_op; // Helper to obtain index into weight / state arrays @@ -138,82 +125,26 @@ template struct ProximalGroupAdagradDnsRspKernel { } state_data[grad_idx[i]] += grad_ssq / row_length; - // Number of weight updates skipped due to lazy_update - DType delay{0}; - if (l2_regularization_strength > 0) { - // last_update_data[grad_idx[i]] is only valid if - // l2_regularization_strength > 0. Otherwise may be out of bounds read. - delay = current_update - last_update_data[grad_idx[i]]; - last_update_data[grad_idx[i]] = current_update; - } - - if (l2_regularization_strength <= 0 || delay < 0) { - if (delay < 0) { - std::printf("Got invalid last_update in proximal_adagrad_update. 
" - "Using standard Adagrad update.\n"); - } - - // Standard Adagrad Update - for (index_t j = 0; j < row_length; j++) { - // clang-format off - const DType grad_rescaled = get_grad_rescaled(j); - index_t data_j = get_data_j(j); - const DType div = lr * grad_rescaled / square_root::Map(state_data[grad_idx[i]] + eps); - out_data[data_j] = weight_data[data_j] - div; - // clang-format on - } - } else { - // Compute L2 norm of updated parameter using scaled sum of squares - DType norm, scale; - mshadow_op::nrm2::SetInitValue(norm, scale); - for (index_t j = 0; j < row_length; j++) { - const DType grad_rescaled = get_grad_rescaled(j); - index_t data_j = get_data_j(j); - const DType val = - (weight_data[data_j] - - lr / std::sqrt(state_data[grad_idx[i]] + eps) * grad_rescaled); - mshadow_op::nrm2::Reduce(norm, val, scale); - } - mshadow_op::nrm2::Finalize(norm, scale); - - // Compute regularization lambda - DType lambda = l2_regularization_strength * lr / - square_root::Map(state_data[grad_idx[i]] + eps); - DType l2_scale = 1 - lambda / norm; - if (l2_scale < 0) { - l2_scale = 0; - } else if (l2_scale > 0) { - scale = math::pow(scale, delay); - } - - if (l2_scale == 0) { - // Soft threshold weights (proximal map for group lasso) - for (index_t j = 0; j < row_length; j++) { - index_t data_j = get_data_j(j); - out_data[data_j] = 0; - } - } else { - for (index_t j = 0; j < row_length; j++) { - // clang-format off - const DType grad_rescaled = get_grad_rescaled(j); - index_t data_j = get_data_j(j); - const DType div = lr * grad_rescaled / square_root::Map(state_data[grad_idx[i]] + eps); - out_data[data_j] = (weight_data[data_j] - div) * l2_scale; - // clang-format on - } - } + // Standard Adagrad Update + for (index_t j = 0; j < row_length; j++) { + // clang-format off + const DType grad_rescaled = get_grad_rescaled(j); + index_t data_j = get_data_j(j); + const DType div = lr * grad_rescaled / square_root::Map(state_data[grad_idx[i]] + eps); + out_data[data_j] = weight_data[data_j] - div; + // clang-format on } } }; /* - * \brief Proximal Group Adagrad update implementation for dense weight and row_sparse grad. + * \brief Group Adagrad update implementation for dense weight and row_sparse + * grad. 
  */
 template <typename xpu>
-inline void ProximalGroupAdagradUpdateDnsRspDnsImpl(
-    const ProximalGroupAdagradParam &param, const OpContext &ctx,
-    const TBlob &weight, const NDArray &grad, const TBlob &state,
-    const TBlob &last_update, const OpReqType &req, TBlob *out) {
+inline void GroupAdagradUpdateDnsRspDnsImpl(
+    const GroupAdagradParam &param, const OpContext &ctx, const TBlob &weight,
+    const NDArray &grad, const TBlob &state, const OpReqType &req, TBlob *out) {
   using namespace mshadow;
   using namespace mshadow::expr;
   using namespace mshadow_op;
@@ -225,7 +156,7 @@ inline void ProximalGroupAdagradUpdateDnsRspDnsImpl(
     return;
   }
   CHECK_EQ(req, kWriteInplace)
-      << "kWriteInplace is expected for sparse proximal_adagrad_update";
+      << "kWriteInplace is expected for sparse group_adagrad_update";
   CHECK_GT(weight.shape_.Size(), 0);
   CHECK_GT(state.shape_.Size(), 0);
@@ -236,7 +167,6 @@ inline void ProximalGroupAdagradUpdateDnsRspDnsImpl(
       const IType *grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
       const DType *grad_val = grad.data().dptr<DType>();
       DType *state_data = state.dptr<DType>();
-      DType *last_update_data = last_update.dptr<DType>();
       const nnvm::dim_t num_grad = grad.aux_shape(rowsparse::kIdx)[0];
       const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
 
@@ -245,73 +175,67 @@ inline void ProximalGroupAdagradUpdateDnsRspDnsImpl(
         return;
       }
-      Kernel<ProximalGroupAdagradDnsRspKernel<xpu>, xpu>::Launch(
+      Kernel<GroupAdagradDnsRspKernel<xpu>, xpu>::Launch(
          s, num_grad, row_length, out_data, state_data, weight_data, grad_idx,
-          grad_val, last_update_data, static_cast<DType>(param.current_update),
-          static_cast<DType>(param.clip_gradient),
-          static_cast<DType>(param.rescale_grad),
-          static_cast<DType>(param.l2_regularization_strength),
-          static_cast<DType>(param.lr), static_cast<DType>(param.epsilon));
+          grad_val, static_cast<DType>(param.clip_gradient),
+          static_cast<DType>(param.rescale_grad), static_cast<DType>(param.lr),
+          static_cast<DType>(param.epsilon));
    });
  });
 }
 
 /*
- * \brief Proximal adagrad update implementation for row_sparse grad.
- *        Both standard update and lazy update are supported.
+ * \brief AdaGrad update implementation for row_sparse grad. Both standard
+ * update and lazy update are supported.
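+ * Lazy update: only the rows listed in the gradient's kIdx aux array are
+ * touched; all other rows of the weight and state arrays are left unchanged.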
  */
 template <typename xpu>
-inline void ProximalGroupAdagradUpdateRspRspRspImpl(
-    const ProximalGroupAdagradParam &param, const OpContext &ctx,
-    const NDArray &weight, const NDArray &grad, const NDArray &state,
-    const NDArray &last_update_buffer, const OpReqType &req, NDArray *out) {
+inline void
+GroupAdagradUpdateRspRspRspImpl(const GroupAdagradParam &param,
+                                const OpContext &ctx, const NDArray &weight,
+                                const NDArray &grad, const NDArray &state,
+                                const OpReqType &req, NDArray *out) {
   using namespace mshadow;
   using namespace mxnet_op;
   using namespace rowsparse;
-  CheckAllRowsPresent(weight, "ProximalGroupAdagradUpdate", "weights");
+  CheckAllRowsPresent(weight, "GroupAdagradUpdate", "weights");
   Stream<xpu> *s = ctx.get_stream<xpu>();
   // fill history with zero values
   if (!state.storage_initialized()) {
     NDArray state_zeros = state;
     FillDnsZerosRspImpl(s, &state_zeros);
   } else {
-    CheckAllRowsPresent(state, "ProximalGroupAdagradUpdate", "states");
+    CheckAllRowsPresent(state, "GroupAdagradUpdate", "states");
   }
   // reuse dns rsp implementation when storage_shape == shape
   TBlob out_blob = out->data();
-  ProximalGroupAdagradUpdateDnsRspDnsImpl<xpu>(
-      param, ctx, weight.data(), grad, state.data(), last_update_buffer.data(),
-      req, &out_blob);
+  GroupAdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                       state.data(), req, &out_blob);
 }
 
 template <typename xpu>
-inline void ProximalGroupAdagradUpdateEx(const nnvm::NodeAttrs &attrs,
-                                         const OpContext &ctx,
-                                         const std::vector<NDArray> &inputs,
-                                         const std::vector<OpReqType> &req,
-                                         const std::vector<NDArray> &outputs) {
-  const ProximalGroupAdagradParam &param =
-      nnvm::get<ProximalGroupAdagradParam>(attrs.parsed);
+inline void GroupAdagradUpdateEx(const nnvm::NodeAttrs &attrs,
+                                 const OpContext &ctx,
+                                 const std::vector<NDArray> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<NDArray> &outputs) {
+  const GroupAdagradParam &param = nnvm::get<GroupAdagradParam>(attrs.parsed);
   const auto weight_stype = inputs[0].storage_type();
   const auto grad_stype = inputs[1].storage_type();
   const auto state_stype = inputs[2].storage_type();
-  const auto counter_stype = inputs[3].storage_type();
   const auto output_stype = outputs[0].storage_type();
 
   if (state_stype == weight_stype && output_stype == weight_stype &&
-      weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
-      counter_stype == kDefaultStorage) {
+      weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) {
     NDArray out = outputs[0];
-    ProximalGroupAdagradUpdateRspRspRspImpl<xpu>(
-        param, ctx, inputs[0], inputs[1], inputs[2], inputs[3], req[0], &out);
+    GroupAdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1],
+                                         inputs[2], req[0], &out);
   } else if (state_stype == weight_stype && output_stype == weight_stype &&
             weight_stype == kDefaultStorage &&
-             grad_stype == kRowSparseStorage &&
-             counter_stype == kDefaultStorage) {
+             grad_stype == kRowSparseStorage) {
     TBlob out_blob = outputs[0].data();
-    ProximalGroupAdagradUpdateDnsRspDnsImpl<xpu>(
-        param, ctx, inputs[0].data(), inputs[1], inputs[2].data(),
-        inputs[3].data(), req[0], &out_blob);
+    GroupAdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(),
+                                         inputs[1], inputs[2].data(), req[0],
+                                         &out_blob);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
diff --git a/src/operator/contrib/optimizer_op.cc b/src/operator/contrib/optimizer_op.cc
index 278ec62eab63..3abc70d6fdf3 100644
--- a/src/operator/contrib/optimizer_op.cc
+++ b/src/operator/contrib/optimizer_op.cc
@@ -23,21 +23,21 @@
  * \brief Optimizer operators
  * \author Leonard Lausen
  */
-#include "./optimizer_op-inl.h"
 #include "../elemwise_op_common.h"
+#include "./optimizer_op-inl.h"
 
 namespace mxnet {
 namespace op {
 
-DMLC_REGISTER_PARAMETER(ProximalGroupAdagradParam);
+DMLC_REGISTER_PARAMETER(GroupAdagradParam);
 
 /*!
- * \brief Shape inference function for Proximal Group AdaGrad.
+ * \brief Shape inference function for Group AdaGrad.
  */
-inline bool ProximalGroupAdagradShape(const nnvm::NodeAttrs &attrs,
-                                      std::vector<TShape> *in_attrs,
-                                      std::vector<TShape> *out_attrs) {
-  CHECK_EQ(in_attrs->size(), 4U);
+inline bool GroupAdagradShape(const nnvm::NodeAttrs &attrs,
+                              std::vector<TShape> *in_attrs,
+                              std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 1U);
 
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
@@ -50,8 +50,8 @@ inline bool ProximalGroupAdagradShape(const nnvm::NodeAttrs &attrs,
          (in_attrs->at(0)[0] == in_attrs->at(2)[0]);
 }
 
-NNVM_REGISTER_OP(_contrib_proximal_group_adagrad_update)
-.describe(R"code(Update function for Proximal Group AdaGrad optimizer.
+NNVM_REGISTER_OP(_contrib_group_adagrad_update)
+.describe(R"code(Update function for Group AdaGrad optimizer.
 
 Referenced from *Adaptive Subgradient Methods for Online Learning and
 Stochastic Optimization*, and available at
 http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf but
@@ -64,31 +64,29 @@ Updates are applied by::
     div = grad / sqrt(history + float_stable_eps)
     weight -= div * lr
 
-If `l2_regularization_strength > 0` a proximal operator is used to optimize with
-group lasso objective. Weights are updated lazily if the gradient is sparse.
-In particular, before using a set of weights for a forward pass, you may
-want to ensure that the lazily accumulated group lasso regularization is
-applied.
+Weights are updated lazily if the gradient is sparse: only the rows whose
+indices appear in the gradient are updated. All other rows of the weight and
+history arrays are left untouched by a call to this operator.
 
 Note that non-zero values for the weight decay option are not supported.
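+
+A minimal imperative example (shapes and values are illustrative; ``history``
+holds one accumulator per row of ``weight``)::
+
+  weight = mx.nd.random.uniform(shape=(3, 4))
+  grad = mx.nd.random.uniform(shape=(3, 4))
+  history = mx.nd.zeros((3, 1))
+  mx.nd.contrib.group_adagrad_update(weight, grad, history, out=weight, lr=0.1)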
)code" ADD_FILELINE) -.set_num_inputs(4) +.set_num_inputs(3) .set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", ProximalGroupAdagradShape) -.set_attr("FInferType", ElemwiseType<4, 1>) -.set_attr("FInferStorageType", ProximalGroupAdagradStorageType) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", GroupAdagradShape) +.set_attr("FInferType", ElemwiseType<3, 1>) +.set_attr("FInferStorageType", GroupAdagradStorageType) .set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { - return std::vector{2, 3}; + return std::vector{2}; }) -.set_attr("FComputeEx", ProximalGroupAdagradUpdateEx) +.set_attr("FComputeEx", GroupAdagradUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_argument("history", "NDArray-or-Symbol", "History") .add_argument("last_update", "NDArray-or-Symbol", "Array storing last update counter for each row.") -.add_arguments(ProximalGroupAdagradParam::__FIELDS__()); +.add_arguments(GroupAdagradParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/optimizer_op.cu b/src/operator/contrib/optimizer_op.cu index 49221e17c42c..40d99c5f0071 100644 --- a/src/operator/contrib/optimizer_op.cu +++ b/src/operator/contrib/optimizer_op.cu @@ -29,8 +29,8 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_contrib_proximal_group_adagrad_update) -.set_attr("FComputeEx", ProximalGroupAdagradUpdateEx); +NNVM_REGISTER_OP(_contrib_group_adagrad_update) +.set_attr("FComputeEx", GroupAdagradUpdateEx); } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py index 71a50d8dc065..8ff8a7e1436b 100644 --- a/tests/python/unittest/test_contrib_optimizer.py +++ b/tests/python/unittest/test_contrib_optimizer.py @@ -23,22 +23,19 @@ from mxnet.test_utils import * -# ProximalGroupAdaGrad -class PyProximalGroupAdaGrad(mx.optimizer.Optimizer): - """The python reference of Proximal Group AdaGrad optimizer. +# * GroupAdaGrad +class PyGroupAdaGrad(mx.optimizer.Optimizer): + """The python reference of Group AdaGrad optimizer. Parameters ---------- - l2_regularization_strength : float - Strength of group lasso L2 regularization. eps: float, optional Small value to avoid division by 0. 
""" - def __init__(self, l2_regularization_strength=0.0, eps=1e-5, **kwargs): - super(PyProximalGroupAdaGrad, self).__init__(**kwargs) - self.l2_regularization_strength = l2_regularization_strength + def __init__(self, eps=1e-5, **kwargs): + super(PyGroupAdaGrad, self).__init__(**kwargs) self.float_stable_eps = eps def create_state(self, index, weight): @@ -59,34 +56,19 @@ def update(self, index, weight, grad, state): grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient) history[:] += mx.nd.mean(mx.nd.square(grad), axis=1, keepdims=True) div = lr * grad / mx.nd.sqrt(history + self.float_stable_eps) + weight[:] -= div - if self.l2_regularization_strength > 0: - scaled_l2 = lr / mx.nd.sqrt(history + self.float_stable_eps) \ - * self.l2_regularization_strength - norm = mx.nd.norm(weight - div, ord=2, axis=1, keepdims=True) - weight[:] = (weight - div) * \ - (1 - scaled_l2 / norm) - weight[:] *= norm > scaled_l2 - else: - weight[:] -= div - -def test_proximal_group_adagrad(): +def test_group_adagrad(): mx.random.seed(0) - opt1 = PyProximalGroupAdaGrad - opt2 = mx.optimizer.contrib.ProximalGroupAdaGrad + opt1 = PyGroupAdaGrad + opt2 = mx.optimizer.contrib.GroupAdaGrad shape = (3, 4) eps_options = [{}, {'eps': 1e-8}] cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}] rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] - l2_options = [{ - 'l2_regularization_strength': 0.0 - }, { - 'l2_regularization_strength': 0.05 - }] for dtype in [np.float32]: - for options in itertools.product(eps_options, cg_options, rg_options, - l2_options): + for options in itertools.product(eps_options, cg_options, rg_options): kwarg = dict(wd=0.0) for option in options: kwarg.update(option) @@ -96,25 +78,21 @@ def test_proximal_group_adagrad(): shape, dtype, compare_states=False) - if kwarg.get('l2_regularization_strength', 0.0) == 0.0: - # By design results for PyOp which always performs - # dense update will differ if - # l2_regularization_strength > 0 - compare_optimizer( - opt1(**kwarg), - opt2(**kwarg), - shape, - dtype, - w_stype='row_sparse', - g_stype='row_sparse', - compare_states=False) - compare_optimizer( - opt1(**kwarg), - opt2(**kwarg), - shape, - dtype, - g_stype='row_sparse', - compare_states=False) + compare_optimizer( + opt1(**kwarg), + opt2(**kwarg), + shape, + dtype, + w_stype='row_sparse', + g_stype='row_sparse', + compare_states=False) + compare_optimizer( + opt1(**kwarg), + opt2(**kwarg), + shape, + dtype, + g_stype='row_sparse', + compare_states=False) if __name__ == '__main__':