Skip to content

Commit

Permalink
[op] add back support for scalar type rescale_grad argument for adamw…
Browse files Browse the repository at this point in the history
…_update/mp_adamw_update (apache#14221)

* support scalar

* remove two copies of documentation for adamw

* fix lint
  • Loading branch information
eric-haibin-lin authored and vdantu committed Mar 31, 2019
1 parent 145a30c commit 99c434e
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 6 deletions.
26 changes: 26 additions & 0 deletions python/mxnet/ndarray/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,3 +542,29 @@ def isnan(data):
<NDArray 2 @cpu(0)>
"""
return data != data

def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
if not isinstance(rescale_grad, ndarray.NDArray):
rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
else:
rescale_grad = rescale_grad.as_in_context(weight.context)
return ndarray._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
rescale_grad=rescale_grad, lr=lr, eta=eta,
beta1=beta1, beta2=beta2, epsilon=epsilon,
wd=wd, clip_gradient=clip_gradient, out=out,
name=name, **kwargs)

def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
name=None, **kwargs):
if not isinstance(rescale_grad, ndarray.NDArray):
rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
else:
rescale_grad = rescale_grad.as_in_context(weight.context)
return ndarray._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
weight32=weight32,
rescale_grad=rescale_grad, lr=lr, eta=eta,
beta1=beta1, beta2=beta2, epsilon=epsilon,
wd=wd, clip_gradient=clip_gradient, out=out,
name=name, **kwargs)
11 changes: 11 additions & 0 deletions python/mxnet/ndarray/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,14 @@ def _make_ndarray_function(handle, name, func_name):
return ndarray_function

_init_op_module('mxnet', 'ndarray', _make_ndarray_function)

# Update operator documentation with added float support
# Note that we can only do this after the op module is initialized
# Otherwise the backend operators cannot be found
# pylint: disable=wrong-import-position
from .contrib import adamw_update, mp_adamw_update
from ._internal import _adamw_update, _mp_adamw_update
adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : NDArray",
"rescale_grad : NDArray or float")
mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : NDArray",
"rescale_grad : NDArray or float")
22 changes: 22 additions & 0 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,3 +727,25 @@ def _union_inputs(*graphs):
outputs = [result[i] for i in range(then_num_outputs)]
outputs, _ = _regroup(outputs, then_fmt)
return outputs

def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
if not isinstance(rescale_grad, Symbol):
rescale_grad = symbol.full(shape=(1,), val=rescale_grad)
return symbol._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
rescale_grad=rescale_grad, lr=lr, eta=eta,
beta1=beta1, beta2=beta2, epsilon=epsilon,
wd=wd, clip_gradient=clip_gradient, out=out,
name=name, **kwargs)

def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
name=None, **kwargs):
if not isinstance(rescale_grad, Symbol):
rescale_grad = symbol.full(shape=(1,), val=rescale_grad)
return symbol._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
weight32=weight32,
rescale_grad=rescale_grad, lr=lr, eta=eta,
beta1=beta1, beta2=beta2, epsilon=epsilon,
wd=wd, clip_gradient=clip_gradient, out=out,
name=name, **kwargs)
11 changes: 11 additions & 0 deletions python/mxnet/symbol/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,14 @@ def _make_symbol_function(handle, name, func_name):
return symbol_function

_init_op_module('mxnet', 'symbol', _make_symbol_function)

# Update operator documentation with added float support
# Note that we can only do this after the op module is initialized
# Otherwise the backend operators cannot be found
# pylint: disable=wrong-import-position
from .contrib import adamw_update, mp_adamw_update
from ._internal import _adamw_update, _mp_adamw_update
adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : Symbol",
"rescale_grad : Symbol or float")
mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : Symbol",
"rescale_grad : Symbol or float")
10 changes: 6 additions & 4 deletions src/operator/contrib/adamw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ inline void MPUpdateCPU(const nnvm::NodeAttrs& attrs,
});
}

NNVM_REGISTER_OP(_contrib_mp_adamw_update)
NNVM_REGISTER_OP(_mp_adamw_update)
.describe(R"code(Update function for multi-precision AdamW optimizer.
AdamW is seen as a modification of Adam by decoupling the weight decay from the
Expand Down Expand Up @@ -91,10 +91,11 @@ the update is skipped.
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_argument("rescale_grad", "NDArray-or-Symbol",
"Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
"Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
"the update is skipped.")
.add_arguments(AdamWParam::__FIELDS__());

NNVM_REGISTER_OP(_contrib_adamw_update)
NNVM_REGISTER_OP(_adamw_update)
.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of
Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function.
Expand Down Expand Up @@ -132,7 +133,8 @@ the update is skipped.
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_argument("rescale_grad", "NDArray-or-Symbol",
"Rescale gradient to rescale_grad * grad. If NaN, the update is skipped.")
"Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, "
"the update is skipped.")
.add_arguments(AdamWParam::__FIELDS__());

} // namespace op
Expand Down
4 changes: 2 additions & 2 deletions src/operator/contrib/adamw.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ inline void MPUpdateGPU(const nnvm::NodeAttrs& attrs,
});
}

NNVM_REGISTER_OP(_contrib_adamw_update)
NNVM_REGISTER_OP(_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<AdamWUpdate>);

NNVM_REGISTER_OP(_contrib_mp_adamw_update)
NNVM_REGISTER_OP(_mp_adamw_update)
.set_attr<FCompute>("FCompute<gpu>", MPUpdateGPU<MPAdamWUpdate>);

} // namespace op
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_contrib_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def test_adamw():
kwargs = {'eta': eta, 'lr': lr, 'wd': wd, 'epsilon': epsilon,
'beta1': beta1, 'beta2': beta2}

# update is skipped for rescale = nan scalar
mx.nd.contrib.adamw_update(weight, grad, m, v,
np.nan, out=weight, **kwargs)
# weight remains unchanged
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())

# update is skipped for rescale = 0
mx.nd.contrib.adamw_update(weight, grad, m, v,
rescale_grad * 0, out=weight, **kwargs)
Expand Down Expand Up @@ -134,6 +140,12 @@ def test_adamw():
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())

# multi-precision update is skipped for rescale = nan scalar
mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
np.nan, out=weight_fp16, **kwargs)
mx.test_utils.assert_almost_equal(weight_ref.asnumpy(), weight.asnumpy())
mx.test_utils.assert_almost_equal(weight_fp16_ref.asnumpy(), weight_fp16.asnumpy())

# multi-precision update is skipped for rescale = inf
mx.nd.contrib.mp_adamw_update(weight_fp16, grad_fp16, m, v, weight,
rescale_grad * np.inf, out=weight_fp16, **kwargs)
Expand Down

0 comments on commit 99c434e

Please sign in to comment.