Commit
Aggregated adamw update (#16398)
* Trigger CI

* MXNet operator for aggregated Adam update (see the usage sketch after this list)

* Fixing a problem with the getRescaleGrad(...) call in Python 2,
  plus some minor changes requested by Przemek

* Fix a problem appearing in Python 2

* Minor cleanup

* Changing function name

* Trigger CI

* Eliminating "asnumpy()" conversion

* Trigger CI
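
The core addition is a pair of contrib wrappers, multi_adamw_update and multi_mp_adamw_update, that apply the AdamW step to a whole list of weights in one operator call instead of looping one call per tensor. A minimal usage sketch, assuming an MXNet build that includes this commit; the shapes, hyperparameter values, and three-tensor setup are illustrative assumptions, not taken from the commit:

import mxnet as mx
from mxnet.ndarray.contrib import adamw_update, multi_adamw_update

# Illustrative state: three weight tensors with matching gradient,
# mean, and var buffers (values are made up for the example).
shapes = [(2, 2), (4,), (3, 1)]
weights = [mx.nd.ones(s) for s in shapes]
grads = [mx.nd.ones(s) * 0.1 for s in shapes]
means = [mx.nd.zeros(s) for s in shapes]
variances = [mx.nd.zeros(s) for s in shapes]

# Per-tensor path: one adamw_update call (one kernel launch) per weight.
for w, g, m, v in zip(weights, grads, means, variances):
    adamw_update(w, g, m, v, rescale_grad=1.0, lr=0.01, eta=1.0, out=w)

# Aggregated path added by this commit: one call covers all weights;
# lrs, wds, and etas each take one value per weight.
multi_adamw_update(weights, grads, means, variances, rescale_grad=1.0,
                   lrs=[0.01] * 3, wds=[0.0] * 3, etas=[1.0] * 3,
                   out=weights)
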
drivanov authored and apeforest committed Nov 6, 2019
1 parent 0415a2f commit 8c22fac
Showing 6 changed files with 666 additions and 195 deletions.
python/mxnet/ndarray/contrib.py (56 changes: 46 additions & 10 deletions)
@@ -21,6 +21,7 @@
 from __future__ import absolute_import
 import math
 import numpy as np
+import mxnet as mx
 from ..context import current_context
 from ..random import uniform
 from ..base import _as_list
@@ -32,6 +33,9 @@
 
 __all__ = ["rand_zipfian", "foreach", "while_loop", "cond", "isinf", "isfinite", "isnan"]
 
+def _flatten_list(nested_list):
+    return [item for sublist in nested_list for item in sublist]
+
 # pylint: disable=line-too-long
 def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
     """Draw random samples from an approximately log-uniform or Zipfian distribution.
@@ -514,7 +518,7 @@ def isfinite(data):
     [0. 0. 0. 1.]
     <NDArray 4 @cpu(0)>
     """
-    is_data_not_nan = data == data # pylint: disable=comparison-with-itself
+    is_data_not_nan = data == data  # pylint: disable=comparison-with-itself
     is_data_not_infinite = data.abs() != np.inf
     return ndarray.logical_and(is_data_not_infinite, is_data_not_nan)
 
@@ -542,14 +546,17 @@ def isnan(data):
     [1. 0.]
     <NDArray 2 @cpu(0)>
     """
-    return data != data # pylint: disable=comparison-with-itself
+    return data != data  # pylint: disable=comparison-with-itself
 
-def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
-                 epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+def _get_rescale_grad(rescale_grad, ctx=mx.cpu()):
     if not isinstance(rescale_grad, ndarray.NDArray):
-        rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
+        return ndarray.full(shape=(1,), val=rescale_grad, ctx=ctx)
     else:
-        rescale_grad = rescale_grad.as_in_context(weight.context)
+        return rescale_grad.as_in_context(ctx)
 
+def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999,
+                 epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs):
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weight.context)
     return ndarray._internal._adamw_update(weight=weight, grad=grad, mean=mean, var=var,
                                            rescale_grad=rescale_grad, lr=lr, eta=eta,
                                            beta1=beta1, beta2=beta2, epsilon=epsilon,
@@ -559,13 +566,42 @@ def adamw_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta
 def mp_adamw_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9,
                     beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None,
                     name=None, **kwargs):
-    if not isinstance(rescale_grad, ndarray.NDArray):
-        rescale_grad = ndarray.full(shape=(1,), val=rescale_grad, ctx=weight.context)
-    else:
-        rescale_grad = rescale_grad.as_in_context(weight.context)
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weight.context)
     return ndarray._internal._mp_adamw_update(weight=weight, grad=grad, mean=mean, var=var,
                                               weight32=weight32,
                                               rescale_grad=rescale_grad, lr=lr, eta=eta,
                                               beta1=beta1, beta2=beta2, epsilon=epsilon,
                                               wd=wd, clip_gradient=clip_gradient, out=out,
                                               name=name, **kwargs)
+
+def multi_adamw_update(weights, grads, mean, var, rescale_grad, lrs, wds, etas,
+                       out=None, name=None, size=0, **kwargs):
+    if not size:
+        size = len(weights)
+
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context)
+    temp_list = _flatten_list(zip(weights, grads, mean, var)) + [rescale_grad]
+    return ndarray._internal._multi_adamw_update(*temp_list,
+                                                 out=out,
+                                                 num_weights=size,
+                                                 lrs=lrs,
+                                                 wds=wds,
+                                                 etas=etas,
+                                                 name=name,
+                                                 **kwargs)
+
+def multi_mp_adamw_update(weights, grads, mean, var, weights32, rescale_grad, lrs, wds, etas,
+                          out=None, name=None, size=0, **kwargs):
+    if not size:
+        size = len(weights)
+
+    rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context)
+    temp_list = _flatten_list(zip(weights, grads, mean, var, weights32)) + [rescale_grad]
+    return ndarray._internal._multi_mp_adamw_update(*temp_list,
+                                                    out=out,
+                                                    num_weights=size,
+                                                    lrs=lrs,
+                                                    wds=wds,
+                                                    etas=etas,
+                                                    name=name,
+                                                    **kwargs)
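
The flat argument list built by _flatten_list is what the backend _multi_adamw_update operator consumes positionally: the weight, grad, mean, and var tensors for each index in turn, with the rescale_grad NDArray appended last (the mixed-precision variant adds the float32 master weight as a fifth tensor per group). A small self-contained sketch of that interleaving, using string stand-ins for NDArrays purely for illustration:

# String stand-ins for NDArrays, purely to show the argument layout.
weights = ["w0", "w1"]
grads = ["g0", "g1"]
mean = ["m0", "m1"]
var = ["v0", "v1"]

def _flatten_list(nested_list):
    # Same helper the commit adds: flatten one level of nesting.
    return [item for sublist in nested_list for item in sublist]

# zip(...) yields one (weight, grad, mean, var) tuple per index;
# flattening interleaves them group by group, rescale_grad goes last.
flat = _flatten_list(zip(weights, grads, mean, var)) + ["rescale_grad"]
print(flat)
# ['w0', 'g0', 'm0', 'v0', 'w1', 'g1', 'm1', 'v1', 'rescale_grad']
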
