Proximal Group Adagrad optimizer
leezu committed Aug 27, 2018
1 parent 48d2155 commit de44d8d
Showing 8 changed files with 753 additions and 37 deletions.
1 change: 1 addition & 0 deletions python/mxnet/contrib/__init__.py
@@ -29,6 +29,7 @@

from . import text
from . import onnx
from . import optimizer
from . import io
from . import quantization
from . import quantization as quant
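With this export in place, the new class is importable from the contrib namespace. A minimal construction sketch (the argument values are illustrative; learning_rate is assumed to pass through to the base Optimizer as usual):

    from mxnet.contrib.optimizer import ProximalGroupAdaGrad

    opt = ProximalGroupAdaGrad(l2_regularization_strength=0.01, eps=1e-5,
                               learning_rate=0.1)

Because the class is decorated with @register, it should also be constructible by name through MXNet's optimizer registry (conventionally the lower-cased class name, 'proximalgroupadagrad'), though that string form is an assumption here.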
141 changes: 141 additions & 0 deletions python/mxnet/contrib/optimizer.py
@@ -0,0 +1,141 @@
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=too-many-lines
"""Contrib optimizers."""
from ..ndarray import (NDArray, cast, clip, full, mean, norm,
proximal_group_adagrad_update, sparse, sqrt, square,
zeros)
from ..optimizer import Optimizer

# Convenience wrapper for Optimizer.register
register = Optimizer.register # pylint: disable=invalid-name


@register
class ProximalGroupAdaGrad(Optimizer):
"""Proximal Adagrad optimizer with row-wise learning rates.
This class implements the AdaGrad optimizer described in *Adaptive
Subgradient Methods for Online Learning and Stochastic Optimization*, and
available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf but
uses only a single learning rate for every row of the parameter array.
This optimizer updates each weight by::
grad = clip(grad * rescale_grad, clip_gradient)
history += mean(square(grad), axis=1, keepdims=True)
div = grad / sqrt(history + float_stable_eps)
weight += (div + weight * wd) * -lr
If `l2_regularization_strength > 0` a proximal operator is used to optimize
with group lasso objective. Weights are updated lazily if the gradient is
sparse. In particular, before using a set of weights for a forward pass,
you may want to ensure that the lazily accumulated group lasso
regularization is applied. This can be achieved by creating a sparse
gradient array that contains explicit 0 data for the indices to be updated:
fake_grad = mx.nd.sparse.row_sparse_array(
(mx.nd.zeros((len(indices), dim)), indices))
weight.grad()[:] = fake_grad
weight.data()._fresh_grad = True
trainer._optimizer._index_update_count[0] -= 1
trainer._optimizer.num_update -= 1
trainer.step(batch_size=1)
For details of the update algorithm see
:class:`~mxnet.ndarray.proximal_group_adagrad_update`.
This optimizer accepts the following parameters in addition to those
accepted by :class:`.Optimizer`. Weight decay is not supported.
Parameters
----------
l2_regularization_strength : float
Strength of group lasso L2 regularization.
eps: float, optional
Initial value of the history accumulator. Avoids division by 0.
"""

def __init__(self, l2_regularization_strength=0.0, eps=1e-5, **kwargs):
super(ProximalGroupAdaGrad, self).__init__(**kwargs)
self.l2_regularization_strength = l2_regularization_strength
self.float_stable_eps = eps

def create_state(self, index, weight):
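        # Row-wise optimizer: weights must be 2-D, with one history entry per row.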
assert len(weight.shape) == 2
history = zeros(
(weight.shape[0], 1), weight.context, stype=weight.stype)
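        # Track the update count at which each row was last touched, so that
        # group lasso regularization skipped by lazy sparse updates can be
        # applied later. Only needed when group lasso is enabled.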
last_update_buffer = None
if self.l2_regularization_strength > 0:
last_update_buffer = full(
shape=(weight.shape[0], ),
val=self.num_update,
ctx=weight.context)
else:
last_update_buffer = zeros(1, ctx=weight.context)
return (history, last_update_buffer)

def update(self, index, weight, grad, state):
        assert isinstance(weight, NDArray)
        assert isinstance(grad, NDArray)
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
        assert wd == 0, 'Weight decay is not supported'

is_sparse = grad.stype == 'row_sparse'
history = state[0]
last_update_buffer = state[1]
if self.l2_regularization_strength > 0 and is_sparse:
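            # Sparse gradient: delegate to the fused kernel, which updates
            # only the rows present in the gradient and applies the group
            # lasso regularization for skipped rows lazily.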
kwargs = dict()
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient
proximal_group_adagrad_update(
weight,
grad,
history,
out=weight,
last_update_buffer=last_update_buffer,
rescale_grad=self.rescale_grad,
epsilon=self.float_stable_eps,
lr=lr,
current_update=self.num_update,
l2_regularization_strength=self.l2_regularization_strength,
**kwargs)
elif self.l2_regularization_strength > 0:
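            # Dense gradient with group lasso: apply the proximal operator
            # eagerly to all rows.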
grad = grad * self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
history[:] += mean(square(grad), axis=1, keepdims=True)
div = lr * grad / sqrt(history + self.float_stable_eps)
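            # Rows skipped by earlier lazy sparse updates accumulate extra
            # regularization, proportional to the number of missed updates.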
num_skipped = (self.num_update - last_update_buffer).expand_dims(1)
scaled_l2 = lr / sqrt(history + self.float_stable_eps) \
* self.l2_regularization_strength * num_skipped
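            # Group lasso proximal operator: shrink each row toward zero and
            # zero out rows whose norm falls below the threshold.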
nrm = norm(weight - div, ord=2, axis=1, keepdims=True)
weight[:] = (weight - div) * (1 - scaled_l2 / nrm)
weight[:] *= nrm > scaled_l2
last_update_buffer[:] = self.num_update
else:
grad = grad * self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
history[:] += mean(square(grad), axis=1, keepdims=True)
div = lr * grad / sqrt(history + self.float_stable_eps)
weight[:] -= div
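For intuition, the dense group lasso branch above amounts to an AdaGrad step followed by a block soft-threshold on each row. Below is a minimal NumPy sketch of that step (illustrative only, not the fused proximal_group_adagrad_update kernel; the np.maximum guard on the row norm is added here to avoid a 0/0 warning):

    import numpy as np

    def proximal_group_adagrad_step(weight, grad, history, lr, l2,
                                    eps=1e-5, num_skipped=1):
        """One dense update; weight (n, d), grad (n, d), history (n, 1)."""
        # AdaGrad history: one accumulator per row (mean of squared entries).
        history += np.mean(np.square(grad), axis=1, keepdims=True)
        # Row-wise adaptive gradient step.
        div = lr * grad / np.sqrt(history + eps)
        # Threshold for the group (row-wise) soft-threshold operator.
        scaled_l2 = lr / np.sqrt(history + eps) * l2 * num_skipped
        nrm = np.linalg.norm(weight - div, ord=2, axis=1, keepdims=True)
        # Shrink each row toward zero; rows whose norm falls below the
        # threshold become exactly zero.
        shrunk = (weight - div) * (1 - scaled_l2 / np.maximum(nrm, eps))
        return np.where(nrm > scaled_l2, shrunk, 0.0)

With l2 = 0 the thresholding is a no-op and this reduces to the plain group AdaGrad step of the final else branch above.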
39 changes: 39 additions & 0 deletions python/mxnet/test_utils.py
@@ -1954,3 +1954,42 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc
% (str(cs_ret_l), str(obs_freq_l), str(expected_freq_l),
str(buckets), str(probs)))
return cs_ret_l

def compare_ndarray_tuple(t1, t2, rtol=None, atol=None):
if t1 is not None and t2 is not None:
if isinstance(t1, tuple):
for s1, s2 in zip(t1, t2):
compare_ndarray_tuple(s1, s2, rtol, atol)
else:
assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol)


def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
rtol=1e-4, atol=1e-5, compare_states=True):
if w_stype == 'default':
w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
w1 = w2.copyto(default_context())
elif w_stype == 'row_sparse' or w_stype == 'csr':
w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
w1 = w2.copyto(default_context()).tostype('default')
else:
        raise ValueError('weight stype %s is not supported yet' % w_stype)
if g_stype == 'default':
g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
g1 = g2.copyto(default_context())
elif g_stype == 'row_sparse' or g_stype == 'csr':
g2 = rand_ndarray(shape, g_stype, dtype=dtype)
g1 = g2.copyto(default_context()).tostype('default')
else:
        raise ValueError('gradient stype %s is not supported yet' % g_stype)

state1 = opt1.create_state_multi_precision(0, w1)
state2 = opt2.create_state_multi_precision(0, w2)
if compare_states:
compare_ndarray_tuple(state1, state2)

opt1.update_multi_precision(0, w1, g1, state1)
opt2.update_multi_precision(0, w2, g2, state2)
if compare_states:
compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
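A natural use of these helpers is to check the new optimizer's sparse code path against its dense fallback. A minimal sketch, assuming the usual test harness where mxnet and these utilities are importable; note that with l2_regularization_strength > 0 the lazy sparse path intentionally diverges from the eager dense path until the accumulated regularization is flushed, so this comparison uses the unregularized configuration:

    from mxnet.contrib.optimizer import ProximalGroupAdaGrad
    from mxnet.test_utils import compare_optimizer

    # opt1 receives a dense gradient, opt2 the equivalent row_sparse one.
    opt1 = ProximalGroupAdaGrad(l2_regularization_strength=0.0)
    opt2 = ProximalGroupAdaGrad(l2_regularization_strength=0.0)
    compare_optimizer(opt1, opt2, shape=(8, 4), dtype='float32',
                      g_stype='row_sparse')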
