Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add unit test for gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Jan 29, 2019
1 parent 9672091 commit b31d19f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tests/python/gpu/test_gluon_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,14 @@ def test_normalize():
# Invalid Input - Channel neither 1 or 3
invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300))
normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
assertRaises(MXNetError, normalize_transformer, invalid_data_in)
assertRaises(MXNetError, normalize_transformer, invalid_data_in)


@with_seed()
def test_resize():
# Test with normal case 3D input
data_in_3d = nd.random.uniform(0, 255, (300, 300, 3)).astype('uint8')
out_nd_3d = transforms.Resize((100, 100))(data_in_3d)
data_in_4d_nchw = nd.moveaxis(nd.expand_dims(data_in_3d, axis=0), 3, 1)
data_expected_3d = (nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, 100, 100), 1, 3))[0]
assert_almost_equal(out_nd_3d.asnumpy(), data_expected_3d.asnumpy())
1 change: 1 addition & 0 deletions tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_normalize():
normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
assertRaises(MXNetError, normalize_transformer, invalid_data_in)


@with_seed()
def test_resize():
def _test_resize_with_diff_type(dtype):
Expand Down

0 comments on commit b31d19f

Please sign in to comment.