From b785251496e4518dafa54d543d0d6acd64d79cda Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 21 Feb 2019 06:01:47 +0000 Subject: [PATCH 1/3] support scalar --- python/mxnet/ndarray/contrib.py | 150 ++++++++++++++++++ python/mxnet/symbol/contrib.py | 146 +++++++++++++++++ src/operator/contrib/adamw.cc | 10 +- src/operator/contrib/adamw.cu | 4 +- .../python/unittest/test_contrib_optimizer.py | 12 ++ 5 files changed, 316 insertions(+), 6 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index 6bbee8a6fbc3..b15415e879a5 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -542,3 +542,153 @@ def isnan(data): """ return data != data + +def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999, + epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs): + """Update function for AdamW optimizer. + + AdamW is seen as a modification of Adam by decoupling the weight + decay from the optimization steps taken w.r.t. the loss function. + + Adam update consists of the following steps, where g represents gradient and m, v + are 1st and 2nd order moment estimates (mean and variance). + + .. math:: + + g_t = \nabla J(W_{t-1})\\ + m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ + v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ + W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) + + It updates the weights using:: + + m = beta1*m + (1-beta1)*grad + v = beta2*v + (1-beta2)*(grad**2) + w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) + + Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, + or 0, the update is skipped. + + Parameters + ---------- + weight : NDArray + Weight + grad : NDArray + Gradient + mean : NDArray + Moving mean + var : NDArray + Moving variance + rescale_grad : float or NDArray + Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped. + lr : float + Learning rate + eta : float + Learning rate schedule multiplier + beta1 : float, optional, default is 0.9 + The decay rate for the 1st moment estimates. + beta2 : float, optional, default is 0.999 + The decay rate for the 2nd moment estimates. + epsilon : float, optional, default is 1e-08 + A small constant for numerical stability. + wd : float, optional, default is 0 + Weight decay augments the objective function with a regularization term that penalizes + large weights. The penalty scales with the square of the magnitude of each weight. + clip_gradient : float, optional, default is -1 + Clip gradient to the range of [-clip_gradient, clip_gradient] + If clip_gradient <= 0, gradient clipping is turned off. + grad = max(min(grad, clip_gradient), -clip_gradient). + out : NDArray, optional + The output NDArray to hold the result. 
+ + Returns + ------- + output: NDArray + + """ + if not isinstance(rescale_grad, ndarray.NDArray): + rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context) + else: + rescale_grad = rescale_grad.as_in_context(weight.context) + return ndarray._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var, + rescale_grad=rescale_grad, lr=lr, eta=eta, + beta1=beta1, beta2=beta2, epsilon=epsilon, + wd=wd, clip_gradient=clip_gradient, out=out, + name=name, **kwargs) + +def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9, + beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None, + name=None, **kwargs): + """Update function for multi-precision AdamW optimizer. + + AdamW is seen as a modification of Adam by decoupling the weight + decay from the optimization steps taken w.r.t. the loss function. + + Adam update consists of the following steps, where g represents gradient and m, v + are 1st and 2nd order moment estimates (mean and variance). + + .. math:: + + g_t = \nabla J(W_{t-1})\\ + m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ + v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ + W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) + + It updates the weights using:: + + m = beta1*m + (1-beta1)*grad + v = beta2*v + (1-beta2)*(grad**2) + w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) + + Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, + or 0, the update is skipped. + + Parameters + ---------- + weight : NDArray + Weight + grad : NDArray + Gradient + mean : NDArray + Moving mean + var : NDArray + Moving variance + weight32 : NDArray + Weight in fp32. + rescale_grad : float or NDArray + Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped. + lr : float + Learning rate + eta : float + Learning rate schedule multiplier + beta1 : float, optional, default is 0.9 + The decay rate for the 1st moment estimates. + beta2 : float, optional, default is 0.999 + The decay rate for the 2nd moment estimates. + epsilon : float, optional, default is 1e-08 + A small constant for numerical stability. + wd : float, optional, default is 0 + Weight decay augments the objective function with a regularization term that penalizes + large weights. The penalty scales with the square of the magnitude of each weight. + clip_gradient : float, optional, default is -1 + Clip gradient to the range of [-clip_gradient, clip_gradient] + If clip_gradient <= 0, gradient clipping is turned off. + grad = max(min(grad, clip_gradient), -clip_gradient). + out : NDArray, optional + The output NDArray to hold the result. 
+ + Returns + ------- + output: NDArray + + """ + if not isinstance(rescale_grad, ndarray.NDArray): + rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context) + else: + rescale_grad = rescale_grad.as_in_context(weight.context) + return ndarray._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var, + weight32=weight32, + rescale_grad=rescale_grad, lr=lr, eta=eta, + beta1=beta1, beta2=beta2, epsilon=epsilon, + wd=wd, clip_gradient=clip_gradient, out=out, + name=name, **kwargs) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index a83227a0261c..5d0f63ba763c 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -727,3 +727,149 @@ def _union_inputs(*graphs): outputs = [result[i] for i in range(then_num_outputs)] outputs, _ = _regroup(outputs, then_fmt) return outputs + +def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999, + epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs): + """Update function for AdamW optimizer. + + AdamW is seen as a modification of Adam by decoupling the weight + decay from the optimization steps taken w.r.t. the loss function. + + Adam update consists of the following steps, where g represents gradient and m, v + are 1st and 2nd order moment estimates (mean and variance). + + .. math:: + + g_t = \nabla J(W_{t-1})\\ + m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ + v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ + W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) + + It updates the weights using:: + + m = beta1*m + (1-beta1)*grad + v = beta2*v + (1-beta2)*(grad**2) + w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) + + Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, + or 0, the update is skipped. + + Parameters + ---------- + weight : Symbol + Weight + grad : Symbol + Gradient + mean : Symbol + Moving mean + var : Symbol + Moving variance + rescale_grad : float or Symbol + Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped. + lr : float + Learning rate + eta : float + Learning rate schedule multiplier + beta1 : float, optional, default is 0.9 + The decay rate for the 1st moment estimates. + beta2 : float, optional, default is 0.999 + The decay rate for the 2nd moment estimates. + epsilon : float, optional, default is 1e-08 + A small constant for numerical stability. + wd : float, optional, default is 0 + Weight decay augments the objective function with a regularization term that penalizes + large weights. The penalty scales with the square of the magnitude of each weight. + clip_gradient : float, optional, default is -1 + Clip gradient to the range of [-clip_gradient, clip_gradient] + If clip_gradient <= 0, gradient clipping is turned off. + grad = max(min(grad, clip_gradient), -clip_gradient). + out : Symbol, optional + The output NDArray to hold the result. 
+ + Returns + ------- + output: Symbol + + """ + if not isinstance(rescale_grad, Symbol): + rescale_grad = symbol.full(shape=(1,), val=rescale_grad) + return symbol._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var, + rescale_grad=rescale_grad, lr=lr, eta=eta, + beta1=beta1, beta2=beta2, epsilon=epsilon, + wd=wd, clip_gradient=clip_gradient, out=out, + name=name, **kwargs) + +def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9, + beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None, + name=None, **kwargs): + """Update function for multi-precision AdamW optimizer. + + AdamW is seen as a modification of Adam by decoupling the weight + decay from the optimization steps taken w.r.t. the loss function. + + Adam update consists of the following steps, where g represents gradient and m, v + are 1st and 2nd order moment estimates (mean and variance). + + .. math:: + + g_t = \nabla J(W_{t-1})\\ + m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ + v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ + W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) + + It updates the weights using:: + + m = beta1*m + (1-beta1)*grad + v = beta2*v + (1-beta2)*(grad**2) + w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) + + Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, + or 0, the update is skipped. + + Parameters + ---------- + weight : Symbol + Weight + grad : Symbol + Gradient + mean : Symbol + Moving mean + var : Symbol + Moving variance + weight32 : Symbol + Weight in fp32. + rescale_grad : float or Symbol + Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped. + lr : float + Learning rate + eta : float + Learning rate schedule multiplier + beta1 : float, optional, default is 0.9 + The decay rate for the 1st moment estimates. + beta2 : float, optional, default is 0.999 + The decay rate for the 2nd moment estimates. + epsilon : float, optional, default is 1e-08 + A small constant for numerical stability. + wd : float, optional, default is 0 + Weight decay augments the objective function with a regularization term that penalizes + large weights. The penalty scales with the square of the magnitude of each weight. + clip_gradient : float, optional, default is -1 + Clip gradient to the range of [-clip_gradient, clip_gradient] + If clip_gradient <= 0, gradient clipping is turned off. + grad = max(min(grad, clip_gradient), -clip_gradient). + out : Symbol, optional + The output Symbol to hold the result. + + Returns + ------- + output: Symbol + + """ + if not isinstance(rescale_grad, Symbol): + rescale_grad = symbol.full(shape=(1,), val=rescale_grad) + return symbol._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var, + weight32=weight32, + rescale_grad=rescale_grad, lr=lr, eta=eta, + beta1=beta1, beta2=beta2, epsilon=epsilon, + wd=wd, clip_gradient=clip_gradient, out=out, + name=name, **kwargs) diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc index 2fbc39743c93..874cce8d8772 100644 --- a/src/operator/contrib/adamw.cc +++ b/src/operator/contrib/adamw.cc @@ -50,7 +50,7 @@ inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs, }); } -NNVM_REGISTER_OP(_contrib_mp_adamw_update) +NNVM_REGISTER_OP(_mp_adamw_update) .describe(R"code(Update function for multi-precision AdamW optimizer. AdamW is seen as a modification of Adam by decoupling the weight decay from the @@ -91,10 +91,11 @@ the update is skipped. 
.add_argument("var", "NDArray-or-Symbol", "Moving variance") .add_argument("weight32", "NDArray-or-Symbol", "Weight32") .add_argument("rescale_grad", "NDArray-or-Symbol", - "Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.") + "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, " + "the update is skipped.") .add_arguments(AdamWParam::__FIELDS__()); -NNVM_REGISTER_OP(_contrib_adamw_update) +NNVM_REGISTER_OP(_adamw_update) .describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function. @@ -132,7 +133,8 @@ the update is skipped. .add_argument("mean", "NDArray-or-Symbol", "Moving mean") .add_argument("var", "NDArray-or-Symbol", "Moving variance") .add_argument("rescale_grad", "NDArray-or-Symbol", - "Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.") + "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, " + "the update is skipped.") .add_arguments(AdamWParam::__FIELDS__()); } // namespace op diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu index e21b83b8aba6..1521749904b9 100644 --- a/src/operator/contrib/adamw.cu +++ b/src/operator/contrib/adamw.cu @@ -50,10 +50,10 @@ inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs, }); } -NNVM_REGISTER_OP(_contrib_adamw_update) +NNVM_REGISTER_OP(_adamw_update) .set_attr("FCompute", MPUpdateGPU); -NNVM_REGISTER_OP(_contrib_mp_adamw_update) +NNVM_REGISTER_OP(_mp_adamw_update) .set_attr("FCompute", MPUpdateGPU); } // namespace op diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py index dad7bed3a923..675cc94c64f1 100644 --- a/tests/python/unittest/test_contrib_optimizer.py +++ b/tests/python/unittest/test_contrib_optimizer.py @@ -107,6 +107,12 @@ def test_adamw(): kwargs = {'eta': eta, 'lr': lr, 'wd': wd, 'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + # update is skipped for rescale = nan scalar + mx.nd.contrib.adamw_update(weight, grad, m, v, + np.nan, out=weight, **kwargs) + # weight remains unchanged + mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy()) + # update is skipped for rescale = 0 mx.nd.contrib.adamw_update(weight, grad, m, v, rescale_grad * 0, out=weight, **kwargs) @@ -134,6 +140,12 @@ def test_adamw(): mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy()) mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy()) + # multi-precision update is skipped for rescale = nan scalar + mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight, + np.nan, out=weight_fp16, **kwargs) + mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy()) + mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy()) + # multi-precision update is skipped for rescale = inf mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight, rescale_grad * np.inf, out=weight_fp16, **kwargs) From d17fc1fbaa2b60869b0000f9ccc605d6c4b792fb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 22 Feb 2019 23:35:04 +0000 Subject: [PATCH 2/3] remove two copies of documentation for adamw --- python/mxnet/ndarray/contrib.py | 124 ------------------------------- python/mxnet/ndarray/register.py | 9 +++ python/mxnet/symbol/contrib.py | 124 ------------------------------- python/mxnet/symbol/register.py | 10 +++ 4 files changed, 19 insertions(+), 248 deletions(-) diff --git 
a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index b15415e879a5..74c355dc1288 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -545,67 +545,6 @@ def isnan(data): def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs): - """Update function for AdamW optimizer. - - AdamW is seen as a modification of Adam by decoupling the weight - decay from the optimization steps taken w.r.t. the loss function. - - Adam update consists of the following steps, where g represents gradient and m, v - are 1st and 2nd order moment estimates (mean and variance). - - .. math:: - - g_t = \nabla J(W_{t-1})\\ - m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ - v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ - W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) - - It updates the weights using:: - - m = beta1*m + (1-beta1)*grad - v = beta2*v + (1-beta2)*(grad**2) - w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) - - Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, - or 0, the update is skipped. - - Parameters - ---------- - weight : NDArray - Weight - grad : NDArray - Gradient - mean : NDArray - Moving mean - var : NDArray - Moving variance - rescale_grad : float or NDArray - Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped. - lr : float - Learning rate - eta : float - Learning rate schedule multiplier - beta1 : float, optional, default is 0.9 - The decay rate for the 1st moment estimates. - beta2 : float, optional, default is 0.999 - The decay rate for the 2nd moment estimates. - epsilon : float, optional, default is 1e-08 - A small constant for numerical stability. - wd : float, optional, default is 0 - Weight decay augments the objective function with a regularization term that penalizes - large weights. The penalty scales with the square of the magnitude of each weight. - clip_gradient : float, optional, default is -1 - Clip gradient to the range of [-clip_gradient, clip_gradient] - If clip_gradient <= 0, gradient clipping is turned off. - grad = max(min(grad, clip_gradient), -clip_gradient). - out : NDArray, optional - The output NDArray to hold the result. - - Returns - ------- - output: NDArray - - """ if not isinstance(rescale_grad, ndarray.NDArray): rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context) else: @@ -619,69 +558,6 @@ def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9, beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs): - """Update function for multi-precision AdamW optimizer. - - AdamW is seen as a modification of Adam by decoupling the weight - decay from the optimization steps taken w.r.t. the loss function. - - Adam update consists of the following steps, where g represents gradient and m, v - are 1st and 2nd order moment estimates (mean and variance). - - .. 
math:: - - g_t = \nabla J(W_{t-1})\\ - m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ - v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ - W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) - - It updates the weights using:: - - m = beta1*m + (1-beta1)*grad - v = beta2*v + (1-beta2)*(grad**2) - w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) - - Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, - or 0, the update is skipped. - - Parameters - ---------- - weight : NDArray - Weight - grad : NDArray - Gradient - mean : NDArray - Moving mean - var : NDArray - Moving variance - weight32 : NDArray - Weight in fp32. - rescale_grad : float or NDArray - Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped. - lr : float - Learning rate - eta : float - Learning rate schedule multiplier - beta1 : float, optional, default is 0.9 - The decay rate for the 1st moment estimates. - beta2 : float, optional, default is 0.999 - The decay rate for the 2nd moment estimates. - epsilon : float, optional, default is 1e-08 - A small constant for numerical stability. - wd : float, optional, default is 0 - Weight decay augments the objective function with a regularization term that penalizes - large weights. The penalty scales with the square of the magnitude of each weight. - clip_gradient : float, optional, default is -1 - Clip gradient to the range of [-clip_gradient, clip_gradient] - If clip_gradient <= 0, gradient clipping is turned off. - grad = max(min(grad, clip_gradient), -clip_gradient). - out : NDArray, optional - The output NDArray to hold the result. - - Returns - ------- - output: NDArray - - """ if not isinstance(rescale_grad, ndarray.NDArray): rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context) else: diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py index 3b19a772411d..ed5482b7c503 100644 --- a/python/mxnet/ndarray/register.py +++ b/python/mxnet/ndarray/register.py @@ -167,3 +167,12 @@ def _make_ndarray_function(handle, name, func_name): return ndarray_function _init_op_module('mxnet', 'ndarray', _make_ndarray_function) +# Update operator documentation with added float support +# Note that we can only do this after the op module is initialized +# Otherwise the backend operators cannot be found +from .contrib import adamw_update, mp_adamw_update +from ._internal import _adamw_update, _mp_adamw_update +adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : NDArray", + "rescale_grad : NDArray or float") +mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : NDArray", + "rescale_grad : NDArray or float") diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 5d0f63ba763c..d1048df7cd6f 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -730,67 +730,6 @@ def _union_inputs(*graphs): def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs): - """Update function for AdamW optimizer. - - AdamW is seen as a modification of Adam by decoupling the weight - decay from the optimization steps taken w.r.t. the loss function. - - Adam update consists of the following steps, where g represents gradient and m, v - are 1st and 2nd order moment estimates (mean and variance). - - .. 
math:: - - g_t = \nabla J(W_{t-1})\\ - m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ - v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ - W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) - - It updates the weights using:: - - m = beta1*m + (1-beta1)*grad - v = beta2*v + (1-beta2)*(grad**2) - w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) - - Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, - or 0, the update is skipped. - - Parameters - ---------- - weight : Symbol - Weight - grad : Symbol - Gradient - mean : Symbol - Moving mean - var : Symbol - Moving variance - rescale_grad : float or Symbol - Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped. - lr : float - Learning rate - eta : float - Learning rate schedule multiplier - beta1 : float, optional, default is 0.9 - The decay rate for the 1st moment estimates. - beta2 : float, optional, default is 0.999 - The decay rate for the 2nd moment estimates. - epsilon : float, optional, default is 1e-08 - A small constant for numerical stability. - wd : float, optional, default is 0 - Weight decay augments the objective function with a regularization term that penalizes - large weights. The penalty scales with the square of the magnitude of each weight. - clip_gradient : float, optional, default is -1 - Clip gradient to the range of [-clip_gradient, clip_gradient] - If clip_gradient <= 0, gradient clipping is turned off. - grad = max(min(grad, clip_gradient), -clip_gradient). - out : Symbol, optional - The output NDArray to hold the result. - - Returns - ------- - output: Symbol - - """ if not isinstance(rescale_grad, Symbol): rescale_grad = symbol.full(shape=(1,), val=rescale_grad) return symbol._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var, @@ -802,69 +741,6 @@ def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9, beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs): - """Update function for multi-precision AdamW optimizer. - - AdamW is seen as a modification of Adam by decoupling the weight - decay from the optimization steps taken w.r.t. the loss function. - - Adam update consists of the following steps, where g represents gradient and m, v - are 1st and 2nd order moment estimates (mean and variance). - - .. math:: - - g_t = \nabla J(W_{t-1})\\ - m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ - v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ - W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) - - It updates the weights using:: - - m = beta1*m + (1-beta1)*grad - v = beta2*v + (1-beta2)*(grad**2) - w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) - - Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, - or 0, the update is skipped. - - Parameters - ---------- - weight : Symbol - Weight - grad : Symbol - Gradient - mean : Symbol - Moving mean - var : Symbol - Moving variance - weight32 : Symbol - Weight in fp32. - rescale_grad : float or Symbol - Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, the update is skipped. - lr : float - Learning rate - eta : float - Learning rate schedule multiplier - beta1 : float, optional, default is 0.9 - The decay rate for the 1st moment estimates. - beta2 : float, optional, default is 0.999 - The decay rate for the 2nd moment estimates. 
- epsilon : float, optional, default is 1e-08 - A small constant for numerical stability. - wd : float, optional, default is 0 - Weight decay augments the objective function with a regularization term that penalizes - large weights. The penalty scales with the square of the magnitude of each weight. - clip_gradient : float, optional, default is -1 - Clip gradient to the range of [-clip_gradient, clip_gradient] - If clip_gradient <= 0, gradient clipping is turned off. - grad = max(min(grad, clip_gradient), -clip_gradient). - out : Symbol, optional - The output Symbol to hold the result. - - Returns - ------- - output: Symbol - - """ if not isinstance(rescale_grad, Symbol): rescale_grad = symbol.full(shape=(1,), val=rescale_grad) return symbol._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var, diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py index c147914ddb70..65d8daee7a0a 100644 --- a/python/mxnet/symbol/register.py +++ b/python/mxnet/symbol/register.py @@ -208,3 +208,13 @@ def _make_symbol_function(handle, name, func_name): return symbol_function _init_op_module('mxnet', 'symbol', _make_symbol_function) + +# Update operator documentation with added float support +# Note that we can only do this after the op module is initialized +# Otherwise the backend operators cannot be found +from .contrib import adamw_update, mp_adamw_update +from ._internal import _adamw_update, _mp_adamw_update +adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : Symbol", + "rescale_grad : Symbol or float") +mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : Symbol", + "rescale_grad : Symbol or float") From 7907641c2ffe845597c93d290d2b1cc0a36fc5f6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 23 Feb 2019 00:18:59 +0000 Subject: [PATCH 3/3] fix lint --- python/mxnet/ndarray/register.py | 2 ++ python/mxnet/symbol/register.py | 1 + 2 files changed, 3 insertions(+) diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py index ed5482b7c503..05d7f17a8fc1 100644 --- a/python/mxnet/ndarray/register.py +++ b/python/mxnet/ndarray/register.py @@ -167,9 +167,11 @@ def _make_ndarray_function(handle, name, func_name): return ndarray_function _init_op_module('mxnet', 'ndarray', _make_ndarray_function) + # Update operator documentation with added float support # Note that we can only do this after the op module is initialized # Otherwise the backend operators cannot be found +# pylint: disable=wrong-import-position from .contrib import adamw_update, mp_adamw_update from ._internal import _adamw_update, _mp_adamw_update adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : NDArray", diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py index 65d8daee7a0a..15c8e5e1fa68 100644 --- a/python/mxnet/symbol/register.py +++ b/python/mxnet/symbol/register.py @@ -212,6 +212,7 @@ def _make_symbol_function(handle, name, func_name): # Update operator documentation with added float support # Note that we can only do this after the op module is initialized # Otherwise the backend operators cannot be found +# pylint: disable=wrong-import-position from .contrib import adamw_update, mp_adamw_update from ._internal import _adamw_update, _mp_adamw_update adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : Symbol",
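
A hedged usage sketch of the new scalar support introduced by patch 1, mirroring the added unit tests: rescale_grad may now be passed either as a plain Python float (the wrapper converts it to a 1-element NDArray on the weight's context) or as an NDArray such as a dynamic loss-scaling factor. The shapes and hyperparameter values below are arbitrary, and the snippet assumes an MXNet build that includes these patches:

    import mxnet as mx
    import numpy as np

    shape = (3, 4)
    weight = mx.nd.random.uniform(shape=shape)
    grad = mx.nd.random.uniform(shape=shape)
    mean = mx.nd.zeros(shape)
    var = mx.nd.zeros(shape)
    kwargs = {'lr': 0.01, 'eta': 1.0, 'wd': 0.1, 'epsilon': 1e-8,
              'beta1': 0.9, 'beta2': 0.999}

    # rescale_grad as a plain Python scalar (new with this patch series)
    mx.nd.contrib.adamw_update(weight, grad, mean, var, 1.0 / 128,
                               out=weight, **kwargs)

    # rescale_grad as a 1-element NDArray, e.g. produced by a loss-scaling schedule
    scale = mx.nd.array([1.0 / 128])
    mx.nd.contrib.adamw_update(weight, grad, mean, var, scale,
                               out=weight, **kwargs)

    # A NaN, Inf, or 0 scale leaves the weight untouched: the update is skipped
    mx.nd.contrib.adamw_update(weight, grad, mean, var, np.nan,
                               out=weight, **kwargs)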
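
For reference, a minimal NumPy sketch of the update rule documented in the docstrings and operator descriptions above; it is not part of the patch, and the helper name adamw_ref, the in-place array convention, and the early return for a non-finite or zero scale are illustrative assumptions rather than MXNet API:

    import numpy as np

    def adamw_ref(weight, grad, mean, var, lr, eta, beta1=0.9, beta2=0.999,
                  epsilon=1e-8, wd=0.0, clip_gradient=-1.0, rescale_grad=1.0):
        # Skip the step entirely for NaN, Inf, or 0 scale, as the docs describe.
        if not np.isfinite(rescale_grad) or rescale_grad == 0:
            return weight
        g = grad * rescale_grad
        if clip_gradient > 0:
            g = np.clip(g, -clip_gradient, clip_gradient)
        # m = beta1*m + (1-beta1)*grad ; v = beta2*v + (1-beta2)*grad**2
        mean[:] = beta1 * mean + (1.0 - beta1) * g
        var[:] = beta2 * var + (1.0 - beta2) * g * g
        # w -= eta * (lr * m / (sqrt(v) + epsilon) + wd * w)
        weight[:] -= eta * (lr * mean / (np.sqrt(var) + epsilon) + wd * weight)
        return weight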