This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Address comments
leezu committed Aug 27, 2018
1 parent 023e058 commit 94783e8
Showing 4 changed files with 52 additions and 56 deletions.
30 changes: 16 additions & 14 deletions python/mxnet/contrib/optimizer.py
@@ -18,8 +18,8 @@

# pylint: disable=too-many-lines
"""Contrib optimizers."""
from ..ndarray import (NDArray, clip, full, mean, norm,
proximal_group_adagrad_update, sqrt, square, zeros)
from ..ndarray import (NDArray, clip, contrib, full, mean, norm, sqrt, square,
zeros)
from ..optimizer import Optimizer

# convenience wrapper for Optimizer.Register
@@ -40,7 +40,7 @@ class ProximalGroupAdaGrad(Optimizer):
grad = clip(grad * rescale_grad, clip_gradient)
history += mean(square(grad), axis=1, keepdims=True)
div = grad / sqrt(history + float_stable_eps)
weight += (div + weight * wd) * -lr
weight -= div * lr
If `l2_regularization_strength > 0` a proximal operator is used to optimize
with group lasso objective. Weights are updated lazily if the gradient is
@@ -58,7 +58,7 @@ class ProximalGroupAdaGrad(Optimizer):
trainer.step(batch_size=1)
For details of the update algorithm see
:class:`~mxnet.ndarray.proximal_group_adagrad_update`.
:class:`~mxnet.ndarray.contrib.proximal_group_adagrad_update`.
This optimizer accepts the following parameters in addition to those
accepted by :class:`.Optimizer`. Weight decay is not supported.
@@ -76,42 +76,44 @@ def __init__(self, l2_regularization_strength=0.0, eps=1e-5, **kwargs):
super(ProximalGroupAdaGrad, self).__init__(**kwargs)
self.l2_regularization_strength = l2_regularization_strength
self.float_stable_eps = eps
wd = self._get_wd(index)
assert wd == 0, 'Weight decay is not supported for ProximalGroupAdaGrad'

def create_state(self, index, weight):
assert len(weight.shape) == 2
history = zeros(
(weight.shape[0], 1), weight.context, stype=weight.stype)
last_update_buffer = None
last_update = None
if self.l2_regularization_strength > 0:
last_update_buffer = full(
last_update = full(
shape=(weight.shape[0], ),
val=self.num_update,
ctx=weight.context)
else:
last_update_buffer = zeros(1, ctx=weight.context)
return (history, last_update_buffer)
last_update = zeros(1, ctx=weight.context)
return (history, last_update)

def update(self, index, weight, grad, state):
assert (isinstance(weight, NDArray))
assert (isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
assert wd == 0
assert wd == 0, 'Weight decay is not supported for ProximalGroupAdaGrad'

is_sparse = grad.stype == 'row_sparse'
history = state[0]
last_update_buffer = state[1]
last_update = state[1]
if self.l2_regularization_strength > 0 and is_sparse:
kwargs = dict()
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient
proximal_group_adagrad_update(
contrib.proximal_group_adagrad_update(
weight,
grad,
history,
out=weight,
last_update_buffer=last_update_buffer,
last_update=last_update,
rescale_grad=self.rescale_grad,
epsilon=self.float_stable_eps,
lr=lr,
@@ -124,13 +126,13 @@ def update(self, index, weight, grad, state):
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
history[:] += mean(square(grad), axis=1, keepdims=True)
div = lr * grad / sqrt(history + self.float_stable_eps)
num_skipped = (self.num_update - last_update_buffer).expand_dims(1)
num_skipped = (self.num_update - last_update).expand_dims(1)
scaled_l2 = lr / sqrt(history + self.float_stable_eps) \
* self.l2_regularization_strength * num_skipped
nrm = norm(weight - div, ord=2, axis=1, keepdims=True)
weight[:] = (weight - div) * (1 - scaled_l2 / nrm)
weight[:] *= nrm > scaled_l2
last_update_buffer[:] = self.num_update
last_update[:] = self.num_update
else:
grad = grad * self.rescale_grad
if self.clip_gradient is not None:
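For reference, the dense fallback branch above (history accumulation, scaled division, and the group-lasso shrinkage weighted by the number of skipped updates) can be written as a standalone NumPy sketch. The function name and default hyperparameter values below are illustrative, not part of the MXNet API:

```python
import numpy as np

def proximal_group_adagrad_dense(weight, grad, history, last_update, num_update,
                                 lr=0.01, eps=1e-5, l2_strength=0.05,
                                 rescale_grad=1.0, clip_gradient=None):
    """Reference sketch of the dense proximal update path shown in the diff above."""
    grad = grad * rescale_grad
    if clip_gradient is not None:
        grad = np.clip(grad, -clip_gradient, clip_gradient)
    # Accumulate one AdaGrad history value per row (group).
    history += np.mean(np.square(grad), axis=1, keepdims=True)
    div = lr * grad / np.sqrt(history + eps)
    # Scale the group-lasso penalty by how many updates each row has skipped.
    num_skipped = (num_update - last_update)[:, np.newaxis]
    scaled_l2 = lr / np.sqrt(history + eps) * l2_strength * num_skipped
    # Group-lasso proximal step: shrink each row, zeroing rows whose norm
    # falls below the accumulated penalty. (A zero row norm is not handled
    # here, mirroring the Python reference path.)
    nrm = np.linalg.norm(weight - div, ord=2, axis=1, keepdims=True)
    weight[:] = (weight - div) * (1 - scaled_l2 / nrm)
    weight[:] *= nrm > scaled_l2
    last_update[:] = num_update
    return weight
```

Multiplying the penalty by `num_skipped` applies the regularization that the lazily skipped steps would have contributed, which keeps the lazy sparse path and the dense path consistent.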
8 changes: 4 additions & 4 deletions src/operator/contrib/optimizer_op-inl.h
@@ -143,7 +143,7 @@ template <typename xpu> struct ProximalGroupAdagradDnsRspKernel {
// Compute number of weight updates skipped due to lazy_update
DType num_skipped = current_update - last_update_data[grad_idx[i]];
last_update_data[grad_idx[i]] = current_update;
// Warn in case of erroneous last_update_buffer
// Warn in case of erroneous last_update
if (num_skipped < 0) {
num_skipped = 0;
std::printf("Got invalid last_update in proximal_adagrad_update. "
@@ -200,13 +200,13 @@ template <typename xpu> struct ProximalGroupAdagradDnsRspKernel {
};

/*
* \brief Adagrad update implementation for dense weight and row_sparse grad.
* \brief Proximal Group Adagrad update implementation for dense weight and row_sparse grad.
*/
template <typename xpu>
inline void ProximalGroupAdagradUpdateDnsRspDnsImpl(
const ProximalGroupAdagradParam &param, const OpContext &ctx,
const TBlob &weight, const NDArray &grad, const TBlob &state,
const TBlob &last_update_buffer, const OpReqType &req, TBlob *out) {
const TBlob &last_update, const OpReqType &req, TBlob *out) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow_op;
@@ -229,7 +229,7 @@ inline void ProximalGroupAdagradUpdateDnsRspDnsImpl(
const IType *grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
const DType *grad_val = grad.data().dptr<DType>();
DType *state_data = state.dptr<DType>();
DType *last_update_data = last_update_buffer.dptr<DType>();
DType *last_update_data = last_update.dptr<DType>();
const nnvm::dim_t num_grad = grad.aux_shape(rowsparse::kIdx)[0];
const auto row_length = weight.shape_.ProdShape(1, weight.ndim());

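The kernel above only touches rows that appear in the row_sparse gradient and keeps a per-row `last_update` counter so the penalty can account for skipped steps. The following is a rough Python analogue of that bookkeeping; the function and variable names are illustrative, `history` is a flat per-row accumulator here, and the shrinkage follows the Python reference path rather than the exact kernel code, which is truncated in this diff:

```python
import numpy as np

def lazy_row_update(weight, grad_rows, grad_idx, history, last_update,
                    current_update, lr=0.01, eps=1e-5, l2_strength=0.05):
    """Sketch of the per-row lazy bookkeeping done by the row_sparse kernel."""
    for i, row in enumerate(grad_idx):
        num_skipped = current_update - last_update[row]
        if num_skipped < 0:
            num_skipped = 0  # guard against an inconsistent last_update array
        last_update[row] = current_update
        g = grad_rows[i]
        history[row] += np.mean(np.square(g))
        div = lr * g / np.sqrt(history[row] + eps)
        scaled_l2 = lr / np.sqrt(history[row] + eps) * l2_strength * num_skipped
        candidate = weight[row] - div
        nrm = np.linalg.norm(candidate)
        # Shrink the row towards zero; rows whose norm falls below the
        # accumulated penalty are zeroed entirely.
        weight[row] = candidate * max(1.0 - scaled_l2 / nrm, 0.0) if nrm > 0 else 0.0
    return weight
```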
5 changes: 2 additions & 3 deletions src/operator/contrib/optimizer_op.cc
@@ -50,8 +50,7 @@ inline bool ProximalGroupAdagradShape(const nnvm::NodeAttrs &attrs,
(in_attrs->at(0)[0] == in_attrs->at(2)[0]);
}

NNVM_REGISTER_OP(proximal_group_adagrad_update)
MXNET_ADD_SPARSE_OP_ALIAS(proximal_group_adagrad_update)
NNVM_REGISTER_OP(_contrib_proximal_group_adagrad_update)
.describe(R"code(Update function for Proximal Group AdaGrad optimizer.
Referenced from *Adaptive Subgradient Methods for Online Learning and Stochastic Optimization*,
@@ -88,7 +87,7 @@ Note that non-zero values for the weight decay option are not supported.
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("history", "NDArray-or-Symbol", "History")
.add_argument("last_update_buffer", "NDArray-or-Symbol", "Last update buffer")
.add_argument("last_update", "NDArray-or-Symbol", "Array storing last update counter for each row.")
.add_arguments(ProximalGroupAdagradParam::__FIELDS__());

} // namespace op
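With the operator now registered under the contrib namespace, it is exposed in Python as `mx.nd.contrib.proximal_group_adagrad_update`. A minimal invocation sketch, assuming MXNet built from this commit: the positional arguments follow the registration above and the keyword names follow the Python wrapper in the first file, while the shapes and hyperparameter values are illustrative; additional parameters from `ProximalGroupAdagradParam::__FIELDS__` (e.g. `clip_gradient`) are omitted.

```python
import mxnet as mx

rows, cols = 4, 3
weight = mx.nd.ones((rows, cols))
grad = mx.nd.ones((rows, cols)).tostype('row_sparse')  # the op expects a row_sparse gradient
history = mx.nd.zeros((rows, 1))    # one accumulator per row, as in create_state
last_update = mx.nd.zeros((rows,))  # per-row counter of the last update step

mx.nd.contrib.proximal_group_adagrad_update(
    weight, grad, history,
    last_update=last_update,
    out=weight,
    lr=0.01,
    epsilon=1e-5,
    rescale_grad=1.0)
```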
65 changes: 30 additions & 35 deletions tests/python/unittest/test_contrib_optimizer.py
@@ -83,41 +83,36 @@ def test_proximal_group_adagrad():
'l2_regularization_strength': 0.05
}]
for dtype in [np.float32]:
for eps_option in eps_options:
for cg_option in cg_options:
for rg_option in rg_options:
for l2_option in l2_options:
kwarg = dict(wd=0.0)
kwarg.update(eps_option)
kwarg.update(cg_option)
kwarg.update(rg_option)
kwarg.update(l2_option)
compare_optimizer(
opt1(**kwarg),
opt2(**kwarg),
shape,
dtype,
compare_states=False)
if l2_option.get('l2_regularization_strength',
0.0) == 0.0:
# By design results for PyOp which always performs
# dense update will differ if
# l2_regularization_strength > 0
compare_optimizer(
opt1(**kwarg),
opt2(**kwarg),
shape,
dtype,
w_stype='row_sparse',
g_stype='row_sparse',
compare_states=False)
compare_optimizer(
opt1(**kwarg),
opt2(**kwarg),
shape,
dtype,
g_stype='row_sparse',
compare_states=False)
for options in itertools.product(eps_options, cg_options, rg_options,
l2_options):
kwarg = dict(wd=0.0)
for option in options:
kwarg.update(option)
compare_optimizer(
opt1(**kwarg),
opt2(**kwarg),
shape,
dtype,
compare_states=False)
if kwarg.get('l2_regularization_strength', 0.0) == 0.0:
# By design results for PyOp which always performs
# dense update will differ if
# l2_regularization_strength > 0
compare_optimizer(
opt1(**kwarg),
opt2(**kwarg),
shape,
dtype,
w_stype='row_sparse',
g_stype='row_sparse',
compare_states=False)
compare_optimizer(
opt1(**kwarg),
opt2(**kwarg),
shape,
dtype,
g_stype='row_sparse',
compare_states=False)


if __name__ == '__main__':
