Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1235] Add a test for AdaMax optimizer #13467

Merged
merged 3 commits into from
Dec 4, 2018
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import numpy as np
import itertools
import mxnet as mx
import mxnet.lr_scheduler as lr_scheduler
from mxnet import gluon
Expand Down Expand Up @@ -501,7 +502,6 @@ def test_ftml():


# ADAM

class PyAdam(mx.optimizer.Optimizer):
"""python reference implemenation of adam"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
Expand Down Expand Up @@ -613,6 +613,80 @@ def test_adam():
dtype, w_stype='default', g_stype='row_sparse',
rtol=1e-4, atol=2e-5)


# AdaMax
class PyAdamax(mx.optimizer.Optimizer):
"""The python reference of AdaMax optimizer.

This class implements the AdaMax optimizer, a variant of Adam based on the infinity norm,
available at http://arxiv.org/abs/1412.6980 Section 7.

The optimizer updates the weight by::
grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
m = beta1 * m_t + (1 - beta1) * grad
u = maximum(beta2 * u, abs(grad))
weight -= lr / (1 - beta1**t) * m / u

This optimizer accepts the following parameters in addition to those accepted
by :class:`.Optimizer`.

Parameters
----------
beta1 : float, optional
Exponential decay rate for the first moment estimates.
beta2 : float, optional
Exponential decay rate for the second moment estimates.
"""
def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999, **kwargs):
super(PyAdamax, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2

def create_state(self, index, weight):
return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype), # mean
mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance

def update(self, index, weight, grad, state):
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)

t = self._index_update_count[index]
lr /= (1. - self.beta1**t)

# preprocess grad
grad = grad * self.rescale_grad + wd * weight
if self.clip_gradient is not None:
grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)

# update m_t and u_t
m_t, u_t = state
m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad
u_t[:] = mx.nd.maximum(self.beta2 * u_t, mx.nd.abs(grad))

# update weight
weight[:] -= lr * m_t / u_t


@with_seed()
def test_adamax():
opt1 = PyAdamax
opt2 = mx.optimizer.Adamax
shape = (3, 4, 5)
cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]
for dtype in [np.float16, np.float32, np.float64]:
D-Roberts marked this conversation as resolved.
Show resolved Hide resolved
for params in itertools.product(cg_options, rg_options, wd_options, mp_options):
kwarg = {k: v for param in params for k, v in param.items()}
if (dtype == np.float16 and
('multi_precision' not in kwarg or
not kwarg['multi_precision'])):
continue
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)


# Signum
class PySignum(mx.optimizer.Optimizer):
"""The python reference of Signum optimizer.
Expand Down