Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix recordfile dataset with multi worker
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold committed Jun 22, 2018
1 parent 5550c0a commit 209414a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
1 change: 1 addition & 0 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def default_mp_batchify_fn(data):

def worker_loop(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
dataset._fork()
while True:
idx, samples = key_queue.get()
if idx is None:
Expand Down
13 changes: 11 additions & 2 deletions python/mxnet/gluon/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ def base_fn(x, *args):
return fn(x)
return self.transform(base_fn, lazy)

def _fork(self):
    """Hook run by a multiprocessing DataLoader worker before it serves samples.

    ``worker_loop`` calls ``dataset._fork()`` once, ahead of its request loop.
    Datasets that hold file-descriptor-backed state can override this to
    rebuild that state; this base implementation deliberately does nothing.
    """
    # Nothing to protect for datasets that are not tied to a file descriptor.
    return None


class SimpleDataset(Dataset):
"""Simple Dataset wrapper for lists and arrays.
Expand Down Expand Up @@ -173,8 +178,12 @@ class RecordFileDataset(Dataset):
Path to rec file.
"""
def __init__(self, filename):
    """Store the record/index file paths and open the record file for reading.

    Parameters
    ----------
    filename : str
        Path to rec file. The index file path is derived from it by replacing
        the extension with ``.idx``.
    """
    # Keep both paths on the instance: `_fork` needs them to (re)open the
    # reader, and it is also invoked by the multiprocessing worker loop.
    self.idx_file = os.path.splitext(filename)[0] + '.idx'
    self.filename = filename
    # Open the reader exactly once, through `_fork`. Opening it directly here
    # as well would create a handle that is immediately overwritten.
    self._fork()

def _fork(self):
    """(Re)open the indexed record file in read mode and store it on ``self._record``.

    Runs from ``__init__`` and again whenever the multiprocessing worker loop
    calls ``dataset._fork()``, replacing whatever reader was there before.
    Presumably this gives each worker its own handle instead of the one
    opened in the parent process -- confirm against the recordio reader.
    """
    index_path, record_path = self.idx_file, self.filename
    self._record = recordio.MXIndexedRecordIO(index_path, record_path, 'r')

def __getitem__(self, idx):
    """Return the record at position ``idx``.

    ``idx`` is first translated to a record key through ``self._record.keys``;
    the record stored under that key is then read with ``read_idx``.
    """
    reader = self._record
    record_key = reader.keys[idx]
    return reader.read_idx(record_key)
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ def test_recordimage_dataset():
assert x.shape[0] == 1 and x.shape[3] == 3
assert y.asscalar() == i

@with_seed()
def test_recordimage_dataset_with_data_loader_multiworker():
    """Iterate an ImageRecordDataset through a DataLoader that uses worker processes.

    Checks that every batch yielded by the multi-worker loader has the expected
    layout and that labels come back in sample order.
    """
    # This test is pointless on Windows because Windows doesn't fork
    if platform.system() == 'Windows':
        return

    recfile = prepare_record()
    dataset = gluon.data.vision.ImageRecordDataset(recfile)
    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)

    for i, (x, y) in enumerate(loader):
        # One sample per batch, 3 entries along the last axis.
        assert x.shape[0] == 1 and x.shape[3] == 3
        # The label of the i-th batch must equal its position in the iteration.
        assert y.asscalar() == i

@with_seed()
def test_sampler():
seq_sampler = gluon.data.SequentialSampler(10)
Expand Down

0 comments on commit 209414a

Please sign in to comment.