[MXNET-1235] Add a test for AdaMax optimizer (apache#13467)
* Add a test for AdaMax optimizer

* Modify nested for loop with itertools.product and left tolerance to default

* Trigger
D-Roberts authored and zhaoyao73 committed Dec 9, 2018
1 parent 3285084 commit 3d3964e
Showing 1 changed file with 75 additions and 1 deletion.
76 changes: 75 additions & 1 deletion tests/python/unittest/test_optimizer.py
@@ -16,6 +16,7 @@
# under the License.

import numpy as np
import itertools
import mxnet as mx
import mxnet.lr_scheduler as lr_scheduler
from mxnet import gluon
@@ -501,7 +502,6 @@ def test_ftml():


# ADAM

class PyAdam(mx.optimizer.Optimizer):
"""python reference implemenation of adam"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
@@ -613,6 +613,80 @@ def test_adam():
                                      dtype, w_stype='default', g_stype='row_sparse',
                                      rtol=1e-4, atol=2e-5)


# AdaMax
class PyAdamax(mx.optimizer.Optimizer):
"""The python reference of AdaMax optimizer.
This class implements the AdaMax optimizer, one variant of Adam based on the infinity norm,
available at http://arxiv.org/abs/1412.6980 Section 7.
The optimizer updates the weight by::
grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
m = beta1 * m_t + (1 - beta1) * grad
u = maximum(beta2 * u, abs(grad))
weight -= lr / (1 - beta1**t) * m / u
This optimizer accepts the following parameters in addition to those accepted
by :class:`.Optimizer`.
Parameters
----------
beta1 : float, optional
Exponential decay rate for the first moment estimates.
beta2 : float, optional
Exponential decay rate for the second moment estimates.
"""
    def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999, **kwargs):
        super(PyAdamax, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2

    def create_state(self, index, weight):
        return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype),  # mean (m)
                mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype))  # infinity norm (u)

    def update(self, index, weight, grad, state):
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        t = self._index_update_count[index]
        lr /= (1. - self.beta1**t)

        # preprocess grad
        grad = grad * self.rescale_grad + wd * weight
        if self.clip_gradient is not None:
            grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)

        # update m_t and u_t
        m_t, u_t = state
        m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad
        u_t[:] = mx.nd.maximum(self.beta2 * u_t, mx.nd.abs(grad))

        # update weight
        weight[:] -= lr * m_t / u_t
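For reference, a minimal standalone NumPy sketch of a single AdaMax step, mirroring the update rule quoted in the PyAdamax docstring above (the tensors and hyperparameters are illustrative and not taken from the test):

    import numpy as np

    lr, beta1, beta2, t = 0.002, 0.9, 0.999, 1
    weight = np.array([0.5, -0.3])
    grad = np.array([0.1, -0.2])
    m = np.zeros_like(weight)   # first moment
    u = np.zeros_like(weight)   # exponentially weighted infinity norm

    # one AdaMax update with the bias-corrected learning rate
    m = beta1 * m + (1. - beta1) * grad
    u = np.maximum(beta2 * u, np.abs(grad))
    weight -= (lr / (1. - beta1 ** t)) * m / u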


@with_seed()
def test_adamax():
    opt1 = PyAdamax
    opt2 = mx.optimizer.Adamax
    shape = (3, 4, 5)
    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
    mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]
    for dtype in [np.float16, np.float32, np.float64]:
        for params in itertools.product(cg_options, rg_options, wd_options, mp_options):
            kwarg = {k: v for param in params for k, v in param.items()}
            if (dtype == np.float16 and
                    ('multi_precision' not in kwarg or
                     not kwarg['multi_precision'])):
                continue
            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
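As the commit message notes, the nested option loops were replaced with itertools.product. A short sketch of the equivalence, with shortened option lists purely for illustration (not part of the test):

    import itertools

    cg_options = [{}, {'clip_gradient': 0.4}]
    wd_options = [{}, {'wd': 0.03}]

    # nested-loop form
    nested = [dict(**cg, **wd) for cg in cg_options for wd in wd_options]

    # itertools.product form, as used in test_adamax
    flat = [{k: v for param in params for k, v in param.items()}
            for params in itertools.product(cg_options, wd_options)]

    assert nested == flat  # both produce the same four kwarg dicts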


# Signum
class PySignum(mx.optimizer.Optimizer):
"""The python reference of Signum optimizer.
