Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Remove proximal implementation and rename to GroupAdagrad
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Oct 6, 2018
1 parent 47da99f commit be20af2
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 252 deletions.
2 changes: 1 addition & 1 deletion docs/api/python/optimization/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ In the rest of this document, we list routines provided by the `optimizer.contri
.. autosummary::
:nosignatures:
ProximalGroupAdaGrad
GroupAdaGrad
```

## API Reference
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Optimizer API of MXNet."""

from . import optimizer, contrib
# pylint: disable=wildcard-import
from .optimizer import *
# pylint: enable=wildcard-import

Expand Down
69 changes: 18 additions & 51 deletions python/mxnet/optimizer/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,18 @@

# pylint: disable=too-many-lines
"""Contrib optimizers."""
from ..ndarray import (NDArray, clip, contrib, full, mean, norm, sparse, sqrt,
square, zeros)
from ..ndarray import (NDArray, clip, contrib, mean, sqrt, square, zeros)
from .optimizer import Optimizer

# convenience wrapper for Optimizer.Register
register = Optimizer.register # pylint: disable=invalid-name

__all__ = ['ProximalGroupAdaGrad']
__all__ = ['GroupAdaGrad']


@register
class ProximalGroupAdaGrad(Optimizer):
"""Proximal Adagrad optimizer with row-wise learning rates.
class GroupAdaGrad(Optimizer):
"""Adagrad optimizer with row-wise learning rates.
This class implements the AdaGrad optimizer described in *Adaptive
Subgradient Methods for Online Learning and Stochastic Optimization*, and
Expand All @@ -44,12 +43,11 @@ class ProximalGroupAdaGrad(Optimizer):
div = grad / sqrt(history + float_stable_eps)
weight -= div * lr
If `l2_regularization_strength > 0` a proximal operator is used to optimize
with group lasso objective. Weights are updated lazily if the gradient is
sparse. In particular, before using a set of weights for a forward pass,
you may want to ensure that the lazily accumulated group lasso
regularization is applied. This can be achieved by creating a sparse
gradient array that contains explicit 0 data for the indices to be updated:
Weights are updated lazily if the gradient is sparse. In particular, before
using a set of weights for a forward pass, you may want to ensure that the
lazily accumulated group lasso regularization is applied. This can be
achieved by creating a sparse gradient array that contains explicit 0 data
for the indices to be updated:
fake_grad = mx.nd.sparse.row_sparse_array(
(mx.nd.zeros((len(indices), dim)), indices))
Expand All @@ -60,86 +58,55 @@ class ProximalGroupAdaGrad(Optimizer):
trainer.step(batch_size=1)
For details of the update algorithm see
:class:`~mxnet.ndarray.contrib.proximal_group_adagrad_update`.
:class:`~mxnet.ndarray.contrib.group_adagrad_update`.
This optimizer accepts the following parameters in addition to those
accepted by :class:`.Optimizer`. Weight decay is not supported.
Parameters
----------
l2_regularization_strength : float
Strength of group lasso L2 regularization.
eps: float, optional
Small constant added to the history accumulator inside the square root. Avoids division by 0.
"""

def __init__(self, l2_regularization_strength=0.0, eps=1e-5, **kwargs):
super(ProximalGroupAdaGrad, self).__init__(**kwargs)
self.l2_regularization_strength = l2_regularization_strength
def __init__(self, eps=1e-5, **kwargs):
super(GroupAdaGrad, self).__init__(**kwargs)
self.float_stable_eps = eps

def create_state(self, index, weight):
assert len(weight.shape) == 2
history = zeros(
(weight.shape[0], 1), weight.context, stype=weight.stype)
last_update = None
if self.l2_regularization_strength > 0:
last_update = full(
shape=(weight.shape[0], ),
val=self.num_update,
ctx=weight.context)
else:
last_update = zeros(1, ctx=weight.context)
return (history, last_update)
return history

def update(self, index, weight, grad, state):
assert (isinstance(weight, NDArray))
assert (isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
assert wd == 0, 'Weight decay is not supported for ProximalGroupAdaGrad'
assert wd == 0, 'Weight decay is not supported for GroupAdaGrad'

is_sparse = grad.stype == 'row_sparse'
history = state[0]
last_update = state[1]
if is_sparse:
kwargs = {
'epsilon': self.float_stable_eps,
'rescale_grad': self.rescale_grad
}
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient
if self.l2_regularization_strength:
kwargs['l2_regularization_strength'] = \
self.l2_regularization_strength
contrib.proximal_group_adagrad_update(
contrib.group_adagrad_update(
weight,
grad,
history,
state,
out=weight,
last_update=last_update,
lr=lr,
current_update=self.num_update,
**kwargs)
elif self.l2_regularization_strength > 0:
grad = grad * self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
history[:] += mean(square(grad), axis=1, keepdims=True)
div = lr * grad / sqrt(history + self.float_stable_eps)
num_skipped = (self.num_update - last_update).expand_dims(1)
scaled_l2 = lr / sqrt(history + self.float_stable_eps) \
* self.l2_regularization_strength * num_skipped
nrm = norm(weight - div, ord=2, axis=1, keepdims=True)
weight[:] = (weight - div) * (1 - scaled_l2 / nrm)
weight[:] *= nrm > scaled_l2
last_update[:] = self.num_update
else:
grad = grad * self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
history[:] += mean(square(grad), axis=1, keepdims=True)
div = lr * grad / sqrt(history + self.float_stable_eps)
state[:] += mean(square(grad), axis=1, keepdims=True)
div = lr * grad / sqrt(state + self.float_stable_eps)
weight[:] -= div
Loading

0 comments on commit be20af2

Please sign in to comment.