Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix race condition in unittest by pytest temp_dir fixtures (#18323)
Browse files Browse the repository at this point in the history
* serial tests

* pytest fixture temp_dir

* address comments
  • Loading branch information
zhreshold committed May 21, 2020
1 parent f4d0290 commit 67b5d31
Showing 1 changed file with 54 additions and 56 deletions.
110 changes: 54 additions & 56 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,25 @@ def test_array_dataset():
for i, x in enumerate(loader):
assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2])


def prepare_record():
if not os.path.isdir("data/test_images"):
os.makedirs('data/test_images')
if not os.path.isdir("data/test_images/test_images"):
gluon.utils.download("http://data.mxnet.io/data/test_images.tar.gz", "data/test_images.tar.gz")
tarfile.open('data/test_images.tar.gz').extractall('data/test_images/')
if not os.path.exists('data/test.rec') or not os.path.exists('data/test.idx'):
imgs = os.listdir('data/test_images/test_images')
record = mx.recordio.MXIndexedRecordIO('data/test.idx', 'data/test.rec', 'w')
for i, img in enumerate(imgs):
str_img = open('data/test_images/test_images/'+img, 'rb').read()
@pytest.fixture(scope="session")
def prepare_record(tmpdir_factory):
test_images = tmpdir_factory.mktemp("test_images")
test_images_tar = test_images.join("test_images.tar.gz")
gluon.utils.download("http://data.mxnet.io/data/test_images.tar.gz", str(test_images_tar))
tarfile.open(test_images_tar).extractall(str(test_images))
imgs = os.listdir(str(test_images.join("test_images")))
record = mx.recordio.MXIndexedRecordIO(str(test_images.join("test.idx")), str(test_images.join("test.rec")), 'w')
for i, img in enumerate(imgs):
with open(str(test_images.join("test_images").join(img)), 'rb') as f:
str_img = f.read()
s = mx.recordio.pack((0, i, i, 0), str_img)
record.write_idx(i, s)
return 'data/test.rec'
return str(test_images.join('test.rec'))


@with_seed()
def test_recordimage_dataset():
recfile = prepare_record()
def test_recordimage_dataset(prepare_record):
recfile = prepare_record
fn = lambda x, y : (x, y)
dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn)
loader = gluon.data.DataLoader(dataset, 1)
Expand All @@ -77,8 +76,8 @@ def test_recordimage_dataset():
assert y.asscalar() == i

@with_seed()
def test_recordimage_dataset_handle():
recfile = prepare_record()
def test_recordimage_dataset_handle(prepare_record):
recfile = prepare_record
class TmpTransform(mx.gluon.HybridBlock):
def hybrid_forward(self, F, x):
return x
Expand All @@ -99,8 +98,8 @@ def _dataset_transform_first_fn(x):
return x

@with_seed()
def test_recordimage_dataset_with_data_loader_multiworker():
recfile = prepare_record()
def test_recordimage_dataset_with_data_loader_multiworker(prepare_record):
recfile = prepare_record
dataset = gluon.data.vision.ImageRecordDataset(recfile)
loader = gluon.data.DataLoader(dataset, 1, num_workers=5, try_nopython=False)

Expand Down Expand Up @@ -138,52 +137,51 @@ def test_sampler():
assert sorted(sum(list(rand_batch_keep), [])) == list(range(10))

@with_seed()
def test_datasets():
assert len(gluon.data.vision.MNIST(root='data/mnist')) == 60000
assert len(gluon.data.vision.MNIST(root='data/mnist', train=False)) == 10000
assert len(gluon.data.vision.FashionMNIST(root='data/fashion-mnist')) == 60000
assert len(gluon.data.vision.FashionMNIST(root='data/fashion-mnist', train=False)) == 10000
assert len(gluon.data.vision.CIFAR10(root='data/cifar10')) == 50000
assert len(gluon.data.vision.CIFAR10(root='data/cifar10', train=False)) == 10000
assert len(gluon.data.vision.CIFAR100(root='data/cifar100')) == 50000
assert len(gluon.data.vision.CIFAR100(root='data/cifar100', fine_label=True)) == 50000
assert len(gluon.data.vision.CIFAR100(root='data/cifar100', train=False)) == 10000
def test_datasets(tmpdir):
p = tmpdir.mkdir("test_datasets")
assert len(gluon.data.vision.MNIST(root=str(p.join('mnist')))) == 60000
assert len(gluon.data.vision.MNIST(root=str(p.join('mnist')), train=False)) == 10000
assert len(gluon.data.vision.FashionMNIST(root=str(p.join('fashion-mnist')))) == 60000
assert len(gluon.data.vision.FashionMNIST(root=str(p.join('fashion-mnist')), train=False)) == 10000
assert len(gluon.data.vision.CIFAR10(root=str(p.join('cifar10')))) == 50000
assert len(gluon.data.vision.CIFAR10(root=str(p.join('cifar10')), train=False)) == 10000
assert len(gluon.data.vision.CIFAR100(root=str(p.join('cifar100')))) == 50000
assert len(gluon.data.vision.CIFAR100(root=str(p.join('cifar100')), fine_label=True)) == 50000
assert len(gluon.data.vision.CIFAR100(root=str(p.join('cifar100')), train=False)) == 10000

@with_seed()
@pytest.mark.serial
def test_datasets_handles():
assert len(gluon.data.vision.MNIST(root='data/mnist').__mx_handle__()) == 60000
assert len(gluon.data.vision.MNIST(root='data/mnist', train=False).__mx_handle__()) == 10000
assert len(gluon.data.vision.FashionMNIST(root='data/fashion-mnist').__mx_handle__()) == 60000
assert len(gluon.data.vision.FashionMNIST(root='data/fashion-mnist', train=False).__mx_handle__()) == 10000
assert len(gluon.data.vision.CIFAR10(root='data/cifar10').__mx_handle__()) == 50000
assert len(gluon.data.vision.CIFAR10(root='data/cifar10', train=False).__mx_handle__()) == 10000
assert len(gluon.data.vision.CIFAR100(root='data/cifar100').__mx_handle__()) == 50000
assert len(gluon.data.vision.CIFAR100(root='data/cifar100', fine_label=True).__mx_handle__()) == 50000
assert len(gluon.data.vision.CIFAR100(root='data/cifar100', train=False).__mx_handle__()) == 10000
def test_datasets_handles(tmpdir):
p = tmpdir.mkdir("test_datasets_handles")
assert len(gluon.data.vision.MNIST(root=str(p.join('mnist'))).__mx_handle__()) == 60000
assert len(gluon.data.vision.MNIST(root=str(p.join('mnist')), train=False).__mx_handle__()) == 10000
assert len(gluon.data.vision.FashionMNIST(root=str(p.join('fashion-mnist'))).__mx_handle__()) == 60000
assert len(gluon.data.vision.FashionMNIST(root=str(p.join('fashion-mnist')), train=False).__mx_handle__()) == 10000
assert len(gluon.data.vision.CIFAR10(root=str(p.join('cifar10'))).__mx_handle__()) == 50000
assert len(gluon.data.vision.CIFAR10(root=str(p.join('cifar10')), train=False).__mx_handle__()) == 10000
assert len(gluon.data.vision.CIFAR100(root=str(p.join('cifar100'))).__mx_handle__()) == 50000
assert len(gluon.data.vision.CIFAR100(root=str(p.join('cifar100')), fine_label=True).__mx_handle__()) == 50000
assert len(gluon.data.vision.CIFAR100(root=str(p.join('cifar100')), train=False).__mx_handle__()) == 10000

@with_seed()
def test_image_folder_dataset():
prepare_record()
dataset = gluon.data.vision.ImageFolderDataset('data/test_images')
def test_image_folder_dataset(prepare_record):
dataset = gluon.data.vision.ImageFolderDataset(os.path.dirname(prepare_record))
assert dataset.synsets == ['test_images']
assert len(dataset.items) == 16

@with_seed()
def test_image_folder_dataset_handle():
prepare_record()
dataset = gluon.data.vision.ImageFolderDataset('data/test_images')
def test_image_folder_dataset_handle(prepare_record):
dataset = gluon.data.vision.ImageFolderDataset(os.path.dirname(prepare_record))
hd = dataset.__mx_handle__()
assert len(hd) == 16
assert (hd[1][0] == dataset[1][0]).asnumpy().all()
assert hd[5][1] == dataset[5][1]

@with_seed()
def test_image_list_dataset():
prepare_record()
imlist = os.listdir('data/test_images/test_images')
def test_image_list_dataset(prepare_record):
root = os.path.join(os.path.dirname(prepare_record), 'test_images')
imlist = os.listdir(root)
imglist = [(0, path) for i, path in enumerate(imlist)]
dataset = gluon.data.vision.ImageListDataset(root='data/test_images/test_images', imglist=imglist)
dataset = gluon.data.vision.ImageListDataset(root=root, imglist=imglist)
assert len(dataset) == 16, len(dataset)
img, label = dataset[0]
assert len(img.shape) == 3
Expand All @@ -196,18 +194,18 @@ def test_image_list_dataset():
fp.write(line + '\n')
fp.close()

dataset = gluon.data.vision.ImageListDataset(root='data/test_images/test_images', imglist=fp.name)
dataset = gluon.data.vision.ImageListDataset(root=root, imglist=fp.name)
assert len(dataset) == 16, len(dataset)
img, label = dataset[0]
assert len(img.shape) == 3
assert label == 0

@with_seed()
def test_image_list_dataset_handle():
prepare_record()
imlist = os.listdir('data/test_images/test_images')
def test_image_list_dataset_handle(prepare_record):
root = os.path.join(os.path.dirname(prepare_record), 'test_images')
imlist = os.listdir(root)
imglist = [(0, path) for i, path in enumerate(imlist)]
dataset = gluon.data.vision.ImageListDataset(root='data/test_images/test_images', imglist=imglist).__mx_handle__()
dataset = gluon.data.vision.ImageListDataset(root=root, imglist=imglist).__mx_handle__()
assert len(dataset) == 16, len(dataset)
img, label = dataset[0]
assert len(img.shape) == 3
Expand All @@ -220,7 +218,7 @@ def test_image_list_dataset_handle():
fp.write(line + '\n')
fp.close()

dataset = gluon.data.vision.ImageListDataset(root='data/test_images/test_images', imglist=fp.name).__mx_handle__()
dataset = gluon.data.vision.ImageListDataset(root=root, imglist=fp.name).__mx_handle__()
assert len(dataset) == 16
img, label = dataset[0]
assert len(img.shape) == 3
Expand Down

0 comments on commit 67b5d31

Please sign in to comment.