Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Update image record iterator tests to check the whole iterator not on…
Browse files Browse the repository at this point in the history
…ly first image
  • Loading branch information
perdasilva committed Jan 9, 2019
1 parent a856fde commit 1861e85
Showing 1 changed file with 60 additions and 21 deletions.
81 changes: 60 additions & 21 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
import sys
from common import assertRaises
import unittest
# Python 2 exposes this helper as izip_longest; Python 3 renamed it to
# zip_longest. Catch only ImportError so unrelated failures (and
# KeyboardInterrupt/SystemExit) are not silently swallowed.
try:
    from itertools import izip_longest as zip_longest
except ImportError:
    from itertools import zip_longest


def test_MNISTIter():
Expand Down Expand Up @@ -427,13 +431,56 @@ def check_CSVIter_synthetic(dtype='float32'):
for dtype in ['int32', 'int64', 'float32']:
check_CSVIter_synthetic(dtype=dtype)

# @unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/11359")
def test_ImageRecordIter_seed_augmentation():
get_cifar10()
seed_aug = 3

def assert_dataiter_items_equals(dataiter1, dataiter2):
    """
    Asserts that two data iterators have the same number of batches,
    that the batches have the same number of items, and that the items
    are equal.

    Both iterators are consumed in lockstep. Every item of ``batch.data``
    is converted with ``asnumpy()`` and cast to ``np.uint8`` before it is
    compared with ``np.array_equal``.

    Raises an AssertionError on the first mismatch that is found.
    """
    for batch_idx, (batch1, batch2) in enumerate(zip_longest(dataiter1, dataiter2)):

        # zip_longest pads the shorter iterator with None, so a None on either
        # side means the iterators ran out of batches at different times.
        # Compare against None explicitly instead of relying on the truthiness
        # of the batch object itself.
        assert batch1 is not None and batch2 is not None, \
            'The iterators do not contain the same number of batches'

        # ensure batches are of same length
        assert len(batch1.data) == len(batch2.data), \
            'The returned batches are not of the same length'

        # ensure batch data is the same
        for item_idx, (item1, item2) in enumerate(zip(batch1.data, batch2.data)):
            data1 = item1.asnumpy().astype(np.uint8)
            data2 = item2.asnumpy().astype(np.uint8)
            assert np.array_equal(data1, data2), \
                'Item %d of batch %d differs between the iterators' % (item_idx, batch_idx)

def assert_dataiter_items_not_equals(dataiter1, dataiter2):
    """
    Asserts that two data iterators differ in at least one item.

    The iterators are consumed in lockstep and the function returns as soon
    as the first differing item is found; batches after that point are not
    inspected. Up to that point the iterators must yield the same number of
    batches, and paired batches must have the same number of items.

    Every item of ``batch.data`` is converted with ``asnumpy()`` and cast to
    ``np.uint8`` before it is compared with ``np.array_equal``.

    Raises an AssertionError if a length mismatch is seen before any
    difference, or if both iterators are exhausted without a difference.
    """
    for batch1, batch2 in zip_longest(dataiter1, dataiter2):

        # zip_longest pads the shorter iterator with None, so a None on either
        # side means the iterators ran out of batches at different times.
        # Compare against None explicitly instead of relying on the truthiness
        # of the batch object itself.
        assert batch1 is not None and batch2 is not None, \
            'The iterators do not contain the same number of batches'

        # ensure batches are of same length
        assert len(batch1.data) == len(batch2.data), \
            'The returned batches are not of the same length'

        # a single differing item is enough to prove the iterators differ
        for item1, item2 in zip(batch1.data, batch2.data):
            data1 = item1.asnumpy().astype(np.uint8)
            data2 = item2.asnumpy().astype(np.uint8)
            if not np.array_equal(data1, data2):
                return

    # Raise explicitly: `assert False` is removed when Python runs with -O,
    # which would let identical iterators pass this check silently.
    raise AssertionError('Expected data iterators to be different, but they are the same')

# check that two iterators built with the same seed_aug produce identical images
dataiter = mx.io.ImageRecordIter(
dataiter1 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -449,11 +496,8 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -469,12 +513,12 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

assert_dataiter_items_equals(dataiter1, dataiter2)

# check that changing seed_aug produces different images
dataiter = mx.io.ImageRecordIter(
dataiter1.reset()
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -490,32 +534,27 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug+1)
batch = dataiter.next()
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(not np.array_equal(data,data2))

assert_dataiter_items_not_equals(dataiter1, dataiter2)

# check whether seed_aug changes the iterator behavior
dataiter = mx.io.ImageRecordIter(
dataiter1 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

assert_dataiter_items_equals(dataiter1, dataiter2)

if __name__ == "__main__":
test_NDArrayIter()
Expand Down

0 comments on commit 1861e85

Please sign in to comment.