This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[op] add back support for scalar type rescale_grad argument for adamw_update/mp_adamw_update #14221

Merged 4 commits on Feb 27, 2019
Changes from 1 commit
150 changes: 150 additions & 0 deletions python/mxnet/ndarray/contrib.py
@@ -542,3 +542,153 @@ def isnan(data):
<NDArray 2 @cpu(0)>
"""
return data != data

def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
"""Update function for AdamW optimizer.

AdamW is seen as a modification of Adam by decoupling the weight
decay from the optimization steps taken w.r.t. the loss function.

Adam update consists of the following steps, where g represents gradient and m, v
are 1st and 2nd order moment estimates (mean and variance).

.. math::

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})

It updates the weights using::

m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*(grad**2)
w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf,
or 0, the update is skipped.

Parameters
----------
weight : NDArray
Weight
grad : NDArray
Gradient
mean : NDArray
Moving mean
var : NDArray
Moving variance
rescale_grad : float or NDArray
Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped.
lr : float
Learning rate
eta : float
Learning rate schedule multiplier
beta1 : float, optional, default is 0.9
The decay rate for the 1st moment estimates.
beta2 : float, optional, default is 0.999
The decay rate for the 2nd moment estimates.
epsilon : float, optional, default is 1e-08
A small constant for numerical stability.
wd : float, optional, default is 0
Weight decay augments the objective function with a regularization term that penalizes
large weights. The penalty scales with the square of the magnitude of each weight.
clip_gradient : float, optional, default is -1
Clip gradient to the range of [-clip_gradient, clip_gradient]
If clip_gradient <= 0, gradient clipping is turned off.
grad = max(min(grad, clip_gradient), -clip_gradient).
out : NDArray, optional
The output NDArray to hold the result.

Returns
-------
output: NDArray

"""
if not isinstance(rescale_grad, ndarray.NDArray):
rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
else:
rescale_grad = rescale_grad.as_in_context(weight.context)
return ndarray._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
rescale_grad=rescale_grad, lr=lr, eta=eta,
beta1=beta1, beta2=beta2, epsilon=epsilon,
wd=wd, clip_gradient=clip_gradient, out=out,
name=name, **kwargs)

def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
name=None, **kwargs):
"""Update function for multi-precision AdamW optimizer.

AdamW is seen as a modification of Adam by decoupling the weight
decay from the optimization steps taken w.r.t. the loss function.

Adam update consists of the following steps, where g represents gradient and m, v
are 1st and 2nd order moment estimates (mean and variance).

.. math::

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})

It updates the weights using::

m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*(grad**2)
w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf,
or 0, the update is skipped.

Parameters
----------
weight : NDArray
Weight
grad : NDArray
Gradient
mean : NDArray
Moving mean
var : NDArray
Moving variance
weight32 : NDArray
Weight in fp32.
rescale_grad : float or NDArray
Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped.
lr : float
Learning rate
eta : float
Learning rate schedule multiplier
beta1 : float, optional, default is 0.9
The decay rate for the 1st moment estimates.
beta2 : float, optional, default is 0.999
The decay rate for the 2nd moment estimates.
epsilon : float, optional, default is 1e-08
A small constant for numerical stability.
wd : float, optional, default is 0
Weight decay augments the objective function with a regularization term that penalizes
large weights. The penalty scales with the square of the magnitude of each weight.
clip_gradient : float, optional, default is -1
Clip gradient to the range of [-clip_gradient, clip_gradient]
If clip_gradient <= 0, gradient clipping is turned off.
grad = max(min(grad, clip_gradient), -clip_gradient).
out : NDArray, optional
The output NDArray to hold the result.

Returns
-------
output: NDArray

"""
if not isinstance(rescale_grad, ndarray.NDArray):
rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
else:
rescale_grad = rescale_grad.as_in_context(weight.context)
return ndarray._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
weight32=weight32,
rescale_grad=rescale_grad, lr=lr, eta=eta,
beta1=beta1, beta2=beta2, epsilon=epsilon,
wd=wd, clip_gradient=clip_gradient, out=out,
name=name, **kwargs)
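
For context (not part of the diff), here is a minimal usage sketch of the NDArray wrapper added above, assuming a build with this change; the shapes and hyperparameter values are illustrative.

```python
import mxnet as mx

shape = (3, 4)
weight = mx.nd.ones(shape)
grad = mx.nd.ones(shape) * 0.1
mean = mx.nd.zeros(shape)   # 1st-moment state
var = mx.nd.zeros(shape)    # 2nd-moment state

# A plain float rescale_grad is accepted again; the wrapper converts it to a
# 1-element NDArray on weight.context before calling the internal op.
mx.nd.contrib.adamw_update(weight, grad, mean, var,
                           rescale_grad=1.0, lr=1e-3, eta=1.0,
                           beta1=0.9, beta2=0.999, epsilon=1e-8,
                           wd=0.01, out=weight)

# A NaN, Inf, or 0 rescale_grad skips the update and leaves `weight` unchanged.
mx.nd.contrib.adamw_update(weight, grad, mean, var,
                           rescale_grad=float('nan'), lr=1e-3, eta=1.0,
                           out=weight)
```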
146 changes: 146 additions & 0 deletions python/mxnet/symbol/contrib.py
@@ -727,3 +727,149 @@ def _union_inputs(*graphs):
outputs = [result[i] for i in range(then_num_outputs)]
outputs, _ = _regroup(outputs, then_fmt)
return outputs

def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
"""Update function for AdamW optimizer.

AdamW is seen as a modification of Adam by decoupling the weight
decay from the optimization steps taken w.r.t. the loss function.

Adam update consists of the following steps, where g represents gradient and m, v
are 1st and 2nd order moment estimates (mean and variance).

.. math::

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})

It updates the weights using::

m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*(grad**2)
w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf,
or 0, the update is skipped.

Parameters
----------
weight : Symbol
Weight
grad : Symbol
Gradient
mean : Symbol
Moving mean
var : Symbol
Moving variance
rescale_grad : float or Symbol
Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped.
lr : float
Learning rate
eta : float
Learning rate schedule multiplier
beta1 : float, optional, default is 0.9
The decay rate for the 1st moment estimates.
beta2 : float, optional, default is 0.999
The decay rate for the 2nd moment estimates.
epsilon : float, optional, default is 1e-08
A small constant for numerical stability.
wd : float, optional, default is 0
Weight decay augments the objective function with a regularization term that penalizes
large weights. The penalty scales with the square of the magnitude of each weight.
clip_gradient : float, optional, default is -1
Clip gradient to the range of [-clip_gradient, clip_gradient]
If clip_gradient <= 0, gradient clipping is turned off.
grad = max(min(grad, clip_gradient), -clip_gradient).
out : Symbol, optional
The output Symbol to hold the result.

Returns
-------
output: Symbol

"""
if not isinstance(rescale_grad, Symbol):
rescale_grad = symbol.full(shape=(1,), val=rescale_grad)
return symbol._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
rescale_grad=rescale_grad, lr=lr, eta=eta,
beta1=beta1, beta2=beta2, epsilon=epsilon,
wd=wd, clip_gradient=clip_gradient, out=out,
name=name, **kwargs)

def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
name=None, **kwargs):
"""Update function for multi-precision AdamW optimizer.

AdamW is seen as a modification of Adam by decoupling the weight
decay from the optimization steps taken w.r.t. the loss function.

Adam update consists of the following steps, where g represents gradient and m, v
are 1st and 2nd order moment estimates (mean and variance).

.. math::

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})

It updates the weights using::

m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*(grad**2)
w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd)

Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf,
or 0, the update is skipped.

Parameters
----------
weight : Symbol
Weight
grad : Symbol
Gradient
mean : Symbol
Moving mean
var : Symbol
Moving variance
weight32 : Symbol
Weight in fp32.
rescale_grad : float or Symbol
Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped.
lr : float
Learning rate
eta : float
Learning rate schedule multiplier
beta1 : float, optional, default is 0.9
The decay rate for the 1st moment estimates.
beta2 : float, optional, default is 0.999
The decay rate for the 2nd moment estimates.
epsilon : float, optional, default is 1e-08
A small constant for numerical stability.
wd : float, optional, default is 0
Weight decay augments the objective function with a regularization term that penalizes
large weights. The penalty scales with the square of the magnitude of each weight.
clip_gradient : float, optional, default is -1
Clip gradient to the range of [-clip_gradient, clip_gradient]
If clip_gradient <= 0, gradient clipping is turned off.
grad = max(min(grad, clip_gradient), -clip_gradient).
out : Symbol, optional
The output Symbol to hold the result.

Returns
-------
output: Symbol

"""
if not isinstance(rescale_grad, Symbol):
rescale_grad = symbol.full(shape=(1,), val=rescale_grad)
return symbol._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
weight32=weight32,
rescale_grad=rescale_grad, lr=lr, eta=eta,
beta1=beta1, beta2=beta2, epsilon=epsilon,
wd=wd, clip_gradient=clip_gradient, out=out,
name=name, **kwargs)
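
A similar sketch for the Symbol-API path (again not part of the diff; variable names are illustrative):

```python
import mxnet as mx

weight = mx.sym.Variable('weight')
grad = mx.sym.Variable('grad')
mean = mx.sym.Variable('mean')
var = mx.sym.Variable('var')

# A float rescale_grad is wrapped into a constant 1-element symbol via
# symbol.full before the internal op is invoked.
update = mx.sym.contrib.adamw_update(weight, grad, mean, var,
                                     rescale_grad=1.0, lr=1e-3, eta=1.0,
                                     wd=0.01)
```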
10 changes: 6 additions & 4 deletions src/operator/contrib/adamw.cc
@@ -50,7 +50,7 @@ inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
});
}

NNVM_REGISTER_OP(_contrib_mp_adamw_update)
NNVM_REGISTER_OP(_mp_adamw_update)
.describe(R"code(Update function for multi-precision AdamW optimizer.

AdamW is seen as a modification of Adam by decoupling the weight decay from the
@@ -91,10 +91,11 @@ the update is skipped.
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_argument("rescale_grad", "NDArray-or-Symbol",
"Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
"Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
"the update is skipped.")
.add_arguments(AdamWParam::__FIELDS__());

NNVM_REGISTER_OP(_contrib_adamw_update)
NNVM_REGISTER_OP(_adamw_update)
.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.

@@ -132,7 +133,8 @@ the update is skipped.
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_argument("rescale_grad", "NDArray-or-Symbol",
"Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
"Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
"the update is skipped.")
.add_arguments(AdamWParam::__FIELDS__());

} // namespace op
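
The registrations above drop the `_contrib_` prefix; with the leading underscore the raw operators are presumably hidden from the public namespace, leaving the Python contrib wrappers (which handle the scalar-to-NDArray conversion) as the user-facing entry points. A minimal sketch, assuming direct access to the internal op, of what the wrapper dispatches to; note rescale_grad must already be a 1-element NDArray here:

```python
import mxnet as mx

shape = (3, 4)
w, g = mx.nd.ones(shape), mx.nd.ones(shape) * 0.1
m, v = mx.nd.zeros(shape), mx.nd.zeros(shape)

# The internal op takes rescale_grad as an NDArray input rather than a scalar
# attribute; the contrib wrapper performs the float -> NDArray conversion.
mx.nd._internal._adamw_update(weight=w, grad=g, mean=m, var=v,
                              rescale_grad=mx.nd.ones((1,)),
                              lr=1e-3, eta=1.0, out=w)
```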
4 changes: 2 additions & 2 deletions src/operator/contrib/adamw.cu
@@ -50,10 +50,10 @@ inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs,
});
}

NNVM_REGISTER_OP(_contrib_adamw_update)
NNVM_REGISTER_OP(_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<AdamWUpdate>);

NNVM_REGISTER_OP(_contrib_mp_adamw_update)
NNVM_REGISTER_OP(_mp_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<MPAdamWUpdate>);

} // namespace op
12 changes: 12 additions & 0 deletions tests/python/unittest/test_contrib_optimizer.py
@@ -107,6 +107,12 @@ def test_adamw():
kwargs = {'eta': eta, 'lr': lr, 'wd': wd, 'epsilon': epsilon,
'beta1': beta1, 'beta2': beta2}

# update is skipped for rescale = nan scalar
mx.nd.contrib.adamw_update(weight, grad, m, v,
np.nan, out=weight, **kwargs)
# weight remains unchanged
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())

# update is skipped for rescale = 0
mx.nd.contrib.adamw_update(weight, grad, m, v,
rescale_grad * 0, out=weight, **kwargs)
@@ -134,6 +140,12 @@ def test_adamw():
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())

# multi-precision update is skipped for rescale = nan scalar
mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
np.nan, out=weight_fp16, **kwargs)
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())

# multi-precision update is skipped for rescale = inf
mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
rescale_grad * np.inf, out=weight_fp16, **kwargs)
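
For readers checking the assertions above, a NumPy sketch of the update rule from the docstrings (my own reference, not taken from the test file; argument names are illustrative):

```python
import numpy as np

def adamw_step(w, g, m, v, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
               epsilon=1e-8, wd=0.0, clip_gradient=-1):
    """Reference AdamW step following the docstring pseudocode."""
    if not np.isfinite(rescale_grad) or rescale_grad == 0:
        return w, m, v                      # NaN/Inf/0 rescale skips the update
    g = g * rescale_grad
    if clip_gradient > 0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    w = w - eta * (lr * m / (np.sqrt(v) + epsilon) + wd * w)
    return w, m, v
```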