diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index ea51858c0b3b..ed70f8ccfc6e 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -723,13 +723,13 @@ def multi_mp_adabelief_update(weights, grads, mean, var, weights32, rescale_grad rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context) temp_list = _flatten_list(zip(weights, grads, mean, var, weights32)) + [rescale_grad] return ndarray._internal._multi_mp_adabelief_update(*temp_list, - out=out, - num_weights=size, - lrs=lrs, - wds=wds, - etas=etas, - name=name, - **kwargs) + out=out, + num_weights=size, + lrs=lrs, + wds=wds, + etas=etas, + name=name, + **kwargs) def multi_lans_update(weights, grads, mean, var, step_count, lrs, wds, out=None, num_tensors=0, **kwargs):