Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Aug 6, 2018
1 parent d4142fa commit 4546948
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
11 changes: 7 additions & 4 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,13 @@ def test_gluon_ctc_consistency():
def test_global_norm_clip_multi_device():
x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
x2 = mx.nd.ones((4,4), ctx=mx.cpu(0))
norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
assert norm == 5.0
assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
for check_isfinite in [True, False]:
for check_scale in [True, False]:
norm = gluon.utils.clip_global_norm([x1, x2], 1.0, check_isfinite=check_isfinite,
check_scale=check_scale)
assert norm == 5.0
assert_almost_equal(x1.asnumpy(), np.ones((3, 3)) / 5)
assert_almost_equal(x2.asnumpy(), np.ones((4, 4)) / 5)


def _check_batchnorm_result(input, num_devices=1, cuda=False):
Expand Down
9 changes: 6 additions & 3 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,10 +735,11 @@ def test_sequential_warning():
@with_seed()
def test_global_norm_clip():
stypes = ['default', 'row_sparse']
def check_global_norm_clip(stype):
def check_global_norm_clip(stype, check_isfinite, check_scale):
x1 = mx.nd.ones((3,3)).tostype(stype)
x2 = mx.nd.ones((4,4)).tostype(stype)
norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
norm = gluon.utils.clip_global_norm([x1, x2], 1.0, check_isfinite=check_isfinite,
check_scale=check_scale)
assert norm == 5.0
assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
Expand All @@ -750,7 +751,9 @@ def check_global_norm_clip(stype):
assert len(w) == 1

for stype in stypes:
check_global_norm_clip(stype)
for check_isfinite in [True, False]:
for check_scale in [True, False]:
check_global_norm_clip(stype, check_isfinite, check_scale)

@with_seed()
def test_embedding():
Expand Down

0 comments on commit 4546948

Please sign in to comment.