Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Make check_isfinite, check_scale optional in clip_global_norm (#12042)
Browse files Browse the repository at this point in the history
* Make check_isfinite, check_scale optional in clip_global_norm

If both are set to false, clip_global_norm does not force any synchronization
and throughput can be increased.

* Add tests

* Remove check_scale

* Document return type

* Fix test_gluon_gpu
  • Loading branch information
leezu authored and eric-haibin-lin committed Aug 27, 2018
1 parent 48d2155 commit 308ada1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 20 deletions.
38 changes: 29 additions & 9 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,23 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
return [i.as_in_context(ctx) for i, ctx in zip(slices, ctx_list)]


def clip_global_norm(arrays, max_norm):
def clip_global_norm(arrays, max_norm, check_isfinite=True):
"""Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`.
Parameters
----------
arrays : list of NDArray
max_norm : float
check_isfinite : bool, default True
If True, check that the total_norm is finite (not nan or inf). This
requires a blocking .asscalar() call.
Returns
-------
NDArray or float
Total norm. Return type is NDArray of shape (1,) if check_isfinite is
False. Otherwise a float is returned.
"""
def _norm(array):
if array.stype == 'default':
Expand All @@ -126,15 +141,20 @@ def _norm(array):
assert len(arrays) > 0
ctx = arrays[0].context
total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays])
total_norm = ndarray.sqrt(total_norm).asscalar()
if not np.isfinite(total_norm):
warnings.warn(UserWarning('nan or inf is detected. Clipping results will be undefined.'),
stacklevel=2)
total_norm = ndarray.sqrt(total_norm)
if check_isfinite:
if not np.isfinite(total_norm.asscalar()):
warnings.warn(
UserWarning('nan or inf is detected. '
'Clipping results will be undefined.'), stacklevel=2)
scale = max_norm / (total_norm + 1e-8)
if scale < 1.0:
for arr in arrays:
arr *= scale
return total_norm
scale = ndarray.min(ndarray.concat(scale, ndarray.ones(1, ctx=ctx), dim=0))
for arr in arrays:
arr *= scale.as_in_context(arr.context)
if check_isfinite:
return total_norm.asscalar()
else:
return total_norm


def _indent(s_, numSpaces):
Expand Down
16 changes: 10 additions & 6 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,16 @@ def test_gluon_ctc_consistency():

@with_seed()
def test_global_norm_clip_multi_device():
x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
x2 = mx.nd.ones((4,4), ctx=mx.cpu(0))
norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
assert norm == 5.0
assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
for check_isfinite in [True, False]:
x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
x2 = mx.nd.ones((4,4), ctx=mx.cpu(0))
norm = gluon.utils.clip_global_norm([x1, x2], 1.0, check_isfinite=check_isfinite)
if check_isfinite:
assert norm == 5.0
else:
assert norm.asscalar() == 5.0
assert_almost_equal(x1.asnumpy(), np.ones((3, 3)) / 5)
assert_almost_equal(x2.asnumpy(), np.ones((4, 4)) / 5)


def _check_batchnorm_result(input, num_devices=1, cuda=False):
Expand Down
11 changes: 6 additions & 5 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,22 +735,23 @@ def test_sequential_warning():
@with_seed()
def test_global_norm_clip():
stypes = ['default', 'row_sparse']
def check_global_norm_clip(stype):
def check_global_norm_clip(stype, check_isfinite):
x1 = mx.nd.ones((3,3)).tostype(stype)
x2 = mx.nd.ones((4,4)).tostype(stype)
norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
norm = gluon.utils.clip_global_norm([x1, x2], 1.0, check_isfinite=check_isfinite)
assert norm == 5.0
assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)

x3 = mx.nd.array([1.0, 2.0, float('nan')]).tostype(stype)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
gluon.utils.clip_global_norm([x1, x3], 2.0)
assert len(w) == 1
gluon.utils.clip_global_norm([x1, x3], 2.0, check_isfinite=check_isfinite)
assert len(w) == check_isfinite

for stype in stypes:
check_global_norm_clip(stype)
for check_isfinite in [True, False]:
check_global_norm_clip(stype, check_isfinite)

@with_seed()
def test_embedding():
Expand Down

0 comments on commit 308ada1

Please sign in to comment.