Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Adds extra randomness to imagerecordio iterator test
Browse files Browse the repository at this point in the history
  • Loading branch information
perdasilva committed Oct 1, 2018
1 parent de34abd commit 1550da3
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,8 @@ def test_ImageRecordIter_seed_augmentation():
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
Expand All @@ -397,7 +398,7 @@ def test_ImageRecordIter_seed_augmentation():
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

# check whether to get different images after change seed_aug
Expand All @@ -418,7 +419,7 @@ def test_ImageRecordIter_seed_augmentation():
max_shear_ratio=2,
seed_aug=seed_aug+1)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(not np.array_equal(data,data2))

# check whether seed_aug changes the iterator behavior
Expand All @@ -430,7 +431,8 @@ def test_ImageRecordIter_seed_augmentation():
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
Expand All @@ -440,7 +442,7 @@ def test_ImageRecordIter_seed_augmentation():
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

if __name__ == "__main__":
Expand Down

0 comments on commit 1550da3

Please sign in to comment.