Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Adds extra randomness to imagerecordio iterator test
Browse files Browse the repository at this point in the history
  • Loading branch information
perdasilva committed Nov 5, 2018
1 parent 98a9ee7 commit 12ffd99
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ def test_ImageRecordIter_seed_augmentation():
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
Expand All @@ -469,7 +470,7 @@ def test_ImageRecordIter_seed_augmentation():
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

# check whether to get different images after change seed_aug
Expand All @@ -490,7 +491,7 @@ def test_ImageRecordIter_seed_augmentation():
max_shear_ratio=2,
seed_aug=seed_aug+1)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(not np.array_equal(data,data2))

# check whether seed_aug changes the iterator behavior
Expand All @@ -502,7 +503,8 @@ def test_ImageRecordIter_seed_augmentation():
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
Expand All @@ -512,7 +514,7 @@ def test_ImageRecordIter_seed_augmentation():
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

if __name__ == "__main__":
Expand Down

0 comments on commit 12ffd99

Please sign in to comment.