diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 412d3134476b..1c54158a2ba4 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -36,6 +36,7 @@ from . import sampler as _sampler from ... import nd, context +from ...recordio import MXRecordIO if sys.platform == 'darwin' or sys.platform == 'win32': def rebuild_ndarray(*args): @@ -158,10 +159,24 @@ def _as_in_context(data, ctx): return [_as_in_context(d, ctx) for d in data] return data +def _recursive_fork_recordio(obj, depth, max_depth=1000): + """Recursively find instance of MXRecordIO and reset file handler. + This is required for MXRecordIO which holds a C pointer to a opened file after fork. + """ + if depth >= max_depth: + return + if isinstance(obj, MXRecordIO): + obj.close() + obj.open() # re-obtain file hanlder in new process + elif (hasattr(obj, '__dict__')): + for _, v in obj.__dict__.items(): + _recursive_fork_recordio(v, depth + 1, max_depth) + def worker_loop(dataset, key_queue, data_queue, batchify_fn): """Worker loop for multiprocessing DataLoader.""" - if hasattr(dataset, '_fork') and callable(dataset._fork): - dataset._fork() + # re-fork a new recordio handler in new process if applicable + _recursive_fork_recordio(dataset, 0, 1000) + while True: idx, samples = key_queue.get() if idx is None: @@ -181,6 +196,7 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False): batch = _as_in_context(batch, context.cpu()) data_buffer[idx] = batch + class _MultiWorkerIter(object): """Interal multi-worker iterator for DataLoader.""" def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False, diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index 13e2b57a8c59..c93a4b1cd6b9 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -94,11 +94,6 @@ def base_fn(x, *args): return fn(x) return self.transform(base_fn, lazy) - def _fork(self): - """Protective operations required when launching multiprocess workers.""" - # for non file descriptor related datasets, just skip - pass - class SimpleDataset(Dataset): """Simple Dataset wrapper for lists and arrays. @@ -180,9 +175,6 @@ class RecordFileDataset(Dataset): def __init__(self, filename): self.idx_file = os.path.splitext(filename)[0] + '.idx' self.filename = filename - self._fork() - - def _fork(self): self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r') def __getitem__(self, idx): diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 53ce600629c8..cc80aacb6447 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -65,7 +65,8 @@ def prepare_record(): @with_seed() def test_recordimage_dataset(): recfile = prepare_record() - dataset = gluon.data.vision.ImageRecordDataset(recfile) + fn = lambda x, y : (x, y) + dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn) loader = gluon.data.DataLoader(dataset, 1) for i, (x, y) in enumerate(loader): @@ -84,6 +85,15 @@ def test_recordimage_dataset_with_data_loader_multiworker(): assert x.shape[0] == 1 and x.shape[3] == 3 assert y.asscalar() == i + # with transform + fn = lambda x, y : (x, y) + dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn) + loader = gluon.data.DataLoader(dataset, 1, num_workers=5) + + for i, (x, y) in enumerate(loader): + assert x.shape[0] == 1 and x.shape[3] == 3 + assert y.asscalar() == i + @with_seed() def test_sampler(): seq_sampler = gluon.data.SequentialSampler(10)