diff --git a/docs/api/python/index.md b/docs/api/python/index.md
index 8f60bcd0f13c..de86aedff691 100644
--- a/docs/api/python/index.md
+++ b/docs/api/python/index.md
@@ -136,6 +136,7 @@ Code examples are placed throughout the API documentation and these can be run a
:maxdepth: 1
optimization/optimization.md
+ optimization/contrib.md
```
## Profiler API
diff --git a/docs/api/python/optimization/contrib.md b/docs/api/python/optimization/contrib.md
new file mode 100644
index 000000000000..8fc261f4f052
--- /dev/null
+++ b/docs/api/python/optimization/contrib.md
@@ -0,0 +1,52 @@
+# Contrib Optimization API
+
+```eval_rst
+.. currentmodule:: mxnet.optimizer.contrib
+```
+
+## Overview
+
+This document summarizes the contrib APIs used to initialize and update model
+weights during training.
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ mxnet.optimizer.contrib
+```
+
+The `Contrib Optimization` API, defined in the `optimizer.contrib` package, provides
+experimental optimizers and related APIs for new features.
+It is a place for the community to try out new features
+so that their contributors can receive feedback.
+
+```eval_rst
+.. warning:: This package contains experimental APIs and may change in the near future.
+```
+
+In the rest of this document, we list routines provided by the `optimizer.contrib` package.
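+
+For example, here is a minimal sketch of constructing the `GroupAdaGrad` optimizer
+listed below and applying a single update step directly (the shapes are illustrative;
+`GroupAdaGrad` expects 2-D weights):
+
+```python
+import mxnet as mx
+
+opt = mx.optimizer.contrib.GroupAdaGrad(learning_rate=0.1)
+weight = mx.nd.random.uniform(shape=(5, 4))
+grad = mx.nd.random.uniform(shape=(5, 4))
+state = opt.create_state(0, weight)  # row-wise history, shape (5, 1)
+opt.update(0, weight, grad, state)   # updates `weight` in place
+```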
+
+## Contrib
+
+```eval_rst
+.. currentmodule:: mxnet.optimizer.contrib
+
+.. autosummary::
+ :nosignatures:
+
+ GroupAdaGrad
+```
+
+## API Reference
+
+
+
+```eval_rst
+
+.. automodule:: mxnet.optimizer.contrib
+ :members:
+
+```
+
+
diff --git a/python/mxnet/optimizer/__init__.py b/python/mxnet/optimizer/__init__.py
new file mode 100644
index 000000000000..72eb5a741520
--- /dev/null
+++ b/python/mxnet/optimizer/__init__.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Optimizer API of MXNet."""
+
+from . import optimizer, contrib
+# pylint: disable=wildcard-import
+from .optimizer import *
+# pylint: enable=wildcard-import
+
+__all__ = optimizer.__all__ + ['contrib']
diff --git a/python/mxnet/optimizer/contrib.py b/python/mxnet/optimizer/contrib.py
new file mode 100644
index 000000000000..d269aa1bd069
--- /dev/null
+++ b/python/mxnet/optimizer/contrib.py
@@ -0,0 +1,100 @@
+# coding: utf-8
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=too-many-lines
+"""Contrib optimizers."""
+from ..ndarray import (NDArray, clip, contrib, mean, sqrt, square, zeros)
+from .optimizer import Optimizer
+
+# convenience wrapper for Optimizer.Register
+register = Optimizer.register # pylint: disable=invalid-name
+
+__all__ = ['GroupAdaGrad']
+
+
+@register
+class GroupAdaGrad(Optimizer):
+ """Adagrad optimizer with row-wise learning rates.
+
+    This class implements the AdaGrad optimizer described in *Adaptive
+    Subgradient Methods for Online Learning and Stochastic Optimization*
+    (http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf), but maintains
+    only a single learning rate for every row of the parameter array.
+
+ This optimizer updates each weight by::
+
+ grad = clip(grad * rescale_grad, clip_gradient)
+ history += mean(square(grad), axis=1, keepdims=True)
+ div = grad / sqrt(history + float_stable_eps)
+ weight -= div * lr
+
+ Weights are updated lazily if the gradient is sparse.
+
+ For details of the update algorithm see
+ :class:`~mxnet.ndarray.contrib.group_adagrad_update`.
+
+ This optimizer accepts the following parameters in addition to those
+ accepted by :class:`.Optimizer`. Weight decay is not supported.
+
+ Parameters
+ ----------
+ eps: float, optional
+ Initial value of the history accumulator. Avoids division by 0.
+
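+    Examples
+    --------
+    A minimal sketch; ``net`` below is assumed to be a Gluon block whose
+    parameters are all 2-D (for example a single ``Embedding`` layer)::
+
+        opt = mx.optimizer.contrib.GroupAdaGrad(learning_rate=0.1)
+        trainer = mx.gluon.Trainer(net.collect_params(), opt)
+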
+ """
+
+ def __init__(self, eps=1e-5, **kwargs):
+ super(GroupAdaGrad, self).__init__(**kwargs)
+ self.float_stable_eps = eps
+
+ def create_state(self, index, weight):
+ assert len(weight.shape) == 2
+ history = zeros(
+ (weight.shape[0], 1), weight.context, stype=weight.stype)
+ return history
+
+ def update(self, index, weight, grad, state):
+ assert (isinstance(weight, NDArray))
+ assert (isinstance(grad, NDArray))
+ self._update_count(index)
+ lr = self._get_lr(index)
+ wd = self._get_wd(index)
+ assert wd == 0, 'Weight decay is not supported for GroupAdaGrad'
+
+ is_sparse = grad.stype == 'row_sparse'
+ if is_sparse:
+ kwargs = {
+ 'epsilon': self.float_stable_eps,
+ 'rescale_grad': self.rescale_grad
+ }
+ if self.clip_gradient:
+ kwargs['clip_gradient'] = self.clip_gradient
+ contrib.group_adagrad_update(
+ weight,
+ grad,
+ state,
+ out=weight,
+ lr=lr,
+ **kwargs)
+ else:
+ grad = grad * self.rescale_grad
+ if self.clip_gradient is not None:
+ grad = clip(grad, -self.clip_gradient, self.clip_gradient)
+ state[:] += mean(square(grad), axis=1, keepdims=True)
+ div = lr * grad / sqrt(state + self.float_stable_eps)
+ weight[:] -= div
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer/optimizer.py
similarity index 98%
rename from python/mxnet/optimizer.py
rename to python/mxnet/optimizer/optimizer.py
index b69d0c9af0dc..8f9cf366f09b 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -23,13 +23,19 @@
import pickle
import warnings
import numpy
-from .base import py_str
-from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
-from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
- mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
- signsgd_update, signum_update)
-from .ndarray import sparse
-from .random import normal
+from ..base import py_str
+from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
+from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
+ mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
+ signsgd_update, signum_update)
+from ..ndarray import sparse
+from ..random import normal
+
+__all__ = [
+ 'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD',
+ 'NAG', 'NDArray', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD',
+ 'Signum', 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
+]
class Optimizer(object):
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index c555b2fdfaf8..0bb28a0ef13a 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1957,3 +1957,44 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc
% (str(cs_ret_l), str(obs_freq_l), str(expected_freq_l),
str(buckets), str(probs)))
return cs_ret_l
+
+def compare_ndarray_tuple(t1, t2, rtol=None, atol=None):
+ """Compare ndarray tuple."""
+ if t1 is not None and t2 is not None:
+ if isinstance(t1, tuple):
+ for s1, s2 in zip(t1, t2):
+ compare_ndarray_tuple(s1, s2, rtol, atol)
+ else:
+ assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol)
+
+
+def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
+ rtol=1e-4, atol=1e-5, compare_states=True):
+ """Compare opt1 and opt2."""
+ if w_stype == 'default':
+ w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+ w1 = w2.copyto(default_context())
+ elif w_stype == 'row_sparse' or w_stype == 'csr':
+ w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
+ w1 = w2.copyto(default_context()).tostype('default')
+ else:
+ raise Exception("type not supported yet")
+ if g_stype == 'default':
+ g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
+ g1 = g2.copyto(default_context())
+ elif g_stype == 'row_sparse' or g_stype == 'csr':
+ g2 = rand_ndarray(shape, g_stype, dtype=dtype)
+ g1 = g2.copyto(default_context()).tostype('default')
+ else:
+ raise Exception("type not supported yet")
+
+ state1 = opt1.create_state_multi_precision(0, w1)
+ state2 = opt2.create_state_multi_precision(0, w2)
+ if compare_states:
+ compare_ndarray_tuple(state1, state2)
+
+ opt1.update_multi_precision(0, w1, g1, state1)
+ opt2.update_multi_precision(0, w2, g2, state2)
+ if compare_states:
+ compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
+ assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
diff --git a/src/operator/contrib/optimizer_op-inl.h b/src/operator/contrib/optimizer_op-inl.h
new file mode 100644
index 000000000000..fd556a4231cb
--- /dev/null
+++ b/src/operator/contrib/optimizer_op-inl.h
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file optimizer_op-inl.h
+ * \brief Optimizer operators
+ * \author Leonard Lausen
+ */
+#ifndef MXNET_OPERATOR_CONTRIB_OPTIMIZER_OP_INL_H_
+#define MXNET_OPERATOR_CONTRIB_OPTIMIZER_OP_INL_H_
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/operator_util.h>
+#include <mshadow/base.h>
+#include <nnvm/op.h>
+#include <nnvm/op_attr_types.h>
+#include <vector>
+#include <string>
+#include "../elemwise_op_common.h"
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../tensor/init_op.h"
+#include "../tensor/util/tensor_util-inl.h"
+
+namespace mxnet {
+namespace op {
+
+struct GroupAdagradParam : public dmlc::Parameter<GroupAdagradParam> {
+ float lr;
+ float epsilon;
+ float rescale_grad;
+ float clip_gradient;
+ DMLC_DECLARE_PARAMETER(GroupAdagradParam) {
+ DMLC_DECLARE_FIELD(lr).describe("Learning rate");
+ DMLC_DECLARE_FIELD(rescale_grad)
+ .set_default(1.0f)
+ .describe("Rescale gradient to grad = rescale_grad*grad.");
+ DMLC_DECLARE_FIELD(clip_gradient)
+ .set_default(-1.0f)
+ .describe(
+ "Clip gradient to the range of [-clip_gradient, clip_gradient] "
+ "If clip_gradient <= 0, gradient clipping is turned off. "
+ "grad = max(min(grad, clip_gradient), -clip_gradient).");
+ DMLC_DECLARE_FIELD(epsilon).set_default(1.0e-5).describe(
+ "Epsilon for numerical stability");
+ }
+};
+
+inline bool GroupAdagradStorageType(const nnvm::NodeAttrs &attrs,
+                                    const int dev_mask,
+                                    DispatchMode *dispatch_mode,
+                                    std::vector<int> *in_attrs,
+                                    std::vector<int> *out_attrs) {
+ CHECK_EQ(in_attrs->size(), 3U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ const int weight_stype = in_attrs->at(0);
+ const int grad_stype = in_attrs->at(1);
+ const int state_stype = in_attrs->at(2);
+ bool dispatched = false;
+ if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+ // dns, ... -> dns
+ dispatched = storage_type_assign(out_attrs, kDefaultStorage, dispatch_mode,
+ DispatchMode::kFCompute);
+ }
+ if (!dispatched && grad_stype == kRowSparseStorage &&
+ (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
+ state_stype == weight_stype) {
+ // weight and state share stype, grad's stype = rsp
+ dispatched = storage_type_assign(
+        out_attrs, static_cast<NDArrayStorageType>(weight_stype), dispatch_mode,
+ DispatchMode::kFComputeEx);
+ }
+ return dispatched;
+}
+
+/*! \brief kernel for sparse adagrad update with group sparsity regularization
+ */
+template <typename xpu> struct GroupAdagradDnsRspKernel {
+  template <typename DType, typename IType>
+ MSHADOW_XINLINE static void
+ Map(int i, const index_t row_length, DType *out_data, DType *state_data,
+ DType *weight_data, const IType *grad_idx, const DType *grad_data,
+ const DType clip_gradient, const DType rescale_grad, const DType lr,
+ const DType eps) {
+ using namespace mshadow_op;
+
+ // Helper to obtain index into weight / state arrays
+ auto get_data_j = [&i, &grad_idx, &row_length](index_t j) -> index_t {
+ return grad_idx[i] * row_length + j;
+ };
+ // Helper to obtain explicit rescaled and clipped grad
+ auto get_grad_rescaled = [&i, &row_length, &grad_data, &rescale_grad,
+ &clip_gradient](index_t j) -> DType {
+ index_t grad_j = i * row_length + j;
+ DType grad_rescaled = grad_data[grad_j] * rescale_grad;
+ if (clip_gradient >= 0.0f) {
+ grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+ }
+ return grad_rescaled;
+ };
+
+ // Update history states
+ DType grad_ssq = 0;
+ for (index_t j = 0; j < row_length; j++) {
+ const DType grad_rescaled = get_grad_rescaled(j);
+ grad_ssq += grad_rescaled * grad_rescaled;
+ }
+ state_data[grad_idx[i]] += grad_ssq / row_length;
+
+ // Standard Adagrad Update
+ for (index_t j = 0; j < row_length; j++) {
+ // clang-format off
+ const DType grad_rescaled = get_grad_rescaled(j);
+ index_t data_j = get_data_j(j);
+ const DType div = lr * grad_rescaled / square_root::Map(state_data[grad_idx[i]] + eps);
+ out_data[data_j] = weight_data[data_j] - div;
+ // clang-format on
+ }
+ }
+};
+
+/*
+ * \brief Group Adagrad update implementation for dense weight and row_sparse
+ * grad.
+ */
+template <typename xpu>
+inline void GroupAdagradUpdateDnsRspDnsImpl(
+ const GroupAdagradParam ¶m, const OpContext &ctx, const TBlob &weight,
+ const NDArray &grad, const TBlob &state, const OpReqType &req, TBlob *out) {
+ using namespace mshadow;
+ using namespace mshadow::expr;
+ using namespace mshadow_op;
+ using namespace mxnet_op;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+ CHECK_EQ(grad.storage_type(), kRowSparseStorage);
+ // if gradients are zeros, no weights are updated
+ if (req == kNullOp) {
+ return;
+ }
+ CHECK_EQ(req, kWriteInplace)
+ << "kWriteInplace is expected for sparse group_adagrad_update";
+ CHECK_GT(weight.shape_.Size(), 0);
+ CHECK_GT(state.shape_.Size(), 0);
+
+ MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+ MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
+      DType *weight_data = weight.dptr<DType>();
+      DType *out_data = out->dptr<DType>();
+      const IType *grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
+      const DType *grad_val = grad.data().dptr<DType>();
+      DType *state_data = state.dptr<DType>();
+ const nnvm::dim_t num_grad = grad.aux_shape(rowsparse::kIdx)[0];
+ const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
+
+ if (!grad.storage_initialized()) {
+ // Lazy update with 0 gradient
+ return;
+ }
+
+      Kernel<GroupAdagradDnsRspKernel<xpu>, xpu>::Launch(
+ s, num_grad, row_length, out_data, state_data, weight_data, grad_idx,
+          grad_val, static_cast<DType>(param.clip_gradient),
+          static_cast<DType>(param.rescale_grad), static_cast<DType>(param.lr),
+          static_cast<DType>(param.epsilon));
+ });
+ });
+}
+
+/*
+ * \brief AdaGrad update implementation for row_sparse grad. Both standard
+ * update and lazy update are supported.
+ */
+template <typename xpu>
+inline void
+GroupAdagradUpdateRspRspRspImpl(const GroupAdagradParam ¶m,
+ const OpContext &ctx, const NDArray &weight,
+ const NDArray &grad, const NDArray &state,
+ const OpReqType &req, NDArray *out) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ using namespace rowsparse;
+ CheckAllRowsPresent(weight, "GroupAdagradUpdate", "weights");
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+ // fill history with zero values
+ if (!state.storage_initialized()) {
+ NDArray state_zeros = state;
+ FillDnsZerosRspImpl(s, &state_zeros);
+ } else {
+ CheckAllRowsPresent(state, "GroupAdagradUpdate", "states");
+ }
+ // reuse dns rsp implementation when storage_shape == shape
+ TBlob out_blob = out->data();
+  GroupAdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                       state.data(), req, &out_blob);
+}
+
+template <typename xpu>
+inline void GroupAdagradUpdateEx(const nnvm::NodeAttrs &attrs,
+                                 const OpContext &ctx,
+                                 const std::vector<NDArray> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<NDArray> &outputs) {
+  const GroupAdagradParam &param = nnvm::get<GroupAdagradParam>(attrs.parsed);
+ const auto weight_stype = inputs[0].storage_type();
+ const auto grad_stype = inputs[1].storage_type();
+ const auto state_stype = inputs[2].storage_type();
+ const auto output_stype = outputs[0].storage_type();
+
+ if (state_stype == weight_stype && output_stype == weight_stype &&
+ weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) {
+ NDArray out = outputs[0];
+    GroupAdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1],
+                                         inputs[2], req[0], &out);
+ } else if (state_stype == weight_stype && output_stype == weight_stype &&
+ weight_stype == kDefaultStorage &&
+ grad_stype == kRowSparseStorage) {
+ TBlob out_blob = outputs[0].data();
+    GroupAdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(),
+                                         inputs[1], inputs[2].data(), req[0],
+                                         &out_blob);
+ } else {
+ LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+ }
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_CONTRIB_OPTIMIZER_OP_INL_H_
diff --git a/src/operator/contrib/optimizer_op.cc b/src/operator/contrib/optimizer_op.cc
new file mode 100644
index 000000000000..96f431bc569d
--- /dev/null
+++ b/src/operator/contrib/optimizer_op.cc
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file optimizer_op.cc
+ * \brief Optimizer operators
+ * \author Leonard Lausen
+ */
+#include "../elemwise_op_common.h"
+#include "./optimizer_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(GroupAdagradParam);
+
+/*!
+ * \brief Shape inference function for Group AdaGrad.
+ */
+inline bool GroupAdagradShape(const nnvm::NodeAttrs &attrs,
+                              std::vector<TShape> *in_attrs,
+                              std::vector<TShape> *out_attrs) {
+ CHECK_EQ(in_attrs->size(), 3U);
+ CHECK_EQ(out_attrs->size(), 1U);
+
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
+ SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+ SHAPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
+
+ return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U &&
+ (in_attrs->at(0)[0] == in_attrs->at(1)[0]) &&
+ (in_attrs->at(0)[0] == in_attrs->at(2)[0]);
+}
+
+NNVM_REGISTER_OP(_contrib_group_adagrad_update)
+.describe(R"code(Update function for Group AdaGrad optimizer.
+
+This update is based on *Adaptive Subgradient Methods for Online Learning and Stochastic
+Optimization* (http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf), but maintains
+only a single learning rate for every row of the parameter array.
+
+Updates are applied by::
+
+ grad = clip(grad * rescale_grad, clip_gradient)
+ history += mean(square(grad), axis=1, keepdims=True)
+ div = grad / sqrt(history + float_stable_eps)
+ weight -= div * lr
+
+Weights are updated lazily if the gradient is sparse.
+
+Note that non-zero values for the weight decay option are not supported.
+
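+A minimal, illustrative sketch of calling this operator from the ndarray API
+(shapes and values are arbitrary)::
+
+  weight = mx.nd.random.uniform(shape=(4, 2))
+  grad = mx.nd.random.uniform(shape=(4, 2)).tostype('row_sparse')
+  history = mx.nd.zeros((4, 1))
+  mx.nd.contrib.group_adagrad_update(weight, grad, history, lr=0.01, out=weight)
+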
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<GroupAdagradParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", GroupAdagradShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", GroupAdagradStorageType)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+          [](const nnvm::NodeAttrs& attrs) {
+            return std::vector<uint32_t>{2};
+          })
+.set_attr<FComputeEx>("FComputeEx<cpu>", GroupAdagradUpdateEx<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("history", "NDArray-or-Symbol", "History")
+.add_arguments(GroupAdagradParam::__FIELDS__());
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/contrib/optimizer_op.cu b/src/operator/contrib/optimizer_op.cu
new file mode 100644
index 000000000000..40d99c5f0071
--- /dev/null
+++ b/src/operator/contrib/optimizer_op.cu
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file optimizer_op.cu
+ * \brief Optimizer operators
+ * \author Leonard Lausen
+ */
+#include "./optimizer_op-inl.h"
+#include <cub/cub.cuh>
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_contrib_group_adagrad_update)
+.set_attr("FComputeEx", GroupAdagradUpdateEx);
+
+} // namespace op
+} // namespace mxnet
diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py
new file mode 100644
index 000000000000..8ff8a7e1436b
--- /dev/null
+++ b/tests/python/unittest/test_contrib_optimizer.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import itertools
+
+import numpy as np
+
+import mxnet as mx
+from mxnet.test_utils import *
+
+
+# * GroupAdaGrad
+class PyGroupAdaGrad(mx.optimizer.Optimizer):
+ """The python reference of Group AdaGrad optimizer.
+
+ Parameters
+ ----------
+ eps: float, optional
+ Small value to avoid division by 0.
+
+ """
+
+ def __init__(self, eps=1e-5, **kwargs):
+ super(PyGroupAdaGrad, self).__init__(**kwargs)
+ self.float_stable_eps = eps
+
+ def create_state(self, index, weight):
+ assert len(weight.shape) == 2
+ history = mx.nd.zeros(
+ (weight.shape[0], 1), weight.context, stype=weight.stype)
+ return history
+
+ def update(self, index, weight, grad, state):
+ self._update_count(index)
+ lr = self._get_lr(index)
+ wd = self._get_wd(index)
+ assert wd == 0
+
+ history = state
+ grad = grad * self.rescale_grad
+ if self.clip_gradient is not None:
+ grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
+ history[:] += mx.nd.mean(mx.nd.square(grad), axis=1, keepdims=True)
+ div = lr * grad / mx.nd.sqrt(history + self.float_stable_eps)
+ weight[:] -= div
+
+
+def test_group_adagrad():
+ mx.random.seed(0)
+ opt1 = PyGroupAdaGrad
+ opt2 = mx.optimizer.contrib.GroupAdaGrad
+ shape = (3, 4)
+ eps_options = [{}, {'eps': 1e-8}]
+ cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+ rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+ for dtype in [np.float32]:
+ for options in itertools.product(eps_options, cg_options, rg_options):
+ kwarg = dict(wd=0.0)
+ for option in options:
+ kwarg.update(option)
+ compare_optimizer(
+ opt1(**kwarg),
+ opt2(**kwarg),
+ shape,
+ dtype,
+ compare_states=False)
+ compare_optimizer(
+ opt1(**kwarg),
+ opt2(**kwarg),
+ shape,
+ dtype,
+ w_stype='row_sparse',
+ g_stype='row_sparse',
+ compare_states=False)
+ compare_optimizer(
+ opt1(**kwarg),
+ opt2(**kwarg),
+ shape,
+ dtype,
+ g_stype='row_sparse',
+ compare_states=False)
+
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 496a61f356b3..334b7d4c0fdb 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -71,43 +71,6 @@ def test_lr_wd_mult():
assert not mx.test_utils.almost_equal(args1['fc1_bias'], args2['fc1_bias'], 1e-1)
assert not mx.test_utils.almost_equal(args1['fc2_weight'], args2['fc2_weight'], 1e-1)
-def compare_ndarray_tuple(t1, t2, rtol=None, atol=None):
- if t1 is not None and t2 is not None:
- if isinstance(t1, tuple):
- for s1, s2 in zip(t1, t2):
- compare_ndarray_tuple(s1, s2, rtol, atol)
- else:
- assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol)
-
-
-def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
- rtol=1e-4, atol=1e-5):
- if w_stype == 'default':
- w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
- w1 = w2.copyto(default_context())
- elif w_stype == 'row_sparse' or w_stype == 'csr':
- w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
- w1 = w2.copyto(default_context()).tostype('default')
- else:
- raise Exception("type not supported yet")
- if g_stype == 'default':
- g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
- g1 = g2.copyto(default_context())
- elif g_stype == 'row_sparse' or g_stype == 'csr':
- g2 = rand_ndarray(shape, g_stype, dtype=dtype)
- g1 = g2.copyto(default_context()).tostype('default')
- else:
- raise Exception("type not supported yet")
-
- state1 = opt1.create_state_multi_precision(0, w1)
- state2 = opt2.create_state_multi_precision(0, w2)
- compare_ndarray_tuple(state1, state2)
-
- opt1.update_multi_precision(0, w1, g1, state1)
- opt2.update_multi_precision(0, w2, g2, state2)
- compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
- assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
-
# SGD
class PySGD(mx.optimizer.Optimizer):