diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 334b7d4c0fdb..b03dcdcfba44 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -16,6 +16,7 @@ # under the License. import numpy as np +import itertools import mxnet as mx import mxnet.lr_scheduler as lr_scheduler from mxnet import gluon @@ -501,7 +502,6 @@ def test_ftml(): # ADAM - class PyAdam(mx.optimizer.Optimizer): """python reference implemenation of adam""" def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, @@ -613,6 +613,80 @@ def test_adam(): dtype, w_stype='default', g_stype='row_sparse', rtol=1e-4, atol=2e-5) + +# AdaMax +class PyAdamax(mx.optimizer.Optimizer): + """The python reference of AdaMax optimizer. + + This class implements the AdaMax optimizer, one variant of Adam based on the infinity norm, + available at http://arxiv.org/abs/1412.6980 Section 7. + + The optimizer updates the weight by:: + grad = clip(grad * rescale_grad + wd * weight, clip_gradient) + m = beta1 * m_t + (1 - beta1) * grad + u = maximum(beta2 * u, abs(grad)) + weight -= lr / (1 - beta1**t) * m / u + + This optimizer accepts the following parameters in addition to those accepted + by :class:`.Optimizer`. + + Parameters + ---------- + beta1 : float, optional + Exponential decay rate for the first moment estimates. + beta2 : float, optional + Exponential decay rate for the second moment estimates. + """ + def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999, **kwargs): + super(PyAdamax, self).__init__(learning_rate=learning_rate, **kwargs) + self.beta1 = beta1 + self.beta2 = beta2 + + def create_state(self, index, weight): + return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype), # mean + mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance + + def update(self, index, weight, grad, state): + self._update_count(index) + lr = self._get_lr(index) + wd = self._get_wd(index) + + t = self._index_update_count[index] + lr /= (1. - self.beta1**t) + + # preprocess grad + grad = grad * self.rescale_grad + wd * weight + if self.clip_gradient is not None: + grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient) + + # update m_t and u_t + m_t, u_t = state + m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad + u_t[:] = mx.nd.maximum(self.beta2 * u_t, mx.nd.abs(grad)) + + # update weight + weight[:] -= lr * m_t / u_t + + +@with_seed() +def test_adamax(): + opt1 = PyAdamax + opt2 = mx.optimizer.Adamax + shape = (3, 4, 5) + cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}] + rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] + wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] + mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}] + for dtype in [np.float16, np.float32, np.float64]: + for params in itertools.product(cg_options, rg_options, wd_options, mp_options): + kwarg = {k: v for param in params for k, v in param.items()} + if (dtype == np.float16 and + ('multi_precision' not in kwarg or + not kwarg['multi_precision'])): + continue + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) + + # Signum class PySignum(mx.optimizer.Optimizer): """The python reference of Signum optimizer.