diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 1c54158a2ba4..50e2ad9f784d 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -175,7 +175,12 @@ def _recursive_fork_recordio(obj, depth, max_depth=1000):
 
 def worker_loop(dataset, key_queue, data_queue, batchify_fn):
     """Worker loop for multiprocessing DataLoader."""
     # re-fork a new recordio handler in new process if applicable
-    _recursive_fork_recordio(dataset, 0, 1000)
+    # for a dataset with transform function, the depth of MXRecordIO is 1
+    # for a lazy transformer, the depth is 2
+    # for a user defined transformer, the depth is unknown, try a reasonable depth
+    limit = sys.getrecursionlimit()
+    max_recursion_depth = min(limit - 5, max(10, limit // 2))
+    _recursive_fork_recordio(dataset, 0, max_recursion_depth)
     while True:
         idx, samples = key_queue.get()
diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py
index 2ebe657accbd..6fc4d8e7bf57 100644
--- a/python/mxnet/recordio.py
+++ b/python/mxnet/recordio.py
@@ -83,6 +83,32 @@ def open(self):
     def __del__(self):
         self.close()
 
+    def __getstate__(self):
+        """Override pickling behavior."""
+        # pickling pointer is not allowed
+        is_open = self.is_open
+        self.close()
+        d = dict(self.__dict__)
+        d['is_open'] = is_open
+        uri = self.uri.value
+        try:
+            uri = uri.decode('utf-8')
+        except AttributeError:
+            pass
+        del d['handle']
+        d['uri'] = uri
+        return d
+
+    def __setstate__(self, d):
+        """Restore from pickled."""
+        self.__dict__ = d
+        is_open = d['is_open']
+        self.is_open = False
+        self.handle = RecordIOHandle()
+        self.uri = c_str(self.uri)
+        if is_open:
+            self.open()
+
     def close(self):
         """Closes the record file."""
         if not self.is_open:
@@ -217,6 +243,12 @@ def close(self):
         super(MXIndexedRecordIO, self).close()
         self.fidx.close()
 
+    def __getstate__(self):
+        """Override pickling behavior."""
+        d = super(MXIndexedRecordIO, self).__getstate__()
+        d['fidx'] = None
+        return d
+
     def seek(self, idx):
         """Sets the current read pointer position.
 
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index cc80aacb6447..c731f8d782d1 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -73,26 +73,39 @@ def test_recordimage_dataset():
         assert x.shape[0] == 1 and x.shape[3] == 3
         assert y.asscalar() == i
 
+def _dataset_transform_fn(x, y):
+    """Named transform function since lambda function cannot be pickled."""
+    return x, y
+
 @with_seed()
 def test_recordimage_dataset_with_data_loader_multiworker():
-    # This test is pointless on Windows because Windows doesn't fork
-    if platform.system() != 'Windows':
-        recfile = prepare_record()
-        dataset = gluon.data.vision.ImageRecordDataset(recfile)
-        loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
-
-        for i, (x, y) in enumerate(loader):
-            assert x.shape[0] == 1 and x.shape[3] == 3
-            assert y.asscalar() == i
-
-        # with transform
-        fn = lambda x, y : (x, y)
-        dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn)
-        loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
-
-        for i, (x, y) in enumerate(loader):
-            assert x.shape[0] == 1 and x.shape[3] == 3
-            assert y.asscalar() == i
+    recfile = prepare_record()
+    dataset = gluon.data.vision.ImageRecordDataset(recfile)
+    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+    for i, (x, y) in enumerate(loader):
+        assert x.shape[0] == 1 and x.shape[3] == 3
+        assert y.asscalar() == i
+
+    # with transform
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
+    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+    for i, (x, y) in enumerate(loader):
+        assert x.shape[0] == 1 and x.shape[3] == 3
+        assert y.asscalar() == i
+
+    # try limit recursion depth
+    import sys
+    old_limit = sys.getrecursionlimit()
+    sys.setrecursionlimit(500) # this should be smaller than any default value used in python
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
+    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+    for i, (x, y) in enumerate(loader):
+        assert x.shape[0] == 1 and x.shape[3] == 3
+        assert y.asscalar() == i
+    sys.setrecursionlimit(old_limit)
 
 @with_seed()
 def test_sampler():