diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index ce664b6f8714..f051472d1be7 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -450,7 +450,8 @@ def test_ImageRecordIter_seed_augmentation(): max_shear_ratio=2, seed_aug=seed_aug) batch = dataiter.next() - data = batch.data[0].asnumpy().astype(np.uint8) + test_index = rnd.randint(0, len(batch.data)) + data = batch.data[test_index].asnumpy().astype(np.uint8) dataiter = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", @@ -469,7 +470,7 @@ def test_ImageRecordIter_seed_augmentation(): max_shear_ratio=2, seed_aug=seed_aug) batch = dataiter.next() - data2 = batch.data[0].asnumpy().astype(np.uint8) + data2 = batch.data[test_index].asnumpy().astype(np.uint8) assert(np.array_equal(data,data2)) # check whether to get different images after change seed_aug @@ -490,7 +491,7 @@ def test_ImageRecordIter_seed_augmentation(): max_shear_ratio=2, seed_aug=seed_aug+1) batch = dataiter.next() - data2 = batch.data[0].asnumpy().astype(np.uint8) + data2 = batch.data[test_index].asnumpy().astype(np.uint8) assert(not np.array_equal(data,data2)) # check whether seed_aug changes the iterator behavior @@ -502,7 +503,8 @@ def test_ImageRecordIter_seed_augmentation(): batch_size=3, seed_aug=seed_aug) batch = dataiter.next() - data = batch.data[0].asnumpy().astype(np.uint8) + test_index = rnd.randint(0, len(batch.data)) + data = batch.data[test_index].asnumpy().astype(np.uint8) dataiter = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", @@ -512,7 +514,7 @@ def test_ImageRecordIter_seed_augmentation(): batch_size=3, seed_aug=seed_aug) batch = dataiter.next() - data2 = batch.data[0].asnumpy().astype(np.uint8) + data2 = batch.data[test_index].asnumpy().astype(np.uint8) assert(np.array_equal(data,data2)) if __name__ == "__main__":