diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py index c7afc762bd80..3760c57523f0 100644 --- a/tests/python/gpu/test_gluon_transforms.py +++ b/tests/python/gpu/test_gluon_transforms.py @@ -69,4 +69,14 @@ def test_normalize(): # Invalid Input - Channel neither 1 or 3 invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300)) normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) - assertRaises(MXNetError, normalize_transformer, invalid_data_in) \ No newline at end of file + assertRaises(MXNetError, normalize_transformer, invalid_data_in) + + +@with_seed() +def test_resize(): + # Test with normal case 3D input + data_in_3d = nd.random.uniform(0, 255, (300, 300, 3)).astype('uint8') + out_nd_3d = transforms.Resize((100, 100))(data_in_3d) + data_in_4d_nchw = nd.moveaxis(nd.expand_dims(data_in_3d, axis=0), 3, 1) + data_expected_3d = (nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, 100, 100), 1, 3))[0] + assert_almost_equal(out_nd_3d.asnumpy(), data_expected_3d.asnumpy()) diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index 154edb866730..0bfaa78b243c 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -67,6 +67,7 @@ def test_normalize(): normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) assertRaises(MXNetError, normalize_transformer, invalid_data_in) + @with_seed() def test_resize(): def _test_resize_with_diff_type(dtype):