From 12ffd997e54be7d5d313406c03a7bb1b4ae3b21e Mon Sep 17 00:00:00 2001 From: "Per G. da Silva" Date: Wed, 26 Sep 2018 12:49:56 +0200 Subject: [PATCH] Adds extra randomness to imagerecordio iterator test --- tests/python/unittest/test_io.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index ce664b6f8714..f051472d1be7 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -450,7 +450,8 @@ def test_ImageRecordIter_seed_augmentation(): max_shear_ratio=2, seed_aug=seed_aug) batch = dataiter.next() - data = batch.data[0].asnumpy().astype(np.uint8) + test_index = rnd.randint(0, len(batch.data)) + data = batch.data[test_index].asnumpy().astype(np.uint8) dataiter = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", @@ -469,7 +470,7 @@ def test_ImageRecordIter_seed_augmentation(): max_shear_ratio=2, seed_aug=seed_aug) batch = dataiter.next() - data2 = batch.data[0].asnumpy().astype(np.uint8) + data2 = batch.data[test_index].asnumpy().astype(np.uint8) assert(np.array_equal(data,data2)) # check whether to get different images after change seed_aug @@ -490,7 +491,7 @@ def test_ImageRecordIter_seed_augmentation(): max_shear_ratio=2, seed_aug=seed_aug+1) batch = dataiter.next() - data2 = batch.data[0].asnumpy().astype(np.uint8) + data2 = batch.data[test_index].asnumpy().astype(np.uint8) assert(not np.array_equal(data,data2)) # check whether seed_aug changes the iterator behavior @@ -502,7 +503,8 @@ def test_ImageRecordIter_seed_augmentation(): batch_size=3, seed_aug=seed_aug) batch = dataiter.next() - data = batch.data[0].asnumpy().astype(np.uint8) + test_index = rnd.randint(0, len(batch.data)) + data = batch.data[test_index].asnumpy().astype(np.uint8) dataiter = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", @@ -512,7 +514,7 @@ def test_ImageRecordIter_seed_augmentation(): batch_size=3, seed_aug=seed_aug) batch = dataiter.next() - data2 = batch.data[0].asnumpy().astype(np.uint8) + data2 = batch.data[test_index].asnumpy().astype(np.uint8) assert(np.array_equal(data,data2)) if __name__ == "__main__":