diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 6bbee8a6fbc3..74c355dc1288 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -542,3 +542,29 @@ def isnan(data):
     """
     return data != data
+
+def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
+                 epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+    if not isinstance(rescale_grad, ndarray.NDArray):
+        rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
+    else:
+        rescale_grad = rescale_grad.as_in_context(weight.context)
+    return ndarray._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                           rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                           beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                           wd=wd, clip_gradient=clip_gradient, out=out,
+                                           name=name, **kwargs)
+
+def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
+                    beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
+                    name=None, **kwargs):
+    if not isinstance(rescale_grad, ndarray.NDArray):
+        rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
+    else:
+        rescale_grad = rescale_grad.as_in_context(weight.context)
+    return ndarray._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                              weight32=weight32,
+                                              rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                              beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                              wd=wd, clip_gradient=clip_gradient, out=out,
+                                              name=name, **kwargs)
diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py
index 3b19a772411d..05d7f17a8fc1 100644
--- a/python/mxnet/ndarray/register.py
+++ b/python/mxnet/ndarray/register.py
@@ -167,3 +167,14 @@ def _make_ndarray_function(handle, name, func_name):
     return ndarray_function
 
 _init_op_module('mxnet', 'ndarray', _make_ndarray_function)
+
+# Update operator documentation with added float support
+# Note that we can only do this after the op module is initialized
+# Otherwise the backend operators cannot be found
+# pylint: disable=wrong-import-position
+from .contrib import adamw_update, mp_adamw_update
+from ._internal import _adamw_update, _mp_adamw_update
+adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : NDArray",
+                                                     "rescale_grad : NDArray or float")
+mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : NDArray",
+                                                           "rescale_grad : NDArray or float")
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index a83227a0261c..d1048df7cd6f 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -727,3 +727,25 @@ def _union_inputs(*graphs):
     outputs = [result[i] for i in range(then_num_outputs)]
     outputs, _ = _regroup(outputs, then_fmt)
     return outputs
+
+def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
+                 epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+    if not isinstance(rescale_grad, Symbol):
+        rescale_grad = symbol.full(shape=(1,), val=rescale_grad)
+    return symbol._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                          rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                          beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                          wd=wd, clip_gradient=clip_gradient, out=out,
+                                          name=name, **kwargs)
+
+def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
+                    beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
+                    name=None, **kwargs):
+    if not isinstance(rescale_grad, Symbol):
+        rescale_grad = symbol.full(shape=(1,), val=rescale_grad)
+    return symbol._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
+                                             weight32=weight32,
+                                             rescale_grad=rescale_grad, lr=lr, eta=eta,
+                                             beta1=beta1, beta2=beta2, epsilon=epsilon,
+                                             wd=wd, clip_gradient=clip_gradient, out=out,
+                                             name=name, **kwargs)
diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py
index c147914ddb70..15c8e5e1fa68 100644
--- a/python/mxnet/symbol/register.py
+++ b/python/mxnet/symbol/register.py
@@ -208,3 +208,14 @@ def _make_symbol_function(handle, name, func_name):
     return symbol_function
 
 _init_op_module('mxnet', 'symbol', _make_symbol_function)
+
+# Update operator documentation with added float support
+# Note that we can only do this after the op module is initialized
+# Otherwise the backend operators cannot be found
+# pylint: disable=wrong-import-position
+from .contrib import adamw_update, mp_adamw_update
+from ._internal import _adamw_update, _mp_adamw_update
+adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : Symbol",
+                                                     "rescale_grad : Symbol or float")
+mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : Symbol",
+                                                           "rescale_grad : Symbol or float")
diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc
index 2fbc39743c93..874cce8d8772 100644
--- a/src/operator/contrib/adamw.cc
+++ b/src/operator/contrib/adamw.cc
@@ -50,7 +50,7 @@ inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
   });
 }
 
-NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+NNVM_REGISTER_OP(_mp_adamw_update)
 .describe(R"code(Update function for multi-precision AdamW optimizer.
 
 AdamW is seen as a modification of Adam by decoupling the weight decay from the
@@ -91,10 +91,11 @@ the update is skipped.
 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
 .add_argument("weight32", "NDArray-or-Symbol", "Weight32")
 .add_argument("rescale_grad", "NDArray-or-Symbol",
-              "Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
+              "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
+              "the update is skipped.")
 .add_arguments(AdamWParam::__FIELDS__());
 
-NNVM_REGISTER_OP(_contrib_adamw_update)
+NNVM_REGISTER_OP(_adamw_update)
 .describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
 Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.
 
@@ -132,7 +133,8 @@ the update is skipped.
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
 .add_argument("rescale_grad", "NDArray-or-Symbol",
-              "Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
+              "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
+              "the update is skipped.")
 .add_arguments(AdamWParam::__FIELDS__());
 
 }  // namespace op
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index e21b83b8aba6..1521749904b9 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -50,10 +50,10 @@ inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs,
   });
 }
 
-NNVM_REGISTER_OP(_contrib_adamw_update)
+NNVM_REGISTER_OP(_adamw_update)
 .set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<AdamWUpdate>);
 
-NNVM_REGISTER_OP(_contrib_mp_adamw_update)
+NNVM_REGISTER_OP(_mp_adamw_update)
 .set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<MPAdamWUpdate>);
 
 }  // namespace op
diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py
index dad7bed3a923..675cc94c64f1 100644
--- a/tests/python/unittest/test_contrib_optimizer.py
+++ b/tests/python/unittest/test_contrib_optimizer.py
@@ -107,6 +107,12 @@ def test_adamw():
     kwargs = {'eta': eta, 'lr': lr, 'wd': wd, 'epsilon': epsilon,
               'beta1': beta1, 'beta2': beta2}
 
+    # update is skipped for rescale = nan scalar
+    mx.nd.contrib.adamw_update(weight, grad, m, v,
+                               np.nan, out=weight, **kwargs)
+    # weight remains unchanged
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+
     # update is skipped for rescale = 0
     mx.nd.contrib.adamw_update(weight, grad, m, v,
                                rescale_grad * 0, out=weight, **kwargs)
@@ -134,6 +140,12 @@ def test_adamw():
     mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
     mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
 
+    # multi-precision update is skipped for rescale = nan scalar
+    mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
+                                  np.nan, out=weight_fp16, **kwargs)
+    mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
+    mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())
+
     # multi-precision update is skipped for rescale = inf
     mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
                                   rescale_grad * np.inf, out=weight_fp16, **kwargs)
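
For reference, a minimal usage sketch of the new `mx.nd.contrib.adamw_update` wrapper (not part of the patch). It mirrors the unit test above: `rescale_grad` may now be a float or an NDArray, and a NaN/Inf/0 value skips the update. The shapes and hyperparameter values below are made up for illustration, and the snippet assumes an MXNet build that registers these contrib ops.

```python
import mxnet as mx
import numpy as np

# illustrative shapes and hyperparameters (not taken from the patch)
shape = (3, 4)
weight = mx.nd.random.uniform(shape=shape)
grad = mx.nd.random.uniform(shape=shape)
m = mx.nd.zeros(shape)
v = mx.nd.zeros(shape)
kwargs = {'eta': 0.1, 'lr': 0.01, 'wd': 0.1, 'epsilon': 1e-8,
          'beta1': 0.9, 'beta2': 0.999}

# float rescale_grad: the wrapper wraps it in a 1-element NDArray on weight.context
mx.nd.contrib.adamw_update(weight, grad, m, v, 0.5, out=weight, **kwargs)

# NDArray rescale_grad: passed through (moved to weight.context if necessary)
rescale = mx.nd.full(shape=(1,), val=0.5)
mx.nd.contrib.adamw_update(weight, grad, m, v, rescale, out=weight, **kwargs)

# NaN rescale_grad: the backend op skips the update, so weight is unchanged
before = weight.asnumpy()
mx.nd.contrib.adamw_update(weight, grad, m, v, np.nan, out=weight, **kwargs)
mx.test_utils.assert_almost_equal(before, weight.asnumpy())
```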