Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Adagrad optimizer with row-wise learning rate #12365

Merged
merged 4 commits into from
Oct 11, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/python/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ Code examples are placed throughout the API documentation and these can be run a
:maxdepth: 1
optimization/optimization.md
optimization/contrib.md
```

## Profiler API
Expand Down
52 changes: 52 additions & 0 deletions docs/api/python/optimization/contrib.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Contrib Optimization API

```eval_rst
.. currentmodule:: mxnet.optimizer.contrib
```

## Overview

This document summaries the contrib APIs used to initialize and update the model
weights during training

```eval_rst
.. autosummary::
:nosignatures:
mxnet.optimizer.contrib
```

The `Contrib Optimization` API, defined in the `optimizer.contrib` package, provides
many useful experimental APIs for new features.
This is a place for the community to try out the new features,
so that feature contributors can receive feedback.

```eval_rst
.. warning:: This package contains experimental APIs and may change in the near future.
```

In the rest of this document, we list routines provided by the `optimizer.contrib` package.

## Contrib

```eval_rst
.. currentmodule:: mxnet.optimizer.contrib
.. autosummary::
:nosignatures:
GroupAdaGrad
```

## API Reference

<script type="text/javascript" src='../../../_static/js/auto_module_index.js'></script>

```eval_rst
.. automodule:: mxnet.optimizer.contrib
:members:
```

<script>auto_index("api-reference");</script>
24 changes: 24 additions & 0 deletions python/mxnet/optimizer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Optimizer API of MXNet."""

from . import optimizer, contrib
# pylint: disable=wildcard-import
from .optimizer import *
# pylint: enable=wildcard-import

__all__ = optimizer.__all__ + ['contrib']
100 changes: 100 additions & 0 deletions python/mxnet/optimizer/contrib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=too-many-lines
"""Contrib optimizers."""
from ..ndarray import (NDArray, clip, contrib, mean, sqrt, square, zeros)
from .optimizer import Optimizer

# convenience wrapper for Optimizer.Register
register = Optimizer.register # pylint: disable=invalid-name

__all__ = ['GroupAdaGrad']


@register
class GroupAdaGrad(Optimizer):
"""Adagrad optimizer with row-wise learning rates.
This class implements the AdaGrad optimizer described in *Adaptive
Subgradient Methods for Online Learning and Stochastic Optimization*, and
available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf but
uses only a single learning rate for every row of the parameter array.
This optimizer updates each weight by::
grad = clip(grad * rescale_grad, clip_gradient)
history += mean(square(grad), axis=1, keepdims=True)
div = grad / sqrt(history + float_stable_eps)
weight -= div * lr
Weights are updated lazily if the gradient is sparse.
For details of the update algorithm see
:class:`~mxnet.ndarray.contrib.group_adagrad_update`.
This optimizer accepts the following parameters in addition to those
accepted by :class:`.Optimizer`. Weight decay is not supported.
Parameters
----------
eps: float, optional
Initial value of the history accumulator. Avoids division by 0.
"""

def __init__(self, eps=1e-5, **kwargs):
super(GroupAdaGrad, self).__init__(**kwargs)
self.float_stable_eps = eps

def create_state(self, index, weight):
assert len(weight.shape) == 2
history = zeros(
(weight.shape[0], 1), weight.context, stype=weight.stype)
return history

def update(self, index, weight, grad, state):
assert (isinstance(weight, NDArray))
assert (isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
assert wd == 0, 'Weight decay is not supported for GroupAdaGrad'

is_sparse = grad.stype == 'row_sparse'
if is_sparse:
kwargs = {
'epsilon': self.float_stable_eps,
'rescale_grad': self.rescale_grad
}
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient
contrib.group_adagrad_update(
weight,
grad,
state,
out=weight,
lr=lr,
**kwargs)
else:
grad = grad * self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
state[:] += mean(square(grad), axis=1, keepdims=True)
div = lr * grad / sqrt(state + self.float_stable_eps)
weight[:] -= div
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,19 @@
import pickle
import warnings
import numpy
from .base import py_str
from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
signsgd_update, signum_update)
from .ndarray import sparse
from .random import normal
from ..base import py_str
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
signsgd_update, signum_update)
from ..ndarray import sparse
from ..random import normal

__all__ = [
'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD',
'NAG', 'NDArray', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD',
'Signum', 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
]


class Optimizer(object):
Expand Down
41 changes: 41 additions & 0 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,3 +1957,44 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc
% (str(cs_ret_l), str(obs_freq_l), str(expected_freq_l),
str(buckets), str(probs)))
return cs_ret_l

def compare_ndarray_tuple(t1, t2, rtol=None, atol=None):
"""Compare ndarray tuple."""
if t1 is not None and t2 is not None:
if isinstance(t1, tuple):
for s1, s2 in zip(t1, t2):
compare_ndarray_tuple(s1, s2, rtol, atol)
else:
assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol)


def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
rtol=1e-4, atol=1e-5, compare_states=True):
"""Compare opt1 and opt2."""
if w_stype == 'default':
w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
w1 = w2.copyto(default_context())
elif w_stype == 'row_sparse' or w_stype == 'csr':
w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
w1 = w2.copyto(default_context()).tostype('default')
else:
raise Exception("type not supported yet")
if g_stype == 'default':
g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
g1 = g2.copyto(default_context())
elif g_stype == 'row_sparse' or g_stype == 'csr':
g2 = rand_ndarray(shape, g_stype, dtype=dtype)
g1 = g2.copyto(default_context()).tostype('default')
else:
raise Exception("type not supported yet")

state1 = opt1.create_state_multi_precision(0, w1)
state2 = opt2.create_state_multi_precision(0, w2)
if compare_states:
compare_ndarray_tuple(state1, state2)

opt1.update_multi_precision(0, w1, g1, state1)
opt2.update_multi_precision(0, w2, g2, state2)
if compare_states:
compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
Loading