diff --git a/docs/api/python/index.md b/docs/api/python/index.md
index 8f60bcd0f13c..de86aedff691 100644
--- a/docs/api/python/index.md
+++ b/docs/api/python/index.md
@@ -136,6 +136,7 @@ Code examples are placed throughout the API documentation and these can be run a
    :maxdepth: 1
 
    optimization/optimization.md
+   optimization/contrib.md
 ```
 
 ## Profiler API
diff --git a/docs/api/python/optimization/contrib.md b/docs/api/python/optimization/contrib.md
new file mode 100644
index 000000000000..8fc261f4f052
--- /dev/null
+++ b/docs/api/python/optimization/contrib.md
@@ -0,0 +1,52 @@
+# Contrib Optimization API
+
+```eval_rst
+    .. currentmodule:: mxnet.optimizer.contrib
+```
+
+## Overview
+
+This document summarizes the contrib APIs used to initialize and update the model
+weights during training.
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    mxnet.optimizer.contrib
+```
+
+The `Contrib Optimization` API, defined in the `optimizer.contrib` package, provides
+experimental APIs for new features.
+It is a place for the community to try out new features,
+so that feature contributors can receive feedback.
+
+```eval_rst
+.. warning:: This package contains experimental APIs and may change in the near future.
+```
+
+In the rest of this document, we list the routines provided by the `optimizer.contrib` package.
+
+## Contrib
+
+```eval_rst
+.. currentmodule:: mxnet.optimizer.contrib
+
+.. autosummary::
+    :nosignatures:
+
+    GroupAdaGrad
+```
+
+## API Reference
+
+
+```eval_rst
+
+.. automodule:: mxnet.optimizer.contrib
+    :members:
+
+```
+
diff --git a/python/mxnet/optimizer/__init__.py b/python/mxnet/optimizer/__init__.py
new file mode 100644
index 000000000000..72eb5a741520
--- /dev/null
+++ b/python/mxnet/optimizer/__init__.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Optimizer API of MXNet."""
+
+from . import optimizer, contrib
+# pylint: disable=wildcard-import
+from .optimizer import *
+# pylint: enable=wildcard-import
+
+__all__ = optimizer.__all__ + ['contrib']
diff --git a/python/mxnet/optimizer/contrib.py b/python/mxnet/optimizer/contrib.py
new file mode 100644
index 000000000000..d269aa1bd069
--- /dev/null
+++ b/python/mxnet/optimizer/contrib.py
@@ -0,0 +1,100 @@
+# coding: utf-8
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=too-many-lines +"""Contrib optimizers.""" +from ..ndarray import (NDArray, clip, contrib, mean, sqrt, square, zeros) +from .optimizer import Optimizer + +# convenience wrapper for Optimizer.Register +register = Optimizer.register # pylint: disable=invalid-name + +__all__ = ['GroupAdaGrad'] + + +@register +class GroupAdaGrad(Optimizer): + """Adagrad optimizer with row-wise learning rates. + + This class implements the AdaGrad optimizer described in *Adaptive + Subgradient Methods for Online Learning and Stochastic Optimization*, and + available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf but + uses only a single learning rate for every row of the parameter array. + + This optimizer updates each weight by:: + + grad = clip(grad * rescale_grad, clip_gradient) + history += mean(square(grad), axis=1, keepdims=True) + div = grad / sqrt(history + float_stable_eps) + weight -= div * lr + + Weights are updated lazily if the gradient is sparse. + + For details of the update algorithm see + :class:`~mxnet.ndarray.contrib.group_adagrad_update`. + + This optimizer accepts the following parameters in addition to those + accepted by :class:`.Optimizer`. Weight decay is not supported. + + Parameters + ---------- + eps: float, optional + Initial value of the history accumulator. Avoids division by 0. + + """ + + def __init__(self, eps=1e-5, **kwargs): + super(GroupAdaGrad, self).__init__(**kwargs) + self.float_stable_eps = eps + + def create_state(self, index, weight): + assert len(weight.shape) == 2 + history = zeros( + (weight.shape[0], 1), weight.context, stype=weight.stype) + return history + + def update(self, index, weight, grad, state): + assert (isinstance(weight, NDArray)) + assert (isinstance(grad, NDArray)) + self._update_count(index) + lr = self._get_lr(index) + wd = self._get_wd(index) + assert wd == 0, 'Weight decay is not supported for GroupAdaGrad' + + is_sparse = grad.stype == 'row_sparse' + if is_sparse: + kwargs = { + 'epsilon': self.float_stable_eps, + 'rescale_grad': self.rescale_grad + } + if self.clip_gradient: + kwargs['clip_gradient'] = self.clip_gradient + contrib.group_adagrad_update( + weight, + grad, + state, + out=weight, + lr=lr, + **kwargs) + else: + grad = grad * self.rescale_grad + if self.clip_gradient is not None: + grad = clip(grad, -self.clip_gradient, self.clip_gradient) + state[:] += mean(square(grad), axis=1, keepdims=True) + div = lr * grad / sqrt(state + self.float_stable_eps) + weight[:] -= div diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer/optimizer.py similarity index 98% rename from python/mxnet/optimizer.py rename to python/mxnet/optimizer/optimizer.py index b69d0c9af0dc..8f9cf366f09b 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -23,13 +23,19 @@ import pickle import warnings import numpy -from .base import py_str -from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) -from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, - mp_sgd_update, mp_sgd_mom_update, square, 
ftrl_update, ftml_update, - signsgd_update, signum_update) -from .ndarray import sparse -from .random import normal +from ..base import py_str +from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) +from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, + mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, + signsgd_update, signum_update) +from ..ndarray import sparse +from ..random import normal + +__all__ = [ + 'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD', + 'NAG', 'NDArray', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', + 'Signum', 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register' +] class Optimizer(object): diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index c555b2fdfaf8..0bb28a0ef13a 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1957,3 +1957,44 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc % (str(cs_ret_l), str(obs_freq_l), str(expected_freq_l), str(buckets), str(probs))) return cs_ret_l + +def compare_ndarray_tuple(t1, t2, rtol=None, atol=None): + """Compare ndarray tuple.""" + if t1 is not None and t2 is not None: + if isinstance(t1, tuple): + for s1, s2 in zip(t1, t2): + compare_ndarray_tuple(s1, s2, rtol, atol) + else: + assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol) + + +def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default', + rtol=1e-4, atol=1e-5, compare_states=True): + """Compare opt1 and opt2.""" + if w_stype == 'default': + w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) + w1 = w2.copyto(default_context()) + elif w_stype == 'row_sparse' or w_stype == 'csr': + w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype) + w1 = w2.copyto(default_context()).tostype('default') + else: + raise Exception("type not supported yet") + if g_stype == 'default': + g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) + g1 = g2.copyto(default_context()) + elif g_stype == 'row_sparse' or g_stype == 'csr': + g2 = rand_ndarray(shape, g_stype, dtype=dtype) + g1 = g2.copyto(default_context()).tostype('default') + else: + raise Exception("type not supported yet") + + state1 = opt1.create_state_multi_precision(0, w1) + state2 = opt2.create_state_multi_precision(0, w2) + if compare_states: + compare_ndarray_tuple(state1, state2) + + opt1.update_multi_precision(0, w1, g1, state1) + opt2.update_multi_precision(0, w2, g2, state2) + if compare_states: + compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol) + assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol) diff --git a/src/operator/contrib/optimizer_op-inl.h b/src/operator/contrib/optimizer_op-inl.h new file mode 100644 index 000000000000..fd556a4231cb --- /dev/null +++ b/src/operator/contrib/optimizer_op-inl.h @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file optimizer_op-inl.h + * \brief Optimizer operators + * \author Leonard Lausen + */ +#ifndef MXNET_OPERATOR_CONTRIB_OPTIMIZER_OP_INL_H_ +#define MXNET_OPERATOR_CONTRIB_OPTIMIZER_OP_INL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "../elemwise_op_common.h" +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "../tensor/init_op.h" +#include "../tensor/util/tensor_util-inl.h" + +namespace mxnet { +namespace op { + +struct GroupAdagradParam : public dmlc::Parameter { + float lr; + float epsilon; + float rescale_grad; + float clip_gradient; + DMLC_DECLARE_PARAMETER(GroupAdagradParam) { + DMLC_DECLARE_FIELD(lr).describe("Learning rate"); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe( + "Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(epsilon).set_default(1.0e-5).describe( + "Epsilon for numerical stability"); + } +}; + +inline bool GroupAdagradStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 1U); + const int weight_stype = in_attrs->at(0); + const int grad_stype = in_attrs->at(1); + const int state_stype = in_attrs->at(2); + bool dispatched = false; + if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { + // dns, ... -> dns + dispatched = storage_type_assign(out_attrs, kDefaultStorage, dispatch_mode, + DispatchMode::kFCompute); + } + if (!dispatched && grad_stype == kRowSparseStorage && + (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) && + state_stype == weight_stype) { + // weight and state share stype, grad's stype = rsp + dispatched = storage_type_assign( + out_attrs, static_cast(weight_stype), dispatch_mode, + DispatchMode::kFComputeEx); + } + return dispatched; +} + +/*! 
\brief kernel for sparse adagrad update with group sparsity regularization + */ +template struct GroupAdagradDnsRspKernel { + template + MSHADOW_XINLINE static void + Map(int i, const index_t row_length, DType *out_data, DType *state_data, + DType *weight_data, const IType *grad_idx, const DType *grad_data, + const DType clip_gradient, const DType rescale_grad, const DType lr, + const DType eps) { + using namespace mshadow_op; + + // Helper to obtain index into weight / state arrays + auto get_data_j = [&i, &grad_idx, &row_length](index_t j) -> index_t { + return grad_idx[i] * row_length + j; + }; + // Helper to obtain explicit rescaled and clipped grad + auto get_grad_rescaled = [&i, &row_length, &grad_data, &rescale_grad, + &clip_gradient](index_t j) -> DType { + index_t grad_j = i * row_length + j; + DType grad_rescaled = grad_data[grad_j] * rescale_grad; + if (clip_gradient >= 0.0f) { + grad_rescaled = clip::Map(grad_rescaled, clip_gradient); + } + return grad_rescaled; + }; + + // Update history states + DType grad_ssq = 0; + for (index_t j = 0; j < row_length; j++) { + const DType grad_rescaled = get_grad_rescaled(j); + grad_ssq += grad_rescaled * grad_rescaled; + } + state_data[grad_idx[i]] += grad_ssq / row_length; + + // Standard Adagrad Update + for (index_t j = 0; j < row_length; j++) { + // clang-format off + const DType grad_rescaled = get_grad_rescaled(j); + index_t data_j = get_data_j(j); + const DType div = lr * grad_rescaled / square_root::Map(state_data[grad_idx[i]] + eps); + out_data[data_j] = weight_data[data_j] - div; + // clang-format on + } + } +}; + +/* + * \brief Group Adagrad update implementation for dense weight and row_sparse + * grad. + */ +template +inline void GroupAdagradUpdateDnsRspDnsImpl( + const GroupAdagradParam ¶m, const OpContext &ctx, const TBlob &weight, + const NDArray &grad, const TBlob &state, const OpReqType &req, TBlob *out) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + CHECK_EQ(grad.storage_type(), kRowSparseStorage); + // if gradients are zeros, no weights are updated + if (req == kNullOp) { + return; + } + CHECK_EQ(req, kWriteInplace) + << "kWriteInplace is expected for sparse group_adagrad_update"; + CHECK_GT(weight.shape_.Size(), 0); + CHECK_GT(state.shape_.Size(), 0); + + MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { + MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { + DType *weight_data = weight.dptr(); + DType *out_data = out->dptr(); + const IType *grad_idx = grad.aux_data(rowsparse::kIdx).dptr(); + const DType *grad_val = grad.data().dptr(); + DType *state_data = state.dptr(); + const nnvm::dim_t num_grad = grad.aux_shape(rowsparse::kIdx)[0]; + const auto row_length = weight.shape_.ProdShape(1, weight.ndim()); + + if (!grad.storage_initialized()) { + // Lazy update with 0 gradient + return; + } + + Kernel, xpu>::Launch( + s, num_grad, row_length, out_data, state_data, weight_data, grad_idx, + grad_val, static_cast(param.clip_gradient), + static_cast(param.rescale_grad), static_cast(param.lr), + static_cast(param.epsilon)); + }); + }); +} + +/* + * \brief AdaGrad update implementation for row_sparse grad. Both standard + * update and lazy update are supported. 
+ */ +template +inline void +GroupAdagradUpdateRspRspRspImpl(const GroupAdagradParam ¶m, + const OpContext &ctx, const NDArray &weight, + const NDArray &grad, const NDArray &state, + const OpReqType &req, NDArray *out) { + using namespace mshadow; + using namespace mxnet_op; + using namespace rowsparse; + CheckAllRowsPresent(weight, "GroupAdagradUpdate", "weights"); + Stream *s = ctx.get_stream(); + // fill history with zero values + if (!state.storage_initialized()) { + NDArray state_zeros = state; + FillDnsZerosRspImpl(s, &state_zeros); + } else { + CheckAllRowsPresent(state, "GroupAdagradUpdate", "states"); + } + // reuse dns rsp implementation when storage_shape == shape + TBlob out_blob = out->data(); + GroupAdagradUpdateDnsRspDnsImpl(param, ctx, weight.data(), grad, + state.data(), req, &out_blob); +} + +template +inline void GroupAdagradUpdateEx(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const GroupAdagradParam ¶m = nnvm::get(attrs.parsed); + const auto weight_stype = inputs[0].storage_type(); + const auto grad_stype = inputs[1].storage_type(); + const auto state_stype = inputs[2].storage_type(); + const auto output_stype = outputs[0].storage_type(); + + if (state_stype == weight_stype && output_stype == weight_stype && + weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + GroupAdagradUpdateRspRspRspImpl(param, ctx, inputs[0], inputs[1], + inputs[2], req[0], &out); + } else if (state_stype == weight_stype && output_stype == weight_stype && + weight_stype == kDefaultStorage && + grad_stype == kRowSparseStorage) { + TBlob out_blob = outputs[0].data(); + GroupAdagradUpdateDnsRspDnsImpl(param, ctx, inputs[0].data(), + inputs[1], inputs[2].data(), req[0], + &out_blob); + } else { + LogUnimplementedOp(attrs, ctx, inputs, req, outputs); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_OPTIMIZER_OP_INL_H_ diff --git a/src/operator/contrib/optimizer_op.cc b/src/operator/contrib/optimizer_op.cc new file mode 100644 index 000000000000..96f431bc569d --- /dev/null +++ b/src/operator/contrib/optimizer_op.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file optimizer_op.cc + * \brief Optimizer operators + * \author Leonard Lausen + */ +#include "../elemwise_op_common.h" +#include "./optimizer_op-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(GroupAdagradParam); + +/*! + * \brief Shape inference function for Group AdaGrad. 
+ */ +inline bool GroupAdagradShape(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 1U); + + SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); + SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + SHAPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); + + return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U && + (in_attrs->at(0)[0] == in_attrs->at(1)[0]) && + (in_attrs->at(0)[0] == in_attrs->at(2)[0]); +} + +NNVM_REGISTER_OP(_contrib_group_adagrad_update) +.describe(R"code(Update function for Group AdaGrad optimizer. + +Referenced from *Adaptive Subgradient Methods for Online Learning and Stochastic Optimization*, +and available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf but +uses only a single learning rate for every row of the parameter array. + +Updates are applied by:: + + grad = clip(grad * rescale_grad, clip_gradient) + history += mean(square(grad), axis=1, keepdims=True) + div = grad / sqrt(history + float_stable_eps) + weight -= div * lr + +Weights are updated lazily if the gradient is sparse. + +Note that non-zero values for the weight decay option are not supported. + +)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", GroupAdagradShape) +.set_attr("FInferType", ElemwiseType<3, 1>) +.set_attr("FInferStorageType", GroupAdagradStorageType) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2}; + }) +.set_attr("FComputeEx", GroupAdagradUpdateEx) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("history", "NDArray-or-Symbol", "History") +.add_arguments(GroupAdagradParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/optimizer_op.cu b/src/operator/contrib/optimizer_op.cu new file mode 100644 index 000000000000..40d99c5f0071 --- /dev/null +++ b/src/operator/contrib/optimizer_op.cu @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2018 by Contributors + * \file optimizer_op.cu + * \brief Optimizer operators + * \author Leonard Lausen + */ +#include "./optimizer_op-inl.h" +#include + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_contrib_group_adagrad_update) +.set_attr("FComputeEx", GroupAdagradUpdateEx); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py new file mode 100644 index 000000000000..8ff8a7e1436b --- /dev/null +++ b/tests/python/unittest/test_contrib_optimizer.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import itertools + +import numpy as np + +import mxnet as mx +from mxnet.test_utils import * + + +# * GroupAdaGrad +class PyGroupAdaGrad(mx.optimizer.Optimizer): + """The python reference of Group AdaGrad optimizer. + + Parameters + ---------- + eps: float, optional + Small value to avoid division by 0. + + """ + + def __init__(self, eps=1e-5, **kwargs): + super(PyGroupAdaGrad, self).__init__(**kwargs) + self.float_stable_eps = eps + + def create_state(self, index, weight): + assert len(weight.shape) == 2 + history = mx.nd.zeros( + (weight.shape[0], 1), weight.context, stype=weight.stype) + return history + + def update(self, index, weight, grad, state): + self._update_count(index) + lr = self._get_lr(index) + wd = self._get_wd(index) + assert wd == 0 + + history = state + grad = grad * self.rescale_grad + if self.clip_gradient is not None: + grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient) + history[:] += mx.nd.mean(mx.nd.square(grad), axis=1, keepdims=True) + div = lr * grad / mx.nd.sqrt(history + self.float_stable_eps) + weight[:] -= div + + +def test_group_adagrad(): + mx.random.seed(0) + opt1 = PyGroupAdaGrad + opt2 = mx.optimizer.contrib.GroupAdaGrad + shape = (3, 4) + eps_options = [{}, {'eps': 1e-8}] + cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}] + rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] + for dtype in [np.float32]: + for options in itertools.product(eps_options, cg_options, rg_options): + kwarg = dict(wd=0.0) + for option in options: + kwarg.update(option) + compare_optimizer( + opt1(**kwarg), + opt2(**kwarg), + shape, + dtype, + compare_states=False) + compare_optimizer( + opt1(**kwarg), + opt2(**kwarg), + shape, + dtype, + w_stype='row_sparse', + g_stype='row_sparse', + compare_states=False) + compare_optimizer( + opt1(**kwarg), + opt2(**kwarg), + shape, + dtype, + g_stype='row_sparse', + compare_states=False) + + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 496a61f356b3..334b7d4c0fdb 100644 --- 
a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -71,43 +71,6 @@ def test_lr_wd_mult(): assert not mx.test_utils.almost_equal(args1['fc1_bias'], args2['fc1_bias'], 1e-1) assert not mx.test_utils.almost_equal(args1['fc2_weight'], args2['fc2_weight'], 1e-1) -def compare_ndarray_tuple(t1, t2, rtol=None, atol=None): - if t1 is not None and t2 is not None: - if isinstance(t1, tuple): - for s1, s2 in zip(t1, t2): - compare_ndarray_tuple(s1, s2, rtol, atol) - else: - assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol) - - -def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default', - rtol=1e-4, atol=1e-5): - if w_stype == 'default': - w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) - w1 = w2.copyto(default_context()) - elif w_stype == 'row_sparse' or w_stype == 'csr': - w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype) - w1 = w2.copyto(default_context()).tostype('default') - else: - raise Exception("type not supported yet") - if g_stype == 'default': - g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) - g1 = g2.copyto(default_context()) - elif g_stype == 'row_sparse' or g_stype == 'csr': - g2 = rand_ndarray(shape, g_stype, dtype=dtype) - g1 = g2.copyto(default_context()).tostype('default') - else: - raise Exception("type not supported yet") - - state1 = opt1.create_state_multi_precision(0, w1) - state2 = opt2.create_state_multi_precision(0, w2) - compare_ndarray_tuple(state1, state2) - - opt1.update_multi_precision(0, w1, g1, state1) - opt2.update_multi_precision(0, w2, g2, state2) - compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol) - assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol) - # SGD class PySGD(mx.optimizer.Optimizer):
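For reviewers who want to exercise the new optimizer end to end, a minimal usage sketch follows. It is not part of the diff: it assumes MXNet has been built with this change, and the parameter index, shapes, and hyper-parameter values are illustrative only.

```python
import mxnet as mx

# Minimal sketch of the new contrib optimizer (assumes this change is built in).
# GroupAdaGrad keeps one history value per row, so weights must be 2-D.
opt = mx.optimizer.contrib.GroupAdaGrad(learning_rate=0.1, eps=1e-5)

weight = mx.nd.random.uniform(shape=(3, 4))
grad = mx.nd.random.uniform(shape=(3, 4))

state = opt.create_state(0, weight)   # shape (3, 1): one accumulator per row
opt.update(0, weight, grad, state)    # dense update path, weight updated in place

# A row_sparse gradient takes the lazy path backed by the new
# _contrib_group_adagrad_update operator.
opt.update(0, weight, grad.tostype('row_sparse'), state)
```

The same instance can also be passed to a Gluon `Trainer`; note, however, that `create_state` asserts 2-D weights, so 1-D parameters such as biases are not supported by this optimizer.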