
Adagrad optimizer with row-wise learning rate #12365

Merged · 4 commits · Oct 11, 2018
1 change: 1 addition & 0 deletions docs/api/python/index.md
@@ -136,6 +136,7 @@ Code examples are placed throughout the API documentation and these can be run a
:maxdepth: 1

optimization/optimization.md
optimization/contrib.md
```

## Profiler API
52 changes: 52 additions & 0 deletions docs/api/python/optimization/contrib.md
@@ -0,0 +1,52 @@
# Contrib Optimization API

```eval_rst
.. currentmodule:: mxnet.optimizer.contrib
```

## Overview

This document summarizes the contrib APIs used to initialize and update the model
weights during training.

```eval_rst
.. autosummary::
:nosignatures:

mxnet.optimizer.contrib
```

The `Contrib Optimization` API, defined in the `optimizer.contrib` package, provides
experimental APIs for new features. It gives the community a place to try out these
features so that their contributors can receive feedback.

```eval_rst
.. warning:: This package contains experimental APIs and may change in the near future.
```

In the rest of this document, we list routines provided by the `optimizer.contrib` package.

## Contrib

```eval_rst
.. currentmodule:: mxnet.optimizer.contrib

.. autosummary::
:nosignatures:

GroupAdaGrad
```

## API Reference

<script type="text/javascript" src='../../../_static/js/auto_module_index.js'></script>

```eval_rst

.. automodule:: mxnet.optimizer.contrib
:members:

```

<script>auto_index("api-reference");</script>
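
To make the new doc page concrete, here is a minimal usage sketch of the contrib optimizer documented above. It pairs `GroupAdaGrad` with a Gluon `Embedding` using `sparse_grad=True` so gradients arrive as `row_sparse`; the model and hyperparameters are illustrative assumptions, not part of this change.

```python
# Illustrative only: wire GroupAdaGrad into a Gluon Trainer with a sparse
# embedding so gradients are row_sparse and take the lazy update path.
import mxnet as mx
from mxnet.optimizer.contrib import GroupAdaGrad

embedding = mx.gluon.nn.Embedding(input_dim=1000, output_dim=16, sparse_grad=True)
embedding.initialize()

trainer = mx.gluon.Trainer(embedding.collect_params(),
                           GroupAdaGrad(learning_rate=0.1, eps=1e-5))

tokens = mx.nd.array([1, 4, 7])
with mx.autograd.record():
    loss = embedding(tokens).sum()
loss.backward()
trainer.step(batch_size=tokens.shape[0])
```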
24 changes: 24 additions & 0 deletions python/mxnet/optimizer/__init__.py
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Optimizer API of MXNet."""

from . import optimizer, contrib
# pylint: disable=wildcard-import
from .optimizer import *
# pylint: enable=wildcard-import

__all__ = optimizer.__all__ + ['contrib']
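
A quick sanity check of the restructured namespace (a sketch, assuming the import layout in this `__init__.py`): the flat `mxnet.optimizer` exports keep working, and `contrib` becomes reachable as a submodule.

```python
# Sketch: both access paths implied by the new __init__.py above.
import mxnet as mx

sgd = mx.optimizer.SGD(learning_rate=0.01)            # still re-exported via `from .optimizer import *`
group = mx.optimizer.contrib.GroupAdaGrad(eps=1e-5)   # new: exposed through the `contrib` submodule
print(type(sgd).__name__, type(group).__name__)
```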
100 changes: 100 additions & 0 deletions python/mxnet/optimizer/contrib.py
@@ -0,0 +1,100 @@
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=too-many-lines
"""Contrib optimizers."""
from ..ndarray import (NDArray, clip, contrib, mean, sqrt, square, zeros)
from .optimizer import Optimizer

# convenience wrapper for Optimizer.register
register = Optimizer.register # pylint: disable=invalid-name

__all__ = ['GroupAdaGrad']


@register
class GroupAdaGrad(Optimizer):
"""Adagrad optimizer with row-wise learning rates.

This class implements the AdaGrad optimizer described in *Adaptive
Subgradient Methods for Online Learning and Stochastic Optimization*, and
available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf but
uses only a single learning rate for every row of the parameter array.

This optimizer updates each weight by::

grad = clip(grad * rescale_grad, clip_gradient)
history += mean(square(grad), axis=1, keepdims=True)
div = grad / sqrt(history + float_stable_eps)
weight -= div * lr

Weights are updated lazily if the gradient is sparse.

For details of the update algorithm see
:class:`~mxnet.ndarray.contrib.group_adagrad_update`.

This optimizer accepts the following parameters in addition to those
accepted by :class:`.Optimizer`. Weight decay is not supported.

Parameters
----------
eps: float, optional
Initial value of the history accumulator. Avoids division by 0.

"""

    def __init__(self, eps=1e-5, **kwargs):
        super(GroupAdaGrad, self).__init__(**kwargs)
        self.float_stable_eps = eps

    def create_state(self, index, weight):
        assert len(weight.shape) == 2
        history = zeros(
            (weight.shape[0], 1), weight.context, stype=weight.stype)
        return history

    def update(self, index, weight, grad, state):
        assert isinstance(weight, NDArray)
        assert isinstance(grad, NDArray)
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        assert wd == 0, 'Weight decay is not supported for GroupAdaGrad'

        is_sparse = grad.stype == 'row_sparse'
        if is_sparse:
            kwargs = {
                'epsilon': self.float_stable_eps,
                'rescale_grad': self.rescale_grad
            }
            if self.clip_gradient:
                kwargs['clip_gradient'] = self.clip_gradient
            contrib.group_adagrad_update(
                weight,
                grad,
                state,
                out=weight,
                lr=lr,
                **kwargs)
        else:
            grad = grad * self.rescale_grad
            if self.clip_gradient is not None:
                grad = clip(grad, -self.clip_gradient, self.clip_gradient)
            state[:] += mean(square(grad), axis=1, keepdims=True)
            div = lr * grad / sqrt(state + self.float_stable_eps)
            weight[:] -= div
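
For readers skimming the diff, here is the dense (non-sparse) branch of `update` rewritten as a standalone NumPy sketch. It mirrors the docstring pseudocode above and is illustrative only; the real implementation operates on NDArrays and has the sparse fast path shown above.

```python
# NumPy sketch of the dense GroupAdaGrad step: one history value per row,
# shared across that row's columns (mirrors the else-branch of update()).
import numpy as np

def group_adagrad_step(weight, grad, history, lr, eps=1e-5,
                       rescale_grad=1.0, clip_gradient=None):
    grad = grad * rescale_grad
    if clip_gradient is not None:
        grad = np.clip(grad, -clip_gradient, clip_gradient)
    history += np.mean(np.square(grad), axis=1, keepdims=True)
    weight -= lr * grad / np.sqrt(history + eps)
    return weight, history

w = np.random.randn(5, 3)
h = np.zeros((5, 1))
w, h = group_adagrad_step(w, np.random.randn(5, 3), h, lr=0.1)
```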
@@ -23,13 +23,19 @@
import pickle
import warnings
import numpy
from .base import py_str
from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                      signsgd_update, signum_update)
from .ndarray import sparse
from .random import normal
from ..base import py_str
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                       mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                       signsgd_update, signum_update)
from ..ndarray import sparse
from ..random import normal

__all__ = [
    'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD',
    'NAG', 'NDArray', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD',
    'Signum', 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
]


class Optimizer(object):
41 changes: 41 additions & 0 deletions python/mxnet/test_utils.py
@@ -1957,3 +1957,44 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc
% (str(cs_ret_l), str(obs_freq_l), str(expected_freq_l),
str(buckets), str(probs)))
return cs_ret_l

def compare_ndarray_tuple(t1, t2, rtol=None, atol=None):
    """Compare ndarray tuple."""
    if t1 is not None and t2 is not None:
        if isinstance(t1, tuple):
            for s1, s2 in zip(t1, t2):
                compare_ndarray_tuple(s1, s2, rtol, atol)
        else:
            assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol)


def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
                      rtol=1e-4, atol=1e-5, compare_states=True):
    """Compare opt1 and opt2."""
    if w_stype == 'default':
        w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
        w1 = w2.copyto(default_context())
    elif w_stype == 'row_sparse' or w_stype == 'csr':
        w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
        w1 = w2.copyto(default_context()).tostype('default')
    else:
        raise Exception("type not supported yet")
    if g_stype == 'default':
        g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
        g1 = g2.copyto(default_context())
    elif g_stype == 'row_sparse' or g_stype == 'csr':
        g2 = rand_ndarray(shape, g_stype, dtype=dtype)
        g1 = g2.copyto(default_context()).tostype('default')
    else:
        raise Exception("type not supported yet")

    state1 = opt1.create_state_multi_precision(0, w1)
    state2 = opt2.create_state_multi_precision(0, w2)
    if compare_states:
        compare_ndarray_tuple(state1, state2)

    opt1.update_multi_precision(0, w1, g1, state1)
    opt2.update_multi_precision(0, w2, g2, state2)
    if compare_states:
        compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
    assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
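
As an example of how these helpers are meant to be used, a hypothetical test (not included in this diff) could compare GroupAdaGrad on dense versus row-sparse gradients:

```python
# Hypothetical test sketch: the dense Python path and the sparse
# group_adagrad_update path should produce matching weights and states.
import mxnet as mx
from mxnet.optimizer.contrib import GroupAdaGrad
from mxnet.test_utils import compare_optimizer

shape = (10, 4)  # GroupAdaGrad requires 2-D weights (one history entry per row)
compare_optimizer(GroupAdaGrad(learning_rate=0.1),
                  GroupAdaGrad(learning_rate=0.1),
                  shape, 'float32', g_stype='row_sparse')
```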